"""Checkpoint management for training."""
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
from datetime import datetime
import logging
logger = logging.getLogger(__name__)
class CheckpointManager:
"""
Manages model checkpoints during training.
Handles saving, loading, and cleanup of checkpoints.
"""
def __init__(
self,
checkpoint_dir: str = "checkpoints",
max_checkpoints: int = 5,
save_interval: int = 10
):
"""
Initialize checkpoint manager.
Args:
checkpoint_dir: Directory to save checkpoints
max_checkpoints: Maximum number of checkpoints to keep
save_interval: Save checkpoint every N episodes
"""
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints
self.save_interval = save_interval
self.checkpoint_history = []
logger.info(f"CheckpointManager initialized: dir={checkpoint_dir}, max={max_checkpoints}, interval={save_interval}")
def should_save(self, episode: int) -> bool:
"""
Check if checkpoint should be saved at this episode.
Args:
episode: Current episode number
Returns:
True if a checkpoint should be saved at this episode
"""
if episode == 0:
return False
return episode % self.save_interval == 0
def save_checkpoint(
self,
model,
episode: int,
metrics: Optional[Dict[str, Any]] = None,
is_best: bool = False
) -> str:
"""
Save a checkpoint.
Args:
model: Model to save
episode: Current episode number
metrics: Optional training metrics
is_best: Whether this is the best model so far
Returns:
Path to saved checkpoint
"""
# Create checkpoint filename
if is_best:
filename = "best_model.pt"
else:
filename = f"checkpoint_episode_{episode}.pt"
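        # The episode number embedded in this filename is what
        # _get_latest_checkpoint() and _cleanup_old_checkpoints() order by.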
checkpoint_path = self.checkpoint_dir / filename
# Prepare metadata
metadata = {
'episode': episode,
'timestamp': datetime.now().isoformat(),
'is_best': is_best
}
if metrics:
metadata['metrics'] = metrics
        # Delegate serialization to the model's save_checkpoint(path, metadata=...)
model.save_checkpoint(str(checkpoint_path), metadata=metadata)
# Record in history
self.checkpoint_history.append({
'path': str(checkpoint_path),
'episode': episode,
'timestamp': metadata['timestamp'],
'is_best': is_best
})
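        # Note: this history is in-memory only and resets on restart; durable
        # state belongs in save_training_state() instead.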
logger.info(f"Checkpoint saved: {checkpoint_path}")
# Cleanup old checkpoints
if not is_best:
self._cleanup_old_checkpoints()
return str(checkpoint_path)
def load_checkpoint(
self,
model,
checkpoint_path: Optional[str] = None,
load_best: bool = False
) -> Dict[str, Any]:
"""
Load a checkpoint.
Args:
model: Model to load checkpoint into
checkpoint_path: Optional specific checkpoint path
load_best: If True, load best model
Returns:
Checkpoint metadata
"""
        if load_best:
            checkpoint_path = str(self.checkpoint_dir / "best_model.pt")
        elif checkpoint_path is None:
            # Fall back to the most recent episode checkpoint
            checkpoint_path = self._get_latest_checkpoint()
        if checkpoint_path is None:
            raise FileNotFoundError("No checkpoints found")
        if not Path(checkpoint_path).exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
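        # Delegate deserialization to the model; it is expected to return the
        # metadata dict stored alongside the weights.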
metadata = model.load_checkpoint(checkpoint_path)
logger.info(f"Checkpoint loaded: {checkpoint_path}")
logger.info(f"Episode: {metadata.get('episode', 'unknown')}")
return metadata
def _get_latest_checkpoint(self) -> Optional[str]:
"""
Get path to most recent checkpoint.
Returns:
Path to latest checkpoint or None
"""
        # Sort by the episode number embedded in the filename; file mtimes can
        # be unreliable after checkpoints are copied or restored from backup.
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
            reverse=True
        )
if checkpoints:
return str(checkpoints[0])
return None
def _cleanup_old_checkpoints(self) -> None:
"""Remove old checkpoints, keeping only the most recent N."""
        # Gather episode checkpoints, newest first by episode number
        # (best_model.pt does not match the glob and is never cleaned up)
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_episode_*.pt"),
            key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
            reverse=True
        )
# Remove old checkpoints
if len(checkpoints) > self.max_checkpoints:
for old_checkpoint in checkpoints[self.max_checkpoints:]:
old_checkpoint.unlink()
logger.debug(f"Removed old checkpoint: {old_checkpoint}")
def list_checkpoints(self) -> List[Dict[str, Any]]:
"""
List all available checkpoints.
Returns:
List of checkpoint information
"""
checkpoints = []
for checkpoint_file in self.checkpoint_dir.glob("*.pt"):
stat = checkpoint_file.stat()
checkpoints.append({
'path': str(checkpoint_file),
'name': checkpoint_file.name,
'size_mb': stat.st_size / (1024 * 1024),
'modified': datetime.fromtimestamp(stat.st_mtime).isoformat()
})
return sorted(checkpoints, key=lambda x: x['modified'], reverse=True)
def get_checkpoint_history(self) -> List[Dict[str, Any]]:
"""
Get checkpoint history.
Returns:
List of checkpoint records
"""
        # Return a copy so callers cannot mutate the internal record
        return list(self.checkpoint_history)
def save_training_state(
self,
state: Dict[str, Any],
filename: str = "training_state.json"
) -> None:
"""
Save training state to JSON.
Args:
state: Training state dictionary
filename: Output filename
"""
state_path = self.checkpoint_dir / filename
with open(state_path, 'w') as f:
json.dump(state, f, indent=2)
logger.info(f"Training state saved: {state_path}")
def load_training_state(
self,
filename: str = "training_state.json"
) -> Dict[str, Any]:
"""
Load training state from JSON.
Args:
filename: State filename
Returns:
Training state dictionary
"""
state_path = self.checkpoint_dir / filename
if not state_path.exists():
raise FileNotFoundError(f"Training state not found: {state_path}")
with open(state_path, 'r') as f:
state = json.load(f)
logger.info(f"Training state loaded: {state_path}")
return state
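
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). CheckpointManager assumes the model
# exposes save_checkpoint(path, metadata=...) and load_checkpoint(path);
# DummyModel below is a hypothetical stand-in implementing that interface
# with torch.save/torch.load, just to show the intended call pattern.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile
    import torch  # only needed for this demo

    logging.basicConfig(level=logging.INFO)

    class DummyModel:
        """Hypothetical model implementing the expected checkpoint interface."""

        def __init__(self):
            self.net = torch.nn.Linear(4, 2)

        def save_checkpoint(self, path: str, metadata: Optional[Dict[str, Any]] = None) -> None:
            torch.save({'state_dict': self.net.state_dict(),
                        'metadata': metadata or {}}, path)

        def load_checkpoint(self, path: str) -> Dict[str, Any]:
            payload = torch.load(path)
            self.net.load_state_dict(payload['state_dict'])
            return payload['metadata']

    model = DummyModel()
    manager = CheckpointManager(
        checkpoint_dir=tempfile.mkdtemp(prefix="ckpt_demo_"),
        max_checkpoints=3,
        save_interval=10
    )

    # Periodic saves: fires at episodes 10, 20, ..., 50; cleanup keeps only
    # the newest three episode checkpoints.
    for episode in range(1, 51):
        if manager.should_save(episode):
            manager.save_checkpoint(model, episode, metrics={'reward': episode * 0.1})

    # Mark the final model as best, then restore it.
    manager.save_checkpoint(model, 50, is_best=True)
    print(manager.load_checkpoint(model, load_best=True))
    print(manager.list_checkpoints())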