|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import kagglehub |
|
|
import numpy as np |
|
|
import os |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# KaggleHub handle of the fine-tuned checkpoint fetched by load_resources().
MODEL_HANDLE = "prathabmurugan/dlgenai-emotion-classification/pyTorch/1a"

# Label -> decision threshold. Insertion order defines the label order, which
# must line up with the classifier's output columns (presumably the order
# used at training time -- confirm against the training code).
_LABEL_THRESHOLDS = {
    "anger": 0.85,
    "fear": 0.43,
    "joy": 0.21,
    "sadness": 0.7,
    "surprise": 0.36,
}

# Output labels, index-aligned with the model logits and with THRESHOLDS.
EMOTION_LABELS = list(_LABEL_THRESHOLDS)

# Per-label probability cut-offs; predict() marks a label present when its
# sigmoid score is strictly greater than the matching entry here.
THRESHOLDS = np.array(list(_LABEL_THRESHOLDS.values()))

# Token budget per input: longer texts are truncated, shorter ones padded.
MAX_LEN = 100

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
class RobertaClassifier(nn.Module):
    """Pretrained transformer encoder topped with dropout and a linear head.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits produced by the head.
        dropout: Dropout probability applied to the encoder's pooled output.

    The submodule attribute names (``roberta``, ``dropout``, ``classifier``)
    determine the ``state_dict`` keys, so they must match any checkpoint that
    is loaded into this model.
    """

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.3):
        super().__init__()
        backbone = AutoModel.from_pretrained(model_name)
        width = backbone.config.hidden_size
        self.roberta = backbone
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(width, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits, one per label.

        Args:
            input_ids: Token ids produced by the matching tokenizer.
            attention_mask: Mask distinguishing real tokens from padding.
        """
        encoder_out = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        features = self.dropout(encoder_out.pooler_output)
        return self.classifier(features)
|
|
|
|
|
|
|
|
def standardize_space(text):
    """Return ``text`` as a string with every whitespace run collapsed.

    The value is converted with ``str()`` first, so non-string inputs are
    accepted. Leading and trailing whitespace is removed and internal runs of
    whitespace become a single space.
    """
    words = str(text).split()
    return " ".join(words)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_resources():
    """Download the fine-tuned checkpoint and build the model and tokenizer.

    Progress and failures are reported through a single placeholder element.
    If any step fails (download, base-model initialisation, or loading the
    checkpoint), an error is displayed and the script run is halted with
    ``st.stop()``.

    Returns:
        tuple: ``(model, tokenizer)`` -- a ``RobertaClassifier`` with the
        checkpoint weights loaded, moved to ``DEVICE`` and set to eval mode,
        plus the ``roberta-base`` tokenizer it was built on.
    """
    base_model = "roberta-base"
    status_container = st.empty()

    status_container.info(
        f"Downloading model weights from KaggleHub [{MODEL_HANDLE}]")
    try:
        model_dir = kagglehub.model_download(MODEL_HANDLE)
    except Exception as e:
        status_container.error(f"Failed to download model [{e}]")
        st.stop()
    model_path = os.path.join(model_dir, "roberta_best_model.pth")

    status_container.info("Initializing RoBERTa architecture")
    try:
        # Both calls may hit the network, so failures are reported like the
        # other steps instead of surfacing as a raw traceback.
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        # One logit per entry of EMOTION_LABELS keeps the head consistent with
        # THRESHOLDS and the UI tables.
        model = RobertaClassifier(base_model, num_labels=len(EMOTION_LABELS))
    except Exception as e:
        status_container.error(f"Failed to initialize base model [{e}]")
        st.stop()

    try:
        # The file is downloaded from an external source; weights_only=True
        # restricts unpickling to tensors/plain containers, which is all a
        # state_dict needs, instead of allowing arbitrary pickled objects.
        # NOTE(review): requires torch >= 1.13; older versions raise here.
        state_dict = torch.load(
            model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        status_container.error(f"Error loading state dict [{e}]")
        st.stop()

    status_container.empty()
    return model, tokenizer
|
|
|
|
|
|
|
|
def predict(texts, model, tokenizer):
    """Classify a batch of texts and apply the per-label thresholds.

    Args:
        texts: Sequence of input texts. A single ``str`` is treated as a
            one-element batch instead of being iterated character by
            character. Each item is normalised with ``standardize_space``.
        model: Classifier returning one logit per label; assumed to already be
            on ``DEVICE`` and in eval mode (``load_resources`` leaves it so).
        tokenizer: Hugging Face tokenizer matching ``model``.

    Returns:
        tuple: ``(preds, probs)`` NumPy arrays with one row per input text and
        one column per entry of ``EMOTION_LABELS``. ``probs`` holds sigmoid
        scores; ``preds`` is 1 where a score is strictly greater than the
        label's value in ``THRESHOLDS`` and 0 otherwise.
    """
    if isinstance(texts, str):
        texts = [texts]
    processed_texts = [standardize_space(t) for t in texts]

    num_labels = len(EMOTION_LABELS)
    if not processed_texts:
        # Nothing to tokenize: return empty results with the usual columns.
        empty_preds = np.zeros((0, num_labels), dtype=int)
        empty_probs = np.zeros((0, num_labels), dtype=np.float32)
        return empty_preds, empty_probs

    encodings = tokenizer(
        processed_texts,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = encodings['input_ids'].to(DEVICE)
    attention_mask = encodings['attention_mask'].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = torch.sigmoid(logits).cpu().numpy()

    # Broadcasts the (num_labels,) threshold vector across every row.
    preds = (probs > THRESHOLDS).astype(int)
    return preds, probs
|
|
|
|
|
|
|
|
|
|
|
# --- Page layout --------------------------------------------------------------
st.set_page_config(layout="centered", page_title="Emotion Classifier")

st.title("Emotion Classification")
st.markdown(
    "This app pulls a custom fine-tuned **RoBERTa** model from Kaggle "
    "to classify text into 5 emotions."
)

# load_resources is wrapped in st.cache_resource, so reruns reuse the cached
# model/tokenizer pair instead of downloading and rebuilding it.
model, tokenizer = load_resources()

# One tab per inference mode: free-text input and batch CSV upload.
tab1, tab2 = st.tabs(["Single Text Inference", "Batch CSV Inference"])
|
|
|
|
|
# --- Single-text inference ---------------------------------------------------
with tab1:
    st.header("Test a single sentence")
    user_input = st.text_area("Enter text here:", "Hello World!")

    if st.button("Analyze Text", type="primary"):
        if user_input.strip():
            with st.spinner("Analyzing..."):
                preds, probs = predict([user_input], model, tokenizer)

            st.subheader("Results:")
            col1, col2 = st.columns(2)

            with col1:
                st.write("**Detected Emotions:**")
                # preds[0] is the 0/1 row for the single input, aligned with
                # EMOTION_LABELS.
                detected = [
                    label.capitalize()
                    for label, is_present in zip(EMOTION_LABELS, preds[0])
                    if is_present
                ]

                if detected:
                    for d in detected:
                        # Check-mark prefix; the previous literal was
                        # mis-encoded and split across two lines, leaving an
                        # unterminated f-string.
                        st.markdown(f"### ✅ {d}")
                else:
                    st.markdown(
                        "*No specific emotion detected above thresholds.*")

            with col2:
                st.write("**Confidence Scores:**")
                scores_df = pd.DataFrame({
                    "Emotion": EMOTION_LABELS,
                    "Score": probs[0],
                    "Threshold": THRESHOLDS,
                    "Detected": preds[0].astype(bool),
                })

                # Styler.background_gradient needs matplotlib installed.
                st.dataframe(
                    scores_df.style.format(
                        {"Score": "{:.2%}", "Threshold": "{:.2f}"})
                    .background_gradient(subset=["Score"], cmap="Greens"),
                    hide_index=True,
                    use_container_width=True,
                )
        else:
            st.warning("Please enter some text.")
|
|
|
|
|
# --- Batch CSV inference -----------------------------------------------------
with tab2:
    st.header("Batch Process (CSV)")
    st.markdown("Upload a CSV file with a `text` and `id` column.")

    uploaded_file = st.file_uploader("Upload CSV", type=["csv"])

    if uploaded_file is not None:
        # Only parsing is guarded here, so the message really refers to the
        # CSV. Inference failures are reported separately below.
        input_df = None
        try:
            input_df = pd.read_csv(uploaded_file)
        except Exception as e:
            st.error(f"Error reading CSV: {e}")

        if input_df is not None:
            if 'text' not in input_df.columns:
                st.error("CSV must have a 'text' column.")
            elif input_df.empty:
                # np.vstack below needs at least one batch; a header-only CSV
                # previously crashed there.
                st.warning("The uploaded CSV has no rows to process.")
            else:
                st.info(
                    f"Loaded [{len(input_df)}] rows. Click below to start.")

                if st.button("Generate Predictions"):
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    batch_size = 16
                    # Missing cells would otherwise be classified as the
                    # literal string "nan"; treat them as empty text instead.
                    texts = input_df['text'].fillna("").astype(str).tolist()
                    n_rows = len(texts)
                    all_preds = []

                    try:
                        for start in range(0, n_rows, batch_size):
                            batch_texts = texts[start:start + batch_size]
                            batch_preds, _ = predict(
                                batch_texts, model, tokenizer)
                            all_preds.append(batch_preds)

                            done = start + len(batch_texts)
                            progress_bar.progress(done / n_rows)
                            status_text.text(
                                f"Processed {done}/{n_rows} rows")
                    except Exception as e:
                        st.error(f"Error generating predictions: {e}")
                    else:
                        predictions_np = np.vstack(all_preds)
                        submission_df = pd.DataFrame(
                            predictions_np, columns=EMOTION_LABELS, dtype=int)

                        if 'id' in input_df.columns:
                            # Positional alignment: drop any non-default index
                            # so concat pairs row i with prediction i.
                            ids = input_df[['id']].reset_index(drop=True)
                            final_df = pd.concat([ids, submission_df], axis=1)
                        else:
                            final_df = submission_df

                        st.success("Processing complete!")
                        st.dataframe(final_df.head(), use_container_width=True)

                        # NOTE(review): clicking the download button reruns
                        # the script, after which st.button returns False and
                        # these results disappear; persist final_df in
                        # st.session_state if they should stay visible.
                        csv = final_df.to_csv(index=False).encode('utf-8')
                        st.download_button(
                            label="Download Predictions CSV",
                            data=csv,
                            file_name="submission.csv",
                            mime="text/csv",
                        )
|
|
|