Spaces:
Runtime error
Runtime error
File size: 4,317 Bytes
c3efd49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""Configuration management utilities."""
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, Any
import yaml
from pathlib import Path
@dataclass
class ModelConfig:
    """Which model to use and where to put it."""

    # Model identifier; defaults to the "facebook/wav2vec2-base" checkpoint name.
    name: str = "facebook/wav2vec2-base"
    # Device string for the model (default "cuda").
    device: str = "cuda"
    # Optional checkpoint path/identifier; None when none is configured.
    checkpoint: Optional[str] = None
@dataclass
class RLConfig:
    """Hyperparameters for the reinforcement-learning training loop."""

    algorithm: str = "ppo"          # algorithm name
    learning_rate: float = 3e-4
    batch_size: int = 32
    num_episodes: int = 1_000
    episode_length: int = 100
    gamma: float = 0.99             # reward discount factor
    clip_epsilon: float = 0.2       # only relevant to PPO
    max_grad_norm: float = 1.0      # gradient-norm limit (presumably for clipping)
@dataclass
class DataConfig:
    """Where the dataset lives and how it is partitioned."""

    dataset_path: str = "data/processed"
    # Train/validation/test fractions; the defaults add up to 1.0.
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15
    # Audio sample rate (defaults to 16 kHz).
    sample_rate: int = 16_000
@dataclass
class CurriculumConfig:
    """Switches and thresholds for curriculum learning."""

    # Whether curriculum learning is turned on.
    enabled: bool = True
    # Number of curriculum levels.
    levels: int = 5
    # Threshold for moving to the next level (metric not defined here).
    advancement_threshold: float = 0.8
@dataclass
class OptimizationConfig:
    """Flags for training-time performance optimizations."""

    mixed_precision: bool = True          # on by default
    gradient_checkpointing: bool = False  # off by default
@dataclass
class CheckpointConfig:
    """How often checkpoints are written, where, and how many are kept."""

    # Interval between checkpoints, measured in episodes.
    interval: int = 50
    # Directory for checkpoint files.
    save_dir: str = "checkpoints"
    # Number of most recent checkpoints to retain.
    keep_last_n: int = 5
@dataclass
class MonitoringConfig:
    """Logging, visualization and TensorBoard output settings."""

    log_interval: int = 10
    visualization_interval: int = 50
    # Output directory for TensorBoard runs.
    tensorboard_dir: str = "runs"
@dataclass
class ReproducibilityConfig:
    """Settings that make runs repeatable."""

    # Seed value for random number generation.
    random_seed: int = 42
@dataclass
class TrainingConfig:
    """Complete training configuration, one sub-config per concern.

    Every section defaults to a freshly constructed instance of its own
    dataclass, so ``TrainingConfig()`` is a fully populated configuration.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    rl: RLConfig = field(default_factory=RLConfig)
    data: DataConfig = field(default_factory=DataConfig)
    curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
    reproducibility: ReproducibilityConfig = field(default_factory=ReproducibilityConfig)

    @classmethod
    def from_yaml(cls, path: str) -> "TrainingConfig":
        """Load configuration from a YAML file.

        Each top-level key (``model``, ``rl``, ``data``, ...) holds the
        keyword arguments for the matching sub-config. Missing sections, and
        sections written with no value (e.g. ``model:``), use that section's
        defaults. An empty file yields the all-default configuration.

        Args:
            path: Path to the YAML file.

        Returns:
            A new ``TrainingConfig``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the document or a section is not a mapping, or a
                section contains a key its sub-config does not define.
        """
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat it as "no overrides".
        if config_dict is None:
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(config_dict).__name__}"
            )

        def section(name: str) -> Dict[str, Any]:
            # Return the overrides for one section; a bare "name:" key loads as None.
            value = config_dict.get(name)
            if value is None:
                return {}
            if not isinstance(value, dict):
                raise TypeError(
                    f"Section '{name}' in {path} must be a mapping, "
                    f"got {type(value).__name__}"
                )
            return value

        return cls(
            model=ModelConfig(**section('model')),
            rl=RLConfig(**section('rl')),
            data=DataConfig(**section('data')),
            curriculum=CurriculumConfig(**section('curriculum')),
            optimization=OptimizationConfig(**section('optimization')),
            checkpointing=CheckpointConfig(**section('checkpointing')),
            monitoring=MonitoringConfig(**section('monitoring')),
            reproducibility=ReproducibilityConfig(**section('reproducibility')),
        )

    def to_yaml(self, path: str) -> None:
        """Save the configuration to a YAML file, creating parent directories.

        The file content is ``to_dict()`` dumped in block style, so it can be
        read back with ``from_yaml``.

        Args:
            path: Destination file path; an existing file is overwritten.
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a nested plain dictionary.

        Returns:
            A mapping from section name to that section's field dict, in field
            order. Values are deep copies, so mutating the result does not
            affect this object.
        """
        # asdict recurses into the nested dataclasses, giving the same
        # {section: {field: value}} shape as converting each section by hand.
        return asdict(self)
|