""" Utility functions for training and evaluation """ import numpy as np from sklearn.metrics import ( accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report ) import matplotlib.pyplot as plt import seaborn as sns from typing import Dict, Tuple, List, Optional import os def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]: """ Compute comprehensive metrics for evaluation. Args: eval_pred: Tuple of (predictions, labels) id2label: Optional mapping from label IDs to label names for per-class metrics Returns: Dictionary of metrics including overall and per-class metrics """ predictions, labels = eval_pred predictions = np.argmax(predictions, axis=1) # Overall metrics accuracy = accuracy_score(labels, predictions) # Weighted metrics (accounts for class imbalance) precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support( labels, predictions, average='weighted', zero_division=0 ) # Macro-averaged metrics (treats all classes equally) precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support( labels, predictions, average='macro', zero_division=0 ) # Micro-averaged metrics (aggregates contributions of all classes) precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support( labels, predictions, average='micro', zero_division=0 ) metrics = { 'accuracy': accuracy, 'precision_weighted': precision_weighted, 'recall_weighted': recall_weighted, 'f1_weighted': f1_weighted, 'precision_macro': precision_macro, 'recall_macro': recall_macro, 'f1_macro': f1_macro, 'precision_micro': precision_micro, 'recall_micro': recall_micro, 'f1_micro': f1_micro, } # Per-class metrics if label mapping is provided if id2label is not None: num_classes = len(id2label) precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support( labels, predictions, labels=list(range(num_classes)), average=None, zero_division=0 ) for i in range(num_classes): label_name = id2label[i] metrics[f'precision_{label_name}'] = float(precision_per_class[i]) metrics[f'recall_{label_name}'] = float(recall_per_class[i]) metrics[f'f1_{label_name}'] = float(f1_per_class[i]) metrics[f'support_{label_name}'] = int(support[i]) return metrics def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None): """ Factory function to create compute_metrics with label mapping. Args: id2label: Mapping from label IDs to label names Returns: Function compatible with HuggingFace Trainer """ def compute_metrics_fn(eval_pred): return compute_metrics(eval_pred, id2label) return compute_metrics_fn def plot_confusion_matrix( y_true: np.ndarray, y_pred: np.ndarray, labels: List[str], save_path: str = "confusion_matrix.png", normalize: bool = False, figsize: Tuple[int, int] = (10, 8) ) -> None: """ Plot and save confusion matrix with optional normalization. 


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    save_path: str = "confusion_matrix.png",
    normalize: bool = False,
    figsize: Tuple[int, int] = (10, 8)
) -> None:
    """
    Plot and save confusion matrix with optional normalization.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        save_path: Path to save the plot
        normalize: If True, row-normalize the confusion matrix to proportions
        figsize: Figure size (width, height)
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    if normalize:
        # Each row sums to 1, so cell values are per-class proportions
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
    else:
        fmt = 'd'
        title = 'Confusion Matrix'

    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
    )
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion matrix saved to {save_path}")


def print_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    output_dict: bool = False
) -> Optional[Dict]:
    """
    Print detailed classification report.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        output_dict: If True, return report as dictionary instead of printing

    Returns:
        Classification report as dictionary if output_dict=True, else None
    """
    report = classification_report(
        y_true,
        y_pred,
        target_names=labels,
        digits=4,
        output_dict=output_dict,
        zero_division=0
    )

    if output_dict:
        return report

    print("\nClassification Report:")
    print("=" * 60)
    print(report)
    return None
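

# Usage sketch (illustrative only): offline evaluation of saved predictions.
# `y_true` and `y_pred` are hypothetical integer label arrays aligned with
# `label_names`; they are not produced by this module.
#
#   label_names = ["negative", "neutral", "positive"]
#   plot_confusion_matrix(y_true, y_pred, label_names,
#                         save_path="./results/confusion_matrix.png",
#                         normalize=True)
#   print_classification_report(y_true, y_pred, label_names)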


def plot_training_curves(
    train_losses: List[float],
    eval_losses: List[float],
    eval_metrics: Dict[str, List[float]],
    save_path: str = "./results/training_curves.png"
) -> None:
    """
    Plot training and evaluation curves.

    Args:
        train_losses: List of training losses per step/epoch
        eval_losses: List of evaluation losses per step/epoch
        eval_metrics: Dictionary of metric names to lists of values
        save_path: Path to save the plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
    axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
    axes[0, 0].set_xlabel('Step/Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy
    if 'accuracy' in eval_metrics:
        axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
        axes[0, 1].set_xlabel('Step/Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

    # F1 Score
    if 'f1_weighted' in eval_metrics:
        axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
        axes[1, 0].set_xlabel('Step/Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title('Validation F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

    # Precision and Recall
    if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
        axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
        axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
        axes[1, 1].set_xlabel('Step/Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title('Validation Precision and Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Training curves saved to {save_path}")
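

if __name__ == "__main__":
    # Smoke test with synthetic data: a minimal sketch added for illustration.
    # The class names, array sizes, and output path below are arbitrary
    # assumptions, not project defaults.
    rng = np.random.default_rng(0)
    id2label = {0: "negative", 1: "neutral", 2: "positive"}
    label_names = [id2label[i] for i in range(len(id2label))]

    # Fake logits for 60 examples; labels cycle through all classes so every
    # class is guaranteed to appear.
    logits = rng.normal(size=(60, len(id2label)))
    labels = np.arange(60) % len(id2label)

    metrics = compute_metrics((logits, labels), id2label)
    for name, value in sorted(metrics.items()):
        print(f"{name}: {value:.4f}" if isinstance(value, float) else f"{name}: {value}")

    preds = np.argmax(logits, axis=1)
    print_classification_report(labels, preds, label_names)
    plot_confusion_matrix(labels, preds, label_names,
                          save_path="./results/confusion_matrix_demo.png",
                          normalize=True)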