# Imported from upstream (author: mbellan, commit c3efd49, "Initial deployment").
"""Anomaly detection for training monitoring."""
import logging
from collections import deque
from typing import Any, Callable, Dict, List, Optional

import numpy as np
logger = logging.getLogger(__name__)
class AnomalyDetector:
    """
    Detects anomalies during training.

    Keeps sliding windows of rewards, losses, and gradient norms and flags
    reward collapse/decline, gradient explosion/growth, and loss divergence
    (including NaN/Inf losses). Every detected anomaly is stored in
    ``alerts`` and reported through ``alert_callback``.
    """

    def __init__(
        self,
        window_size: int = 10,
        alert_callback: Optional[Callable] = None
    ):
        """
        Initialize anomaly detector.

        Args:
            window_size: Size of sliding window for detection
            alert_callback: Optional callback ``(alert_type, message,
                severity)``; defaults to logging through this module's logger
        """
        self.window_size = window_size
        self.alert_callback = alert_callback or self._default_alert
        # Sliding windows for metrics; maxlen makes them drop oldest samples.
        self.reward_window: deque = deque(maxlen=window_size)
        self.loss_window: deque = deque(maxlen=window_size)
        self.gradient_window: deque = deque(maxlen=window_size)
        # Chronological history of every anomaly raised so far.
        self.alerts: List[Dict[str, Any]] = []
        logger.info("AnomalyDetector initialized: window_size=%s", window_size)

    def _default_alert(self, alert_type: str, message: str, severity: str) -> None:
        """
        Default alert handler: log at the level matching the severity.

        Args:
            alert_type: Type of alert (e.g. 'reward_collapse')
            message: Alert message
            severity: Severity level ('critical', 'warning', 'info');
                unknown values fall back to warning
        """
        log_func = {
            'critical': logger.critical,
            'warning': logger.warning,
            'info': logger.info
        }.get(severity, logger.warning)
        # Lazy %-args instead of f-string: formatting only happens when the
        # record is actually emitted.
        log_func("[%s] %s", alert_type, message)

    def update(
        self,
        reward: Optional[float] = None,
        loss: Optional[float] = None,
        gradient_norm: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Update detector with new metrics and check for anomalies.

        Args:
            reward: Current reward value (skipped if None)
            loss: Current loss value (skipped if None)
            gradient_norm: Current gradient norm (skipped if None)

        Returns:
            List of anomalies detected this step (possibly empty). Each dict
            has 'type', 'message', 'severity' and a 'details' sub-dict, so
            the value type is Any (was mis-annotated as Dict[str, str]).
        """
        anomalies: List[Dict[str, Any]] = []
        # Record only the metrics actually provided this step.
        if reward is not None:
            self.reward_window.append(reward)
        if loss is not None:
            self.loss_window.append(loss)
        if gradient_norm is not None:
            self.gradient_window.append(gradient_norm)
        # Reward/loss checks need a full window; gradient check triggers
        # earlier (3 samples) because explosions must be caught quickly.
        if len(self.reward_window) >= self.window_size:
            reward_anomaly = self.detect_reward_collapse()
            if reward_anomaly:
                anomalies.append(reward_anomaly)
        if len(self.gradient_window) >= 3:
            gradient_anomaly = self.detect_gradient_explosion()
            if gradient_anomaly:
                anomalies.append(gradient_anomaly)
        if len(self.loss_window) >= self.window_size:
            loss_anomaly = self.detect_loss_divergence()
            if loss_anomaly:
                anomalies.append(loss_anomaly)
        # Store and report every anomaly found this step.
        for anomaly in anomalies:
            self.alerts.append(anomaly)
            self.alert_callback(
                anomaly['type'],
                anomaly['message'],
                anomaly['severity']
            )
        return anomalies

    def detect_reward_collapse(self) -> Optional[Dict[str, Any]]:
        """
        Detect reward collapse (rewards stop changing) or steady decline.

        Returns:
            Anomaly dictionary if detected, None otherwise
        """
        if len(self.reward_window) < self.window_size:
            return None
        rewards = list(self.reward_window)
        # Near-zero variance means the policy's rewards have flatlined.
        variance = np.var(rewards)
        if variance < 1e-6:
            return {
                'type': 'reward_collapse',
                'message': f'Reward collapse detected: variance={variance:.2e}',
                'severity': 'critical',
                'details': {
                    'variance': variance,
                    'mean_reward': np.mean(rewards)
                }
            }
        # With at least 5 samples, fit a line and flag a significant
        # negative slope as a declining-reward warning.
        if len(rewards) >= 5:
            recent_trend = np.polyfit(range(len(rewards)), rewards, 1)[0]
            if recent_trend < -0.01:  # Significant negative trend
                return {
                    'type': 'reward_decline',
                    'message': f'Reward declining: trend={recent_trend:.4f}',
                    'severity': 'warning',
                    'details': {
                        'trend': recent_trend,
                        'mean_reward': np.mean(rewards)
                    }
                }
        return None

    def detect_gradient_explosion(self) -> Optional[Dict[str, Any]]:
        """
        Detect gradient explosion (very large or rapidly growing gradients).

        Returns:
            Anomaly dictionary if detected, None otherwise
        """
        if len(self.gradient_window) < 3:
            return None
        gradients = list(self.gradient_window)
        latest_gradient = gradients[-1]
        # Absolute threshold: any single norm above 100 is critical.
        if latest_gradient > 100.0:
            return {
                'type': 'gradient_explosion',
                'message': f'Gradient explosion detected: norm={latest_gradient:.2f}',
                'severity': 'critical',
                'details': {
                    'gradient_norm': latest_gradient,
                    'mean_gradient': np.mean(gradients)
                }
            }
        # Relative check: >10x growth over two steps is an early warning.
        # (Guard above ensures >= 3 samples; the original re-checked the
        # length redundantly here.)  Epsilon avoids division by zero.
        gradient_growth = gradients[-1] / (gradients[-3] + 1e-8)
        if gradient_growth > 10.0:
            return {
                'type': 'gradient_growth',
                'message': f'Rapid gradient growth: {gradient_growth:.2f}x',
                'severity': 'warning',
                'details': {
                    'growth_factor': gradient_growth,
                    'current_gradient': latest_gradient
                }
            }
        return None

    def detect_loss_divergence(self) -> Optional[Dict[str, Any]]:
        """
        Detect loss divergence (loss increasing or becoming NaN/Inf).

        Returns:
            Anomaly dictionary if detected, None otherwise
        """
        if len(self.loss_window) < self.window_size:
            return None
        losses = list(self.loss_window)
        latest_loss = losses[-1]
        # NaN/Inf is checked first: polyfit below would propagate it.
        if np.isnan(latest_loss) or np.isinf(latest_loss):
            return {
                'type': 'loss_invalid',
                'message': f'Invalid loss detected: {latest_loss}',
                'severity': 'critical',
                'details': {
                    'loss_value': str(latest_loss)
                }
            }
        # With at least 5 samples, a significant positive slope in the
        # fitted line means the loss is diverging.
        if len(losses) >= 5:
            loss_trend = np.polyfit(range(len(losses)), losses, 1)[0]
            if loss_trend > 0.1:  # Significant positive trend
                return {
                    'type': 'loss_divergence',
                    'message': f'Loss diverging: trend={loss_trend:.4f}',
                    'severity': 'warning',
                    'details': {
                        'trend': loss_trend,
                        'current_loss': latest_loss,
                        'mean_loss': np.mean(losses)
                    }
                }
        return None

    def get_alerts(self) -> List[Dict[str, Any]]:
        """
        Get all alerts recorded so far.

        Returns:
            List of alert dictionaries (the live internal list)
        """
        return self.alerts

    def get_recent_alerts(self, n: int = 10) -> List[Dict[str, Any]]:
        """
        Get most recent alerts.

        Args:
            n: Number of recent alerts to return

        Returns:
            List of up to ``n`` most recent alert dictionaries
        """
        return self.alerts[-n:]

    def clear_alerts(self) -> None:
        """Clear all alerts."""
        self.alerts.clear()
        logger.info("Alerts cleared")

    def get_summary(self) -> Dict[str, Any]:
        """
        Get summary of detected anomalies.

        Returns:
            Dict with total alert count, per-type counts, and the 5 most
            recent alerts
        """
        # BUG FIX: return annotation was ``Dict[str, any]`` — the builtin
        # ``any`` function, not a type; corrected to ``typing.Any``.
        alert_types: Dict[str, int] = {}
        for alert in self.alerts:
            alert_type = alert['type']
            alert_types[alert_type] = alert_types.get(alert_type, 0) + 1
        return {
            'total_alerts': len(self.alerts),
            'alert_types': alert_types,
            'recent_alerts': self.get_recent_alerts(5)
        }