"""Anomaly detection for training monitoring.""" import numpy as np from typing import List, Dict, Optional, Callable from collections import deque import logging logger = logging.getLogger(__name__) class AnomalyDetector: """ Detects anomalies during training. Monitors for reward collapse, gradient explosion, and other issues. """ def __init__( self, window_size: int = 10, alert_callback: Optional[Callable] = None ): """ Initialize anomaly detector. Args: window_size: Size of sliding window for detection alert_callback: Optional callback function for alerts """ self.window_size = window_size self.alert_callback = alert_callback or self._default_alert # Sliding windows for metrics self.reward_window = deque(maxlen=window_size) self.loss_window = deque(maxlen=window_size) self.gradient_window = deque(maxlen=window_size) # Alert history self.alerts = [] logger.info(f"AnomalyDetector initialized: window_size={window_size}") def _default_alert(self, alert_type: str, message: str, severity: str) -> None: """ Default alert handler. Args: alert_type: Type of alert message: Alert message severity: Severity level """ log_func = { 'critical': logger.critical, 'warning': logger.warning, 'info': logger.info }.get(severity, logger.warning) log_func(f"[{alert_type}] {message}") def update( self, reward: Optional[float] = None, loss: Optional[float] = None, gradient_norm: Optional[float] = None ) -> List[Dict[str, str]]: """ Update detector with new metrics and check for anomalies. Args: reward: Current reward value loss: Current loss value gradient_norm: Current gradient norm Returns: List of detected anomalies """ anomalies = [] # Update windows if reward is not None: self.reward_window.append(reward) if loss is not None: self.loss_window.append(loss) if gradient_norm is not None: self.gradient_window.append(gradient_norm) # Check for anomalies if len(self.reward_window) >= self.window_size: reward_anomaly = self.detect_reward_collapse() if reward_anomaly: anomalies.append(reward_anomaly) if len(self.gradient_window) >= 3: # Need fewer samples for gradient check gradient_anomaly = self.detect_gradient_explosion() if gradient_anomaly: anomalies.append(gradient_anomaly) if len(self.loss_window) >= self.window_size: loss_anomaly = self.detect_loss_divergence() if loss_anomaly: anomalies.append(loss_anomaly) # Store and alert for anomaly in anomalies: self.alerts.append(anomaly) self.alert_callback( anomaly['type'], anomaly['message'], anomaly['severity'] ) return anomalies def detect_reward_collapse(self) -> Optional[Dict[str, str]]: """ Detect reward collapse (rewards stop changing). Returns: Anomaly dictionary if detected, None otherwise """ if len(self.reward_window) < self.window_size: return None rewards = list(self.reward_window) # Check if variance is very low variance = np.var(rewards) if variance < 1e-6: return { 'type': 'reward_collapse', 'message': f'Reward collapse detected: variance={variance:.2e}', 'severity': 'critical', 'details': { 'variance': variance, 'mean_reward': np.mean(rewards) } } # Check if rewards are consistently decreasing if len(rewards) >= 5: recent_trend = np.polyfit(range(len(rewards)), rewards, 1)[0] if recent_trend < -0.01: # Significant negative trend return { 'type': 'reward_decline', 'message': f'Reward declining: trend={recent_trend:.4f}', 'severity': 'warning', 'details': { 'trend': recent_trend, 'mean_reward': np.mean(rewards) } } return None def detect_gradient_explosion(self) -> Optional[Dict[str, str]]: """ Detect gradient explosion (very large gradients). 
Returns: Anomaly dictionary if detected, None otherwise """ if len(self.gradient_window) < 3: return None gradients = list(self.gradient_window) latest_gradient = gradients[-1] # Check for very large gradient if latest_gradient > 100.0: return { 'type': 'gradient_explosion', 'message': f'Gradient explosion detected: norm={latest_gradient:.2f}', 'severity': 'critical', 'details': { 'gradient_norm': latest_gradient, 'mean_gradient': np.mean(gradients) } } # Check for rapidly increasing gradients if len(gradients) >= 3: gradient_growth = gradients[-1] / (gradients[-3] + 1e-8) if gradient_growth > 10.0: return { 'type': 'gradient_growth', 'message': f'Rapid gradient growth: {gradient_growth:.2f}x', 'severity': 'warning', 'details': { 'growth_factor': gradient_growth, 'current_gradient': latest_gradient } } return None def detect_loss_divergence(self) -> Optional[Dict[str, str]]: """ Detect loss divergence (loss increasing or becoming NaN/Inf). Returns: Anomaly dictionary if detected, None otherwise """ if len(self.loss_window) < self.window_size: return None losses = list(self.loss_window) latest_loss = losses[-1] # Check for NaN or Inf if np.isnan(latest_loss) or np.isinf(latest_loss): return { 'type': 'loss_invalid', 'message': f'Invalid loss detected: {latest_loss}', 'severity': 'critical', 'details': { 'loss_value': str(latest_loss) } } # Check for consistently increasing loss if len(losses) >= 5: loss_trend = np.polyfit(range(len(losses)), losses, 1)[0] if loss_trend > 0.1: # Significant positive trend return { 'type': 'loss_divergence', 'message': f'Loss diverging: trend={loss_trend:.4f}', 'severity': 'warning', 'details': { 'trend': loss_trend, 'current_loss': latest_loss, 'mean_loss': np.mean(losses) } } return None def get_alerts(self) -> List[Dict[str, str]]: """ Get all alerts. Returns: List of alert dictionaries """ return self.alerts def get_recent_alerts(self, n: int = 10) -> List[Dict[str, str]]: """ Get most recent alerts. Args: n: Number of recent alerts to return Returns: List of recent alert dictionaries """ return self.alerts[-n:] def clear_alerts(self) -> None: """Clear all alerts.""" self.alerts.clear() logger.info("Alerts cleared") def get_summary(self) -> Dict[str, any]: """ Get summary of detected anomalies. Returns: Summary dictionary """ alert_types = {} for alert in self.alerts: alert_type = alert['type'] alert_types[alert_type] = alert_types.get(alert_type, 0) + 1 return { 'total_alerts': len(self.alerts), 'alert_types': alert_types, 'recent_alerts': self.get_recent_alerts(5) }
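

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the monitoring API):
    # feeds synthetic metric values into the detector to show how anomalies
    # surface through update() and the alert history. All numbers below are
    # arbitrary assumptions chosen for demonstration.
    logging.basicConfig(level=logging.INFO)

    detector = AnomalyDetector(window_size=5)

    # Healthy-looking steps: rising rewards, shrinking loss, modest gradients.
    for step in range(5):
        detector.update(
            reward=0.5 + 0.1 * step,
            loss=1.0 / (step + 1),
            gradient_norm=2.0,
        )

    # Simulate a gradient spike; update() returns the anomalies it detected
    # for this step (here a 'gradient_explosion' alert, since 500 > 100).
    anomalies = detector.update(reward=0.9, loss=0.2, gradient_norm=500.0)
    print(anomalies)

    # Simulate reward collapse: identical rewards drive the window variance
    # toward zero once the sliding window fills with the constant value.
    # Note that windows without new data are re-checked on every call, so
    # the stale gradient spike may re-alert during these updates.
    for _ in range(5):
        detector.update(reward=0.3)

    print(detector.get_summary())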