"""Checkpoint management for training.""" import torch import json from pathlib import Path from typing import Dict, Any, Optional, List from datetime import datetime import logging logger = logging.getLogger(__name__) class CheckpointManager: """ Manages model checkpoints during training. Handles saving, loading, and cleanup of checkpoints. """ def __init__( self, checkpoint_dir: str = "checkpoints", max_checkpoints: int = 5, save_interval: int = 10 ): """ Initialize checkpoint manager. Args: checkpoint_dir: Directory to save checkpoints max_checkpoints: Maximum number of checkpoints to keep save_interval: Save checkpoint every N episodes """ self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.max_checkpoints = max_checkpoints self.save_interval = save_interval self.checkpoint_history = [] logger.info(f"CheckpointManager initialized: dir={checkpoint_dir}, max={max_checkpoints}, interval={save_interval}") def should_save(self, episode: int) -> bool: """ Check if checkpoint should be saved at this episode. Args: episode: Current episode number Returns: True if should save checkpoint """ if episode == 0: return False return episode % self.save_interval == 0 def save_checkpoint( self, model, episode: int, metrics: Optional[Dict[str, Any]] = None, is_best: bool = False ) -> str: """ Save a checkpoint. Args: model: Model to save episode: Current episode number metrics: Optional training metrics is_best: Whether this is the best model so far Returns: Path to saved checkpoint """ # Create checkpoint filename if is_best: filename = "best_model.pt" else: filename = f"checkpoint_episode_{episode}.pt" checkpoint_path = self.checkpoint_dir / filename # Prepare metadata metadata = { 'episode': episode, 'timestamp': datetime.now().isoformat(), 'is_best': is_best } if metrics: metadata['metrics'] = metrics # Save checkpoint model.save_checkpoint(str(checkpoint_path), metadata=metadata) # Record in history self.checkpoint_history.append({ 'path': str(checkpoint_path), 'episode': episode, 'timestamp': metadata['timestamp'], 'is_best': is_best }) logger.info(f"Checkpoint saved: {checkpoint_path}") # Cleanup old checkpoints if not is_best: self._cleanup_old_checkpoints() return str(checkpoint_path) def load_checkpoint( self, model, checkpoint_path: Optional[str] = None, load_best: bool = False ) -> Dict[str, Any]: """ Load a checkpoint. Args: model: Model to load checkpoint into checkpoint_path: Optional specific checkpoint path load_best: If True, load best model Returns: Checkpoint metadata """ if load_best: checkpoint_path = str(self.checkpoint_dir / "best_model.pt") elif checkpoint_path is None: # Load most recent checkpoint checkpoint_path = self._get_latest_checkpoint() if checkpoint_path is None: raise FileNotFoundError("No checkpoints found") metadata = model.load_checkpoint(checkpoint_path) logger.info(f"Checkpoint loaded: {checkpoint_path}") logger.info(f"Episode: {metadata.get('episode', 'unknown')}") return metadata def _get_latest_checkpoint(self) -> Optional[str]: """ Get path to most recent checkpoint. 
        Returns:
            Path to the latest checkpoint, or None if none exist
        """
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: p.stat().st_mtime,
            reverse=True
        )
        if checkpoints:
            return str(checkpoints[0])
        return None

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the most recent N."""
        # Only episode checkpoints are considered; best_model.pt is never removed
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: p.stat().st_mtime,
            reverse=True
        )

        if len(checkpoints) > self.max_checkpoints:
            for old_checkpoint in checkpoints[self.max_checkpoints:]:
                old_checkpoint.unlink()
                logger.debug(f"Removed old checkpoint: {old_checkpoint}")

    def list_checkpoints(self) -> List[Dict[str, Any]]:
        """
        List all available checkpoints.

        Returns:
            List of checkpoint information dicts
        """
        checkpoints = []
        for checkpoint_file in self.checkpoint_dir.glob("*.pt"):
            stat = checkpoint_file.stat()
            checkpoints.append({
                'path': str(checkpoint_file),
                'name': checkpoint_file.name,
                'size_mb': stat.st_size / (1024 * 1024),
                'modified': datetime.fromtimestamp(stat.st_mtime).isoformat()
            })
        # ISO-format timestamps sort correctly as plain strings
        return sorted(checkpoints, key=lambda x: x['modified'], reverse=True)

    def get_checkpoint_history(self) -> List[Dict[str, Any]]:
        """
        Get the checkpoint history recorded in this session.

        Returns:
            List of checkpoint records
        """
        return self.checkpoint_history

    def save_training_state(
        self,
        state: Dict[str, Any],
        filename: str = "training_state.json"
    ) -> None:
        """
        Save training state to JSON.

        Args:
            state: Training state dictionary
            filename: Output filename
        """
        state_path = self.checkpoint_dir / filename
        with open(state_path, 'w') as f:
            json.dump(state, f, indent=2)
        logger.info(f"Training state saved: {state_path}")

    def load_training_state(
        self,
        filename: str = "training_state.json"
    ) -> Dict[str, Any]:
        """
        Load training state from JSON.

        Args:
            filename: State filename

        Returns:
            Training state dictionary
        """
        state_path = self.checkpoint_dir / filename
        if not state_path.exists():
            raise FileNotFoundError(f"Training state not found: {state_path}")
        with open(state_path, 'r') as f:
            state = json.load(f)
        logger.info(f"Training state loaded: {state_path}")
        return state
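

# --- Usage sketch -----------------------------------------------------------
# A minimal example of driving CheckpointManager from a training loop. The
# _DemoModel below is hypothetical: it exists only to illustrate the interface
# this manager assumes, i.e. a model exposing save_checkpoint(path, metadata=...)
# and load_checkpoint(path) -> metadata. Substitute your own model class.
if __name__ == "__main__":
    import tempfile

    class _DemoModel:
        """Stand-in model implementing the expected checkpoint interface."""

        def save_checkpoint(self, path: str, metadata: Optional[Dict[str, Any]] = None) -> None:
            # A real model would serialize its weights here; the demo just
            # writes the metadata as JSON so load_checkpoint can return it.
            with open(path, 'w') as f:
                json.dump(metadata or {}, f)

        def load_checkpoint(self, path: str) -> Dict[str, Any]:
            with open(path, 'r') as f:
                return json.load(f)

    with tempfile.TemporaryDirectory() as tmp_dir:
        manager = CheckpointManager(checkpoint_dir=tmp_dir, max_checkpoints=2, save_interval=10)
        model = _DemoModel()

        best_reward = float('-inf')
        for episode in range(1, 51):
            reward = episode * 0.1  # placeholder metric
            if manager.should_save(episode):
                manager.save_checkpoint(model, episode, metrics={'reward': reward})
            if reward > best_reward:
                best_reward = reward
                manager.save_checkpoint(model, episode, metrics={'reward': reward}, is_best=True)

        # Only the two newest episode checkpoints plus best_model.pt remain
        print(manager.list_checkpoints())
        metadata = manager.load_checkpoint(model, load_best=True)
        print("Resumed from episode", metadata.get('episode'))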