"""Training orchestrator for RL voice model training."""
import torch
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
import time

from src.models.voice_model_wrapper import VoiceModelWrapper
from src.rl.algorithm_base import RLAlgorithm
from src.rl.reward_function import RewardFunction
from src.data.dataset import VoiceDataset

logger = logging.getLogger(__name__)


class TrainingOrchestrator:
    """
    Orchestrates the RL training process.
    
    Coordinates model, algorithm, data, and reward computation.
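
    Example (illustrative sketch; `model`, `algorithm`, `reward_fn`, and
    `train_ds` are assumed to have already been built from this project's
    modules):

        orchestrator = TrainingOrchestrator(
            model=model,
            algorithm=algorithm,
            reward_function=reward_fn,
            train_dataset=train_ds,
            config={'num_episodes': 20, 'batch_size': 8},
        )
        summary = orchestrator.train()
        print(summary['best_reward'], summary['mean_reward'])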
    """
    
    def __init__(
        self,
        model: VoiceModelWrapper,
        algorithm: RLAlgorithm,
        reward_function: RewardFunction,
        train_dataset: VoiceDataset,
        val_dataset: Optional[VoiceDataset] = None,
        metrics_tracker: Optional[Any] = None,
        visualizer: Optional[Any] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize training orchestrator.

        Args:
            model: Voice model wrapper
            algorithm: RL algorithm
            reward_function: Reward function
            train_dataset: Training dataset
            val_dataset: Optional validation dataset
            metrics_tracker: Optional metrics tracker
            visualizer: Optional visualizer
            config: Training configuration
        """
        self.model = model
        self.algorithm = algorithm
        self.reward_function = reward_function
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.metrics_tracker = metrics_tracker
        self.visualizer = visualizer
        
        # Default configuration
        self.config = {
            'num_episodes': 100,
            'episode_length': 10,
            'batch_size': 32,
            'log_interval': 10,
            'checkpoint_interval': 50,
            'checkpoint_dir': 'checkpoints',
            'max_checkpoints': 5,
        }
        if config:
            self.config.update(config)
        
        # Training state
        self.current_episode = 0
        self.training_history = []
        self.best_reward = float('-inf')
        
        # Log configuration
        logger.info("Initialized TrainingOrchestrator")
        logger.info(f"Configuration: {self.config}")
        logger.info(f"Algorithm: {type(self.algorithm).__name__}")
        logger.info(f"Training samples: {len(self.train_dataset)}")
    
    def initialize_training(self) -> None:
        """Initialize training state and prepare for training."""
        self.current_episode = 0
        self.training_history = []
        self.best_reward = float('-inf')
        
        # Ensure checkpoint directory exists
        Path(self.config['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
        
        # Set model to training mode
        self.model.set_training_mode(True)
        
        logger.info("Training initialized")
    
    def train_episode(self) -> Dict[str, Any]:
        """
        Execute one training episode.

        Returns:
            Dictionary of episode metrics: episode index, total and average
            reward, loss, episode time, and any metrics returned by the
            algorithm's policy update
        """
        episode_start = time.time()

        # Sample batch from dataset
        batch_indices = torch.randint(0, len(self.train_dataset), (self.config['batch_size'],))
        batch_samples = [self.train_dataset[int(idx)] for idx in batch_indices]

        # Collect states, actions, rewards, log probs, values
        states = []
        actions = []
        old_log_probs = []
        old_values = []
        rewards = []

        total_reward = 0.0

        for sample in batch_samples:
            # Get input audio and move to model device
            input_audio = sample['audio'].to(self.model.device)

            # Sample action from policy (with gradients for training)
            action, log_prob, value = self.model.sample_action(
                input_audio.unsqueeze(0),
                deterministic=False
            )

            # Generate an output waveform for reward computation
            # (in a full pipeline the sampled action would be decoded to audio;
            # here the model's generate() output stands in for that step)
            output_audio = self.model.generate(input_audio.unsqueeze(0), training=True)

            # Compute reward
            reference_audio = input_audio  # In real scenario, would have separate reference
            reward = self.reward_function.compute_reward(
                output_audio.squeeze(0),
                reference_audio
            )

            total_reward += reward

            # Store for RL update
            states.append(input_audio)
            actions.append(action.squeeze(0))
            old_log_probs.append(log_prob.squeeze(0))
            old_values.append(value.squeeze())  # Fully squeeze to scalar
            rewards.append(reward)

        # Convert to tensors
        # Handle variable-length audio by padding to max length
        max_length = max(s.shape[0] for s in states)

        # Pad states to same length
        states_padded = []
        for s in states:
            if len(s.shape) == 1:
                # Pad 1D tensor
                pad_length = max_length - s.shape[0]
                if pad_length > 0:
                    s_padded = torch.nn.functional.pad(s, (0, pad_length))
                else:
                    s_padded = s
            else:
                # Shouldn't happen but handle it
                s_padded = s
            states_padded.append(s_padded)

        states_tensor = torch.stack(states_padded)
        actions_tensor = torch.stack(actions)
        old_log_probs_tensor = torch.stack(old_log_probs)
        old_values_tensor = torch.stack(old_values)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)

        # Dones (all False for continuous training)
        dones = torch.zeros_like(rewards_tensor)

        # Compute loss using RL algorithm
        loss = self.algorithm.compute_loss(
            states_tensor,
            actions_tensor,
            rewards_tensor,
            states_tensor,  # next_states = current states (simplified)
            old_log_probs=old_log_probs_tensor,
            values=old_values_tensor,
            dones=dones
        )

        # Update policy
        update_metrics = self.algorithm.update_policy(loss)

        # Compute episode metrics
        episode_time = time.time() - episode_start
        avg_reward = total_reward / len(batch_samples)

        metrics = {
            'episode': self.current_episode,
            'total_reward': total_reward,
            'average_reward': avg_reward,
            'loss': loss.item(),
            'episode_time': episode_time,
            **update_metrics
        }

        # Update best reward
        if avg_reward > self.best_reward:
            self.best_reward = avg_reward
            metrics['is_best'] = True
        else:
            metrics['is_best'] = False

        # Log metrics to tracker if available
        if self.metrics_tracker:
            self.metrics_tracker.log_metrics({
                'reward': avg_reward,
                'total_reward': total_reward,
                'loss': loss.item(),
                'episode_time': episode_time,
                **{k: v for k, v in update_metrics.items() if isinstance(v, (int, float))}
            }, step=self.current_episode)

        self.training_history.append(metrics)
        self.current_episode += 1

        return metrics
    
    def should_checkpoint(self) -> bool:
        """
        Check if checkpoint should be saved.
        
        Returns:
            True if checkpoint should be saved
        """
        if self.current_episode == 0:
            return False
        
        return self.current_episode % self.config['checkpoint_interval'] == 0
    
    def should_log(self) -> bool:
        """
        Check if metrics should be logged.
        
        Returns:
            True if should log
        """
        if self.current_episode == 0:
            return True
        
        return self.current_episode % self.config['log_interval'] == 0
    
    def train(self) -> Dict[str, Any]:
        """
        Run full training loop.
        
        Returns:
            Training summary
        """
        self.initialize_training()
        
        logger.info(f"Starting training for {self.config['num_episodes']} episodes")
        
        for episode in range(self.config['num_episodes']):
            # Train one episode
            metrics = self.train_episode()
            
            # Log if needed
            if self.should_log():
                logger.info(
                    f"Episode {metrics['episode']}: "
                    f"reward={metrics['average_reward']:.4f}, "
                    f"loss={metrics['loss']:.4f}, "
                    f"time={metrics['episode_time']:.2f}s"
                )
            
            # Checkpoint if needed
            if self.should_checkpoint():
                self.save_checkpoint()

            # Generate visualizations periodically
            if (
                self.visualizer
                and self.metrics_tracker
                and (episode + 1) % max(1, self.config['num_episodes'] // 5) == 0
            ):
                self.visualizer.plot_training_curves(
                    self.metrics_tracker.get_all_metrics(),
                    title=f"Training Progress (Episode {episode})"
                )

        # Finalize training
        summary = self.finalize_training()

        # Save final metrics if a tracker is attached
        if self.metrics_tracker:
            self.metrics_tracker.save_metrics()

        # Generate final visualizations
        if self.visualizer and self.metrics_tracker:
            self.visualizer.plot_training_curves(
                self.metrics_tracker.get_all_metrics(),
                title="Final Training Results"
            )

        return summary
    
    def save_checkpoint(self, path: Optional[str] = None) -> None:
        """
        Save training checkpoint.
        
        Args:
            path: Optional custom checkpoint path
        """
        if path is None:
            checkpoint_dir = Path(self.config['checkpoint_dir'])
            path = checkpoint_dir / f"checkpoint_episode_{self.current_episode}.pt"
        
        metadata = {
            'episode': self.current_episode,
            'best_reward': self.best_reward,
            'config': self.config,
            'algorithm_hyperparameters': self.algorithm.get_hyperparameters()
        }
        
        self.model.save_checkpoint(str(path), metadata=metadata)
        logger.info(f"Checkpoint saved: {path}")
        
        # Cleanup old checkpoints
        self._cleanup_old_checkpoints()
    
    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the most recent N."""
        checkpoint_dir = Path(self.config['checkpoint_dir'])
        # Sort by episode number (a lexicographic sort would place e.g. episode
        # 100 before episode 50 and prune the wrong files)
        checkpoints = sorted(
            checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.rsplit("_", 1)[-1])
        )
        
        max_checkpoints = self.config.get('max_checkpoints', 5)
        
        if len(checkpoints) > max_checkpoints:
            for old_checkpoint in checkpoints[:-max_checkpoints]:
                old_checkpoint.unlink()
                logger.debug(f"Removed old checkpoint: {old_checkpoint}")
    
    def load_checkpoint(self, path: str) -> None:
        """
        Load training checkpoint.
        
        Args:
            path: Path to checkpoint file
        """
        metadata = self.model.load_checkpoint(path)
        
        self.current_episode = metadata.get('episode', 0)
        self.best_reward = metadata.get('best_reward', float('-inf'))
        
        logger.info(f"Checkpoint loaded from {path}")
        logger.info(f"Resuming from episode {self.current_episode}")
    
    def finalize_training(self) -> Dict[str, Any]:
        """
        Finalize training and generate summary.
        
        Returns:
            Training summary dictionary
        """
        # Save final checkpoint
        final_path = Path(self.config['checkpoint_dir']) / "final_model.pt"
        self.save_checkpoint(str(final_path))
        
        # Compute summary statistics
        if self.training_history:
            rewards = [m['average_reward'] for m in self.training_history]
            losses = [m['loss'] for m in self.training_history]
            
            summary = {
                'total_episodes': self.current_episode,
                'best_reward': self.best_reward,
                'final_reward': rewards[-1] if rewards else 0.0,
                'mean_reward': sum(rewards) / len(rewards),
                'mean_loss': sum(losses) / len(losses),
                'config': self.config,
                'training_history': self.training_history
            }
        else:
            summary = {
                'total_episodes': 0,
                'best_reward': 0.0,
                'final_reward': 0.0,
                'mean_reward': 0.0,
                'mean_loss': 0.0,
                'config': self.config,
                'training_history': []
            }
        
        logger.info("Training finalized")
        logger.info(f"Best reward: {summary['best_reward']:.4f}")
        logger.info(f"Mean reward: {summary['mean_reward']:.4f}")
        
        return summary
    
    def get_training_history(self) -> List[Dict[str, Any]]:
        """
        Get training history.
        
        Returns:
            List of episode metrics
        """
        return self.training_history