"""
Utility functions for training and evaluation: metric computation for the
HuggingFace Trainer, confusion-matrix plotting, classification reports, and
training-curve visualization.
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, List, Optional
import os
def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]:
"""
Compute comprehensive metrics for evaluation.
Args:
eval_pred: Tuple of (predictions, labels)
id2label: Optional mapping from label IDs to label names for per-class metrics
Returns:
Dictionary of metrics including overall and per-class metrics
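    Example (illustrative; uses synthetic logits and labels):
        logits = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.3]])
        labels = np.array([0, 1, 0])
        metrics = compute_metrics((logits, labels))  # metrics['accuracy'] == 1.0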
"""
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):  # some models return (logits, ...) tuples
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)
# Overall metrics
accuracy = accuracy_score(labels, predictions)
# Weighted metrics (accounts for class imbalance)
precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
labels,
predictions,
average='weighted',
zero_division=0
)
# Macro-averaged metrics (treats all classes equally)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
labels,
predictions,
average='macro',
zero_division=0
)
# Micro-averaged metrics (aggregates contributions of all classes)
precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
labels,
predictions,
average='micro',
zero_division=0
)
metrics = {
'accuracy': accuracy,
'precision_weighted': precision_weighted,
'recall_weighted': recall_weighted,
'f1_weighted': f1_weighted,
'precision_macro': precision_macro,
'recall_macro': recall_macro,
'f1_macro': f1_macro,
'precision_micro': precision_micro,
'recall_micro': recall_micro,
'f1_micro': f1_micro,
}
# Per-class metrics if label mapping is provided
if id2label is not None:
num_classes = len(id2label)
precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
labels,
predictions,
labels=list(range(num_classes)),
average=None,
zero_division=0
)
for i in range(num_classes):
label_name = id2label[i]
metrics[f'precision_{label_name}'] = float(precision_per_class[i])
metrics[f'recall_{label_name}'] = float(recall_per_class[i])
metrics[f'f1_{label_name}'] = float(f1_per_class[i])
metrics[f'support_{label_name}'] = int(support[i])
return metrics
def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None):
"""
Factory function to create compute_metrics with label mapping.
Args:
id2label: Mapping from label IDs to label names
Returns:
Function compatible with HuggingFace Trainer
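    Example (sketch; assumes `model`, `training_args`, and the datasets are defined elsewhere):
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics_factory(id2label),
        )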
"""
def compute_metrics_fn(eval_pred):
return compute_metrics(eval_pred, id2label)
return compute_metrics_fn
def plot_confusion_matrix(
y_true: np.ndarray,
y_pred: np.ndarray,
labels: List[str],
save_path: str = "confusion_matrix.png",
normalize: bool = False,
figsize: Tuple[int, int] = (10, 8)
) -> None:
"""
Plot and save confusion matrix with optional normalization.
Args:
y_true: True labels
y_pred: Predicted labels
labels: List of label names
save_path: Path to save the plot
        normalize: If True, normalize each row of the confusion matrix to proportions of the true class
figsize: Figure size (width, height)
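    Example (illustrative; label names and output path are placeholders):
        plot_confusion_matrix(
            y_true=np.array([0, 1, 1, 0, 1]),
            y_pred=np.array([0, 1, 0, 0, 1]),
            labels=["negative", "positive"],
            save_path="./results/confusion_matrix.png",
            normalize=True,
        )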
"""
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    if normalize:
        # Normalize each row (true class) to proportions; guard against
        # all-zero rows for classes that never appear in y_true.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = cm.astype('float') / np.maximum(row_sums, 1)
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
else:
fmt = 'd'
title = 'Confusion Matrix'
plt.figure(figsize=figsize)
sns.heatmap(
cm,
annot=True,
fmt=fmt,
cmap='Blues',
xticklabels=labels,
yticklabels=labels,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
)
plt.title(title, fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
# Create directory if it doesn't exist
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Confusion matrix saved to {save_path}")
def print_classification_report(
y_true: np.ndarray,
y_pred: np.ndarray,
labels: List[str],
output_dict: bool = False
) -> Optional[Dict]:
"""
Print detailed classification report.
Args:
y_true: True labels
y_pred: Predicted labels
labels: List of label names
output_dict: If True, return report as dictionary instead of printing
Returns:
Classification report as dictionary if output_dict=True, else None
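    Example (illustrative; assumes `y_true` and `y_pred` are integer label arrays):
        report = print_classification_report(
            y_true, y_pred, labels=["negative", "positive"], output_dict=True
        )
        positive_f1 = report["positive"]["f1-score"]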
"""
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(labels))),
        target_names=labels,
        digits=4,
        output_dict=output_dict,
        zero_division=0
    )
if output_dict:
return report
print("\nClassification Report:")
print("=" * 60)
print(report)
return None
def plot_training_curves(
train_losses: List[float],
eval_losses: List[float],
eval_metrics: Dict[str, List[float]],
save_path: str = "./results/training_curves.png"
) -> None:
"""
Plot training and evaluation curves.
Args:
train_losses: List of training losses per step/epoch
eval_losses: List of evaluation losses per step/epoch
eval_metrics: Dictionary of metric names to lists of values
save_path: Path to save the plot
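    Example (sketch; the values below are made up for illustration):
        plot_training_curves(
            train_losses=[0.92, 0.61, 0.43],
            eval_losses=[0.85, 0.58, 0.47],
            eval_metrics={
                "accuracy": [0.71, 0.82, 0.86],
                "f1_weighted": [0.69, 0.81, 0.85],
            },
            save_path="./results/training_curves.png",
        )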
"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# Loss curves
axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
axes[0, 0].set_xlabel('Step/Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Accuracy
if 'accuracy' in eval_metrics:
axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
axes[0, 1].set_xlabel('Step/Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# F1 Score
if 'f1_weighted' in eval_metrics:
axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
axes[1, 0].set_xlabel('Step/Epoch')
axes[1, 0].set_ylabel('F1 Score')
axes[1, 0].set_title('Validation F1 Score')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Precision and Recall
if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
axes[1, 1].set_xlabel('Step/Epoch')
axes[1, 1].set_ylabel('Score')
axes[1, 1].set_title('Validation Precision and Recall')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"Training curves saved to {save_path}")