# mbellan's picture
# Initial deployment
# c3efd49
"""Configuration management utilities."""
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, Any
import yaml
from pathlib import Path
@dataclass
class ModelConfig:
    """Settings selecting which model to use and where to run it.

    Fields (in positional order):
        name: Model identifier; presumably a Hugging Face hub id — confirm.
        device: Device string, ``"cuda"`` by default.
        checkpoint: Optional checkpoint reference (presumably a file path —
            confirm); ``None`` when no checkpoint is configured.
    """

    name: str = field(default="facebook/wav2vec2-base")
    device: str = field(default="cuda")
    checkpoint: Optional[str] = field(default=None)
@dataclass
class RLConfig:
    """Hyperparameters for the reinforcement-learning loop.

    Fields keep their original positional order: algorithm, learning_rate,
    batch_size, num_episodes, episode_length, gamma, clip_epsilon,
    max_grad_norm.
    """

    algorithm: str = field(default="ppo")
    learning_rate: float = field(default=3e-4)
    batch_size: int = field(default=32)
    num_episodes: int = field(default=1000)
    episode_length: int = field(default=100)
    gamma: float = field(default=0.99)
    # Only meaningful for the PPO algorithm.
    clip_epsilon: float = field(default=0.2)
    max_grad_norm: float = field(default=1.0)
@dataclass
class DataConfig:
    """Dataset location, split fractions and audio sample rate.

    The three ``*_split`` defaults add up to 1.0. Nothing here enforces
    that for overridden values.
    """

    dataset_path: str = field(default="data/processed")
    train_split: float = field(default=0.7)
    val_split: float = field(default=0.15)
    test_split: float = field(default=0.15)
    # Presumably in Hz — confirm against the audio loader.
    sample_rate: int = field(default=16000)
@dataclass
class CurriculumConfig:
    """Switches and thresholds for curriculum learning.

    Fields (in positional order):
        enabled: Whether curriculum learning is on (default ``True``).
        levels: Number of curriculum levels (default 5).
        advancement_threshold: Threshold used when deciding to advance a
            level (default 0.8); its exact metric is defined by the caller.
    """

    enabled: bool = field(default=True)
    levels: int = field(default=5)
    advancement_threshold: float = field(default=0.8)
@dataclass
class OptimizationConfig:
    """Training-efficiency toggles.

    Mixed precision defaults to on; gradient checkpointing defaults to off.
    """

    mixed_precision: bool = field(default=True)
    gradient_checkpointing: bool = field(default=False)
@dataclass
class CheckpointConfig:
    """How often checkpoints are written, where they go, and how many remain."""

    # Measured in episodes.
    interval: int = field(default=50)
    save_dir: str = field(default="checkpoints")
    keep_last_n: int = field(default=5)
@dataclass
class MonitoringConfig:
    """Logging/visualization cadence and the TensorBoard output directory.

    The interval unit is not stated here; presumably episodes, matching
    ``CheckpointConfig.interval`` — confirm.
    """

    log_interval: int = field(default=10)
    visualization_interval: int = field(default=50)
    tensorboard_dir: str = field(default="runs")
@dataclass
class ReproducibilityConfig:
    """Seed used to make runs repeatable (default 42)."""

    random_seed: int = field(default=42)
@dataclass
class TrainingConfig:
    """Complete training configuration, one dataclass per section.

    Each attribute name doubles as the top-level key used for that section
    in the YAML written by ``to_yaml`` and read by ``from_yaml``.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    rl: RLConfig = field(default_factory=RLConfig)
    data: DataConfig = field(default_factory=DataConfig)
    curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
    reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig)

    @classmethod
    def from_yaml(cls, path: str) -> "TrainingConfig":
        """Load configuration from a YAML file.

        A section that is absent, or present but empty (``model:`` with no
        body), falls back to that section's defaults. An empty file yields an
        all-default configuration. Unknown top-level keys are ignored.

        Args:
            path: Path of the YAML file to read.

        Returns:
            A new ``TrainingConfig`` built from the file's contents.

        Raises:
            TypeError: If the document or one of its sections is not a
                mapping, or a section contains a key that its dataclass does
                not define.
        """
        with open(path, "r", encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        if config_dict is None:
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path!r}, "
                f"got {type(config_dict).__name__}"
            )

        def _section(key, section_cls):
            """Build one section dataclass from its sub-mapping (or defaults)."""
            values = config_dict.get(key)
            # A key written with no value (e.g. ``model:``) parses as None.
            if values is None:
                values = {}
            if not isinstance(values, dict):
                raise TypeError(
                    f"Section {key!r} in {path!r} must be a mapping, "
                    f"got {type(values).__name__}"
                )
            return section_cls(**values)

        return cls(
            model=_section("model", ModelConfig),
            rl=_section("rl", RLConfig),
            data=_section("data", DataConfig),
            curriculum=_section("curriculum", CurriculumConfig),
            optimization=_section("optimization", OptimizationConfig),
            checkpointing=_section("checkpointing", CheckpointConfig),
            monitoring=_section("monitoring", MonitoringConfig),
            reproducibility=_section("reproducibility", ReproducibilityConfig),
        )

    def to_yaml(self, path: str) -> None:
        """Save configuration to a YAML file, creating parent directories.

        Args:
            path: Destination file path; it is overwritten if it exists.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a nested dictionary.

        Returns:
            A mapping from section name to that section's field dictionary.
            ``asdict`` recurses into the nested dataclasses and copies their
            values.
        """
        return asdict(self)