import json
import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --- 1. CONFIGURATION ---
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
BASE_MODEL_ID = "distilbert-base-uncased"
CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


# --- 2. DYNAMIC METRICS LOADING ---
def fetch_metrics():
    """Download ``evaluation_report.json`` from the Model Hub and format its headline metrics.

    The report is fetched from ``ADAPTER_REPO``. The values are read from
    ``overall_metrics["Accuracy"]`` and ``overall_metrics["F1 Macro"]``.

    Returns:
        dict: ``{"Accuracy": <value formatted as a percentage>,
        "F1_Score": <value formatted with 4 decimals>}``. If downloading,
        parsing or formatting fails, returns
        ``{"Accuracy": "N/A", "F1_Score": "N/A"}`` instead.
    """
    try:
        file_path = hf_hub_download(repo_id=ADAPTER_REPO, filename="evaluation_report.json")
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        overall = data["overall_metrics"]
        acc = overall["Accuracy"]
        f1 = overall["F1 Macro"]
        return {
            "Accuracy": f"{acc:.2%}",
            "F1_Score": f"{f1:.4f}",
        }
    except Exception as e:
        # Deliberately broad. This runs at import time (MODEL_METRICS below),
        # so any failure degrades to "N/A" placeholders rather than aborting
        # app startup.
        print(f"Error loading metrics: {e}")
        return {"Accuracy": "N/A", "F1_Score": "N/A"}


# Load metrics on app startup
MODEL_METRICS = fetch_metrics()


# --- 3. MODEL LOADING ---
def load_model():
    """Load the base classifier, the tokenizer, and the PEFT adapter weights.

    The base model is ``BASE_MODEL_ID``, configured with one output per entry
    in ``CLASS_NAMES``. The tokenizer and adapters both come from
    ``ADAPTER_REPO``.

    Returns:
        tuple: ``(model, tokenizer, device)``, where ``model`` is the
        ``PeftModel`` moved to the CPU device and switched to eval mode.
    """
    print("Loading Base Model...")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        # Build both directions from the same (index, label) pairs so the
        # two mappings can never disagree.
        id2label=dict(CLASS_NAMES),
        label2id={label: idx for idx, label in CLASS_NAMES.items()},
    )

    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print("Loading Adapters...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

    # Force CPU for Free Tier Spaces
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, device


model, tokenizer, device = load_model()


# --- 4. PREDICTION LOGIC ---
def predict(text):
    if not text.strip():
        return None, None, None

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=1).squeeze().cpu().numpy()

    # 1.
Get Top Label pred_idx = np.argmax(probs) pred_label = CLASS_NAMES[pred_idx] conf = float(probs[pred_idx]) # 2. Create Probability Dict for the Chart class_probs = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))} # 3. Create HTML for the "Confidence Badge" if conf > 0.85: bg_color, txt_color, icon = "#d4edda", "#155724", "↑" # Green elif conf > 0.60: bg_color, txt_color, icon = "#fff3cd", "#856404", "~" # Yellow else: bg_color, txt_color, icon = "#f8d7da", "#721c24", "↓" # Red badge_html = f"""