Joaquin Villar committed
Commit b33a33c · verified · 1 Parent(s): 735af62

Create app.py

Files changed (1)
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ import streamlit as st
+ import torch
+ import numpy as np
+ import pandas as pd
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from peft import PeftModel, PeftConfig
+ import os
+
+ # --- 1. CONFIGURATION & METRICS ---
+ # These match the final results from your notebook
+ MODEL_METRICS = {
+     "Accuracy": "89.20%",
+     "F1_Score": "0.8931"
+ }
+
+ # Your Hugging Face Model Repository
+ ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
+ BASE_MODEL_ID = "distilbert-base-uncased"
+
+ CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
+
+ # --- 2. PAGE SETUP ---
+ st.set_page_config(page_title="News Classifier", page_icon="📰", layout="centered")
+
+ # --- 3. MODEL LOADING (Cached) ---
+ # @st.cache_resource ensures the model loads only once, making the app fast
+ @st.cache_resource
+ def load_model():
+     try:
+         # Load Base Model
+         base_model = AutoModelForSequenceClassification.from_pretrained(
+             BASE_MODEL_ID,
+             num_labels=len(CLASS_NAMES),
+             id2label={k: v for k, v in enumerate(CLASS_NAMES.values())},
+             label2id={v: k for k, v in CLASS_NAMES.items()}
+         )
+
+         # Load Tokenizer (from your repo to ensure consistency)
+         tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)
+
+         # Load LoRA Adapters
+         model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
+
+         # Force CPU (Standard for free Hugging Face Spaces)
+         device = torch.device("cpu")
+         model.to(device)
+         model.eval()
+
+         return model, tokenizer, device
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         return None, None, None
+
+ # Initialize model
+ model, tokenizer, device = load_model()
+
+ # --- 4. PREDICTION FUNCTION ---
+ def predict(text):
+     # Preprocess text
+     inputs = tokenizer(
+         text,
+         return_tensors="pt",
+         truncation=True,
+         padding="max_length",
+         max_length=128
+     ).to(device)
+
+     # Inference
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Calculate probabilities
+     logits = outputs.logits
+     probs = torch.nn.functional.softmax(logits, dim=1).squeeze().cpu().numpy()
+
+     # Get top prediction
+     pred_idx = np.argmax(probs)
+     pred_label = CLASS_NAMES[pred_idx]
+     pred_conf = probs[pred_idx]
+
+     # Format all probabilities for the chart
+     class_probs = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
+
+     return pred_label, pred_conf, class_probs
+
+ # --- 5. USER INTERFACE ---
+
+ # Header
+ st.title("📰 NLP News Classifier")
+ st.markdown("""
+ This interface uses a **DistilBERT** model fine-tuned with **LoRA (Low-Rank Adaptation)**.
+ It classifies news text into four categories: **World, Sports, Business, or Sci/Tech**.
+ """)
+
+ # Green Performance Banner
+ st.success(f"✅ **Model Performance (Test Set):** Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}")
+
+ # Input Area
+ text_input = st.text_area(
+     "Enter a News Article or Snippet:",
+     height=150,
+     placeholder="e.g., The stock market rallied today as tech companies reported record profits..."
+ )
+
+ # Classify Button
+ if st.button("Classify Article", type="primary"):
+     if not text_input.strip():
+         st.warning("Please enter some text first.")
+     else:
+         with st.spinner("Analyzing..."):
+             label, confidence, all_probs = predict(text_input)
+
+         # --- RESULTS SECTION ---
+         st.divider()
+
+         # Create two columns for layout
+         col1, col2 = st.columns([1, 1.5])
+
+         with col1:
+             st.subheader("Prediction")
+             # Display big label
+             st.markdown(f"<h1>{label}</h1>", unsafe_allow_html=True)
+
+             # Dynamic color for confidence badge
+             if confidence > 0.85:
+                 badge_color = "#d4edda"  # Green
+                 text_color = "#155724"
+             elif confidence > 0.60:
+                 badge_color = "#fff3cd"  # Yellow/Orange
+                 text_color = "#856404"
+             else:
+                 badge_color = "#f8d7da"  # Red
+                 text_color = "#721c24"
+
+             st.markdown(
+                 f"""<div style='background-color:{badge_color}; color:{text_color};
+                 padding: 10px; border-radius: 5px; display: inline-block; font-weight: bold;'>
+                 Confidence: {confidence:.2%}
+                 </div>""",
+                 unsafe_allow_html=True
+             )
+
+         with col2:
+             st.subheader("Probability Breakdown")
+             # Prepare data for chart
+             df_probs = pd.DataFrame(
+                 list(all_probs.items()),
+                 columns=['Category', 'Probability']
+             )
+             # Show bar chart
+             st.bar_chart(df_probs.set_index('Category'))
+
+ # Footer
+ st.markdown("---")
+ st.caption("Built by Joaquin Villar Urrutia | Powered by Hugging Face & Streamlit")
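
Not part of this commit, but as a quick way to sanity-check the published adapters outside Streamlit, the same load-and-predict path can be exercised from a plain Python script. This is a minimal sketch: the repo IDs and label map are taken from app.py above, while the sample headline and printed output format are only illustrative.

    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from peft import PeftModel

    ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"  # same repo as in app.py
    CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Same recipe as load_model(): DistilBERT classification head + LoRA adapters
    base = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=len(CLASS_NAMES)
    )
    model = PeftModel.from_pretrained(base, ADAPTER_REPO).eval()
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    # Illustrative input; any short news snippet works here
    inputs = tokenizer("Stocks rallied after the central bank held rates steady.",
                       return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        probs = model(**inputs).logits.softmax(dim=-1).squeeze()
    print(CLASS_NAMES[int(probs.argmax())], f"{float(probs.max()):.2%}")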