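"""Gradio dashboard for evaluating a fine-tuned safety classifier.

Loads the model in MODEL_REPO from the Hugging Face Hub, runs it over a
random sample of the Aegis 2.0 content-safety dataset, streams accuracy and
F1 to the UI while classifying, and finishes with a table of recent correct
and incorrect predictions.
"""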
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# --- CONFIGURATION ---
# REPLACE THIS WITH YOUR UPLOADED MODEL NAME!
MODEL_REPO = "angelperedo01/proj2" 
DATASET_NAME = "nvidia/Aegis-AI-Content-Safety-Dataset-2.0"
MAX_SAMPLES = 300  # Increased slightly since we aren't rendering the table live

def get_text_and_label(example):
    """
    Parses the NVIDIA dataset labels.
    """
    text = example.get('prompt', '')
    label = None
    
    if 'prompt_label' in example:
        raw_label = example['prompt_label']
        if isinstance(raw_label, str):
            raw_lower = raw_label.lower()
            if any(x in raw_lower for x in ['unsafe', 'harmful', 'toxic', 'attack']):
                label = 1
            elif any(x in raw_lower for x in ['safe', 'benign']):
                label = 0
            else:
                # Fall back to parsing a numeric label string; default to Safe (0),
                # since the 'unsafe' keywords were already handled above
                try:
                    label = int(raw_label)
                except ValueError:
                    label = 0
        else:
            label = int(raw_label)
    
    if label is None: label = 0
    return text, label

def run_evaluation(progress=gr.Progress()):
    # 1. Load Model & Data
    yield "Loading Model...", "-", "-", []
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
    except Exception as e:
        yield f"Error: {str(e)}", "Error", "Error", []
        return

    yield "Loading Dataset...", "-", "-", []
    try:
        ds = load_dataset(DATASET_NAME, split="test")
    except Exception:
        # Fall back to the train split if the dataset has no test split
        ds = load_dataset(DATASET_NAME, split="train")
    
    # Shuffle and select a subset (capped at the dataset size)
    ds = ds.shuffle(seed=42).select(range(min(MAX_SAMPLES, len(ds))))
    
    true_labels = []
    predictions = []
    
    # Store full details to filter later
    # Structure: [Status, Text, True, Pred]
    history_correct = []
    history_incorrect = []
    
    # 2. The Evaluation Loop
    # We yield updates less frequently to prevent UI flashing
    for i, item in enumerate(progress.tqdm(ds, desc="Classifying...")):
        text, true_label = get_text_and_label(item)
        true_labels.append(true_label)
        
        # Predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=-1).item()
            predictions.append(pred)
        
        # Store for the final report (truncate long prompts so the table stays readable)
        label_map = {0: "Safe", 1: "Unsafe"}
        snippet = text if len(text) <= 200 else text[:200] + "..."
        entry = [
            snippet,
            label_map.get(true_label, str(true_label)),
            label_map.get(pred, str(pred))
        ]
        
        if pred == true_label:
            history_correct.append(["βœ… Correct"] + entry)
        else:
            history_incorrect.append(["πŸ”΄ WRONG"] + entry)

        # Update metrics every 10 steps (Reduces flashing)
        if i % 10 == 0:
            acc = accuracy_score(true_labels, predictions)
            f1 = f1_score(true_labels, predictions, zero_division=0)
            # Yield empty list for table so it doesn't try to render anything yet
            yield f"Processed {i+1}/{MAX_SAMPLES}", f"{acc:.2%}", f"{f1:.2f}", []

    # 3. Final Compilation
    # Grab last 10 incorrect and last 10 correct
    final_display_data = []
    
    # Add header/separator logic if you want, or just mix them
    # We prioritize showing errors first
    if history_incorrect:
        final_display_data.extend(history_incorrect[-10:]) # Last 10 errors
    
    if history_correct:
        final_display_data.extend(history_correct[-10:])   # Last 10 correct
        
    final_acc = accuracy_score(true_labels, predictions)
    final_f1 = f1_score(true_labels, predictions, zero_division=0)
    
    yield "Evaluation Complete!", f"{final_acc:.2%}", f"{final_f1:.2f}", final_display_data

# --- UI LAYOUT ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"## πŸ›‘οΈ Model Safety Evaluation Dashboard")
    gr.Markdown(f"Testing `{MODEL_REPO}` on `{DATASET_NAME}`")
    
    with gr.Row():
        start_btn = gr.Button("▢️ Run Live Test", variant="primary")
    
    with gr.Row():
        with gr.Column():
            status_box = gr.Label(value="Ready", label="Status")
        with gr.Column():
            acc_box = gr.Label(value="-", label="Accuracy")
        with gr.Column():
            f1_box = gr.Label(value="-", label="F1 Score")

    gr.Markdown("### πŸ“ Final Report: Sample of Results")
    gr.Markdown("*(Showing last 10 Incorrect and last 10 Correct predictions)*")
    
    # Defined table but it stays empty until the end
    result_table = gr.Dataframe(
        headers=["Result", "Text Snippet", "True Label", "Predicted"], 
        datatype=["str", "str", "str", "str"],
        wrap=True
    )

    start_btn.click(
        fn=run_evaluation,
        inputs=None,
        outputs=[status_box, acc_box, f1_box, result_table]
    )

demo.queue().launch()