"""Visualization tools for training monitoring.""" import matplotlib.pyplot as plt import numpy as np from typing import Dict, List, Optional, Any from pathlib import Path import logging logger = logging.getLogger(__name__) class Visualizer: """ Creates visualizations for training metrics. Supports TensorBoard integration and static plots. """ def __init__(self, output_dir: str = "visualizations"): """ Initialize visualizer. Args: output_dir: Directory to save visualizations """ self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) # Try to import tensorboard self.tensorboard_available = False try: from torch.utils.tensorboard import SummaryWriter self.SummaryWriter = SummaryWriter self.tensorboard_available = True logger.info("TensorBoard available") except ImportError: logger.warning("TensorBoard not available") self.writer = None logger.info(f"Visualizer initialized: output_dir={output_dir}") def initialize_tensorboard(self, log_dir: Optional[str] = None) -> None: """ Initialize TensorBoard writer. Args: log_dir: Optional TensorBoard log directory """ if not self.tensorboard_available: logger.warning("TensorBoard not available, skipping initialization") return if log_dir is None: log_dir = str(self.output_dir / "tensorboard") self.writer = self.SummaryWriter(log_dir) logger.info(f"TensorBoard initialized: {log_dir}") def log_scalar_to_tensorboard( self, tag: str, value: float, step: int ) -> None: """ Log scalar value to TensorBoard. Args: tag: Metric name value: Metric value step: Step number """ if self.writer is not None: self.writer.add_scalar(tag, value, step) def plot_training_curve( self, metrics: Dict[str, List[Dict[str, Any]]], metric_name: str, title: Optional[str] = None, filename: Optional[str] = None ) -> str: """ Plot training curve for a metric. Args: metrics: Dictionary of metrics metric_name: Name of metric to plot title: Optional plot title filename: Optional output filename Returns: Path to saved plot """ if metric_name not in metrics: raise ValueError(f"Metric '{metric_name}' not found") data = metrics[metric_name] steps = [entry['step'] for entry in data] values = [entry['value'] for entry in data] plt.figure(figsize=(10, 6)) plt.plot(steps, values, linewidth=2) plt.xlabel('Step') plt.ylabel(metric_name.replace('_', ' ').title()) plt.title(title or f'{metric_name.replace("_", " ").title()} Over Time') plt.grid(True, alpha=0.3) if filename is None: filename = f"{metric_name}_curve.png" output_path = self.output_dir / filename plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() logger.info(f"Training curve saved: {output_path}") return str(output_path) def plot_multiple_metrics( self, metrics: Dict[str, List[Dict[str, Any]]], metric_names: List[str], title: Optional[str] = None, filename: Optional[str] = None ) -> str: """ Plot multiple metrics on the same figure. Args: metrics: Dictionary of metrics metric_names: List of metric names to plot title: Optional plot title filename: Optional output filename Returns: Path to saved plot """ plt.figure(figsize=(12, 6)) for metric_name in metric_names: if metric_name in metrics: data = metrics[metric_name] steps = [entry['step'] for entry in data] values = [entry['value'] for entry in data] plt.plot(steps, values, label=metric_name, linewidth=2) plt.xlabel('Step') plt.ylabel('Value') plt.title(title or 'Training Metrics') plt.legend() plt.grid(True, alpha=0.3) if filename is None: filename = "multiple_metrics.png" output_path = self.output_dir / filename plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() logger.info(f"Multi-metric plot saved: {output_path}") return str(output_path) def plot_training_curves( self, metrics: Dict[str, List[Dict[str, Any]]], title: str = "Training Progress", filename: Optional[str] = None ) -> str: """ Plot comprehensive training curves with subplots. Args: metrics: Dictionary of all metrics title: Main title for the figure filename: Optional output filename Returns: Path to saved plot """ if not metrics: logger.warning("No metrics to plot") return "" # Determine which metrics to plot metric_names = list(metrics.keys()) num_metrics = len(metric_names) if num_metrics == 0: return "" # Create subplots fig, axes = plt.subplots(2, 2, figsize=(15, 10)) fig.suptitle(title, fontsize=16, fontweight='bold') axes = axes.flatten() # Plot up to 4 key metrics key_metrics = ['reward', 'loss', 'total_reward', 'episode_time'] plot_idx = 0 for metric_name in key_metrics: if metric_name in metrics and plot_idx < 4: data = metrics[metric_name] steps = [entry['step'] for entry in data] values = [entry['value'] for entry in data] ax = axes[plot_idx] ax.plot(steps, values, linewidth=2, marker='o', markersize=4) ax.set_xlabel('Episode') ax.set_ylabel(metric_name.replace('_', ' ').title()) ax.set_title(f'{metric_name.replace("_", " ").title()}') ax.grid(True, alpha=0.3) # Add trend line if len(steps) > 1: z = np.polyfit(steps, values, 1) p = np.poly1d(z) ax.plot(steps, p(steps), "--", alpha=0.5, color='red', label='Trend') ax.legend() plot_idx += 1 # Hide unused subplots for idx in range(plot_idx, 4): axes[idx].axis('off') plt.tight_layout() if filename is None: filename = f"training_curves_{len(steps)}_episodes.png" output_path = self.output_dir / filename plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() logger.info(f"Training curves saved: {output_path}") return str(output_path) def plot_reward_distribution( self, rewards: List[float], title: Optional[str] = None, filename: Optional[str] = None ) -> str: """ Plot reward distribution histogram. Args: rewards: List of reward values title: Optional plot title filename: Optional output filename Returns: Path to saved plot """ plt.figure(figsize=(10, 6)) plt.hist(rewards, bins=30, alpha=0.7, edgecolor='black') plt.xlabel('Reward') plt.ylabel('Frequency') plt.title(title or 'Reward Distribution') plt.grid(True, alpha=0.3, axis='y') # Add statistics mean_reward = np.mean(rewards) std_reward = np.std(rewards) plt.axvline(mean_reward, color='red', linestyle='--', label=f'Mean: {mean_reward:.3f}') plt.axvline(mean_reward + std_reward, color='orange', linestyle=':', alpha=0.7, label=f'±1 Std') plt.axvline(mean_reward - std_reward, color='orange', linestyle=':', alpha=0.7) plt.legend() if filename is None: filename = "reward_distribution.png" output_path = self.output_dir / filename plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() logger.info(f"Reward distribution saved: {output_path}") return str(output_path) def generate_summary_report( self, metrics: Dict[str, List[Dict[str, Any]]], statistics: Dict[str, Dict[str, float]], output_filename: str = "training_summary.txt" ) -> str: """ Generate text summary report. Args: metrics: Dictionary of metrics statistics: Dictionary of metric statistics output_filename: Output filename Returns: Path to saved report """ lines = [] lines.append("=" * 60) lines.append("TRAINING SUMMARY REPORT") lines.append("=" * 60) lines.append("") # Overall statistics lines.append("METRIC STATISTICS:") lines.append("-" * 60) for metric_name, stats in statistics.items(): lines.append(f"\n{metric_name}:") lines.append(f" Count: {stats['count']}") lines.append(f" Mean: {stats['mean']:.6f}") lines.append(f" Std: {stats['std']:.6f}") lines.append(f" Min: {stats['min']:.6f}") lines.append(f" Max: {stats['max']:.6f}") lines.append("") lines.append("=" * 60) report_text = "\n".join(lines) output_path = self.output_dir / output_filename with open(output_path, 'w') as f: f.write(report_text) logger.info(f"Summary report saved: {output_path}") return str(output_path) def close(self) -> None: """Close TensorBoard writer if open.""" if self.writer is not None: self.writer.close() logger.info("TensorBoard writer closed")