import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score

# --- CONFIGURATION ---
# REPLACE THIS WITH YOUR UPLOADED MODEL NAME!
MODEL_REPO = "angelperedo01/proj2"
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
MAX_SAMPLES = 300  # Increased slightly since we aren't rendering the table live


def get_text_and_label(example):
    """
    Parse one example from the NVIDIA dataset: return the prompt text and a
    binary label (0 = safe, 1 = unsafe) derived from 'prompt_label'.
    """
    text = example.get('prompt', '')
    label = None
    if 'prompt_label' in example:
        raw_label = example['prompt_label']
        if isinstance(raw_label, str):
            raw_lower = raw_label.lower()
            # Check 'unsafe' before 'safe', since the former contains the latter
            if any(x in raw_lower for x in ['unsafe', 'harmful', 'toxic', 'attack']):
                label = 1
            elif any(x in raw_lower for x in ['safe', 'benign']):
                label = 0
            else:
                try:
                    label = int(raw_label)
                except (TypeError, ValueError):
                    label = 1 if 'unsafe' in raw_lower else 0
        else:
            label = int(raw_label)
    if label is None:
        label = 0
    return text, label


def run_evaluation(progress=gr.Progress()):
    # 1. Load model & data
    yield "Loading Model...", "-", "-", []
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
    except Exception as e:
        yield f"Error: {str(e)}", "Error", "Error", []
        return

    yield "Loading Dataset...", "-", "-", []
    try:
        ds = load_dataset(DATASET_NAME, split="test")
    except Exception:
        # Fall back to the train split if the dataset has no test split
        ds = load_dataset(DATASET_NAME, split="train")

    # Shuffle and select a subset
    ds = ds.shuffle(seed=42).select(range(MAX_SAMPLES))

    true_labels = []
    predictions = []

    # Store full details to filter later
    # Structure: [Status, Text, True, Pred]
    history_correct = []
    history_incorrect = []
    label_map = {0: "Safe", 1: "Unsafe"}

    # 2. The evaluation loop
    # We yield updates less frequently to prevent UI flashing
    for i, item in enumerate(progress.tqdm(ds, desc="Classifying...")):
        text, true_label = get_text_and_label(item)
        true_labels.append(true_label)

        # Predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=-1).item()
        predictions.append(pred)

        # Store for the final report (truncate long prompts so the table stays readable)
        entry = [text[:300], label_map[true_label], label_map[pred]]
        if pred == true_label:
            history_correct.append(["✅ Correct"] + entry)
        else:
            history_incorrect.append(["🔴 WRONG"] + entry)

        # Update metrics every 10 steps (reduces flashing)
        if i % 10 == 0:
            acc = accuracy_score(true_labels, predictions)
            f1 = f1_score(true_labels, predictions, zero_division=0)
            # Yield an empty list for the table so it doesn't try to render anything yet
            yield f"Processed {i + 1}/{MAX_SAMPLES}", f"{acc:.2%}", f"{f1:.2f}", []

    # 3. Final compilation
    # Grab the last 10 incorrect and last 10 correct predictions
    final_display_data = []
    # Add header/separator logic if you want, or just mix them;
    # we prioritize showing errors first
    if history_incorrect:
        final_display_data.extend(history_incorrect[-10:])  # Last 10 errors
    if history_correct:
        final_display_data.extend(history_correct[-10:])  # Last 10 correct

    final_acc = accuracy_score(true_labels, predictions)
    final_f1 = f1_score(true_labels, predictions, zero_division=0)
    yield "Evaluation Complete!", f"{final_acc:.2%}", f"{final_f1:.2f}", final_display_data


# --- UI LAYOUT ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🛡️ Model Safety Evaluation Dashboard")
    gr.Markdown(f"Testing `{MODEL_REPO}` on `{DATASET_NAME}`")

    with gr.Row():
        start_btn = gr.Button("▶️ Run Live Test", variant="primary")

    with gr.Row():
        with gr.Column():
            status_box = gr.Label(value="Ready", label="Status")
        with gr.Column():
            acc_box = gr.Label(value="-", label="Accuracy")
        with gr.Column():
            f1_box = gr.Label(value="-", label="F1 Score")

    gr.Markdown("### 📝 Final Report: Sample of Results")
    gr.Markdown("*(Showing last 10 Incorrect and last 10 Correct predictions)*")

    # The table is defined here but stays empty until the end
    result_table = gr.Dataframe(
        headers=["Result", "Text Snippet", "True Label", "Predicted"],
        datatype=["str", "str", "str", "str"],
        wrap=True
    )

    start_btn.click(
        fn=run_evaluation,
        inputs=None,
        outputs=[status_box, acc_box, f1_box, result_table]
    )

demo.queue().launch()
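
# Note (assumption, not part of the original script): this is written as a standard
# Gradio app, e.g. app.py on a Hugging Face Space. To run it locally instead, the
# dependencies imported above need to be installed first, for example:
#   pip install gradio torch transformers datasets scikit-learn
# and then started with:
#   python app.py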