| """Training orchestrator for RL voice model training.""" | |
| import torch | |
| import logging | |
| from typing import Dict, Any, Optional, List | |
| from pathlib import Path | |
| import time | |
| from src.models.voice_model_wrapper import VoiceModelWrapper | |
| from src.rl.algorithm_base import RLAlgorithm | |
| from src.rl.reward_function import RewardFunction | |
| from src.data.dataset import VoiceDataset | |
| logger = logging.getLogger(__name__) | |
class TrainingOrchestrator:
    """
    Orchestrates the RL training process.

    Coordinates model, algorithm, data, and reward computation.
    """

    def __init__(
        self,
        model: VoiceModelWrapper,
        algorithm: RLAlgorithm,
        reward_function: RewardFunction,
        train_dataset: VoiceDataset,
        val_dataset: Optional[VoiceDataset] = None,
        metrics_tracker: Optional[Any] = None,
        visualizer: Optional[Any] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize training orchestrator.

        Args:
            model: Voice model wrapper
            algorithm: RL algorithm
            reward_function: Reward function
            train_dataset: Training dataset
            val_dataset: Optional validation dataset
            metrics_tracker: Optional metrics tracker
            visualizer: Optional visualizer
            config: Training configuration
        """
        self.model = model
        self.algorithm = algorithm
        self.reward_function = reward_function
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.metrics_tracker = metrics_tracker
        self.visualizer = visualizer

        # Default configuration
        self.config = {
            'num_episodes': 100,
            'episode_length': 10,
            'batch_size': 32,
            'log_interval': 10,
            'checkpoint_interval': 50,
            'checkpoint_dir': 'checkpoints',
            'max_checkpoints': 5,
        }
        if config:
            self.config.update(config)

        # Training state
        self.current_episode = 0
        self.training_history = []
        self.best_reward = float('-inf')

        # Log configuration
        logger.info("Initialized TrainingOrchestrator")
        logger.info(f"Configuration: {self.config}")
        logger.info(f"Algorithm: {type(self.algorithm).__name__}")
        logger.info(f"Training samples: {len(self.train_dataset)}")

    def initialize_training(self) -> None:
        """Initialize training state and prepare for training."""
        self.current_episode = 0
        self.training_history = []
        self.best_reward = float('-inf')

        # Ensure checkpoint directory exists
        Path(self.config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)

        # Set model to training mode
        self.model.set_training_mode(True)

        logger.info("Training initialized")

    def train_episode(self) -> Dict[str, Any]:
        """
        Execute one training episode.

        Returns:
            Dictionary with episode metrics
        """
        episode_start = time.time()

        # Sample batch from dataset
        batch_indices = torch.randint(0, len(self.train_dataset), (self.config['batch_size'],))
        batch_samples = [self.train_dataset[int(idx)] for idx in batch_indices]

        # Collect states, actions, rewards, log probs, values
        states = []
        actions = []
        old_log_probs = []
        old_values = []
        rewards = []
        total_reward = 0.0

        for sample in batch_samples:
            # Get input audio and move it to the model device
            input_audio = sample['audio'].to(self.model.device)

            # Sample action from policy (with gradients for training)
            action, log_prob, value = self.model.sample_action(
                input_audio.unsqueeze(0),
                deterministic=False
            )

            # Generate output representation for reward computation
            # (In practice, you'd decode the action to audio; here we use a placeholder)
            output_audio = self.model.generate(input_audio.unsqueeze(0), training=True)

            # Compute reward
            reference_audio = input_audio  # In a real scenario, a separate reference would be used
            reward = self.reward_function.compute_reward(
                output_audio.squeeze(0),
                reference_audio
            )
            total_reward += reward

            # Store for RL update
            states.append(input_audio)
            actions.append(action.squeeze(0))
            old_log_probs.append(log_prob.squeeze(0))
            old_values.append(value.squeeze())  # Fully squeeze to scalar
            rewards.append(reward)

        # Convert to tensors.
        # Handle variable-length audio by padding to the max length in the batch.
        max_length = max(s.shape[0] for s in states)

        # Pad states to the same length
        states_padded = []
        for s in states:
            if len(s.shape) == 1:
                # Pad 1D tensor
                pad_length = max_length - s.shape[0]
                if pad_length > 0:
                    s_padded = torch.nn.functional.pad(s, (0, pad_length))
                else:
                    s_padded = s
            else:
                # Shouldn't happen, but handle it
                s_padded = s
            states_padded.append(s_padded)

        states_tensor = torch.stack(states_padded)
        actions_tensor = torch.stack(actions)
        old_log_probs_tensor = torch.stack(old_log_probs)
        old_values_tensor = torch.stack(old_values)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)

        # Dones (all False for continuous training)
        dones = torch.zeros_like(rewards_tensor)

        # Compute loss using the RL algorithm
        loss = self.algorithm.compute_loss(
            states_tensor,
            actions_tensor,
            rewards_tensor,
            states_tensor,  # next_states = current states (simplified)
            old_log_probs=old_log_probs_tensor,
            values=old_values_tensor,
            dones=dones
        )
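
        # Shape assumptions behind the compute_loss call above (they depend on
        # VoiceModelWrapper and the concrete RLAlgorithm, which are defined
        # outside this module, so treat these as illustrative, not guaranteed):
        #   states_tensor / next_states -> (batch_size, max_length)
        #   actions_tensor              -> (batch_size, ...) per the policy's action shape
        #   rewards_tensor, dones       -> (batch_size,)
        #   old_log_probs_tensor        -> (batch_size, ...) matching sample_action's output
        #   old_values_tensor           -> (batch_size,)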

        # Update policy
        update_metrics = self.algorithm.update_policy(loss)

        # Compute episode metrics
        episode_time = time.time() - episode_start
        avg_reward = total_reward / len(batch_samples)

        metrics = {
            'episode': self.current_episode,
            'total_reward': total_reward,
            'average_reward': avg_reward,
            'loss': loss.item(),
            'episode_time': episode_time,
            **update_metrics
        }

        # Update best reward
        if avg_reward > self.best_reward:
            self.best_reward = avg_reward
            metrics['is_best'] = True
        else:
            metrics['is_best'] = False

        # Log metrics to tracker if available
        if self.metrics_tracker:
            self.metrics_tracker.log_metrics({
                'reward': avg_reward,
                'total_reward': total_reward,
                'loss': loss.item(),
                'episode_time': episode_time,
                **{k: v for k, v in update_metrics.items() if isinstance(v, (int, float))}
            }, step=self.current_episode)

        self.training_history.append(metrics)
        self.current_episode += 1

        return metrics

    def should_checkpoint(self) -> bool:
        """
        Check if a checkpoint should be saved.

        Returns:
            True if a checkpoint should be saved
        """
        if self.current_episode == 0:
            return False
        return self.current_episode % self.config['checkpoint_interval'] == 0

    def should_log(self) -> bool:
        """
        Check if metrics should be logged.

        Returns:
            True if metrics should be logged
        """
        if self.current_episode == 0:
            return True
        return self.current_episode % self.config['log_interval'] == 0

    def train(self) -> Dict[str, Any]:
        """
        Run the full training loop.

        Returns:
            Training summary
        """
        self.initialize_training()
        logger.info(f"Starting training for {self.config['num_episodes']} episodes")

        for episode in range(self.config['num_episodes']):
            # Train one episode
            metrics = self.train_episode()

            # Log if needed
            if self.should_log():
                logger.info(
                    f"Episode {metrics['episode']}: "
                    f"reward={metrics['average_reward']:.4f}, "
                    f"loss={metrics['loss']:.4f}, "
                    f"time={metrics['episode_time']:.2f}s"
                )

            # Checkpoint if needed
            if self.should_checkpoint():
                self.save_checkpoint()

            # Generate visualizations periodically (both visualizer and tracker are optional)
            if (
                self.visualizer
                and self.metrics_tracker
                and (episode + 1) % max(1, self.config['num_episodes'] // 5) == 0
            ):
                self.visualizer.plot_training_curves(
                    self.metrics_tracker.get_all_metrics(),
                    title=f"Training Progress (Episode {episode})"
                )

        # Finalize training
        summary = self.finalize_training()

        # Save final metrics if a tracker is attached
        if self.metrics_tracker:
            self.metrics_tracker.save_metrics()

        # Generate final visualizations
        if self.visualizer and self.metrics_tracker:
            self.visualizer.plot_training_curves(
                self.metrics_tracker.get_all_metrics(),
                title="Final Training Results"
            )

        return summary

    def save_checkpoint(self, path: Optional[str] = None) -> None:
        """
        Save training checkpoint.

        Args:
            path: Optional custom checkpoint path
        """
        if path is None:
            checkpoint_dir = Path(self.config['checkpoint_dir'])
            path = checkpoint_dir / f"checkpoint_episode_{self.current_episode}.pt"

        metadata = {
            'episode': self.current_episode,
            'best_reward': self.best_reward,
            'config': self.config,
            'algorithm_hyperparameters': self.algorithm.get_hyperparameters()
        }

        self.model.save_checkpoint(str(path), metadata=metadata)
        logger.info(f"Checkpoint saved: {path}")

        # Clean up old checkpoints
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the most recent N."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        # Sort by episode number; a plain lexicographic sort would misorder
        # e.g. episode 100 before episode 50
        checkpoints = sorted(
            checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.split('_')[-1])
        )
        max_checkpoints = self.config.get('max_checkpoints', 5)

        if len(checkpoints) > max_checkpoints:
            for old_checkpoint in checkpoints[:-max_checkpoints]:
                old_checkpoint.unlink()
                logger.debug(f"Removed old checkpoint: {old_checkpoint}")

    def load_checkpoint(self, path: str) -> None:
        """
        Load training checkpoint.

        Args:
            path: Path to checkpoint file
        """
        metadata = self.model.load_checkpoint(path)
        self.current_episode = metadata.get('episode', 0)
        self.best_reward = metadata.get('best_reward', float('-inf'))

        logger.info(f"Checkpoint loaded from {path}")
        logger.info(f"Resuming from episode {self.current_episode}")
    def finalize_training(self) -> Dict[str, Any]:
        """
        Finalize training and generate summary.

        Returns:
            Training summary dictionary
        """
        # Save final checkpoint
        final_path = Path(self.config['checkpoint_dir']) / "final_model.pt"
        self.save_checkpoint(str(final_path))

        # Compute summary statistics
        if self.training_history:
            rewards = [m['average_reward'] for m in self.training_history]
            losses = [m['loss'] for m in self.training_history]
            summary = {
                'total_episodes': self.current_episode,
                'best_reward': self.best_reward,
                'final_reward': rewards[-1] if rewards else 0.0,
                'mean_reward': sum(rewards) / len(rewards),
                'mean_loss': sum(losses) / len(losses),
                'config': self.config,
                'training_history': self.training_history
            }
        else:
            summary = {
                'total_episodes': 0,
                'best_reward': 0.0,
                'final_reward': 0.0,
                'mean_reward': 0.0,
                'mean_loss': 0.0,
                'config': self.config,
                'training_history': []
            }

        logger.info("Training finalized")
        logger.info(f"Best reward: {summary['best_reward']:.4f}")
        logger.info(f"Mean reward: {summary['mean_reward']:.4f}")

        return summary

    def get_training_history(self) -> List[Dict[str, Any]]:
        """
        Get training history.

        Returns:
            List of episode metrics
        """
        return self.training_history
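
# --- Usage sketch -------------------------------------------------------------
# A minimal wiring example, left as comments because the collaborator classes
# (VoiceModelWrapper, the RLAlgorithm implementations, RewardFunction, and
# VoiceDataset) are constructed elsewhere in this repo; the construction lines
# below are placeholders, not this project's actual API:
#
#     model = VoiceModelWrapper(...)      # placeholder construction
#     algorithm = ...                     # any concrete RLAlgorithm
#     reward_fn = RewardFunction(...)     # placeholder construction
#     train_ds = VoiceDataset(...)        # placeholder construction
#
#     orchestrator = TrainingOrchestrator(
#         model=model,
#         algorithm=algorithm,
#         reward_function=reward_fn,
#         train_dataset=train_ds,
#         config={'num_episodes': 20, 'batch_size': 8, 'log_interval': 5},
#     )
#     summary = orchestrator.train()
#     print(f"Best reward: {summary['best_reward']:.4f}")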