| """Checkpoint management for training.""" | |
| import torch | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List | |
| from datetime import datetime | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class CheckpointManager:
    """
    Manages model checkpoints during training.

    Handles saving, loading, and cleanup of checkpoints.
    """

    def __init__(
        self,
        checkpoint_dir: str = "checkpoints",
        max_checkpoints: int = 5,
        save_interval: int = 10,
    ):
        """
        Initialize checkpoint manager.

        Args:
            checkpoint_dir: Directory to save checkpoints
            max_checkpoints: Maximum number of checkpoints to keep
            save_interval: Save a checkpoint every N episodes
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        self.save_interval = save_interval
        self.checkpoint_history: List[Dict[str, Any]] = []
        logger.info(
            f"CheckpointManager initialized: dir={checkpoint_dir}, "
            f"max={max_checkpoints}, interval={save_interval}"
        )

    def should_save(self, episode: int) -> bool:
        """
        Check whether a checkpoint should be saved at this episode.

        Args:
            episode: Current episode number

        Returns:
            True if a checkpoint should be saved
        """
        if episode == 0:
            return False
        return episode % self.save_interval == 0
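
    # Illustrative cadence (assuming the default save_interval=10):
    #   should_save(0)  -> False  (episode 0 is skipped explicitly)
    #   should_save(5)  -> False
    #   should_save(10) -> True
    #   should_save(20) -> True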

    def save_checkpoint(
        self,
        model,
        episode: int,
        metrics: Optional[Dict[str, Any]] = None,
        is_best: bool = False,
    ) -> str:
        """
        Save a checkpoint.

        Args:
            model: Model to save; must provide a
                save_checkpoint(path, metadata=...) method
            episode: Current episode number
            metrics: Optional training metrics
            is_best: Whether this is the best model so far

        Returns:
            Path to the saved checkpoint
        """
        # Create checkpoint filename
        if is_best:
            filename = "best_model.pt"
        else:
            filename = f"checkpoint_episode_{episode}.pt"
        checkpoint_path = self.checkpoint_dir / filename

        # Prepare metadata
        metadata = {
            'episode': episode,
            'timestamp': datetime.now().isoformat(),
            'is_best': is_best,
        }
        if metrics:
            metadata['metrics'] = metrics

        # Save checkpoint
        model.save_checkpoint(str(checkpoint_path), metadata=metadata)

        # Record in history
        self.checkpoint_history.append({
            'path': str(checkpoint_path),
            'episode': episode,
            'timestamp': metadata['timestamp'],
            'is_best': is_best,
        })
        logger.info(f"Checkpoint saved: {checkpoint_path}")

        # Cleanup old checkpoints
        if not is_best:
            self._cleanup_old_checkpoints()

        return str(checkpoint_path)
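
    # Note (behavioral sketch, as implemented above): "best_model.pt" is a
    # fixed filename, so each new best overwrites the previous one, and it
    # is exempt from _cleanup_old_checkpoints(); only the episode-numbered
    # files are ever pruned.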

    def load_checkpoint(
        self,
        model,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
    ) -> Dict[str, Any]:
        """
        Load a checkpoint.

        Args:
            model: Model to load the checkpoint into; must provide a
                load_checkpoint(path) method that returns the saved metadata
            checkpoint_path: Optional specific checkpoint path
            load_best: If True, load the best model

        Returns:
            Checkpoint metadata
        """
        if load_best:
            checkpoint_path = str(self.checkpoint_dir / "best_model.pt")
        elif checkpoint_path is None:
            # Load most recent checkpoint
            checkpoint_path = self._get_latest_checkpoint()
            if checkpoint_path is None:
                raise FileNotFoundError("No checkpoints found")
        # Fail early with a clear error if the resolved path is missing
        if not Path(checkpoint_path).exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        metadata = model.load_checkpoint(checkpoint_path)
        logger.info(f"Checkpoint loaded: {checkpoint_path}")
        logger.info(f"Episode: {metadata.get('episode', 'unknown')}")
        return metadata
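
    # Path resolution order in load_checkpoint: load_best wins, then an
    # explicit checkpoint_path, then the most recently modified
    # episode-numbered checkpoint on disk.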

    def _get_latest_checkpoint(self) -> Optional[str]:
        """
        Get the path to the most recent checkpoint.

        Returns:
            Path to the latest checkpoint, or None if none exist
        """
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if checkpoints:
            return str(checkpoints[0])
        return None

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints, keeping only the most recent N."""
        # Get all episode checkpoints (not the best model)
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        # Remove old checkpoints
        if len(checkpoints) > self.max_checkpoints:
            for old_checkpoint in checkpoints[self.max_checkpoints:]:
                old_checkpoint.unlink()
                logger.debug(f"Removed old checkpoint: {old_checkpoint}")

    def list_checkpoints(self) -> List[Dict[str, Any]]:
        """
        List all available checkpoints.

        Returns:
            List of checkpoint information, newest first
        """
        checkpoints = []
        for checkpoint_file in self.checkpoint_dir.glob("*.pt"):
            stat = checkpoint_file.stat()
            checkpoints.append({
                'path': str(checkpoint_file),
                'name': checkpoint_file.name,
                'size_mb': stat.st_size / (1024 * 1024),
                'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
            })
        return sorted(checkpoints, key=lambda x: x['modified'], reverse=True)
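
    # Illustrative list_checkpoints() entry (hypothetical values):
    #   {'path': 'checkpoints/checkpoint_episode_20.pt',
    #    'name': 'checkpoint_episode_20.pt',
    #    'size_mb': 3.2,
    #    'modified': '2025-01-15T10:32:07'}
    # ISO-8601 timestamps sort correctly as plain strings, so the
    # newest-first sort above needs no datetime parsing.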

    def get_checkpoint_history(self) -> List[Dict[str, Any]]:
        """
        Get the in-memory checkpoint history for this session.

        Returns:
            List of checkpoint records
        """
        return self.checkpoint_history

    def save_training_state(
        self,
        state: Dict[str, Any],
        filename: str = "training_state.json",
    ) -> None:
        """
        Save training state to JSON.

        Args:
            state: Training state dictionary (must be JSON-serializable)
            filename: Output filename
        """
        state_path = self.checkpoint_dir / filename
        with open(state_path, 'w') as f:
            json.dump(state, f, indent=2)
        logger.info(f"Training state saved: {state_path}")

    def load_training_state(
        self,
        filename: str = "training_state.json",
    ) -> Dict[str, Any]:
        """
        Load training state from JSON.

        Args:
            filename: State filename

        Returns:
            Training state dictionary
        """
        state_path = self.checkpoint_dir / filename
        if not state_path.exists():
            raise FileNotFoundError(f"Training state not found: {state_path}")
        with open(state_path, 'r') as f:
            state = json.load(f)
        logger.info(f"Training state loaded: {state_path}")
        return state
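

if __name__ == "__main__":
    # Minimal usage sketch. The real model class is not part of this module;
    # _DemoModel below is a stand-in that only implements the
    # save_checkpoint/load_checkpoint interface CheckpointManager relies on.
    logging.basicConfig(level=logging.INFO)

    class _DemoModel:
        """Stand-in model: persists its metadata as JSON instead of weights."""

        def save_checkpoint(self, path: str, metadata: Dict[str, Any]) -> None:
            with open(path, 'w') as f:
                json.dump(metadata, f)

        def load_checkpoint(self, path: str) -> Dict[str, Any]:
            with open(path) as f:
                return json.load(f)

    model = _DemoModel()
    manager = CheckpointManager(
        checkpoint_dir="demo_checkpoints", max_checkpoints=3, save_interval=10
    )

    best_reward = float("-inf")
    for episode in range(1, 51):
        reward = episode * 0.5  # hypothetical per-episode metric
        if manager.should_save(episode):
            manager.save_checkpoint(model, episode, metrics={'reward': reward})
        if reward > best_reward:
            best_reward = reward
            manager.save_checkpoint(model, episode, is_best=True)

    # With max_checkpoints=3, only the three newest episode files remain,
    # plus best_model.pt.
    print(manager.list_checkpoints())
    metadata = manager.load_checkpoint(model, load_best=True)
    print(f"Best checkpoint from episode {metadata['episode']}")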