File size: 504 Bytes
c3efd49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""Model configuration classes."""
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Settings record for a voice model.

    Only ``device`` is checked at construction time; the remaining
    fields are stored exactly as supplied.

    Raises:
        ValueError: If ``device`` is not one of 'cuda', 'cpu', or 'mps'.
    """
    # Required; no default is provided.
    name: str
    # Must be one of the values accepted by __post_init__.
    device: str = "cuda"
    # Optional strings, left as None when not given.
    checkpoint: Optional[str] = None
    cache_dir: Optional[str] = None

    def __post_init__(self):
        """Reject any ``device`` outside the supported set."""
        supported_devices = ("cuda", "cpu", "mps")
        if self.device in supported_devices:
            return
        raise ValueError(f"Invalid device: {self.device}. Must be 'cuda', 'cpu', or 'mps'")