"""Training orchestrator for RL voice model training.""" import torch import logging from typing import Dict, Any, Optional, List from pathlib import Path import time from src.models.voice_model_wrapper import VoiceModelWrapper from src.rl.algorithm_base import RLAlgorithm from src.rl.reward_function import RewardFunction from src.data.dataset import VoiceDataset logger = logging.getLogger(__name__) class TrainingOrchestrator: """ Orchestrates the RL training process. Coordinates model, algorithm, data, and reward computation. """ def __init__( self, model: VoiceModelWrapper, algorithm: RLAlgorithm, reward_function: RewardFunction, train_dataset: VoiceDataset, val_dataset: Optional[VoiceDataset] = None, metrics_tracker: Optional[Any] = None, visualizer: Optional[Any] = None, config: Optional[Dict[str, Any]] = None ): """ Initialize training orchestrator. Args: model: Voice model wrapper algorithm: RL algorithm reward_function: Reward function train_dataset: Training dataset val_dataset: Optional validation dataset metrics_tracker: Optional metrics tracker visualizer: Optional visualizer config: Training configuration """ self.model = model self.algorithm = algorithm self.reward_function = reward_function self.train_dataset = train_dataset self.val_dataset = val_dataset self.metrics_tracker = metrics_tracker self.visualizer = visualizer # Default configuration self.config = { 'num_episodes': 100, 'episode_length': 10, 'batch_size': 32, 'log_interval': 10, 'checkpoint_interval': 50, 'checkpoint_dir': 'checkpoints', 'max_checkpoints': 5, } if config: self.config.update(config) # Training state self.current_episode = 0 self.training_history = [] self.best_reward = float('-inf') # Log configuration logger.info("Initialized TrainingOrchestrator") logger.info(f"Configuration: {self.config}") logger.info(f"Algorithm: {type(self.algorithm).__name__}") logger.info(f"Training samples: {len(self.train_dataset)}") def initialize_training(self) -> None: """Initialize training state and prepare for training.""" self.current_episode = 0 self.training_history = [] self.best_reward = float('-inf') # Ensure checkpoint directory exists Path(self.config['checkpoint_dir']).mkdir(parents=True, exist_ok=True) # Set model to training mode self.model.set_training_mode(True) logger.info("Training initialized") def train_episode(self) -> Dict[str, Any]: """ Execute one training episode. 

        Returns:
            Dictionary with episode metrics
        """
        episode_start = time.time()

        # Sample batch from dataset
        batch_indices = torch.randint(0, len(self.train_dataset), (self.config['batch_size'],))
        batch_samples = [self.train_dataset[int(idx)] for idx in batch_indices]

        # Collect states, actions, rewards, log probs, values
        states = []
        actions = []
        old_log_probs = []
        old_values = []
        rewards = []
        total_reward = 0.0

        for sample in batch_samples:
            # Get input audio and move to model device
            input_audio = sample['audio'].to(self.model.device)

            # Sample action from policy (with gradients for training)
            action, log_prob, value = self.model.sample_action(
                input_audio.unsqueeze(0),
                deterministic=False
            )

            # Generate output representation for reward computation
            # (In practice, you'd decode action to audio, here we use a placeholder)
            output_audio = self.model.generate(input_audio.unsqueeze(0), training=True)

            # Compute reward
            reference_audio = input_audio  # In real scenario, would have separate reference
            reward = self.reward_function.compute_reward(
                output_audio.squeeze(0),
                reference_audio
            )
            total_reward += reward

            # Store for RL update
            states.append(input_audio)
            actions.append(action.squeeze(0))
            old_log_probs.append(log_prob.squeeze(0))
            old_values.append(value.squeeze())  # Fully squeeze to scalar
            rewards.append(reward)

        # Convert to tensors
        # Handle variable-length audio by padding to max length
        max_length = max(s.shape[0] for s in states)

        # Pad states to same length
        states_padded = []
        for s in states:
            if len(s.shape) == 1:
                # Pad 1D tensor
                pad_length = max_length - s.shape[0]
                if pad_length > 0:
                    s_padded = torch.nn.functional.pad(s, (0, pad_length))
                else:
                    s_padded = s
            else:
                # Shouldn't happen but handle it
                s_padded = s
            states_padded.append(s_padded)

        states_tensor = torch.stack(states_padded)
        actions_tensor = torch.stack(actions)
        old_log_probs_tensor = torch.stack(old_log_probs)
        old_values_tensor = torch.stack(old_values)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)

        # Dones (all False for continuous training)
        dones = torch.zeros_like(rewards_tensor)

        # Compute loss using RL algorithm
        loss = self.algorithm.compute_loss(
            states_tensor,
            actions_tensor,
            rewards_tensor,
            states_tensor,  # next_states = current states (simplified)
            old_log_probs=old_log_probs_tensor,
            values=old_values_tensor,
            dones=dones
        )

        # Update policy
        update_metrics = self.algorithm.update_policy(loss)

        # Compute episode metrics
        episode_time = time.time() - episode_start
        avg_reward = total_reward / len(batch_samples)

        metrics = {
            'episode': self.current_episode,
            'total_reward': total_reward,
            'average_reward': avg_reward,
            'loss': loss.item(),
            'episode_time': episode_time,
            **update_metrics
        }

        # Update best reward
        if avg_reward > self.best_reward:
            self.best_reward = avg_reward
            metrics['is_best'] = True
        else:
            metrics['is_best'] = False

        # Log metrics to tracker if available
        if self.metrics_tracker:
            self.metrics_tracker.log_metrics({
                'reward': avg_reward,
                'total_reward': total_reward,
                'loss': loss.item(),
                'episode_time': episode_time,
                **{k: v for k, v in update_metrics.items() if isinstance(v, (int, float))}
            }, step=self.current_episode)

        self.training_history.append(metrics)
        self.current_episode += 1

        return metrics

    def should_checkpoint(self) -> bool:
        """
        Check if checkpoint should be saved.

        Returns:
            True if checkpoint should be saved
        """
        if self.current_episode == 0:
            return False
        return self.current_episode % self.config['checkpoint_interval'] == 0

    def should_log(self) -> bool:
        """
        Check if metrics should be logged.

        Returns:
            True if should log
        """
        if self.current_episode == 0:
            return True
        return self.current_episode % self.config['log_interval'] == 0

    def train(self) -> Dict[str, Any]:
        """
        Run full training loop.

        Returns:
            Training summary
        """
        self.initialize_training()

        logger.info(f"Starting training for {self.config['num_episodes']} episodes")

        for episode in range(self.config['num_episodes']):
            # Train one episode
            metrics = self.train_episode()

            # Log if needed
            if self.should_log():
                logger.info(
                    f"Episode {metrics['episode']}: "
                    f"reward={metrics['average_reward']:.4f}, "
                    f"loss={metrics['loss']:.4f}, "
                    f"time={metrics['episode_time']:.2f}s"
                )

            # Checkpoint if needed
            if self.should_checkpoint():
                self.save_checkpoint()

            # Generate visualizations periodically (needs both a visualizer and a tracker)
            if (self.visualizer and self.metrics_tracker
                    and (episode + 1) % max(1, self.config['num_episodes'] // 5) == 0):
                self.visualizer.plot_training_curves(
                    self.metrics_tracker.get_all_metrics(),
                    title=f"Training Progress (Episode {episode})"
                )

        # Finalize training
        summary = self.finalize_training()

        # Save final metrics (the tracker is optional)
        if self.metrics_tracker:
            self.metrics_tracker.save_metrics()

        # Generate final visualizations
        if self.visualizer and self.metrics_tracker:
            self.visualizer.plot_training_curves(
                self.metrics_tracker.get_all_metrics(),
                title="Final Training Results"
            )

        return summary

    def save_checkpoint(self, path: Optional[str] = None) -> None:
        """
        Save training checkpoint.

        Args:
            path: Optional custom checkpoint path
        """
        if path is None:
            checkpoint_dir = Path(self.config['checkpoint_dir'])
            path = checkpoint_dir / f"checkpoint_episode_{self.current_episode}.pt"

        metadata = {
            'episode': self.current_episode,
            'best_reward': self.best_reward,
            'config': self.config,
            'algorithm_hyperparameters': self.algorithm.get_hyperparameters()
        }

        self.model.save_checkpoint(str(path), metadata=metadata)
        logger.info(f"Checkpoint saved: {path}")

        # Cleanup old checkpoints
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the most recent N."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        # Sort numerically by episode number; a plain lexicographic sort would
        # order e.g. episode 100 before episode 50 and delete the wrong files
        checkpoints = sorted(
            checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.rsplit('_', 1)[-1])
        )

        max_checkpoints = self.config.get('max_checkpoints', 5)
        if len(checkpoints) > max_checkpoints:
            for old_checkpoint in checkpoints[:-max_checkpoints]:
                old_checkpoint.unlink()
                logger.debug(f"Removed old checkpoint: {old_checkpoint}")

    def load_checkpoint(self, path: str) -> None:
        """
        Load training checkpoint.

        Args:
            path: Path to checkpoint file
        """
        metadata = self.model.load_checkpoint(path)

        self.current_episode = metadata.get('episode', 0)
        self.best_reward = metadata.get('best_reward', float('-inf'))

        logger.info(f"Checkpoint loaded from {path}")
        logger.info(f"Resuming from episode {self.current_episode}")

    def finalize_training(self) -> Dict[str, Any]:
        """
        Finalize training and generate summary.

        Returns:
            Training summary dictionary
        """
        # Save final checkpoint
        final_path = Path(self.config['checkpoint_dir']) / "final_model.pt"
        self.save_checkpoint(str(final_path))

        # Compute summary statistics
        if self.training_history:
            rewards = [m['average_reward'] for m in self.training_history]
            losses = [m['loss'] for m in self.training_history]

            summary = {
                'total_episodes': self.current_episode,
                'best_reward': self.best_reward,
                'final_reward': rewards[-1] if rewards else 0.0,
                'mean_reward': sum(rewards) / len(rewards),
                'mean_loss': sum(losses) / len(losses),
                'config': self.config,
                'training_history': self.training_history
            }
        else:
            summary = {
                'total_episodes': 0,
                'best_reward': 0.0,
                'final_reward': 0.0,
                'mean_reward': 0.0,
                'mean_loss': 0.0,
                'config': self.config,
                'training_history': []
            }

        logger.info("Training finalized")
        logger.info(f"Best reward: {summary['best_reward']:.4f}")
        logger.info(f"Mean reward: {summary['mean_reward']:.4f}")

        return summary

    def get_training_history(self) -> List[Dict[str, Any]]:
        """
        Get training history.

        Returns:
            List of episode metrics
        """
        return self.training_history
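

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed): shows how the orchestrator
# defined above is typically wired together and run. The constructor calls for
# VoiceModelWrapper, RewardFunction, VoiceDataset, and the RLAlgorithm subclass
# are assumptions here -- consult the actual classes under src/ for their
# signatures before copying this.
#
#     model = VoiceModelWrapper(...)            # assumed constructor
#     algorithm = SomePolicyGradient(model)     # hypothetical RLAlgorithm subclass
#     reward_fn = RewardFunction(...)           # assumed constructor
#     train_ds = VoiceDataset(...)              # assumed constructor
#
#     orchestrator = TrainingOrchestrator(
#         model=model,
#         algorithm=algorithm,
#         reward_function=reward_fn,
#         train_dataset=train_ds,
#         config={'num_episodes': 20, 'batch_size': 8, 'checkpoint_interval': 10},
#     )
#     summary = orchestrator.train()
#     print(f"Best reward: {summary['best_reward']:.4f}")
#
# A checkpoint written during training can later be restored (model weights,
# episode counter, best reward) with:
#
#     orchestrator.load_checkpoint('checkpoints/checkpoint_episode_10.pt')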