"""
Evaluate ensemble performance with cross-validation.

Compares ensemble against:
- Individual models
- Baseline (single best model)
- Ground truth annotations

Metrics:
- Accuracy
- F1-score (per class and macro)
- Confusion matrix
- Agreement rate
- Confidence calibration
"""

import argparse
import ast
import json
import logging
from pathlib import Path
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)
from sklearn.model_selection import KFold

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EnsembleEvaluator:
    """Evaluate ensemble performance with cross-validation."""

    def __init__(self, output_dir: str = "data/evaluation/"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_predictions(self, predictions_file: str) -> pd.DataFrame:
        """Load predictions from a parquet file."""
        logger.info(f"Loading predictions from {predictions_file}")
        df = pd.read_parquet(predictions_file)
        return df

    def load_ground_truth(self, ground_truth_file: str) -> Dict[str, str]:
        """Load ground truth annotations.

        Accepts either a JSON object mapping sample id to emotion label,
        or a parquet file with 'id' and 'emotion' columns; both are
        returned as a dict of id -> emotion.
        """
        logger.info(f"Loading ground truth from {ground_truth_file}")

        if ground_truth_file.endswith('.json'):
            with open(ground_truth_file, 'r') as f:
                return json.load(f)
        elif ground_truth_file.endswith('.parquet'):
            df = pd.read_parquet(ground_truth_file)
            # Cast ids to str so lookups via str(id) elsewhere stay consistent
            return dict(zip(df['id'].astype(str), df['emotion']))
        else:
            raise ValueError("Ground truth must be .json or .parquet")

    def calculate_metrics(self, y_true: List[str], y_pred: List[str]) -> Dict[str, Any]:
        """Calculate comprehensive evaluation metrics."""
        logger.info("Calculating metrics...")

        # Overall scores
        accuracy = accuracy_score(y_true, y_pred)
        f1_macro = f1_score(y_true, y_pred, average='macro')
        f1_weighted = f1_score(y_true, y_pred, average='weighted')

        # Per-class precision/recall/F1
        report = classification_report(y_true, y_pred, output_dict=True)

        # Confusion matrix (rows = true labels, columns = predictions)
        cm = confusion_matrix(y_true, y_pred)

        return {
            "accuracy": float(accuracy),
            "f1_macro": float(f1_macro),
            "f1_weighted": float(f1_weighted),
            "classification_report": report,
            "confusion_matrix": cm.tolist()
        }

    def plot_confusion_matrix(self, y_true: List[str], y_pred: List[str],
                              labels: List[str], save_path: str):
        """Plot and save confusion matrix."""
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=labels, yticklabels=labels)
        plt.title('Confusion Matrix - Ensemble')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close()

        logger.info(f"Confusion matrix saved to {save_path}")

    def compare_models(self, predictions_df: pd.DataFrame,
                       ground_truth: Dict[str, str]) -> pd.DataFrame:
        """Compare ensemble vs individual models."""
        logger.info("Comparing ensemble vs individual models...")

        results = []

        # Ensemble predictions, aligned with ground truth
        ensemble_pred = predictions_df['emotion_label'].tolist()
        ensemble_true = [ground_truth.get(str(sample_id), 'unknown')
                         for sample_id in predictions_df['id']]

        # Keep only samples that have a ground truth label
        valid_indices = [i for i, t in enumerate(ensemble_true) if t != 'unknown']
        ensemble_pred = [ensemble_pred[i] for i in valid_indices]
        ensemble_true = [ensemble_true[i] for i in valid_indices]

        ensemble_acc = accuracy_score(ensemble_true, ensemble_pred)
        ensemble_f1 = f1_score(ensemble_true, ensemble_pred, average='macro')

        results.append({
            "model": "Ensemble (OPTION A)",
            "accuracy": ensemble_acc,
            "f1_macro": ensemble_f1,
            "num_models": 3
        })

        # Per-model metrics, parsed from the stored per-model predictions.
        # NOTE: each entry is assumed to be a dict with 'model' and 'label'
        # keys; adjust the key names if the upstream format differs.
        if 'emotion_predictions' in predictions_df.columns:
            per_model = {}
            for _, row in predictions_df.iterrows():
                if pd.isna(row['emotion_predictions']):
                    continue
                true_label = ground_truth.get(str(row['id']), 'unknown')
                if true_label == 'unknown':
                    continue
                try:
                    preds = ast.literal_eval(row['emotion_predictions'])
                except (ValueError, SyntaxError):
                    continue
                for pred in preds:
                    model_name = pred.get('model', 'unknown')
                    per_model.setdefault(model_name, []).append(
                        (pred.get('label', 'unknown'), true_label))

            for model_name, pairs in per_model.items():
                model_pred, model_true = zip(*pairs)
                results.append({
                    "model": model_name,
                    "accuracy": accuracy_score(model_true, model_pred),
                    "f1_macro": f1_score(model_true, model_pred, average='macro'),
                    "num_models": 1
                })

        df_results = pd.DataFrame(results)
        return df_results

    def cross_validate(self, predictions_df: pd.DataFrame,
                       ground_truth: Dict[str, str],
                       n_splits: int = 5) -> Dict[str, Any]:
        """Perform k-fold cross-validation.

        Note: the ensemble predictions are fixed, so no retraining happens
        here; the folds measure how stable accuracy and F1 are across
        random partitions of the evaluation set.
        """
        logger.info(f"Performing {n_splits}-fold cross-validation...")

        ids = predictions_df['id'].tolist()
        preds = predictions_df['emotion_label'].tolist()
        true_labels = [ground_truth.get(str(sample_id), 'unknown') for sample_id in ids]

        # Drop samples without ground truth
        valid_data = [(p, t) for p, t in zip(preds, true_labels) if t != 'unknown']
        preds, true_labels = zip(*valid_data) if valid_data else ([], [])

        if not preds:
            logger.error("No valid ground truth labels found")
            return {"error": "No valid labels"}

        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        fold_scores = []

        preds_array = np.array(preds)
        true_array = np.array(true_labels)

        for fold, (_, test_idx) in enumerate(kf.split(preds_array)):
            y_test = true_array[test_idx]
            y_pred = preds_array[test_idx]

            acc = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average='macro')

            fold_scores.append({
                "fold": fold + 1,
                "accuracy": float(acc),
                "f1_macro": float(f1)
            })

            logger.info(f"Fold {fold + 1}: Acc={acc:.4f}, F1={f1:.4f}")

        accuracies = [s['accuracy'] for s in fold_scores]
        f1_scores = [s['f1_macro'] for s in fold_scores]

        return {
            "n_splits": n_splits,
            "fold_scores": fold_scores,
            "mean_accuracy": float(np.mean(accuracies)),
            "std_accuracy": float(np.std(accuracies)),
            "mean_f1_macro": float(np.mean(f1_scores)),
            "std_f1_macro": float(np.std(f1_scores))
        }

    def evaluate(self, predictions_file: str, ground_truth_file: str,
                 n_splits: int = 5) -> Dict[str, Any]:
        """Full evaluation pipeline."""
        logger.info("=" * 60)
        logger.info("Ensemble Evaluation")
        logger.info("=" * 60)

        # Load predictions and ground truth
        predictions_df = self.load_predictions(predictions_file)
        ground_truth = self.load_ground_truth(ground_truth_file)

        logger.info(f"Predictions: {len(predictions_df)} samples")
        logger.info(f"Ground truth: {len(ground_truth)} samples")

        # Align predictions with ground truth
        y_pred = predictions_df['emotion_label'].tolist()
        y_true = [ground_truth.get(str(sample_id), 'unknown')
                  for sample_id in predictions_df['id']]

        # Keep only samples that have a ground truth label
        valid_indices = [i for i, t in enumerate(y_true) if t != 'unknown']
        y_pred = [y_pred[i] for i in valid_indices]
        y_true = [y_true[i] for i in valid_indices]

        logger.info(f"Valid samples for evaluation: {len(y_true)}")

        if not y_true:
            logger.error("No valid samples found for evaluation")
            return {"error": "No valid samples"}

        # Overall metrics
        metrics = self.calculate_metrics(y_true, y_pred)

        logger.info("\n📊 Overall Metrics:")
        logger.info(f"   Accuracy: {metrics['accuracy']:.4f}")
        logger.info(f"   F1 (macro): {metrics['f1_macro']:.4f}")
        logger.info(f"   F1 (weighted): {metrics['f1_weighted']:.4f}")

        # Cross-validation
        cv_results = self.cross_validate(predictions_df, ground_truth, n_splits)

        if "error" not in cv_results:
            logger.info(f"\n📊 Cross-Validation ({n_splits}-fold):")
            logger.info(f"   Accuracy: {cv_results['mean_accuracy']:.4f} ± {cv_results['std_accuracy']:.4f}")
            logger.info(f"   F1 (macro): {cv_results['mean_f1_macro']:.4f} ± {cv_results['std_f1_macro']:.4f}")

        # Confusion matrix plot; include every label seen in either the
        # ground truth or the predictions so the axes are complete
        unique_labels = sorted(set(y_true) | set(y_pred))
        cm_path = self.output_dir / "confusion_matrix.png"
        self.plot_confusion_matrix(y_true, y_pred, unique_labels, str(cm_path))

        # Ensemble vs individual models
        comparison = self.compare_models(predictions_df, ground_truth)
        logger.info("\n📊 Model Comparison:")
        logger.info(comparison.to_string())

        # Persist results
        results = {
            "overall_metrics": metrics,
            "cross_validation": cv_results,
            "model_comparison": comparison.to_dict('records')
        }

        results_path = self.output_dir / "evaluation_results.json"
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)

        logger.info(f"\n✅ Results saved to {results_path}")

        return results


def main():
    parser = argparse.ArgumentParser(description="Evaluate ensemble performance")
    parser.add_argument("--predictions", type=str, required=True,
                        help="Path to predictions file (.parquet)")
    parser.add_argument("--ground-truth", type=str, required=True,
                        help="Path to ground truth file (.json or .parquet)")
    parser.add_argument("--output-dir", type=str, default="data/evaluation/",
                        help="Output directory for evaluation results")
    parser.add_argument("--n-splits", type=int, default=5,
                        help="Number of folds for cross-validation")

    args = parser.parse_args()

    evaluator = EnsembleEvaluator(output_dir=args.output_dir)

    evaluator.evaluate(
        predictions_file=args.predictions,
        ground_truth_file=args.ground_truth,
        n_splits=args.n_splits
    )

    logger.info("\n" + "=" * 60)
    logger.info("✅ Evaluation complete!")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()