File size: 5,345 Bytes
b33a33c
 
 
 
bd3aa04
b33a33c
bd3aa04
 
b33a33c
 
bd3aa04
b33a33c
 
 
 
 
 
 
bd3aa04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b33a33c
 
 
 
 
 
 
 
 
 
 
bd3aa04
b33a33c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd3aa04
b33a33c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd3aa04
b33a33c
 
 
 
 
 
 
bd3aa04
 
b33a33c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd3aa04
b33a33c
bd3aa04
b33a33c
bd3aa04
b33a33c
bd3aa04
b33a33c
 
bd3aa04
 
 
 
 
 
 
 
 
b33a33c
 
 
 
 
bd3aa04
b33a33c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import torch
import numpy as np
import pandas as pd
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download
import os

# --- 1. CONFIGURATION ---
# Hub repository that provides the LoRA adapter weights, the tokenizer and
# the evaluation_report.json read by fetch_metrics().
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
# Base checkpoint the adapter is applied on top of.
BASE_MODEL_ID = "distilbert-base-uncased"
# Class index -> human-readable label shown in the UI.
CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# --- 2. PAGE SETUP ---
# The icon is the newspaper emoji (U+1F4F0). It is written as an escape because
# the previous literal was mojibake (its UTF-8 bytes decoded with a legacy code
# page), which is not a valid emoji for the page icon.
st.set_page_config(page_title="News Classifier", page_icon="\U0001F4F0", layout="centered")

# --- 3. DYNAMIC METRICS LOADING ---
@st.cache_data(ttl=3600)  # Cache for 1 hour so we don't download every second
def fetch_metrics():
    """Download evaluation_report.json from the Model Hub and format its metrics.

    Returns:
        dict: ``{"Accuracy": <str>, "F1_Score": <str>}``. Accuracy is rendered
        as a percentage with two decimals and the macro F1 with four decimals.
        Both values are the string ``"N/A"`` when the report cannot be
        downloaded, parsed, or does not contain the expected keys.
    """
    import logging  # local import: only used on the failure path below

    try:
        file_path = hf_hub_download(repo_id=ADAPTER_REPO, filename="evaluation_report.json")
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Expected layout: {"overall_metrics": {"Accuracy": ..., "F1 Macro": ...}}
        # Adjust the keys here if the report structure changes.
        overall = data['overall_metrics']
        acc = overall['Accuracy']
        f1 = overall['F1 Macro']

        return {
            "Accuracy": f"{acc:.2%}",
            "F1_Score": f"{f1:.4f}"
        }
    except Exception:
        # Deliberately best-effort: the metrics banner is informational, so any
        # failure (network, missing file, bad JSON, unexpected schema) degrades
        # to placeholders. Log it so the cause is not lost silently.
        logging.getLogger(__name__).warning(
            "Could not load evaluation_report.json from %s", ADAPTER_REPO, exc_info=True
        )
        return {"Accuracy": "N/A", "F1_Score": "N/A"}

# Load metrics immediately at script start so the banner in the UI section can
# display them. fetch_metrics() returns "N/A" placeholders when the report
# cannot be loaded, so both keys are always present.
MODEL_METRICS = fetch_metrics()

# --- 4. MODEL LOADING (Cached) ---
@st.cache_resource
def _load_model_cached():
    """Build the base model + LoRA adapter and tokenizer; raise on any failure.

    Only successful results are stored by ``st.cache_resource`` because an
    exception propagates out of the cached function instead of being returned.
    """
    # Load Base Model. id2label is taken straight from CLASS_NAMES so it is
    # always the exact inverse of label2id.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        id2label=dict(CLASS_NAMES),
        label2id={name: idx for idx, name in CLASS_NAMES.items()},
    )

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    # Load LoRA Adapters
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

    # Force CPU (Standard for free Hugging Face Spaces)
    device = torch.device("cpu")
    model.to(device)
    model.eval()

    return model, tokenizer, device


def load_model():
    """Return ``(model, tokenizer, device)``, or ``(None, None, None)`` on failure.

    The heavy lifting is delegated to a cached helper. The failure fallback is
    produced here, outside the cache, so a transient error (e.g. a network
    hiccup while downloading weights) is retried on the next rerun instead of
    being cached for the lifetime of the process. On failure the error is
    reported to the page with ``st.error``.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None, None

# Module-level handles used by predict(). load_model() returns
# (None, None, None) when loading fails, so all three may be None.
model, tokenizer, device = load_model()

# --- 5. PREDICTION FUNCTION ---
def predict(text):
    """Classify a single news text with the module-level model.

    The text is tokenized (truncated and padded to 128 tokens), run through
    the model without gradient tracking, and the logits are turned into
    probabilities with a softmax.

    Args:
        text: The article or snippet to classify.

    Returns:
        tuple: ``(pred_label, pred_conf, class_probs)`` where ``pred_label`` is
        the ``CLASS_NAMES`` entry with the highest probability, ``pred_conf``
        is that probability as a Python ``float``, and ``class_probs`` maps
        every class name to its probability as a ``float``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class axis, then take the first (only) row of the batch.
    # Indexing removes just the batch axis, unlike squeeze(), which would drop
    # every size-1 axis. Assumes logits are (batch, num_classes).
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

    # Convert NumPy scalars to builtins so the return types match class_probs.
    pred_idx = int(np.argmax(probs))
    pred_label = CLASS_NAMES[pred_idx]
    pred_conf = float(probs[pred_idx])

    class_probs = {name: float(probs[idx]) for idx, name in CLASS_NAMES.items()}

    return pred_label, pred_conf, class_probs

# --- 6. USER INTERFACE ---

# Emoji in this section are written as escapes: the previous literals were
# mojibake (UTF-8 bytes decoded with a legacy code page) and rendered as
# garbage characters in the title and the banner.
st.title("\U0001F4F0 NLP News Classifier")
st.markdown("""
This interface uses a **DistilBERT** model fine-tuned with **LoRA (Low-Rank Adaptation)**.
It classifies news text into four categories: **World, Sports, Business, or Sci/Tech**.
""")

# Dynamic Green Banner (values are "N/A" when the evaluation report is unavailable)
st.success(f"\u2705 **Live Model Performance:** Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}")

text_input = st.text_area(
    "Enter a News Article or Snippet:",
    height=150,
    placeholder="e.g., The stock market rallied today as tech companies reported record profits..."
)

if st.button("Classify Article", type="primary"):
    if not text_input.strip():
        st.warning("Please enter some text first.")
    elif model is None or tokenizer is None:
        # load_model() yields (None, None, None) on failure; calling predict()
        # in that state would crash, so report the problem instead.
        st.error("The model is not available, so the text cannot be classified.")
    else:
        with st.spinner("Analyzing..."):
            label, confidence, all_probs = predict(text_input)

        st.divider()
        col1, col2 = st.columns([1, 1.5])

        with col1:
            st.subheader("Prediction")
            # label always comes from CLASS_NAMES, never from user input, so
            # interpolating it into raw HTML is safe here.
            st.markdown(f"<h1>{label}</h1>", unsafe_allow_html=True)

            # Stylized Badge: green above 85%, amber above 60%, red otherwise.
            if confidence > 0.85:
                bg_color, text_color, icon = "#1b4332", "#4ade80", "↑"
            elif confidence > 0.60:
                bg_color, text_color, icon = "#453a1f", "#facc15", "~"
            else:
                bg_color, text_color, icon = "#451a1a", "#f87171", "↓"

            st.markdown(
                f"""
                <div style="
                    background-color: {bg_color}; color: {text_color};
                    padding: 5px 15px; border-radius: 20px; display: inline-flex;
                    align-items: center; font-weight: 600; font-size: 14px;
                    border: 1px solid {text_color}40; margin-bottom: 10px;">
                    {icon} &nbsp; {confidence:.2%} Confidence
                </div>
                """,
                unsafe_allow_html=True
            )

        with col2:
            st.subheader("Probability Breakdown")
            df_probs = pd.DataFrame(list(all_probs.items()), columns=['Category', 'Probability'])
            st.bar_chart(df_probs.set_index('Category'))

st.markdown("---")
st.caption("Built by Joaquin Villar Urrutia | Powered by Hugging Face & Streamlit")