Spaces:
Runtime error
Runtime error
| """Anomaly detection for training monitoring.""" | |
import logging
from collections import deque
from typing import Any, Callable, Dict, List, Optional

import numpy as np
| logger = logging.getLogger(__name__) | |
class AnomalyDetector:
    """
    Detects anomalies during training.

    Monitors for reward collapse, gradient explosion, and loss
    divergence over sliding windows of recent metric values.
    """

    def __init__(
        self,
        window_size: int = 10,
        alert_callback: Optional[Callable] = None
    ):
        """
        Initialize anomaly detector.

        Args:
            window_size: Size of sliding window for detection
            alert_callback: Optional callback for alerts, invoked as
                ``callback(alert_type, message, severity)``. Defaults to
                logging via this module's logger.
        """
        self.window_size = window_size
        self.alert_callback = alert_callback or self._default_alert

        # Sliding windows for metrics; deque(maxlen=...) evicts the
        # oldest sample automatically once the window is full.
        self.reward_window = deque(maxlen=window_size)
        self.loss_window = deque(maxlen=window_size)
        self.gradient_window = deque(maxlen=window_size)

        # Full alert history, in detection order.
        self.alerts: List[Dict[str, Any]] = []

        logger.info(f"AnomalyDetector initialized: window_size={window_size}")

    def _default_alert(self, alert_type: str, message: str, severity: str) -> None:
        """
        Default alert handler: route the alert to the module logger.

        Args:
            alert_type: Type of alert
            message: Alert message
            severity: Severity level ('critical', 'warning', 'info');
                unknown severities fall back to warning.
        """
        log_func = {
            'critical': logger.critical,
            'warning': logger.warning,
            'info': logger.info
        }.get(severity, logger.warning)
        log_func(f"[{alert_type}] {message}")

    def update(
        self,
        reward: Optional[float] = None,
        loss: Optional[float] = None,
        gradient_norm: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Update detector with new metrics and check for anomalies.

        Args:
            reward: Current reward value
            loss: Current loss value
            gradient_norm: Current gradient norm

        Returns:
            List of anomalies detected on this step. Each anomaly is
            also appended to the alert history and forwarded to the
            alert callback.
        """
        anomalies: List[Dict[str, Any]] = []

        # Update windows with whichever metrics were supplied.
        if reward is not None:
            self.reward_window.append(reward)
        if loss is not None:
            self.loss_window.append(loss)
        if gradient_norm is not None:
            self.gradient_window.append(gradient_norm)

        # Reward checks need a full window for a stable variance/trend.
        if len(self.reward_window) >= self.window_size:
            reward_anomaly = self.detect_reward_collapse()
            if reward_anomaly:
                anomalies.append(reward_anomaly)

        # Gradient check needs only a few samples so explosions are
        # caught early.
        if len(self.gradient_window) >= 3:
            gradient_anomaly = self.detect_gradient_explosion()
            if gradient_anomaly:
                anomalies.append(gradient_anomaly)

        # Loss check runs as soon as any loss exists so that NaN/Inf is
        # flagged immediately; the trend test inside still waits for a
        # full window. (Bug fix: previously an invalid loss was not
        # reported until window_size losses had accumulated.)
        if self.loss_window:
            loss_anomaly = self.detect_loss_divergence()
            if loss_anomaly:
                anomalies.append(loss_anomaly)

        # Record and report everything found this step.
        for anomaly in anomalies:
            self.alerts.append(anomaly)
            self.alert_callback(
                anomaly['type'],
                anomaly['message'],
                anomaly['severity']
            )

        return anomalies

    def detect_reward_collapse(self) -> Optional[Dict[str, Any]]:
        """
        Detect reward collapse (rewards stop changing) or decline.

        Returns:
            Anomaly dictionary if detected, None otherwise. Requires a
            full window of rewards; the decline check additionally
            requires at least 5 samples.
        """
        if len(self.reward_window) < self.window_size:
            return None

        rewards = list(self.reward_window)

        # Near-zero variance means the policy has stopped exploring /
        # rewards are frozen — treat as critical.
        variance = np.var(rewards)
        if variance < 1e-6:
            return {
                'type': 'reward_collapse',
                'message': f'Reward collapse detected: variance={variance:.2e}',
                'severity': 'critical',
                'details': {
                    'variance': variance,
                    'mean_reward': np.mean(rewards)
                }
            }

        # A significant negative linear trend (least-squares slope)
        # indicates consistently declining rewards.
        if len(rewards) >= 5:
            recent_trend = np.polyfit(range(len(rewards)), rewards, 1)[0]
            if recent_trend < -0.01:
                return {
                    'type': 'reward_decline',
                    'message': f'Reward declining: trend={recent_trend:.4f}',
                    'severity': 'warning',
                    'details': {
                        'trend': recent_trend,
                        'mean_reward': np.mean(rewards)
                    }
                }

        return None

    def detect_gradient_explosion(self) -> Optional[Dict[str, Any]]:
        """
        Detect gradient explosion (very large or rapidly growing norms).

        Returns:
            Anomaly dictionary if detected, None otherwise. Requires at
            least 3 gradient samples.
        """
        if len(self.gradient_window) < 3:
            return None

        gradients = list(self.gradient_window)
        latest_gradient = gradients[-1]

        # Absolute threshold: a norm above 100 is treated as exploded.
        if latest_gradient > 100.0:
            return {
                'type': 'gradient_explosion',
                'message': f'Gradient explosion detected: norm={latest_gradient:.2f}',
                'severity': 'critical',
                'details': {
                    'gradient_norm': latest_gradient,
                    'mean_gradient': np.mean(gradients)
                }
            }

        # Relative threshold: >10x growth over two steps is suspicious.
        # The epsilon guards against division by zero.
        if len(gradients) >= 3:
            gradient_growth = gradients[-1] / (gradients[-3] + 1e-8)
            if gradient_growth > 10.0:
                return {
                    'type': 'gradient_growth',
                    'message': f'Rapid gradient growth: {gradient_growth:.2f}x',
                    'severity': 'warning',
                    'details': {
                        'growth_factor': gradient_growth,
                        'current_gradient': latest_gradient
                    }
                }

        return None

    def detect_loss_divergence(self) -> Optional[Dict[str, Any]]:
        """
        Detect loss divergence (loss increasing or becoming NaN/Inf).

        The NaN/Inf check runs on the latest loss regardless of how many
        samples have accumulated; the trend check requires a full window
        of at least 5 samples.

        Returns:
            Anomaly dictionary if detected, None otherwise
        """
        if not self.loss_window:
            return None

        latest_loss = self.loss_window[-1]

        # An invalid loss is always critical — report immediately, even
        # before the window is full.
        if np.isnan(latest_loss) or np.isinf(latest_loss):
            return {
                'type': 'loss_invalid',
                'message': f'Invalid loss detected: {latest_loss}',
                'severity': 'critical',
                'details': {
                    'loss_value': str(latest_loss)
                }
            }

        if len(self.loss_window) < self.window_size:
            return None

        losses = list(self.loss_window)

        # A significant positive linear trend means the loss is rising
        # rather than converging.
        if len(losses) >= 5:
            loss_trend = np.polyfit(range(len(losses)), losses, 1)[0]
            if loss_trend > 0.1:
                return {
                    'type': 'loss_divergence',
                    'message': f'Loss diverging: trend={loss_trend:.4f}',
                    'severity': 'warning',
                    'details': {
                        'trend': loss_trend,
                        'current_loss': latest_loss,
                        'mean_loss': np.mean(losses)
                    }
                }

        return None

    def get_alerts(self) -> List[Dict[str, Any]]:
        """
        Get all alerts.

        Returns:
            List of alert dictionaries (full history, oldest first)
        """
        return self.alerts

    def get_recent_alerts(self, n: int = 10) -> List[Dict[str, Any]]:
        """
        Get most recent alerts.

        Args:
            n: Number of recent alerts to return

        Returns:
            List of up to n most recent alert dictionaries
        """
        return self.alerts[-n:]

    def clear_alerts(self) -> None:
        """Clear all alerts."""
        self.alerts.clear()
        logger.info("Alerts cleared")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get summary of detected anomalies.

        Returns:
            Summary dictionary with total count, per-type counts, and
            the 5 most recent alerts.
        """
        alert_types: Dict[str, int] = {}
        for alert in self.alerts:
            alert_type = alert['type']
            alert_types[alert_type] = alert_types.get(alert_type, 0) + 1

        return {
            'total_alerts': len(self.alerts),
            'alert_types': alert_types,
            'recent_alerts': self.get_recent_alerts(5)
        }