# SmartVision_AI/scripts/compare_models.py
"""
SMARTVISION AI - Step 2.5: Model Comparison & Selection
This script:
- Loads metrics.json and confusion_matrix.npy for all models.
- Compares accuracy, precision, recall, F1, top-5 accuracy, speed, and model size.
- Generates bar plots for metrics.
- Generates confusion matrix heatmaps per model.
- Selects the best model using an accuracy–speed tradeoff rule.
"""
import os
import json
import numpy as np
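# NOTE: on a headless machine (e.g., a CI runner) matplotlib may need a
# non-interactive backend before pyplot is imported; uncomment if needed:
# import matplotlib
# matplotlib.use("Agg")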
import matplotlib.pyplot as plt
# ------------------------------------------------------------
# 0. CONFIG – resolve paths relative to this file
# ------------------------------------------------------------
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(SCRIPT_DIR) # one level up from scripts/
METRICS_DIR = os.path.join(ROOT_DIR, "smartvision_metrics")
PLOTS_DIR = os.path.join(METRICS_DIR, "comparison_plots")
os.makedirs(PLOTS_DIR, exist_ok=True)
print(f"[INFO] Using METRICS_DIR = {METRICS_DIR}")
print(f"[INFO] Existing subfolders in METRICS_DIR: {os.listdir(METRICS_DIR) if os.path.exists(METRICS_DIR) else 'NOT FOUND'}")
# Map "pretty" model names to their metrics subdirectories
MODEL_PATHS = {
"VGG16" : "vgg16_v2_stage2",
"ResNet50" : "resnet50_v2_stage2",
"MobileNetV2" : "mobilenetv2_v2",
"efficientnetb0" : "efficientnetb0",
# Optional: add more models here, e.g.:
# "ResNet50 v2 (Stage 1)" : "resnet50_v2_stage1",
}
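# Expected layout on disk (each folder name above is joined under METRICS_DIR):
#
#   smartvision_metrics/
#       vgg16_v2_stage2/
#           metrics.json
#           confusion_matrix.npy
#       resnet50_v2_stage2/
#           ...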
# Class names (COCO-style 25 classes)
CLASS_NAMES = [
"airplane", "bed", "bench", "bicycle", "bird",
"bottle", "bowl", "bus", "cake", "car",
"cat", "chair", "couch", "cow", "cup",
"dog", "elephant", "horse", "motorcycle", "person",
"pizza", "potted plant", "stop sign", "traffic light", "truck",
]
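# Each saved confusion matrix is assumed to be len(CLASS_NAMES) x len(CLASS_NAMES),
# with rows (true labels) and columns (predictions) following this class order.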
# ------------------------------------------------------------
# 1. LOAD METRICS & CONFUSION MATRICES
# ------------------------------------------------------------
def load_model_results():
    model_metrics = {}
    model_cms = {}
    for nice_name, folder_name in MODEL_PATHS.items():
        metrics_path = os.path.join(METRICS_DIR, folder_name, "metrics.json")
        cm_path = os.path.join(METRICS_DIR, folder_name, "confusion_matrix.npy")
        print(f"[DEBUG] Looking for {nice_name} metrics at: {metrics_path}")
        print(f"[DEBUG] Looking for {nice_name} CM at     : {cm_path}")
        if not os.path.exists(metrics_path):
            print(f"[WARN] Skipping {nice_name}: missing {metrics_path}")
            continue
        if not os.path.exists(cm_path):
            print(f"[WARN] Skipping {nice_name}: missing {cm_path}")
            continue
        with open(metrics_path, "r") as f:
            metrics = json.load(f)
        cm = np.load(cm_path)
        model_metrics[nice_name] = metrics
        model_cms[nice_name] = cm
        print(f"[INFO] Loaded metrics & CM for {nice_name}")
    return model_metrics, model_cms
# ------------------------------------------------------------
# 2. PLOTTING HELPERS
# ------------------------------------------------------------
def plot_bar_metric(model_metrics, metric_key, ylabel, filename, higher_is_better=True):
    names = list(model_metrics.keys())
    values = [model_metrics[n][metric_key] for n in names]
    plt.figure(figsize=(8, 5))
    bars = plt.bar(names, values)
    plt.ylabel(ylabel)
    plt.xticks(rotation=20, ha="right")
    for bar, val in zip(bars, values):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.3f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )
    title_prefix = "Higher is better" if higher_is_better else "Lower is better"
    plt.title(f"{metric_key} comparison ({title_prefix})")
    plt.tight_layout()
    out_path = os.path.join(PLOTS_DIR, filename)
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[PLOT] Saved {metric_key} comparison to {out_path}")
def plot_confusion_matrix(cm, classes, title, filename, normalize=True):
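    # With normalize=True each row is divided by its row sum, so the diagonal
    # entry for class i is that class's recall; the 1e-12 term below guards
    # against classes with zero test samples.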
    if normalize:
        cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-12)
    plt.figure(figsize=(6, 5))
    im = plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)
    # annotate diagonal only to reduce clutter
    for i in range(cm.shape[0]):
        plt.text(
            i,
            i,
            f"{cm[i, i]:.2f}",
            ha="center",
            va="center",
            color="white" if cm[i, i] > 0.5 else "black",
            fontsize=6,
        )
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    out_path = os.path.join(PLOTS_DIR, filename)
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[PLOT] Saved confusion matrix to {out_path}")
# ------------------------------------------------------------
# 3. MODEL SELECTION (ACCURACY–SPEED TRADEOFF)
# ------------------------------------------------------------
def pick_best_model(model_metrics):
"""
Rule:
1. Prefer highest accuracy.
2. If two models are within 0.5% accuracy, prefer higher images_per_second.
"""
best_name = None
best_acc = -1.0
best_speed = -1.0
for name, m in model_metrics.items():
acc = m["accuracy"]
speed = m.get("images_per_second", 0.0)
if acc > best_acc + 0.005: # clearly better
best_name = name
best_acc = acc
best_speed = speed
elif abs(acc - best_acc) <= 0.005: # within 0.5%, use speed as tie-breaker
if speed > best_speed:
best_name = name
best_acc = acc
best_speed = speed
return best_name, best_acc, best_speed
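# Illustrative walk-through of the rule (made-up numbers):
#   ResNet50:    accuracy 0.903 at  45 img/s
#   MobileNetV2: accuracy 0.900 at 120 img/s
# |0.903 - 0.900| = 0.003 <= 0.005, so the faster MobileNetV2 is selected.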
# ------------------------------------------------------------
# 4. MAIN
# ------------------------------------------------------------
def main():
    model_metrics, model_cms = load_model_results()
    if not model_metrics:
        print("[ERROR] No models found with valid metrics. Check METRICS_DIR and MODEL_PATHS.")
        return
    print("\n===== MODEL METRICS SUMMARY =====")
    print(
        f"{'Model':30s} {'Acc':>6s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s} {'Top5':>6s} {'img/s':>7s} {'Size(MB)':>8s}"
    )
    for name, m in model_metrics.items():
        print(
            f"{name:30s} "
            f"{m['accuracy']:6.3f} "
            f"{m['precision_weighted']:6.3f} "
            f"{m['recall_weighted']:6.3f} "
            f"{m['f1_weighted']:6.3f} "
            f"{m['top5_accuracy']:6.3f} "
            f"{m['images_per_second']:7.2f} "
            f"{m['model_size_mb']:8.1f}"
        )
    # ---- Comparison plots ----
    plot_bar_metric(model_metrics, "accuracy", "Accuracy", "accuracy_comparison.png")
    plot_bar_metric(
        model_metrics, "f1_weighted", "Weighted F1-score", "f1_comparison.png"
    )
    plot_bar_metric(
        model_metrics, "top5_accuracy", "Top-5 Accuracy", "top5_comparison.png"
    )
    plot_bar_metric(
        model_metrics,
        "images_per_second",
        "Images per second",
        "speed_comparison.png",
    )
    plot_bar_metric(
        model_metrics,
        "model_size_mb",
        "Model size (MB)",
        "size_comparison.png",
        higher_is_better=False,
    )
    # ---- Confusion matrices ----
    print("\n===== SAVING CONFUSION MATRICES =====")
    for name, cm in model_cms.items():
        safe_name = name.replace(" ", "_").replace("(", "").replace(")", "")
        filename = f"{safe_name}_cm.png"
        plot_confusion_matrix(
            cm,
            classes=CLASS_NAMES,
            title=f"Confusion Matrix - {name}",
            filename=filename,
            normalize=True,
        )
    # ---- Best model ----
    best_name, best_acc, best_speed = pick_best_model(model_metrics)
    print("\n===== BEST MODEL SELECTION =====")
    print(f"Selected best model: {best_name}")
    print(f"  Test Accuracy     : {best_acc:.4f}")
    print(f"  Images per second : {best_speed:.2f}")
    print("\nRationale:")
    print("- Highest accuracy is preferred.")
    print("- If models are within 0.5 percentage points of accuracy, the faster model (higher img/s) is chosen.")
    print("\nSuggested text for report:")
    print(
        f"\"Among all evaluated architectures, {best_name} achieved the best accuracy–speed "
        f"tradeoff on the SmartVision AI test set, with a top-1 accuracy of {best_acc:.3f} "
        f"and an inference throughput of {best_speed:.2f} images per second on the "
        f"evaluation hardware.\""
    )
if __name__ == "__main__":
    main()
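# Usage (run from any working directory; paths are resolved relative to this file):
#   python scripts/compare_models.py
# Plots are written to smartvision_metrics/comparison_plots/ alongside the
# per-model metrics folders.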