""" SMARTVISION AI - Step 2.5: Model Comparison & Selection This script: - Loads metrics.json and confusion_matrix.npy for all models. - Compares accuracy, precision, recall, F1, top-5 accuracy, speed, and model size. - Generates bar plots for metrics. - Generates confusion matrix heatmaps per model. - Selects the best model using an accuracy–speed tradeoff rule. """ import os import json import numpy as np import matplotlib.pyplot as plt # ------------------------------------------------------------ # 0. CONFIG – resolve paths relative to this file # ------------------------------------------------------------ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(SCRIPT_DIR) # one level up from scripts/ METRICS_DIR = os.path.join(ROOT_DIR, "smartvision_metrics") PLOTS_DIR = os.path.join(METRICS_DIR, "comparison_plots") os.makedirs(PLOTS_DIR, exist_ok=True) print(f"[INFO] Using METRICS_DIR = {METRICS_DIR}") print(f"[INFO] Existing subfolders in METRICS_DIR: {os.listdir(METRICS_DIR) if os.path.exists(METRICS_DIR) else 'NOT FOUND'}") # Map "pretty" model names to their metrics subdirectories MODEL_PATHS = { "VGG16" : "vgg16_v2_stage2", "ResNet50" : "resnet50_v2_stage2", "MobileNetV2" : "mobilenetv2_v2", "efficientnetb0" : "efficientnetb0", # Optional: add more models here, e.g.: # "ResNet50 v2 (Stage 1)" : "resnet50_v2_stage1", } # Class names (COCO-style 25 classes) CLASS_NAMES = [ "airplane", "bed", "bench", "bicycle", "bird", "bottle", "bowl", "bus", "cake", "car", "cat", "chair", "couch", "cow", "cup", "dog", "elephant", "horse", "motorcycle", "person", "pizza", "potted plant", "stop sign", "traffic light", "truck", ] # ------------------------------------------------------------ # 1. LOAD METRICS & CONFUSION MATRICES # ------------------------------------------------------------ def load_model_results(): model_metrics = {} model_cms = {} for nice_name, folder_name in MODEL_PATHS.items(): metrics_path = os.path.join(METRICS_DIR, folder_name, "metrics.json") cm_path = os.path.join(METRICS_DIR, folder_name, "confusion_matrix.npy") print(f"[DEBUG] Looking for {nice_name} metrics at: {metrics_path}") print(f"[DEBUG] Looking for {nice_name} CM at : {cm_path}") if not os.path.exists(metrics_path): print(f"[WARN] Skipping {nice_name}: missing {metrics_path}") continue if not os.path.exists(cm_path): print(f"[WARN] Skipping {nice_name}: missing {cm_path}") continue with open(metrics_path, "r") as f: metrics = json.load(f) cm = np.load(cm_path) model_metrics[nice_name] = metrics model_cms[nice_name] = cm print(f"[INFO] Loaded metrics & CM for {nice_name}") return model_metrics, model_cms # ------------------------------------------------------------ # 2. 

# ------------------------------------------------------------
# 2. PLOTTING HELPERS
# ------------------------------------------------------------
def plot_bar_metric(model_metrics, metric_key, ylabel, filename, higher_is_better=True):
    """Save a bar chart comparing one metric across all loaded models."""
    names = list(model_metrics.keys())
    values = [model_metrics[n][metric_key] for n in names]

    plt.figure(figsize=(8, 5))
    bars = plt.bar(names, values)
    plt.ylabel(ylabel)
    plt.xticks(rotation=20, ha="right")

    # Annotate each bar with its value
    for bar, val in zip(bars, values):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.3f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    title_prefix = "Higher is better" if higher_is_better else "Lower is better"
    plt.title(f"{metric_key} comparison ({title_prefix})")
    plt.tight_layout()

    out_path = os.path.join(PLOTS_DIR, filename)
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[PLOT] Saved {metric_key} comparison to {out_path}")


def plot_confusion_matrix(cm, classes, title, filename, normalize=True):
    """Save a confusion-matrix heatmap; if normalize, each row is scaled to sum to 1."""
    if normalize:
        cm = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-12)

    plt.figure(figsize=(6, 5))
    im = plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar(im, fraction=0.046, pad=0.04)

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Annotate the diagonal only to reduce clutter
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if i == j:
                plt.text(
                    j, i, f"{cm[i, j]:.2f}",
                    ha="center",
                    va="center",
                    color="white" if cm[i, j] > 0.5 else "black",
                    fontsize=6,
                )

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()

    out_path = os.path.join(PLOTS_DIR, filename)
    plt.savefig(out_path, dpi=200)
    plt.close()
    print(f"[PLOT] Saved confusion matrix to {out_path}")


# ------------------------------------------------------------
# 3. MODEL SELECTION (ACCURACY–SPEED TRADEOFF)
# ------------------------------------------------------------
def pick_best_model(model_metrics):
    """
    Rule:
    1. Prefer the highest accuracy.
    2. If two models are within 0.5 percentage points of accuracy, prefer the one
       with higher images_per_second.
    """
    best_name = None
    best_acc = -1.0
    best_speed = -1.0

    for name, m in model_metrics.items():
        acc = m["accuracy"]
        speed = m.get("images_per_second", 0.0)

        if acc > best_acc + 0.005:
            # Clearly better accuracy
            best_name = name
            best_acc = acc
            best_speed = speed
        elif abs(acc - best_acc) <= 0.005:
            # Within 0.5 percentage points: use speed as the tie-breaker
            if speed > best_speed:
                best_name = name
                best_acc = acc
                best_speed = speed

    return best_name, best_acc, best_speed
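
# Illustrative walk-through of the rule above (made-up numbers, not real results):
# model A: accuracy 0.914 at 38 img/s; model B: accuracy 0.911 at 82 img/s.
# B is not "clearly better" (0.911 <= 0.914 + 0.005), but |0.911 - 0.914| = 0.003 <= 0.005,
# so the two count as tied on accuracy and B is selected on throughput (82 > 38 img/s).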

# ------------------------------------------------------------
# 4. MAIN
# ------------------------------------------------------------
def main():
    model_metrics, model_cms = load_model_results()

    if not model_metrics:
        print("[ERROR] No models found with valid metrics. Check METRICS_DIR and MODEL_PATHS.")
        return

    print("\n===== MODEL METRICS SUMMARY =====")
    print(
        f"{'Model':30s} {'Acc':>6s} {'Prec':>6s} {'Rec':>6s} {'F1':>6s} "
        f"{'Top5':>6s} {'img/s':>7s} {'Size(MB)':>8s}"
    )
    for name, m in model_metrics.items():
        print(
            f"{name:30s} "
            f"{m['accuracy']:6.3f} "
            f"{m['precision_weighted']:6.3f} "
            f"{m['recall_weighted']:6.3f} "
            f"{m['f1_weighted']:6.3f} "
            f"{m['top5_accuracy']:6.3f} "
            f"{m['images_per_second']:7.2f} "
            f"{m['model_size_mb']:8.1f}"
        )

    # ---- Comparison plots ----
    plot_bar_metric(model_metrics, "accuracy", "Accuracy", "accuracy_comparison.png")
    plot_bar_metric(model_metrics, "f1_weighted", "Weighted F1-score", "f1_comparison.png")
    plot_bar_metric(model_metrics, "top5_accuracy", "Top-5 Accuracy", "top5_comparison.png")
    plot_bar_metric(
        model_metrics,
        "images_per_second",
        "Images per second",
        "speed_comparison.png",
    )
    plot_bar_metric(
        model_metrics,
        "model_size_mb",
        "Model size (MB)",
        "size_comparison.png",
        higher_is_better=False,
    )

    # ---- Confusion matrices ----
    print("\n===== SAVING CONFUSION MATRICES =====")
    for name, cm in model_cms.items():
        safe_name = name.replace(" ", "_").replace("(", "").replace(")", "")
        filename = f"{safe_name}_cm.png"
        plot_confusion_matrix(
            cm,
            classes=CLASS_NAMES,
            title=f"Confusion Matrix - {name}",
            filename=filename,
            normalize=True,
        )

    # ---- Best model ----
    best_name, best_acc, best_speed = pick_best_model(model_metrics)

    print("\n===== BEST MODEL SELECTION =====")
    print(f"Selected best model: {best_name}")
    print(f"  Test Accuracy     : {best_acc:.4f}")
    print(f"  Images per second : {best_speed:.2f}")
    print("\nRationale:")
    print("- Highest accuracy is preferred.")
    print("- If models are within 0.5 percentage points of accuracy, the faster model (higher img/s) is chosen.")
    print("\nSuggested text for report:")
    print(
        f"\"Among all evaluated architectures, {best_name} achieved the best accuracy–speed "
        f"tradeoff on the SmartVision AI test set, with a top-1 accuracy of {best_acc:.3f} "
        f"and an inference throughput of {best_speed:.2f} images per second on the "
        f"evaluation hardware.\""
    )


if __name__ == "__main__":
    main()
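
# How to run (a minimal sketch; the file name "compare_models.py" and its location
# under scripts/ are assumptions, implied only by ROOT_DIR being one level up):
#
#   python scripts/compare_models.py
#
# Expected layout relative to the project root:
#   smartvision_metrics/<model_folder>/metrics.json
#   smartvision_metrics/<model_folder>/confusion_matrix.npy
# All comparison plots are written to smartvision_metrics/comparison_plots/.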