""" |
|
|
Utility functions for training and evaluation |
|
|
""" |
|
|
import numpy as np |
|
|
from sklearn.metrics import ( |
|
|
accuracy_score, |
|
|
precision_recall_fscore_support, |
|
|
confusion_matrix, |
|
|
classification_report |
|
|
) |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from typing import Dict, Tuple, List, Optional |
|
|
import os |
|
|
|
|
|
|
|
|
def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]:
    """
    Compute comprehensive metrics for evaluation.

    Args:
        eval_pred: Tuple of (predictions, labels)
        id2label: Optional mapping from label IDs to label names for per-class metrics

    Returns:
        Dictionary of metrics including overall and per-class metrics
    """
    predictions, labels = eval_pred
    # Convert logits to predicted class IDs
    predictions = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, predictions)

    # Weighted average: per-class scores weighted by support (robust to class imbalance)
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='weighted',
        zero_division=0
    )

    # Macro average: unweighted mean over classes (treats all classes equally)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='macro',
        zero_division=0
    )

    # Micro average: computed from global TP/FP/FN counts across classes
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='micro',
        zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'f1_micro': f1_micro,
    }

    # Per-class metrics, keyed by label name
    if id2label is not None:
        num_classes = len(id2label)
        precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
            labels,
            predictions,
            labels=list(range(num_classes)),
            average=None,
            zero_division=0
        )

        for i in range(num_classes):
            label_name = id2label[i]
            metrics[f'precision_{label_name}'] = float(precision_per_class[i])
            metrics[f'recall_{label_name}'] = float(recall_per_class[i])
            metrics[f'f1_{label_name}'] = float(f1_per_class[i])
            metrics[f'support_{label_name}'] = int(support[i])

    return metrics
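
# Illustrative direct call (values are made up; the point is the expected shapes:
# logits of shape (n_samples, n_classes) and integer labels of shape (n_samples,)):
#
#   logits = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.3]])
#   labels = np.array([0, 1, 1])
#   compute_metrics((logits, labels), id2label={0: "negative", 1: "positive"})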


def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None):
    """
    Factory function to create compute_metrics with label mapping.

    Args:
        id2label: Mapping from label IDs to label names

    Returns:
        Function compatible with HuggingFace Trainer
    """
    def compute_metrics_fn(eval_pred):
        return compute_metrics(eval_pred, id2label)

    return compute_metrics_fn
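
# Illustrative use with a Hugging Face Trainer (sketch only; `model`, `training_args`,
# `train_ds`, `eval_ds`, and `id2label` are placeholders defined elsewhere):
#
#   from transformers import Trainer
#
#   trainer = Trainer(
#       model=model,
#       args=training_args,
#       train_dataset=train_ds,
#       eval_dataset=eval_ds,
#       compute_metrics=compute_metrics_factory(id2label),
#   )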


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    save_path: str = "confusion_matrix.png",
    normalize: bool = False,
    figsize: Tuple[int, int] = (10, 8)
) -> None:
    """
    Plot and save confusion matrix with optional normalization.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        save_path: Path to save the plot
        normalize: If True, normalize each row to proportions of the true class
        figsize: Figure size (width, height)
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    if normalize:
        # Divide each row by its support so cells show per-class proportions
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
    else:
        fmt = 'd'
        title = 'Confusion Matrix'

    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'}
    )
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    # Ensure the output directory exists before saving
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Confusion matrix saved to {save_path}")


def print_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    output_dict: bool = False
) -> Optional[Dict]:
    """
    Print detailed classification report.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        output_dict: If True, return report as dictionary instead of printing

    Returns:
        Classification report as dictionary if output_dict=True, else None
    """
    report = classification_report(
        y_true,
        y_pred,
        # Pin the label order so target_names stays aligned even if a class
        # is missing from y_true/y_pred
        labels=list(range(len(labels))),
        target_names=labels,
        digits=4,
        output_dict=output_dict,
        zero_division=0
    )

    if output_dict:
        return report

    print("\nClassification Report:")
    print("=" * 60)
    print(report)
    return None
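
# Illustrative way to persist the report alongside other artifacts (sketch only;
# `y_true`, `y_pred`, and `label_names` are placeholders):
#
#   import json
#
#   report = print_classification_report(y_true, y_pred, label_names, output_dict=True)
#   with open("./results/classification_report.json", "w") as f:
#       json.dump(report, f, indent=2)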


def plot_training_curves(
    train_losses: List[float],
    eval_losses: List[float],
    eval_metrics: Dict[str, List[float]],
    save_path: str = "./results/training_curves.png"
) -> None:
    """
    Plot training and evaluation curves.

    Args:
        train_losses: List of training losses per step/epoch
        eval_losses: List of evaluation losses per step/epoch
        eval_metrics: Dictionary of metric names to lists of values
        save_path: Path to save the plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Top left: training vs. validation loss
    axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
    axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
    axes[0, 0].set_xlabel('Step/Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Top right: validation accuracy
    if 'accuracy' in eval_metrics:
        axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
        axes[0, 1].set_xlabel('Step/Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

    # Bottom left: weighted F1
    if 'f1_weighted' in eval_metrics:
        axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
        axes[1, 0].set_xlabel('Step/Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title('Validation F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

    # Bottom right: weighted precision and recall
    if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
        axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
        axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
        axes[1, 1].set_xlabel('Step/Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title('Validation Precision and Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    # Ensure the output directory exists before saving
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Training curves saved to {save_path}")