Spaces:
Runtime error
Runtime error
"""Metrics tracking for training monitoring."""

# Standard library
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

# Third-party
import numpy as np
import torch

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
| class MetricsTracker: | |
| """ | |
| Tracks and aggregates training metrics. | |
| Logs rewards, losses, learning rates, GPU memory, and custom metrics. | |
| """ | |
| def __init__(self, log_dir: str = "logs"): | |
| """ | |
| Initialize metrics tracker. | |
| Args: | |
| log_dir: Directory to save metric logs | |
| """ | |
| self.log_dir = Path(log_dir) | |
| self.log_dir.mkdir(parents=True, exist_ok=True) | |
| # Storage for metrics | |
| self.metrics = defaultdict(list) | |
| self.step_counter = 0 | |
| logger.info(f"MetricsTracker initialized: log_dir={log_dir}") | |
| def log_metric( | |
| self, | |
| name: str, | |
| value: float, | |
| step: Optional[int] = None | |
| ) -> None: | |
| """ | |
| Log a single metric value. | |
| Args: | |
| name: Metric name | |
| value: Metric value | |
| step: Optional step number (uses internal counter if not provided) | |
| """ | |
| if step is None: | |
| step = self.step_counter | |
| self.metrics[name].append({ | |
| 'step': step, | |
| 'value': float(value) | |
| }) | |
| def log_metrics( | |
| self, | |
| metrics: Dict[str, float], | |
| step: Optional[int] = None | |
| ) -> None: | |
| """ | |
| Log multiple metrics at once. | |
| Args: | |
| metrics: Dictionary of metric names and values | |
| step: Optional step number | |
| """ | |
| if step is None: | |
| step = self.step_counter | |
| for name, value in metrics.items(): | |
| self.log_metric(name, value, step) | |
| self.step_counter += 1 | |
| def log_training_metrics( | |
| self, | |
| episode: int, | |
| reward: float, | |
| loss: float, | |
| learning_rate: float, | |
| **kwargs | |
| ) -> None: | |
| """ | |
| Log standard training metrics. | |
| Args: | |
| episode: Episode number | |
| reward: Episode reward | |
| loss: Training loss | |
| learning_rate: Current learning rate | |
| **kwargs: Additional metrics | |
| """ | |
| metrics = { | |
| 'reward': reward, | |
| 'loss': loss, | |
| 'learning_rate': learning_rate, | |
| **kwargs | |
| } | |
| self.log_metrics(metrics, step=episode) | |
| def log_gpu_memory(self, step: Optional[int] = None) -> None: | |
| """ | |
| Log GPU memory usage. | |
| Args: | |
| step: Optional step number | |
| """ | |
| if torch.cuda.is_available(): | |
| allocated = torch.cuda.memory_allocated() / (1024 ** 2) # MB | |
| reserved = torch.cuda.memory_reserved() / (1024 ** 2) # MB | |
| self.log_metric('gpu_memory_allocated_mb', allocated, step) | |
| self.log_metric('gpu_memory_reserved_mb', reserved, step) | |
| def get_metric(self, name: str) -> List[Dict[str, Any]]: | |
| """ | |
| Get all values for a specific metric. | |
| Args: | |
| name: Metric name | |
| Returns: | |
| List of {step, value} dictionaries | |
| """ | |
| return self.metrics.get(name, []) | |
| def get_latest_value(self, name: str) -> Optional[float]: | |
| """ | |
| Get the most recent value for a metric. | |
| Args: | |
| name: Metric name | |
| Returns: | |
| Latest value or None | |
| """ | |
| values = self.metrics.get(name, []) | |
| if values: | |
| return values[-1]['value'] | |
| return None | |
| def get_metric_statistics(self, name: str) -> Dict[str, float]: | |
| """ | |
| Get statistics for a metric. | |
| Args: | |
| name: Metric name | |
| Returns: | |
| Dictionary with mean, std, min, max | |
| """ | |
| values = [entry['value'] for entry in self.metrics.get(name, [])] | |
| if not values: | |
| return { | |
| 'count': 0, | |
| 'mean': 0.0, | |
| 'std': 0.0, | |
| 'min': 0.0, | |
| 'max': 0.0 | |
| } | |
| return { | |
| 'count': len(values), | |
| 'mean': float(np.mean(values)), | |
| 'std': float(np.std(values)), | |
| 'min': float(np.min(values)), | |
| 'max': float(np.max(values)) | |
| } | |
| def get_all_metrics(self) -> Dict[str, List[Dict[str, Any]]]: | |
| """ | |
| Get all tracked metrics. | |
| Returns: | |
| Dictionary of all metrics | |
| """ | |
| return dict(self.metrics) | |
| def get_metric_names(self) -> List[str]: | |
| """ | |
| Get names of all tracked metrics. | |
| Returns: | |
| List of metric names | |
| """ | |
| return list(self.metrics.keys()) | |
| def aggregate_metrics( | |
| self, | |
| window_size: int = 10 | |
| ) -> Dict[str, Dict[str, float]]: | |
| """ | |
| Aggregate metrics over a sliding window. | |
| Args: | |
| window_size: Size of sliding window | |
| Returns: | |
| Dictionary of aggregated metrics | |
| """ | |
| aggregated = {} | |
| for name, values in self.metrics.items(): | |
| if len(values) >= window_size: | |
| recent_values = [v['value'] for v in values[-window_size:]] | |
| aggregated[name] = { | |
| 'mean': float(np.mean(recent_values)), | |
| 'std': float(np.std(recent_values)), | |
| 'min': float(np.min(recent_values)), | |
| 'max': float(np.max(recent_values)) | |
| } | |
| return aggregated | |
| def save_metrics(self, filename: str = "metrics.json") -> None: | |
| """ | |
| Save metrics to JSON file. | |
| Args: | |
| filename: Output filename | |
| """ | |
| output_path = self.log_dir / filename | |
| with open(output_path, 'w') as f: | |
| json.dump(dict(self.metrics), f, indent=2) | |
| logger.info(f"Metrics saved to {output_path}") | |
| def load_metrics(self, filename: str = "metrics.json") -> None: | |
| """ | |
| Load metrics from JSON file. | |
| Args: | |
| filename: Input filename | |
| """ | |
| input_path = self.log_dir / filename | |
| if not input_path.exists(): | |
| raise FileNotFoundError(f"Metrics file not found: {input_path}") | |
| with open(input_path, 'r') as f: | |
| loaded_metrics = json.load(f) | |
| self.metrics = defaultdict(list, loaded_metrics) | |
| logger.info(f"Metrics loaded from {input_path}") | |
| def reset(self) -> None: | |
| """Reset all metrics.""" | |
| self.metrics.clear() | |
| self.step_counter = 0 | |
| logger.info("Metrics reset") | |
| def summary(self) -> Dict[str, Any]: | |
| """ | |
| Generate summary of all metrics. | |
| Returns: | |
| Summary dictionary | |
| """ | |
| summary = { | |
| 'total_steps': self.step_counter, | |
| 'num_metrics': len(self.metrics), | |
| 'metrics': {} | |
| } | |
| for name in self.metrics.keys(): | |
| summary['metrics'][name] = self.get_metric_statistics(name) | |
| return summary | |