Spaces:
Runtime error
Runtime error
| """Voice model wrapper for HuggingFace models.""" | |
| import torch | |
| import torch.nn as nn | |
| import logging | |
| from typing import Optional, Iterator, Dict, Any, Tuple | |
| from pathlib import Path | |
| from transformers import AutoModel, AutoConfig, AutoProcessor | |
| import json | |
| from .policy_wrapper import RLVoiceModel | |
# Module-scoped logger; its name follows this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class VoiceModelWrapper:
    """
    Wrapper for HuggingFace voice models with RL training support.

    Provides a consistent interface for model loading, inference,
    checkpointing, and license verification. When ``enable_rl`` is set,
    the loaded base model is additionally wrapped in an ``RLVoiceModel``
    that carries policy/value heads.
    """

    # Lower-case license identifiers accepted for commercial use. Matching in
    # _verify_license is a case-insensitive substring test, so e.g.
    # "MIT License" matches "mit".
    COMMERCIAL_LICENSES = [
        "apache-2.0",
        "mit",
        "bsd",
        "bsd-3-clause",
        "cc-by-4.0",
        "cc-by-sa-4.0",
        "openrail",
    ]

    # Hidden size used for the RL heads when the config exposes neither
    # ``hidden_size`` nor ``d_model``.
    DEFAULT_HIDDEN_SIZE = 768

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        cache_dir: Optional[str] = None,
        enable_rl: bool = True,
        action_dim: int = 256
    ):
        """
        Initialize the voice model wrapper. No weights are loaded here;
        call ``load_model()`` afterwards.

        Args:
            model_name: HuggingFace model identifier
            device: Device to load model on ('cuda', 'cpu', 'mps')
            cache_dir: Optional cache directory for model files
            enable_rl: Whether to add RL policy/value heads
            action_dim: Dimensionality of action space for RL
        """
        self.model_name = model_name
        self.device = device
        self.cache_dir = cache_dir
        self.enable_rl = enable_rl
        self.action_dim = action_dim
        self.model = None
        self.rl_model = None
        self.processor = None
        self.config = None
        logger.info(f"Initialized VoiceModelWrapper for {model_name} on {device} (RL: {enable_rl})")

    def load_model(self) -> None:
        """
        Load the voice model (and, if available, its processor) from HuggingFace.

        The config is fetched first so that license verification and
        architecture checks run before the weights are downloaded. If any
        step fails, ``model``, ``rl_model`` and ``processor`` are reset to
        ``None`` so the wrapper never stays half-initialised.

        Raises:
            RuntimeError: If loading or verification fails. License and
                architecture rejections are raised internally as
                ``ValueError`` and are available as ``__cause__``.
        """
        try:
            logger.info(f"Loading model: {self.model_name}")
            # Load configuration first
            self.config = AutoConfig.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            )
            self._verify_license()
            self._verify_architecture()

            self.model = AutoModel.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            )
            self.model.to(self.device)
            self.model.train()  # Set to training mode for RL

            # Wrap with RL policy/value heads if enabled
            if self.enable_rl:
                self.rl_model = RLVoiceModel(
                    base_model=self.model,
                    hidden_size=self._resolve_hidden_size(),
                    action_dim=self.action_dim
                )
                self.rl_model.to(self.device)
                logger.info(f"Added RL policy/value heads (action_dim={self.action_dim})")

            # The processor is optional: failure to load it is logged, not fatal.
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    cache_dir=self.cache_dir
                )
            except Exception as e:
                logger.warning(f"Could not load processor: {e}")
                self.processor = None

            logger.info(f"Successfully loaded model: {self.model_name}")
            logger.info(f"Model parameters: {self.count_parameters():,}")
        except Exception as e:
            # Drop partially constructed state so later calls see "not loaded"
            # instead of a base model without its RL heads.
            self.model = None
            self.rl_model = None
            self.processor = None
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def _resolve_hidden_size(self) -> int:
        """
        Return the hidden size the RL heads should be built with.

        Prefers ``config.hidden_size``, then ``config.d_model`` (the name
        used by some transformers configs), and otherwise falls back to
        ``DEFAULT_HIDDEN_SIZE`` with a warning.
        """
        for attr in ("hidden_size", "d_model"):
            value = getattr(self.config, attr, None)
            if value is not None:
                return value
        logger.warning(
            f"Config for {self.model_name} has no hidden_size/d_model; "
            f"falling back to {self.DEFAULT_HIDDEN_SIZE} for RL heads."
        )
        return self.DEFAULT_HIDDEN_SIZE

    def _verify_license(self) -> None:
        """
        Verify that the model's license permits commercial use.

        The license is read from ``self.config.license``. If it is missing or
        empty, a warning is logged and the check is skipped. The value may be
        a single identifier or a list/tuple of identifiers. Every entry must
        contain one of ``COMMERCIAL_LICENSES`` (case-insensitive substring
        match).

        Raises:
            ValueError: If any listed license is not in the approved list.
        """
        license_info = getattr(self.config, 'license', None)
        if not license_info:
            logger.warning(
                f"No license information found for {self.model_name}. "
                "Please verify license manually."
            )
            return

        # License metadata may hold a single identifier or a list of them.
        if isinstance(license_info, (list, tuple)):
            candidates = [str(item).lower() for item in license_info]
        else:
            candidates = [str(license_info).lower()]

        is_commercial = all(
            any(approved in candidate for approved in self.COMMERCIAL_LICENSES)
            for candidate in candidates
        )
        if not is_commercial:
            raise ValueError(
                f"Model {self.model_name} has license '{license_info}' "
                f"which may not be suitable for commercial use. "
                f"Approved licenses: {', '.join(self.COMMERCIAL_LICENSES)}"
            )
        logger.info(f"License verified: {license_info}")

    def _verify_architecture(self) -> None:
        """
        Check that the loaded config is usable for RL training.

        ``load_model`` calls this before the weights are loaded, so the
        checks are config-based. The model-level check only applies if a
        model is already present.

        Raises:
            ValueError: If an existing model does not support training mode.
        """
        model_type = getattr(self.config, 'model_type', 'unknown')
        logger.info(f"Model type: {model_type}")

        # The RL heads need the encoder width. Warn early if it must be guessed.
        if self.enable_rl and all(
            getattr(self.config, attr, None) is None
            for attr in ("hidden_size", "d_model")
        ):
            logger.warning(
                f"Model {self.model_name} exposes neither hidden_size nor d_model; "
                "RL heads will use a default width."
            )

        if self.model is not None and not hasattr(self.model, 'train'):
            raise ValueError("Model does not support training mode")
        logger.info("Architecture compatibility verified")

    def generate(
        self,
        input_features: torch.Tensor,
        training: bool = False,
        **kwargs
    ) -> torch.Tensor:
        """
        Run the base model and return its primary output tensor.

        Args:
            input_features: Input tensor
            training: If True, compute with gradients (for RL training);
                otherwise run under ``torch.no_grad()``
            **kwargs: Additional forward parameters

        Returns:
            ``last_hidden_state`` if present, the output itself if it is a
            tensor, otherwise the first element of the output.

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        if training:
            # During training, keep gradients for backprop
            outputs = self.model(input_features, **kwargs)
        else:
            # During inference, no gradients needed
            with torch.no_grad():
                outputs = self.model(input_features, **kwargs)

        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        elif isinstance(outputs, torch.Tensor):
            return outputs
        else:
            return outputs[0]

    def get_logits(self, input_features: torch.Tensor) -> torch.Tensor:
        """
        Get model logits for input features (gradients are kept).

        Args:
            input_features: Input tensor

        Returns:
            ``logits`` if present, else ``last_hidden_state``, else the
            output itself if it is a plain tensor, else its first element.

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        outputs = self.model(input_features)
        if hasattr(outputs, 'logits'):
            return outputs.logits
        elif hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        elif isinstance(outputs, torch.Tensor):
            # Indexing a plain tensor with [0] would drop the batch dimension.
            return outputs
        else:
            return outputs[0]

    def forward(self, input_features: torch.Tensor, **kwargs) -> Any:
        """
        Forward pass through the model.

        Args:
            input_features: Input tensor
            **kwargs: Additional forward parameters

        Returns:
            Output of ``rl_model`` when RL is enabled, otherwise output of
            the base model.

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        if self.rl_model is not None:
            return self.rl_model(input_features, **kwargs)
        else:
            return self.model(input_features, **kwargs)

    def sample_action(
        self,
        input_features: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample action from the policy (RL training).

        Args:
            input_features: Input audio features
            deterministic: If True, take most likely action

        Returns:
            Whatever ``RLVoiceModel.sample_action`` returns; documented as
            (actions, log_probs, values).

        Raises:
            RuntimeError: If RL model is not enabled
        """
        if self.rl_model is None:
            raise RuntimeError("RL model not enabled. Set enable_rl=True when initializing.")
        return self.rl_model.sample_action(input_features, deterministic)

    def evaluate_actions(
        self,
        input_features: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate actions (for PPO training).

        Args:
            input_features: Input audio features
            actions: Actions to evaluate

        Returns:
            Whatever ``RLVoiceModel.evaluate_actions`` returns; documented
            as (log_probs, values, entropy).

        Raises:
            RuntimeError: If RL model is not enabled
        """
        if self.rl_model is None:
            raise RuntimeError("RL model not enabled. Set enable_rl=True when initializing.")
        return self.rl_model.evaluate_actions(input_features, actions)

    def save_checkpoint(self, path: str, metadata: Optional[Dict] = None) -> None:
        """
        Save model checkpoint with ``torch.save``.

        Parent directories are created as needed. The saved dict contains
        the base model state, model name, config dict, RL settings, the RL
        model state (if present) and ``metadata`` (if non-empty).

        Args:
            path: Path to save checkpoint
            metadata: Optional metadata to save with checkpoint

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        checkpoint_path = Path(path)
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'config': self.config.to_dict() if self.config is not None else None,
            'enable_rl': self.enable_rl,
            'action_dim': self.action_dim,
        }
        if self.rl_model is not None:
            checkpoint['rl_model_state_dict'] = self.rl_model.state_dict()
        if metadata:
            checkpoint['metadata'] = metadata
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, path: str) -> Dict:
        """
        Load a checkpoint written by ``save_checkpoint``.

        Args:
            path: Path to checkpoint file

        Returns:
            The checkpoint's ``metadata`` dict, or ``{}`` if none was saved.

        Raises:
            RuntimeError: If model is not loaded
            FileNotFoundError: If checkpoint file doesn't exist
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        checkpoint_path = Path(path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        if 'rl_model_state_dict' in checkpoint:
            if self.rl_model is not None:
                self.rl_model.load_state_dict(checkpoint['rl_model_state_dict'])
                logger.info("Loaded RL model state")
            else:
                logger.warning(
                    "Checkpoint contains RL model state but RL is disabled; "
                    "RL heads were not restored."
                )
        logger.info(f"Checkpoint loaded from {checkpoint_path}")
        return checkpoint.get('metadata', {})

    def _unique_parameters(self) -> Iterator[torch.nn.Parameter]:
        """Yield every parameter of the base model and RL model exactly once."""
        # The RL model may or may not register the base model as a submodule;
        # deduplicating by identity makes both cases correct.
        modules = [self.model] if self.rl_model is None else [self.model, self.rl_model]
        seen = set()
        for module in modules:
            for param in module.parameters():
                if id(param) not in seen:
                    seen.add(id(param))
                    yield param

    def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Get an iterator over trainable parameters.

        Includes the RL policy/value heads when RL is enabled, so an
        optimizer built from this iterator updates them too.

        Returns:
            Lazy iterator over parameters with ``requires_grad`` set; each
            parameter appears once.

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return (p for p in self._unique_parameters() if p.requires_grad)

    def count_parameters(self, trainable_only: bool = False) -> int:
        """
        Count parameters of the base model plus RL heads (if enabled).

        Args:
            trainable_only: If True, count only trainable parameters

        Returns:
            Number of scalar parameters, or 0 if no model is loaded
        """
        if self.model is None:
            return 0
        return sum(
            p.numel()
            for p in self._unique_parameters()
            if p.requires_grad or not trainable_only
        )

    def set_training_mode(self, mode: bool = True) -> None:
        """
        Put the base model (and RL model, if any) into train or eval mode.

        Args:
            mode: If truthy, training mode; otherwise evaluation mode

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        mode = bool(mode)  # Module.train expects a real bool
        self.model.train(mode)
        if self.rl_model is not None:
            self.rl_model.train(mode)

    def to(self, device: str) -> None:
        """
        Move the base model (and RL model, if any) to ``device``.

        Args:
            device: Target device

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        self.device = device
        self.model.to(device)
        if self.rl_model is not None:
            self.rl_model.to(device)
        logger.info(f"Model moved to {device}")

    def get_rl_model(self) -> Optional[nn.Module]:
        """
        Get the RL-wrapped model.

        Returns:
            RLVoiceModel if RL is enabled and the model is loaded, None otherwise
        """
        return self.rl_model