"""Training orchestrator for RL voice model training."""
import torch
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
import time
from src.models.voice_model_wrapper import VoiceModelWrapper
from src.rl.algorithm_base import RLAlgorithm
from src.rl.reward_function import RewardFunction
from src.data.dataset import VoiceDataset
logger = logging.getLogger(__name__)
class TrainingOrchestrator:
"""
Orchestrates the RL training process.
Coordinates model, algorithm, data, and reward computation.
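
    Example (illustrative sketch; assumes the wrapped components -- `model`,
    `algorithm`, `reward_fn`, `train_ds` -- are constructed elsewhere with the
    interfaces used in this module):

        orchestrator = TrainingOrchestrator(
            model=model,
            algorithm=algorithm,
            reward_function=reward_fn,
            train_dataset=train_ds,
            config={'num_episodes': 200, 'batch_size': 16},
        )
        summary = orchestrator.train()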
"""
def __init__(
self,
model: VoiceModelWrapper,
algorithm: RLAlgorithm,
reward_function: RewardFunction,
train_dataset: VoiceDataset,
val_dataset: Optional[VoiceDataset] = None,
metrics_tracker: Optional[Any] = None,
visualizer: Optional[Any] = None,
config: Optional[Dict[str, Any]] = None
):
"""
Initialize training orchestrator.
Args:
model: Voice model wrapper
algorithm: RL algorithm
reward_function: Reward function
train_dataset: Training dataset
val_dataset: Optional validation dataset
metrics_tracker: Optional metrics tracker
visualizer: Optional visualizer
config: Training configuration
"""
self.model = model
self.algorithm = algorithm
self.reward_function = reward_function
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.metrics_tracker = metrics_tracker
self.visualizer = visualizer
        # Default configuration; any key can be overridden via the `config` argument
        self.config = {
            'num_episodes': 100,           # total training episodes
            'episode_length': 10,          # nominal episode length
            'batch_size': 32,              # samples drawn per episode
            'log_interval': 10,            # episodes between log messages
            'checkpoint_interval': 50,     # episodes between checkpoints
            'checkpoint_dir': 'checkpoints',
            'max_checkpoints': 5,          # most recent checkpoints kept on disk
        }
if config:
self.config.update(config)
# Training state
self.current_episode = 0
self.training_history = []
self.best_reward = float('-inf')
# Log configuration
logger.info("Initialized TrainingOrchestrator")
logger.info(f"Configuration: {self.config}")
logger.info(f"Algorithm: {type(self.algorithm).__name__}")
logger.info(f"Training samples: {len(self.train_dataset)}")
def initialize_training(self) -> None:
"""Initialize training state and prepare for training."""
self.current_episode = 0
self.training_history = []
self.best_reward = float('-inf')
# Ensure checkpoint directory exists
Path(self.config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
# Set model to training mode
self.model.set_training_mode(True)
logger.info("Training initialized")
def train_episode(self) -> Dict[str, Any]:
"""
Execute one training episode.
Returns:
Dictionary with episode metrics
"""
episode_start = time.time()
        # Sample a random batch of indices from the dataset (drawn with replacement)
batch_indices = torch.randint(0, len(self.train_dataset), (self.config['batch_size'],))
batch_samples = [self.train_dataset[int(idx)] for idx in batch_indices]
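        # Each dataset item is assumed to be a dict with an 'audio' tensor;
        # only that field is used below.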
# Collect states, actions, rewards, log probs, values
states = []
actions = []
old_log_probs = []
old_values = []
rewards = []
total_reward = 0.0
for sample in batch_samples:
# Get input audio and move to model device
input_audio = sample['audio'].to(self.model.device)
# Sample action from policy (with gradients for training)
action, log_prob, value = self.model.sample_action(
input_audio.unsqueeze(0),
deterministic=False
)
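            # sample_action is expected to return an (action, log_prob, value)
            # triple; deterministic=False keeps the rollout stochastic so the
            # policy-gradient update has an exploration signal.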
            # Generate an output representation for reward computation
            # (in practice the sampled action would be decoded to audio; here
            # the wrapper's generate() output stands in for it)
output_audio = self.model.generate(input_audio.unsqueeze(0), training=True)
# Compute reward
            reference_audio = input_audio  # in a real scenario a separate reference signal would be used
reward = self.reward_function.compute_reward(
output_audio.squeeze(0),
reference_audio
)
total_reward += reward
# Store for RL update
states.append(input_audio)
actions.append(action.squeeze(0))
old_log_probs.append(log_prob.squeeze(0))
old_values.append(value.squeeze()) # Fully squeeze to scalar
rewards.append(reward)
# Convert to tensors
# Handle variable-length audio by padding to max length
max_length = max(s.shape[0] for s in states)
# Pad states to same length
states_padded = []
for s in states:
if len(s.shape) == 1:
# Pad 1D tensor
pad_length = max_length - s.shape[0]
if pad_length > 0:
s_padded = torch.nn.functional.pad(s, (0, pad_length))
else:
s_padded = s
else:
# Shouldn't happen but handle it
s_padded = s
states_padded.append(s_padded)
states_tensor = torch.stack(states_padded)
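        # Note: torch.nn.utils.rnn.pad_sequence(states, batch_first=True) would
        # produce the same zero-padded batch in a single call; the explicit loop
        # above is kept for clarity.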
actions_tensor = torch.stack(actions)
old_log_probs_tensor = torch.stack(old_log_probs)
old_values_tensor = torch.stack(old_values)
rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)
# Dones (all False for continuous training)
dones = torch.zeros_like(rewards_tensor)
# Compute loss using RL algorithm
loss = self.algorithm.compute_loss(
states_tensor,
actions_tensor,
rewards_tensor,
states_tensor, # next_states = current states (simplified)
old_log_probs=old_log_probs_tensor,
values=old_values_tensor,
dones=dones
)
# Update policy
update_metrics = self.algorithm.update_policy(loss)
# Compute episode metrics
episode_time = time.time() - episode_start
avg_reward = total_reward / len(batch_samples)
metrics = {
'episode': self.current_episode,
'total_reward': total_reward,
'average_reward': avg_reward,
'loss': loss.item(),
'episode_time': episode_time,
**update_metrics
}
# Update best reward
if avg_reward > self.best_reward:
self.best_reward = avg_reward
metrics['is_best'] = True
else:
metrics['is_best'] = False
# Log metrics to tracker if available
if self.metrics_tracker:
self.metrics_tracker.log_metrics({
'reward': avg_reward,
'total_reward': total_reward,
'loss': loss.item(),
'episode_time': episode_time,
**{k: v for k, v in update_metrics.items() if isinstance(v, (int, float))}
}, step=self.current_episode)
self.training_history.append(metrics)
self.current_episode += 1
return metrics
def should_checkpoint(self) -> bool:
"""
Check if checkpoint should be saved.
Returns:
True if checkpoint should be saved
"""
if self.current_episode == 0:
return False
return self.current_episode % self.config['checkpoint_interval'] == 0
def should_log(self) -> bool:
"""
Check if metrics should be logged.
Returns:
True if should log
"""
if self.current_episode == 0:
return True
return self.current_episode % self.config['log_interval'] == 0
def train(self) -> Dict[str, Any]:
"""
Run full training loop.
Returns:
Training summary
"""
self.initialize_training()
logger.info(f"Starting training for {self.config['num_episodes']} episodes")
for episode in range(self.config['num_episodes']):
# Train one episode
metrics = self.train_episode()
# Log if needed
if self.should_log():
logger.info(
f"Episode {metrics['episode']}: "
f"reward={metrics['average_reward']:.4f}, "
f"loss={metrics['loss']:.4f}, "
f"time={metrics['episode_time']:.2f}s"
)
# Checkpoint if needed
if self.should_checkpoint():
self.save_checkpoint()
# Generate visualizations periodically
            if (
                self.visualizer
                and self.metrics_tracker
                and (episode + 1) % max(1, self.config['num_episodes'] // 5) == 0
            ):
self.visualizer.plot_training_curves(
self.metrics_tracker.get_all_metrics(),
title=f"Training Progress (Episode {episode})"
)
# Finalize training
summary = self.finalize_training()
# Save final metrics
        if self.metrics_tracker:
            self.metrics_tracker.save_metrics()
# Generate final visualizations
        if self.visualizer and self.metrics_tracker:
self.visualizer.plot_training_curves(
self.metrics_tracker.get_all_metrics(),
title="Final Training Results"
)
return summary
def save_checkpoint(self, path: Optional[str] = None) -> None:
"""
Save training checkpoint.
Args:
path: Optional custom checkpoint path
"""
if path is None:
checkpoint_dir = Path(self.config['checkpoint_dir'])
            path = str(checkpoint_dir / f"checkpoint_episode_{self.current_episode}.pt")
metadata = {
'episode': self.current_episode,
'best_reward': self.best_reward,
'config': self.config,
'algorithm_hyperparameters': self.algorithm.get_hyperparameters()
}
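        # The metadata is stored alongside the model weights and is read back
        # in load_checkpoint() when resuming.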
self.model.save_checkpoint(str(path), metadata=metadata)
logger.info(f"Checkpoint saved: {path}")
# Cleanup old checkpoints
self._cleanup_old_checkpoints()
def _cleanup_old_checkpoints(self) -> None:
"""Remove old checkpoints, keeping only the most recent N."""
checkpoint_dir = Path(self.config['checkpoint_dir'])
        # Sort numerically by episode number; a plain lexicographic sort would
        # order e.g. episode 100 before episode 50 and delete the wrong files.
        checkpoints = sorted(
            checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.rsplit('_', 1)[-1])
        )
max_checkpoints = self.config.get('max_checkpoints', 5)
if len(checkpoints) > max_checkpoints:
for old_checkpoint in checkpoints[:-max_checkpoints]:
old_checkpoint.unlink()
logger.debug(f"Removed old checkpoint: {old_checkpoint}")
def load_checkpoint(self, path: str) -> None:
"""
Load training checkpoint.
Args:
path: Path to checkpoint file
"""
metadata = self.model.load_checkpoint(path)
self.current_episode = metadata.get('episode', 0)
self.best_reward = metadata.get('best_reward', float('-inf'))
logger.info(f"Checkpoint loaded from {path}")
logger.info(f"Resuming from episode {self.current_episode}")
def finalize_training(self) -> Dict[str, Any]:
"""
Finalize training and generate summary.
Returns:
Training summary dictionary
"""
# Save final checkpoint
final_path = Path(self.config['checkpoint_dir']) / "final_model.pt"
self.save_checkpoint(str(final_path))
# Compute summary statistics
if self.training_history:
rewards = [m['average_reward'] for m in self.training_history]
losses = [m['loss'] for m in self.training_history]
summary = {
'total_episodes': self.current_episode,
'best_reward': self.best_reward,
'final_reward': rewards[-1] if rewards else 0.0,
'mean_reward': sum(rewards) / len(rewards),
'mean_loss': sum(losses) / len(losses),
'config': self.config,
'training_history': self.training_history
}
else:
summary = {
'total_episodes': 0,
'best_reward': 0.0,
'final_reward': 0.0,
'mean_reward': 0.0,
'mean_loss': 0.0,
'config': self.config,
'training_history': []
}
logger.info("Training finalized")
logger.info(f"Best reward: {summary['best_reward']:.4f}")
logger.info(f"Mean reward: {summary['mean_reward']:.4f}")
return summary
def get_training_history(self) -> List[Dict[str, Any]]:
"""
Get training history.
Returns:
List of episode metrics
"""
return self.training_history