jvillar02 committed
Commit c157987 · verified
1 Parent(s): dc2ff4f

Update app.py

Files changed (1)
  1. app.py +41 -18
app.py CHANGED
@@ -1,20 +1,41 @@
 import gradio as gr
 import torch
 import numpy as np
+import json
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from peft import PeftModel
+from huggingface_hub import hf_hub_download
+import os
 
 # --- 1. CONFIGURATION ---
-MODEL_METRICS = {
-    "Accuracy": "89.20%",
-    "F1_Score": "0.8931"
-}
-
 ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
 BASE_MODEL_ID = "distilbert-base-uncased"
 CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
 
-# --- 2. MODEL LOADING ---
+# --- 2. DYNAMIC METRICS LOADING ---
+def fetch_metrics():
+    """Downloads evaluation_report.json from the Model Hub."""
+    try:
+        file_path = hf_hub_download(repo_id=ADAPTER_REPO, filename="evaluation_report.json")
+        with open(file_path, "r") as f:
+            data = json.load(f)
+
+        # Extract numbers
+        acc = data['overall_metrics']['Accuracy']
+        f1 = data['overall_metrics']['F1 Macro']
+
+        return {
+            "Accuracy": f"{acc:.2%}",
+            "F1_Score": f"{f1:.4f}"
+        }
+    except Exception as e:
+        print(f"Error loading metrics: {e}")
+        return {"Accuracy": "N/A", "F1_Score": "N/A"}
+
+# Load metrics on app startup
+MODEL_METRICS = fetch_metrics()
+
+# --- 3. MODEL LOADING ---
 def load_model():
     print("Loading Base Model...")
     base_model = AutoModelForSequenceClassification.from_pretrained(
@@ -30,6 +51,7 @@ def load_model():
     print("Loading Adapters...")
     model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
 
+    # Force CPU for Free Tier Spaces
     device = torch.device("cpu")
     model.to(device)
     model.eval()
@@ -37,7 +59,7 @@ def load_model():
 
 model, tokenizer, device = load_model()
 
-# --- 3. PREDICTION LOGIC ---
+# --- 4. PREDICTION LOGIC ---
 def predict(text):
     if not text.strip():
         return None, None, None
@@ -60,35 +82,36 @@ def predict(text):
     # 2. Create Probability Dict for the Chart
    class_probs = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
 
-    # 3. Create HTML for the "Confidence Badge" (Mimicking Streamlit)
+    # 3. Create HTML for the "Confidence Badge"
     if conf > 0.85:
-        bg_color, txt_color = "#d4edda", "#155724" # Green
+        bg_color, txt_color, icon = "#d4edda", "#155724", "↑" # Green
     elif conf > 0.60:
-        bg_color, txt_color = "#fff3cd", "#856404" # Yellow
+        bg_color, txt_color, icon = "#fff3cd", "#856404", "~" # Yellow
     else:
-        bg_color, txt_color = "#f8d7da", "#721c24" # Red
+        bg_color, txt_color, icon = "#f8d7da", "#721c24", "↓" # Red
 
     badge_html = f"""
     <div style='background-color: {bg_color}; color: {txt_color};
-                padding: 8px 12px; border-radius: 5px; display: inline-block; font-weight: bold; font-size: 16px;'>
-        Confidence: {conf:.2%}
+                padding: 8px 16px; border-radius: 5px; display: inline-block; font-weight: bold; font-size: 16px;'>
+        {icon} Confidence: {conf:.2%}
     </div>
     """
 
     # Return: Label Text, Badge HTML, Chart Data
    return f"# {pred_label}", badge_html, class_probs
 
-# --- 4. UI LAYOUT (gr.Blocks) ---
-with gr.Blocks() as demo:
+# --- 5. UI LAYOUT (gr.Blocks) ---
+# Using Soft theme (requires newer Gradio version in requirements.txt)
+# If it fails, remove theme=gr.themes.Soft()
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
-    # Title
     gr.Markdown("# 📰 NLP News Classifier")
     gr.Markdown("Classify news articles into World, Sports, Business, or Sci/Tech using DistilBERT + LoRA.")
 
     # -- The "Green Banner" (HTML) --
     gr.HTML(f"""
     <div style="background-color: #d1e7dd; color: #0f5132; padding: 15px; border-radius: 5px; border: 1px solid #badbcc; margin-bottom: 20px;">
-        ✅ <b>Model Performance:</b> Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}
+        ✅ <b>Model Performance (Test Set):</b> Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}
     </div>
     """)
 
@@ -121,7 +144,7 @@ with gr.Blocks() as demo:
     out_badge = gr.HTML()
 
     gr.Markdown("### Probability Breakdown")
-    # Output 3: Bar Chart (Label component handles this beautifully)
+    # Output 3: Bar Chart
     out_chart = gr.Label(num_top_classes=4, label="Confidence Scores")
 
     # Wire up the button
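
A quick way to sanity-check the new dynamic metrics path outside the Space: the sketch below mirrors the parsing in fetch_metrics() under the assumption, taken from the diff above, that the adapter repo ships an evaluation_report.json whose overall_metrics object holds Accuracy and F1 Macro as fractions. The numbers used here are placeholders copied from the old hard-coded MODEL_METRICS, not fresh results.

import json
import os
import tempfile

# Placeholder report shaped like what fetch_metrics() expects; the key names
# ("overall_metrics", "Accuracy", "F1 Macro") come from the diff, the values
# simply mirror the previously hard-coded 89.20% / 0.8931.
fake_report = {
    "overall_metrics": {
        "Accuracy": 0.8920,
        "F1 Macro": 0.8931,
    }
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "evaluation_report.json")
    with open(path, "w") as f:
        json.dump(fake_report, f)

    # Same parsing and formatting as fetch_metrics(), minus hf_hub_download().
    with open(path, "r") as f:
        data = json.load(f)
    acc = data["overall_metrics"]["Accuracy"]
    f1 = data["overall_metrics"]["F1 Macro"]
    print({"Accuracy": f"{acc:.2%}", "F1_Score": f"{f1:.4f}"})
    # -> {'Accuracy': '89.20%', 'F1_Score': '0.8931'}

Note that the f"{acc:.2%}" format only renders correctly if the report stores Accuracy as a fraction (0.892); a report that already stores a percentage (89.2) would display as 8920.00%.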
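
The in-code comment about theme=gr.themes.Soft() ("If it fails, remove theme=gr.themes.Soft()") could also be handled defensively. The snippet below is a hypothetical variant, not part of this commit, that falls back to Gradio's default theme when the installed version does not expose gr.themes.

import gradio as gr

# Hypothetical guard (not in the commit): use the Soft theme when available,
# otherwise construct the Blocks with its default theme.
try:
    blocks_kwargs = {"theme": gr.themes.Soft()}
except AttributeError:
    blocks_kwargs = {}  # older Gradio release without gr.themes

with gr.Blocks(**blocks_kwargs) as demo:
    gr.Markdown("# 📰 NLP News Classifier")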