rajendrr committed on
Commit a1b4641 · verified · 1 Parent(s): ca20429

Update app.py

Files changed (1)
  1. app.py +256 -669
app.py CHANGED
@@ -1,693 +1,280 @@
1
- """
2
- HF Space: Normalization + Twitter Sentiment Workbench
3
- Now with built-in datasets:
4
- • Sentiment140 (HF datasets: sentiment140)
5
- • TweetEval (sentiment) (HF datasets: tweet_eval / sentiment)
6
-
7
- Tabs:
8
- • Single Text – step-by-step normalization + sentiment bar
9
- • Batch Tweets (CSV) – upload your own file
10
- • Datasets – pull Sentiment140/TweetEval, sample/filter, analyze, and benchmark
11
-
12
- Models:
13
- • VADER (fast baseline)
14
- • Twitter-RoBERTa (cardiffnlp/twitter-roberta-base-sentiment-latest)
15
-
16
- Run locally:
17
- pip install -r requirements.txt
18
- python app.py
19
- """
20
-
21
- import os
22
- import re
23
- import json
24
- from typing import List, Tuple, Optional, Dict
25
- from collections import Counter, defaultdict
26
-
27
  import gradio as gr
28
  import pandas as pd
29
  import numpy as np
30
- import matplotlib.pyplot as plt
31
-
32
- # ---- NLTK setup ----
33
- import nltk
34
- from nltk.corpus import stopwords, wordnet as wn
35
- from nltk.stem import WordNetLemmatizer
36
- from nltk.sentiment import SentimentIntensityAnalyzer
37
- from nltk.tokenize import TweetTokenizer
38
-
39
- for pkg in [
40
- "punkt", "punkt_tab", "stopwords", "wordnet", "omw-1.4",
41
- "averaged_perceptron_tagger", "averaged_perceptron_tagger_eng",
42
- "vader_lexicon"
43
- ]:
44
  try:
45
- nltk.download(pkg, quiet=True)
46
  except Exception:
47
- pass
48
-
49
- # ---- Transformers (Twitter-RoBERTa) ----
50
- TRANSFORMERS_AVAILABLE = True
51
- try:
52
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
53
- import torch
54
- import torch.nn.functional as F
55
- except Exception:
56
- TRANSFORMERS_AVAILABLE = False
57
-
58
- # ---- Hugging Face datasets (for Sentiment140 / TweetEval) ----
59
- DATASETS_AVAILABLE = True
60
- try:
61
- from datasets import load_dataset
62
- except Exception:
63
- DATASETS_AVAILABLE = False
64
-
65
- # =========================
66
- # Core text normalization
67
- # =========================
68
- _punct_re = re.compile(r"[^\w\s]", flags=re.UNICODE)
69
- _tkn = TweetTokenizer()
70
-
71
- def remove_non_ascii(words: List[str]) -> List[str]:
72
- out = []
73
- for w in words:
74
- ascii_w = "".join(ch for ch in w if ord(ch) < 128)
75
- if ascii_w:
76
- out.append(ascii_w)
77
- return out
78
-
79
- def to_lowercase(words: List[str]) -> List[str]:
80
- return [w.lower() for w in words]
81
-
82
- def remove_punctuation(words: List[str]) -> List[str]:
83
- out = []
84
- for w in words:
85
- stripped = _punct_re.sub("", w)
86
- if stripped:
87
- out.append(stripped)
88
- return out
89
-
90
- def _build_stopword_set() -> set:
91
- base = set(stopwords.words("english"))
92
- base |= {"rt","amp","https","http","t","co","u","s","us"} # twitter-ish noise
93
- stripped_variants = {_punct_re.sub("", w) for w in base}
94
- return base | stripped_variants
95
-
96
- _STOPWORDS = _build_stopword_set()
97
- _lemmatizer = WordNetLemmatizer()
98
-
99
- def _to_wordnet_pos(treebank_tag: str):
100
- if not treebank_tag:
101
- return wn.NOUN
102
- t = treebank_tag[0].upper()
103
- if t == "J": return wn.ADJ
104
- if t == "V": return wn.VERB
105
- if t == "N": return wn.NOUN
106
- if t == "R": return wn.ADV
107
- return wn.NOUN
108
-
109
- def lemmatize_list(words: List[str]) -> List[str]:
110
- try:
111
- tagged = nltk.pos_tag(words)
112
- except LookupError:
113
- tagged = [(w, "N") for w in words]
114
- return [_lemmatizer.lemmatize(w, _to_wordnet_pos(tag)) for w, tag in tagged]
115
-
116
- def tokenize(text: str) -> List[str]:
117
- return _tkn.tokenize(text)
118
-
119
- def normalize(text: str) -> str:
120
- """Full preprocessing pipeline (your original)."""
121
- words = tokenize(text)
122
- words = remove_non_ascii(words)
123
- words = to_lowercase(words)
124
- words = remove_punctuation(words)
125
- words = [w for w in words if w not in _STOPWORDS]
126
- words = lemmatize_list(words)
127
- return " ".join(words)
128
-
129
- # =========================
130
- # Twitter-aware cleaning
131
- # =========================
132
- url_re = re.compile(r"https?://\S+|www\.\S+")
133
- mention_re = re.compile(r"@\w+")
134
- hashtag_re = re.compile(r"#(\w+)")
135
- rt_re = re.compile(r"\brt\b", re.IGNORECASE)
136
- amp_re = re.compile(r"\bamp\b", re.IGNORECASE)
137
-
138
- def twitter_clean(text: str) -> str:
139
- if not text: return ""
140
- s = url_re.sub("", text)
141
- s = mention_re.sub("", s)
142
- s = hashtag_re.sub(lambda m: m.group(1), s) # keep hashtag word
143
- s = rt_re.sub("", s)
144
- s = amp_re.sub("", s)
145
- s = s.replace("U.S.", "US").replace("u.s.", "us")
146
- return re.sub(r"\s+", " ", s).strip()
147
-
148
- # =========================
149
- # Sentiment backends
150
- # =========================
151
- _sia = SentimentIntensityAnalyzer()
152
-
153
- ROBERTA_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
154
- _roberta_tok = None
155
- _roberta_model = None
156
-
157
- def _load_roberta():
158
- global _roberta_tok, _roberta_model
159
- if not TRANSFORMERS_AVAILABLE:
160
- return False
161
- if _roberta_model is None:
162
- _roberta_tok = AutoTokenizer.from_pretrained(ROBERTA_ID)
163
- _roberta_model = AutoModelForSequenceClassification.from_pretrained(ROBERTA_ID)
164
- _roberta_model.eval()
165
- return True
166
-
167
- def vader_scores(text: str) -> Dict[str, float]:
168
- s = twitter_clean(text)
169
- sc = _sia.polarity_scores(s)
170
- return sc # keys: neg, neu, pos, compound
171
-
172
- def roberta_scores(text: str) -> Optional[Dict[str, float]]:
173
- if not _load_roberta():
174
- return None
175
- s = twitter_clean(text)
176
- inputs = _roberta_tok(s, return_tensors="pt", truncation=True, max_length=256)
177
- with torch.no_grad():
178
- logits = _roberta_model(**inputs).logits
179
- probs = F.softmax(logits, dim=1).squeeze().cpu().tolist()
180
- # Map to VADER-like schema; define compound = pos - neg
181
- return {"neg": float(probs[0]), "neu": float(probs[1]), "pos": float(probs[2]), "compound": float(probs[2] - probs[0])}
182
-
183
- def score_text(text: str, model_name: str) -> Dict[str, float]:
184
- if model_name == "Twitter-RoBERTa":
185
- sc = roberta_scores(text)
186
- if sc is not None:
187
- return sc
188
- return vader_scores(text)
189
-
190
- def label_from_compound(c: float, pos_thr: float = 0.05, neg_thr: float = -0.05) -> str:
191
- if c >= pos_thr: return "positive"
192
- if c <= neg_thr: return "negative"
193
- return "neutral"
194
-
195
- # =========================
196
- # Visual helpers (matplotlib; default colors only)
197
- # =========================
198
- def plot_sentiment_bar(scores: Dict[str, float]):
199
- fig = plt.figure(figsize=(4.8, 3.0))
200
- keys = ["neg","neu","pos","compound"]
201
- vals_adj = [scores["neg"], scores["neu"], scores["pos"], (scores["compound"] + 1) / 2]
202
- plt.bar(keys, vals_adj)
203
- plt.title("Sentiment Scores")
204
- plt.ylim(0, 1)
205
- return fig
206
 
207
- def plot_hist(vals: List[float], title: str, bins: int = 20):
208
- fig = plt.figure(figsize=(6,3))
209
- plt.hist(vals, bins=bins)
210
- plt.title(title)
211
- plt.xlabel("compound")
212
- plt.ylabel("frequency")
213
- plt.tight_layout()
214
- return fig
215
 
216
- def plot_counts(labels: List[str], title: str):
217
- fig = plt.figure(figsize=(6,3))
218
- series = pd.Series(labels).value_counts().reindex(["negative","neutral","positive"]).fillna(0)
219
- plt.bar(series.index.astype(str), series.values.astype(int))
220
- plt.title(title)
221
- plt.xlabel("label")
222
- plt.ylabel("count")
223
- plt.tight_layout()
224
- return fig
225
 
226
- def plot_top_bar(pairs: List[Tuple[str,int]], title: str, rotate: int = 45):
227
- fig = plt.figure(figsize=(8,3.5))
228
- if pairs:
229
- labels, values = zip(*pairs)
230
- plt.bar(labels, values)
231
- plt.xticks(rotation=rotate, ha="right")
232
- plt.title(title)
233
- plt.tight_layout()
234
  return fig
235
 
236
- from wordcloud import WordCloud
237
- from PIL import Image
238
-
239
- def wordcloud_from_tokens(tokens: List[str]):
240
- text = " ".join(tokens)
241
- if not text.strip():
242
- return Image.new("RGB", (800, 400), color=(255,255,255))
243
- wc = WordCloud(width=800, height=400, background_color="white")
244
- return wc.generate(text).to_image()
245
-
246
- # =========================
247
- # Token analytics
248
- # =========================
249
- def tokens_from_texts(texts: List[str]) -> List[str]:
250
- all_toks = []
251
- for t in texts:
252
- s = twitter_clean(t)
253
- toks = tokenize(s)
254
- toks = [w.lower() for w in toks]
255
- toks = [ _punct_re.sub("", w) for w in toks ]
256
- toks = [w for w in toks if w and (w not in _STOPWORDS)]
257
- toks = [ _lemmatizer.lemmatize(w) for w in toks ]
258
- all_toks.extend(toks)
259
- return all_toks
260
-
261
- def bigrams(tokens: List[str]):
262
- return list(zip(tokens, tokens[1:]))
263
-
264
- # =========================
265
- # Aspect-based (simple window)
266
- # =========================
267
- DEFAULT_ASPECTS = ["tariff","jobs","prices","china","farmers","john", "deere"]
268
-
269
- def aspect_sentiment(texts: List[str], aspects: List[str], model_name: str, window: int = 6):
270
- out = {a.lower(): [] for a in aspects}
271
- for t in texts:
272
- clean = twitter_clean(t)
273
- toks = clean.split()
274
- for i, tok in enumerate(toks):
275
- for a in aspects:
276
- key = a.lower().split()[0]
277
- if tok.lower() == key:
278
- lo, hi = max(0, i-window), min(len(toks), i+window+1)
279
- chunk = " ".join(toks[lo:hi])
280
- sc = score_text(chunk, model_name)["compound"]
281
- out[a.lower()].append(sc)
282
- rows = []
283
- for a, vals in out.items():
284
- rows.append({
285
- "aspect": a,
286
- "n": len(vals),
287
- "mean_compound": float(np.mean(vals)) if vals else 0.0
288
- })
289
- df = pd.DataFrame(rows).sort_values(["n","mean_compound"], ascending=[False, False])
290
- return df
291
-
292
- # =========================
293
- # Topic clustering (TF-IDF + k-means)
294
- # =========================
295
- from sklearn.feature_extraction.text import TfidfVectorizer
296
- from sklearn.cluster import KMeans
297
- from sklearn.metrics import classification_report, confusion_matrix
298
-
299
- def cluster_topics(texts: List[str], n_clusters: int, model_name: str):
300
- docs = [twitter_clean(t) for t in texts]
301
- base_docs = [d for d in docs if len(d.split()) >= 3]
302
- if len(base_docs) < max(5, n_clusters):
303
- return pd.DataFrame(columns=["cluster","size","mean_compound","top_terms"]), None
304
- vec = TfidfVectorizer(max_features=4000, ngram_range=(1,2), stop_words="english")
305
- X = vec.fit_transform(base_docs)
306
- km = KMeans(n_clusters=n_clusters, n_init="auto", random_state=0)
307
- labels = km.fit_predict(X)
308
- terms = np.array(vec.get_feature_names_out())
309
- order_centroids = km.cluster_centers_.argsort()[:, ::-1]
310
- top_terms = {i: ", ".join(terms[order_centroids[i, :8]]) for i in range(n_clusters)}
311
- comp = [score_text(d, model_name)["compound"] for d in base_docs]
312
- df = pd.DataFrame({"cluster": labels, "doc": base_docs, "compound": comp})
313
- agg = df.groupby("cluster")["compound"].agg(["size","mean"]).reset_index().rename(columns={"mean":"mean_compound"})
314
- agg["top_terms"] = agg["cluster"].map(top_terms)
315
- agg = agg.sort_values("size", ascending=False)
316
- fig = plt.figure(figsize=(6,3))
317
- plt.bar(agg["cluster"].astype(str), agg["mean_compound"])
318
- plt.title("Cluster mean sentiment (compound)")
319
- plt.xlabel("cluster")
320
- plt.ylabel("mean compound")
321
- plt.tight_layout()
322
- return agg, fig
323
-
324
- # =========================
325
- # SINGLE TEXT: step-by-step
326
- # =========================
327
- def normalize_with_steps(text: str, model_name: str):
328
- if not text or not text.strip():
329
- df = pd.DataFrame([{"Step":"No input","Tokens":"[]","As Text":""}])
330
- return df, "", pd.DataFrame([{"neg":0,"neu":0,"pos":0,"compound":0}]), None
331
- steps = []
332
- tokens = tokenize(text); steps.append(("Tokenize", tokens, " ".join(tokens)))
333
- tokens = remove_non_ascii(tokens); steps.append(("Remove non-ASCII", tokens, " ".join(tokens)))
334
- tokens = to_lowercase(tokens); steps.append(("Lowercase", tokens, " ".join(tokens)))
335
- tokens = remove_punctuation(tokens); steps.append(("Remove punctuation", tokens, " ".join(tokens)))
336
- tokens = [w for w in tokens if w not in _STOPWORDS]; steps.append(("Remove stopwords", tokens, " ".join(tokens)))
337
- tokens = lemmatize_list(tokens); steps.append(("Lemmatize", tokens, " ".join(tokens)))
338
- final_text = " ".join(tokens); steps.append(("Final join", tokens, final_text))
339
- rows = [{"Step":n, "Tokens":json.dumps(t, ensure_ascii=False), "As Text":s} for n,t,s in steps]
340
- steps_df = pd.DataFrame(rows, columns=["Step","Tokens","As Text"])
341
-
342
- scores = score_text(text, model_name)
343
- sent_df = pd.DataFrame([scores])
344
- fig = plot_sentiment_bar(scores)
345
- return steps_df, final_text, sent_df, fig
346
-
347
- # =========================
348
- # ANALYSIS CORE (shared by CSV & datasets)
349
- # =========================
350
- def detect_text_column(df: pd.DataFrame) -> str:
351
- candidates = ["text","tweet","full_text","content","body"]
352
- for c in candidates:
353
- if c in df.columns: return c
354
- for c in df.columns:
355
- if df[c].dtype == object:
356
- return c
357
- return df.columns[0]
358
-
359
- def analyze_df(df_in: pd.DataFrame, model_name: str, pos_thr: float, neg_thr: float,
360
- dedup: bool, min_len: int, top_n: int, n_clusters: int,
361
- aspects_str: str, gold_series: Optional[pd.Series] = None):
362
- df = df_in.copy()
363
- text_col = detect_text_column(df)
364
- df["raw"] = df[text_col].astype(str)
365
-
366
- if dedup:
367
- df = df.drop_duplicates(subset=["raw"])
368
- df = df[df["raw"].str.split().str.len().fillna(0) >= int(min_len)].copy()
369
-
370
- # Score
371
- scs = df["raw"].apply(lambda t: score_text(t, model_name))
372
- sent_df = pd.DataFrame(list(scs))
373
- df = pd.concat([df.reset_index(drop=True), sent_df.reset_index(drop=True)], axis=1)
374
- df["label"] = df["compound"].apply(lambda c: label_from_compound(c, pos_thr, neg_thr))
375
-
376
- # Summary
377
- n = len(df)
378
- share_pos = (df["label"]=="positive").mean() if n else 0
379
- share_neu = (df["label"]=="neutral").mean() if n else 0
380
- share_neg = (df["label"]=="negative").mean() if n else 0
381
- extremes = (df["compound"].abs() >= 0.6).mean() if n else 0
382
- summary = pd.DataFrame([{
383
- "n_tweets": n,
384
- "share_positive": round(share_pos,3),
385
- "share_neutral": round(share_neu,3),
386
- "share_negative": round(share_neg,3),
387
- "share_extremes_|compound|>=0.6": round(extremes,3),
388
- "compound_mean": round(df["compound"].mean() if n else 0, 4),
389
- "compound_std": round(df["compound"].std(ddof=1) if n>1 else 0, 4),
390
- }])
391
 
392
- # Plots
393
- hist_fig = plot_hist(df["compound"].tolist(), "Distribution of compound", bins=20)
394
- count_fig = plot_counts(df["label"].tolist(), "Tweet sentiment counts")
395
-
396
- # Tokens
397
- toks = tokens_from_texts(df["raw"].tolist())
398
- top_words = Counter(toks).most_common(int(top_n))
399
- top_bi = Counter(bigrams(toks)).most_common(int(top_n))
400
- top_bi_pairs = [(" ".join([a,b]), c) for (a,b), c in top_bi]
401
- words_fig = plot_top_bar(top_words, f"Top {top_n} words", rotate=45)
402
- bigrams_fig = plot_top_bar(top_bi_pairs, f"Top {top_n} bigrams", rotate=45)
403
- wc_img = wordcloud_from_tokens(toks)
404
-
405
- # Hashtag sentiment
406
- all_rows = []
407
- for t, comp in zip(df["raw"], df["compound"]):
408
- tags = re.findall(r"#(\w+)", t)
409
- for tag in tags:
410
- all_rows.append((tag.lower(), comp))
411
- tag_map = defaultdict(list)
412
- for tag, sc in all_rows:
413
- tag_map[tag].append(sc)
414
- tag_stats = sorted([(k, len(v), float(np.mean(v))) for k, v in tag_map.items()],
415
- key=lambda x: x[1], reverse=True)[:top_n]
416
- tag_df = pd.DataFrame(tag_stats, columns=["hashtag","count","mean_compound"])
417
- tag_fig = plot_top_bar([(h, c) for h,c,_ in tag_stats], "Top hashtags (by count)", rotate=45)
418
-
419
- # Aspects
420
- aspects = [a.strip() for a in (aspects_str or "").split(",") if a.strip()] or DEFAULT_ASPECTS
421
- asp_df = aspect_sentiment(df["raw"].tolist(), aspects, model_name)
422
-
423
- # Clusters
424
- cluster_tbl, cluster_fig = cluster_topics(df["raw"].tolist(), int(n_clusters), model_name)
425
-
426
- # Evaluation vs gold labels (if provided)
427
- report_df = pd.DataFrame()
428
- cm_fig = None
429
- if gold_series is not None and len(gold_series) == len(df):
430
- y_true = gold_series.tolist()
431
- # Drop rows with unknown gold
432
- mask = pd.Series([y in {"negative","neutral","positive"} for y in y_true])
433
- y_true = pd.Series(y_true)[mask].tolist()
434
- y_pred = df["label"][mask.values].tolist()
435
- if y_true:
436
- report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
437
- report_df = pd.DataFrame(report).transpose().reset_index().rename(columns={"index":"class"})
438
- labels_order = ["negative","neutral","positive"]
439
- cm = confusion_matrix(y_true, y_pred, labels=labels_order)
440
- fig = plt.figure(figsize=(4.5,3.8))
441
- plt.imshow(cm, interpolation="nearest")
442
- plt.title("Confusion matrix")
443
- plt.xticks(range(len(labels_order)), labels_order, rotation=45, ha="right")
444
- plt.yticks(range(len(labels_order)), labels_order)
445
- for i in range(cm.shape[0]):
446
- for j in range(cm.shape[1]):
447
- plt.text(j, i, str(cm[i, j]), ha="center", va="center")
448
- plt.tight_layout()
449
- cm_fig = fig
450
-
451
- # Output file
452
- out_csv = "tweets_with_sentiment.csv"
453
- df.to_csv(out_csv, index=False)
454
-
455
- return (
456
- summary,
457
- hist_fig, count_fig,
458
- words_fig, bigrams_fig, wc_img,
459
- tag_df, tag_fig,
460
- asp_df,
461
- cluster_tbl, cluster_fig,
462
- out_csv,
463
- report_df, cm_fig
464
- )
465
 
466
- # =========================
467
- # CSV entry point (wrap analyze_df)
468
- # =========================
469
- def analyze_csv(file, model_name: str, pos_thr: float, neg_thr: float,
470
- dedup: bool, min_len: int, top_n: int, n_clusters: int,
471
- aspects_str: str):
472
- if file is None:
473
  return (
474
- gr.update(value=pd.DataFrame([{"info":"Upload a CSV file with a tweet text column."}])),
475
- None, None, None, None, None, None, None, None, None,
476
- None, # out_file
477
- pd.DataFrame(), None # report, cm_fig
478
  )
479
- df = pd.read_csv(file.name)
480
- return analyze_df(df, model_name, pos_thr, neg_thr, dedup, min_len, top_n, n_clusters, aspects_str)
481
-
482
- # =========================
483
- # DATASETS entry point
484
- # =========================
485
- def load_hf_dataset(dataset_name: str, split: str, sample_n: int, keyword: str, random_sample: bool):
486
- if not DATASETS_AVAILABLE:
487
- raise RuntimeError("The 'datasets' library is not available in this Space.")
488
- if dataset_name == "Sentiment140":
489
- # Split choices on HF are often train only; accept 'train' fallback
490
- ds = load_dataset("sentiment140", split=split or "train")
491
- df = ds.to_pandas()
492
- text_col = "text" if "text" in df.columns else detect_text_column(df)
493
- gold = None
494
- # sentiment140 labels: 0=neg, 4=pos (no neutral)
495
- if "sentiment" in df.columns:
496
- gold_map = {0: "negative", 4: "positive"}
497
- gold = df["sentiment"].map(gold_map).fillna("neutral")
498
- df = df.rename(columns={text_col: "text"})[["text"]].copy()
499
- elif dataset_name == "TweetEval (sentiment)":
500
- ds = load_dataset("tweet_eval", "sentiment", split=split or "test")
501
- df = ds.to_pandas()
502
- # labels: 0=neg, 1=neu, 2=pos
503
- label_map = {0:"negative", 1:"neutral", 2:"positive"}
504
- gold = df["label"].map(label_map)
505
- df = df.rename(columns={"text": "text"})[["text"]].copy()
506
- else:
507
- raise ValueError("Unknown dataset.")
508
- if keyword:
509
- df = df[df["text"].str.contains(keyword, case=False, na=False)]
510
- if gold is not None:
511
- gold = gold.loc[df.index]
512
- if sample_n and sample_n > 0 and sample_n < len(df):
513
- if random_sample:
514
- df = df.sample(n=sample_n, random_state=0)
515
- else:
516
- df = df.head(sample_n)
517
- if gold is not None:
518
- gold = gold.loc[df.index]
519
- gold = gold.reset_index(drop=True) if gold is not None else None
520
- return df.reset_index(drop=True), gold
521
-
522
- def analyze_dataset(dataset_name: str, split: str, sample_n: int, keyword: str, random_sample: bool,
523
- model_name: str, pos_thr: float, neg_thr: float,
524
- dedup: bool, min_len: int, top_n: int, n_clusters: int,
525
- aspects_str: str):
526
- try:
527
- df, gold = load_hf_dataset(dataset_name, split, sample_n, keyword, random_sample)
528
- except Exception as e:
529
- msg = pd.DataFrame([{"error": str(e)}])
530
- return (msg, None, None, None, None, None, None, None, None, None, None,
531
- None, pd.DataFrame(), None)
532
- results = analyze_df(df, model_name, pos_thr, neg_thr, dedup, min_len, top_n, n_clusters, aspects_str, gold_series=gold)
533
- # Prepend a small preview table of the dataset
534
- preview = df.head(10)
535
- return (preview, *results)
536
-
537
- # =========================
538
- # UI
539
- # =========================
540
- EXAMPLES = [
541
- "Cats, DOGS!!! aren't running; they're sleeping.",
542
- "U.S. tariffs on steel & aluminum — what's next?",
543
- "This movie was absolutely amazing—loved every scene!",
544
- "Service was terrible; I’m never coming back."
545
- ]
546
 
547
- with gr.Blocks(title="Normalization + Twitter Sentiment Workbench") as demo:
548
- gr.Markdown("# 🔤 Normalization + 📊 Sentiment (Twitter) Workbench")
549
- gr.Markdown(
550
- "Switch between **VADER** and **Twitter-RoBERTa**; analyze CSVs or pull open datasets "
551
- "(*Sentiment140*, *TweetEval*). Tune thresholds, inspect tokens/hashtags/aspects, and "
552
- "benchmark against gold labels when available."
553
  )
554
 
555
- # ----- Single text -----
556
- with gr.Tab("Single Text"):
557
- with gr.Row():
558
- model_dd = gr.Dropdown(["VADER","Twitter-RoBERTa"], value="VADER", label="Sentiment model")
559
- inp = gr.Textbox(label="Input text", lines=5, placeholder="Type or pick an example…")
560
- gr.Examples(examples=EXAMPLES, inputs=[inp])
561
- run_btn = gr.Button("Normalize & Analyze", variant="primary")
562
- steps_out = gr.Dataframe(headers=["Step","Tokens","As Text"], label="Step-by-step", interactive=False)
563
- final_out = gr.Textbox(label="Final normalized output", interactive=False)
564
- sent_df = gr.Dataframe(label="Sentiment scores", interactive=False)
565
- sent_plot = gr.Plot(label="Sentiment (bar plot)")
566
- run_btn.click(fn=normalize_with_steps, inputs=[inp, model_dd],
567
- outputs=[steps_out, final_out, sent_df, sent_plot])
568
-
569
- # ----- Batch CSV -----
570
- with gr.Tab("Batch Tweets (CSV)"):
571
- gr.Markdown("Upload a CSV with a tweet text column (auto-detected).")
572
- with gr.Row():
573
- file_up = gr.File(file_types=[".csv"], label="Upload CSV")
574
- model_csv = gr.Dropdown(["VADER","Twitter-RoBERTa"], value="VADER", label="Model")
575
- pos_thr = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Positive threshold (compound ≥)")
576
- neg_thr = gr.Slider(-0.5, 0.0, value=-0.05, step=0.01, label="Negative threshold (compound ≤)")
577
- with gr.Row():
578
- dedup = gr.Checkbox(value=True, label="Drop duplicate tweets")
579
- min_len = gr.Slider(0, 10, value=3, step=1, label="Min token length (filter)")
580
- top_n = gr.Slider(5, 30, value=15, step=1, label="Top-N for words/bigrams/hashtags")
581
- n_clusters = gr.Slider(2, 8, value=4, step=1, label="Topic clusters (k-means)")
582
- aspects = gr.Textbox(value="tariff, jobs, prices, china, farmers, john deere",
583
- label="Aspects (comma-separated)")
584
- go = gr.Button("Analyze CSV", variant="primary")
585
-
586
- summary_table = gr.Dataframe(label="Summary", interactive=False)
587
- hist_fig = gr.Plot(label="Distribution of compound")
588
- count_fig = gr.Plot(label="Sentiment counts")
589
- with gr.Row():
590
- words_fig = gr.Plot(label="Top words")
591
- bigrams_fig = gr.Plot(label="Top bigrams")
592
- wc_img = gr.Image(label="Word cloud", type="pil")
593
- with gr.Row():
594
- tag_df = gr.Dataframe(label="Hashtag sentiment (count & mean compound)", interactive=False)
595
- tag_fig = gr.Plot(label="Top hashtags (by count)")
596
- asp_df = gr.Dataframe(label="Aspect sentiment (windowed)", interactive=False)
597
- with gr.Row():
598
- cluster_tbl = gr.Dataframe(label="Topic clusters (size & mean compound + top terms)", interactive=False)
599
- cluster_fig = gr.Plot(label="Cluster mean sentiment")
600
- out_file = gr.File(label="Download augmented CSV")
601
- report_df = gr.Dataframe(label="Benchmark vs gold labels (if present)", interactive=False)
602
- cm_plot = gr.Plot(label="Confusion matrix (if gold labels present)")
603
-
604
- go.click(
605
- fn=analyze_csv,
606
- inputs=[file_up, model_csv, pos_thr, neg_thr, dedup, min_len, top_n, n_clusters, aspects],
607
- outputs=[
608
- summary_table,
609
- hist_fig, count_fig,
610
- words_fig, bigrams_fig, wc_img,
611
- tag_df, tag_fig,
612
- asp_df,
613
- cluster_tbl, cluster_fig,
614
- out_file,
615
- report_df, cm_plot
616
- ],
617
- show_progress=True
618
- )
619
 
620
- # ----- Datasets -----
621
- with gr.Tab("Datasets (Sentiment140 / TweetEval)"):
622
- gr.Markdown(
623
- "Download open tweet datasets—no account required. Optionally filter by keyword and sample size, "
624
- "then analyze and (when available) benchmark against gold labels."
625
- )
626
- with gr.Row():
627
- ds_name = gr.Dropdown(
628
- ["Sentiment140", "TweetEval (sentiment)"],
629
- value="TweetEval (sentiment)",
630
- label="Dataset"
631
- )
632
- ds_split = gr.Textbox(value="test", label="Split (e.g., train / validation / test)",)
633
- sample_n = gr.Slider(0, 20000, value=2000, step=100, label="Sample size (0 = all)")
634
- keyword = gr.Textbox(value="", label="Keyword filter (optional)")
635
- rnd = gr.Checkbox(value=True, label="Random sample")
636
- with gr.Row():
637
- model_ds = gr.Dropdown(["VADER","Twitter-RoBERTa"], value="VADER", label="Model")
638
- pos_thr_ds = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Positive threshold (compound ≥)")
639
- neg_thr_ds = gr.Slider(-0.5, 0.0, value=-0.05, step=0.01, label="Negative threshold (compound ≤)")
640
- with gr.Row():
641
- dedup_ds = gr.Checkbox(value=True, label="Drop duplicate tweets")
642
- min_len_ds = gr.Slider(0, 10, value=3, step=1, label="Min token length (filter)")
643
- top_n_ds = gr.Slider(5, 30, value=15, step=1, label="Top-N words/bigrams/hashtags")
644
- n_clusters_ds = gr.Slider(2, 8, value=4, step=1, label="Topic clusters (k-means)")
645
- aspects_ds = gr.Textbox(value="tariff, jobs, prices, china, farmers, john deere",
646
- label="Aspects (comma-separated)")
647
-
648
- fetch = gr.Button("Load & Analyze Dataset", variant="primary")
649
-
650
- preview = gr.Dataframe(label="Dataset preview (first rows)", interactive=False)
651
- summary_table_ds = gr.Dataframe(label="Summary", interactive=False)
652
- hist_fig_ds = gr.Plot(label="Distribution of compound")
653
- count_fig_ds = gr.Plot(label="Sentiment counts")
654
- with gr.Row():
655
- words_fig_ds = gr.Plot(label="Top words")
656
- bigrams_fig_ds = gr.Plot(label="Top bigrams")
657
- wc_img_ds = gr.Image(label="Word cloud", type="pil")
658
- with gr.Row():
659
- tag_df_ds = gr.Dataframe(label="Hashtag sentiment (count & mean compound)", interactive=False)
660
- tag_fig_ds = gr.Plot(label="Top hashtags (by count)")
661
- asp_df_ds = gr.Dataframe(label="Aspect sentiment (windowed)", interactive=False)
662
- with gr.Row():
663
- cluster_tbl_ds = gr.Dataframe(label="Topic clusters (size & mean compound + top terms)", interactive=False)
664
- cluster_fig_ds = gr.Plot(label="Cluster mean sentiment")
665
- out_file_ds = gr.File(label="Download augmented CSV")
666
- report_df_ds = gr.Dataframe(label="Benchmark vs gold labels", interactive=False)
667
- cm_plot_ds = gr.Plot(label="Confusion matrix")
668
-
669
- fetch.click(
670
- fn=analyze_dataset,
671
- inputs=[ds_name, ds_split, sample_n, keyword, rnd,
672
- model_ds, pos_thr_ds, neg_thr_ds, dedup_ds, min_len_ds, top_n_ds, n_clusters_ds, aspects_ds],
673
- outputs=[
674
- preview,
675
- summary_table_ds,
676
- hist_fig_ds, count_fig_ds,
677
- words_fig_ds, bigrams_fig_ds, wc_img_ds,
678
- tag_df_ds, tag_fig_ds,
679
- asp_df_ds,
680
- cluster_tbl_ds, cluster_fig_ds,
681
- out_file_ds,
682
- report_df_ds, cm_plot_ds
683
- ],
684
- show_progress=True
685
- )
686
 
687
  gr.Markdown(
688
- "> Notes: RoBERTa downloads the model on first run. For Sentiment140, gold labels are "
689
- "mapped as 0→negative, 4→positive (no neutral). TweetEval has gold labels for all three classes."
690
  )
691
 
692
  if __name__ == "__main__":
693
  demo.launch()
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
+ import re
5
+ from typing import List, Tuple
6
+
7
+ # Lazy imports for heavy deps so the Space boots faster
8
+ from functools import lru_cache
9
+
10
+ def _lazy_imports():
11
+ global datasets, pipeline, WordCloud, plt
12
+ import matplotlib.pyplot as plt # noqa: F401
13
+ from datasets import load_dataset # noqa: F401
14
+ from transformers import pipeline as hf_pipeline # noqa: F401
15
  try:
16
+ from wordcloud import WordCloud # noqa: F401
17
  except Exception:
18
+ WordCloud = None
19
+ return locals()
20
+
21
+ # ----------------------------
22
+ # Helpers
23
+ # ----------------------------
24
+ TARIFF_KEYWORDS_DEFAULT = [
25
+ "tariff", "tariffs", "import tax", "trade war", "section 301", "section301",
26
+ "customs duty", "custom duties", "duties", "anti-dumping", "countervailing",
27
+ "steel tariff", "aluminum tariff", "aluminium tariff", "US tariff", "U.S. tariff",
28
+ "tariff policy", "retaliatory tariff", "tariff hike", "tariff cut"
29
+ ]
30
 
31
+ KEYWORD_PATTERN_CACHE = {}
32
+
33
+ def compile_keyword_pattern(keywords: List[str]) -> re.Pattern:
34
+ key = "\u0001".join(sorted([k.strip().lower() for k in keywords if k.strip()]))
35
+ if key in KEYWORD_PATTERN_CACHE:
36
+ return KEYWORD_PATTERN_CACHE[key]
37
+ escaped = [re.escape(k) for k in keywords if k.strip()]
38
+ pattern = re.compile(r"(" + r"|".join(escaped) + r")", flags=re.IGNORECASE)
39
+ KEYWORD_PATTERN_CACHE[key] = pattern
40
+ return pattern
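+ # Illustrative note (not part of the commit): compile_keyword_pattern(["tariff", "trade war"])
+ # builds roughly the case-insensitive pattern (tariff|trade war) and memoizes it in
+ # KEYWORD_PATTERN_CACHE so repeated runs with the same keyword list reuse the compiled regex.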
41
+
42
+
43
+ def normalize_text(s: str) -> str:
44
+ s = re.sub(r"https?://\S+", " ", s) # drop urls
45
+ s = re.sub(r"@[A-Za-z0-9_]+", " ", s) # drop @mentions
46
+ s = re.sub(r"#[A-Za-z0-9_]+", " ", s) # drop hashtags (we'll match keywords separately)
47
+ s = re.sub(r"\s+", " ", s).strip()
48
+ return s
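+ # Illustrative example (not part of the commit): normalize_text("New tariffs! https://t.co/x @user #tradewar")
+ # -> "New tariffs!" (URL, @mention, and hashtag dropped, whitespace collapsed).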
49
+
50
+
51
+ @lru_cache(maxsize=2)
52
+ def load_sentiment_pipeline(model_name: str = "cardiffnlp/twitter-roberta-base-sentiment-latest"):
53
+ _ = _lazy_imports()
54
+ from transformers import pipeline as hf_pipeline
55
+ pipe = hf_pipeline(
56
+ task="sentiment-analysis",
57
+ model=model_name,
58
+ tokenizer=model_name,
59
+ truncation=True,
60
+ max_length=256,
61
+ return_all_scores=False,
62
+ device=-1,
63
+ )
64
+ return pipe
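+ # Usage sketch (score illustrative, not part of the commit):
+ # load_sentiment_pipeline()("tariffs are hurting small farms")
+ # -> [{'label': 'negative', 'score': 0.9}]; labels are normalized again in run_inference().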
65
 
66
 
67
+ @lru_cache(maxsize=2)
68
+ def load_hf_dataset(name: str):
69
+ _ = _lazy_imports()
70
+ from datasets import load_dataset
71
+ if name == "sentiment140":
72
+ # 1.6M tweets; we'll stream and sample later
73
+ ds = load_dataset("sentiment140")
74
+ # columns: ['sentiment','ids','date','query','user','text']
75
+ return ds
76
+ elif name == "tweet_eval":
77
+ # We'll use the sentiment subset
78
+ ds = load_dataset("tweet_eval", "sentiment")
79
+ # columns: ['text','label'] where label in {0:negative,1:neutral,2:positive}
80
+ return ds
81
+ else:
82
+ raise ValueError("Unsupported dataset: " + name)
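+ # Illustrative note (not part of the commit): load_hf_dataset("tweet_eval") returns a DatasetDict
+ # with train/validation/test splits; results are memoized via lru_cache so repeat runs skip the download.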
83
+
84
+
85
+ def filter_and_sample(df: pd.DataFrame, keywords: List[str], sample_size: int, random_state: int = 42) -> pd.DataFrame:
86
+ pat = compile_keyword_pattern(keywords)
87
+ mask = df['text'].str.contains(pat, na=False)
88
+ subset = df.loc[mask].copy()
89
+ if subset.empty:
90
+ return subset
91
+ if sample_size > 0 and len(subset) > sample_size:
92
+ subset = subset.sample(n=sample_size, random_state=random_state)
93
+ return subset
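+ # Illustrative example (not part of the commit): filter_and_sample(df, ["tariff", "trade war"], sample_size=500)
+ # keeps only keyword-matching rows and, if more than 500 match, draws a reproducible
+ # random sample (random_state=42 by default).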
94
+
95
+
96
+ def run_inference(texts: List[str], batch_size: int = 64) -> List[dict]:
97
+ pipe = load_sentiment_pipeline()
98
+ results = []
99
+ for i in range(0, len(texts), batch_size):
100
+ batch = texts[i:i+batch_size]
101
+ out = pipe(batch)
102
+ # normalize labels to {positive, neutral, negative}
103
+ for o in out:
104
+ lab = o.get('label', '').lower()
105
+ if 'pos' in lab:
106
+ label = 'positive'
107
+ elif 'neg' in lab:
108
+ label = 'negative'
109
+ else:
110
+ label = 'neutral'
111
+ results.append({'label': label, 'score': float(o.get('score', 0.0))})
112
+ return results
113
+
114
+
115
+ def make_bar_plot(counts: pd.Series):
116
+ import matplotlib.pyplot as plt
117
+ fig = plt.figure(figsize=(5, 3.2), dpi=140)
118
+ ax = fig.gca()
119
+ counts = counts.reindex(['negative', 'neutral', 'positive']).fillna(0)
120
+ ax.bar(counts.index, counts.values)
121
+ total = int(counts.sum())
122
+ ax.set_title(f"Sentiment distribution (n={total})")
123
+ ax.set_xlabel("Sentiment")
124
+ ax.set_ylabel("# Tweets")
125
+ fig.tight_layout()
126
  return fig
127
 
128
 
129
+ def make_wordcloud(texts: List[str]):
130
+ # Optional; will return None if wordcloud isn't available
131
+ try:
132
+ from wordcloud import WordCloud
133
+ except Exception:
134
+ return None
135
+ joined = " ".join(texts)
136
+ wc = WordCloud(width=800, height=320, background_color="white").generate(joined)
137
+ import matplotlib.pyplot as plt
138
+ fig = plt.figure(figsize=(8, 3.6), dpi=120)
139
+ plt.imshow(wc)
140
+ plt.axis("off")
141
+ fig.tight_layout()
142
+ return fig
143
+
144
 
145
+ # ----------------------------
146
+ # Core pipeline
147
+ # ----------------------------
148
+
149
+ def analyze(dataset_choice: str,
150
+ keywords_csv: str,
151
+ max_rows: int,
152
+ include_wordcloud: bool) -> Tuple[str, "matplotlib.figure.Figure", "matplotlib.figure.Figure", pd.DataFrame]:
153
+ """Return (summary_markdown, bar_fig, wordcloud_fig|None, table_df)"""
154
+ ds = load_hf_dataset(dataset_choice)
155
+
156
+ # Convert to pandas
157
+ if dataset_choice == "sentiment140":
158
+ # concatenate a manageable slice from train/test (to keep runtime reasonable)
159
+ train = ds.get('train')
160
+ test = ds.get('test')
161
+ frames = []
162
+ for split in [train, test]:
163
+ if split is None:
164
+ continue
165
+ # Take a small random slice to keep Space responsive
166
+ n = len(split)
167
+ take = min(n, 150_000) # cap
168
+ frames.append(split.shuffle(seed=42).select(range(take)).to_pandas()[['text', 'date']])
169
+ df = pd.concat(frames, ignore_index=True)
170
+ else:
171
+ # tweet_eval sentiment
172
+ frames = []
173
+ for name in ['train', 'validation', 'test']:
174
+ if name in ds:
175
+ frames.append(ds[name].to_pandas()[['text']])
176
+ df = pd.concat(frames, ignore_index=True)
177
+ if 'date' not in df.columns:
178
+ df['date'] = np.nan
179
+
180
+ # Clean
181
+ df['text'] = df['text'].astype(str).apply(normalize_text)
182
+
183
+ # Keywords
184
+ keywords = [k.strip() for k in (keywords_csv or "").split(',') if k.strip()] or TARIFF_KEYWORDS_DEFAULT
185
+
186
+ # Filter + sample
187
+ subset = filter_and_sample(df, keywords, sample_size=max_rows)
188
+ if subset.empty:
189
  return (
190
+ "### No matches found\nTry broadening keywords or increasing the sample size.",
191
+ make_bar_plot(pd.Series(dtype=int)),
192
+ None,
193
+ pd.DataFrame(columns=['text','pred_label','pred_score','date'])
194
  )
195
 
196
+ # Inference
197
+ preds = run_inference(subset['text'].tolist())
198
+ pred_df = pd.DataFrame(preds)
199
+ subset = subset.reset_index(drop=True).copy()
200
+ subset['pred_label'] = pred_df['label']
201
+ subset['pred_score'] = pred_df['score']
202
+
203
+ # Metrics
204
+ counts = subset['pred_label'].value_counts()
205
+ total = int(counts.sum())
206
+ pct = (counts / max(total, 1) * 100).round(1)
207
+
208
+ # Summary text
209
+ sentiment_line = (
210
+ f"**Negative:** {int(counts.get('negative', 0))} ({pct.get('negative', 0.0)}%) | "
211
+ f"**Neutral:** {int(counts.get('neutral', 0))} ({pct.get('neutral', 0.0)}%) | "
212
+ f"**Positive:** {int(counts.get('positive', 0))} ({pct.get('positive', 0.0)}%)"
213
  )
214
 
215
+ summary = (
216
+ "## Tariff Tweet Sentiment — Snapshot\n"
217
+ f"Dataset: **{dataset_choice}** | Sampled tweets: **{total}**\n\n"
218
+ f"Keyword filter: `{', '.join(keywords)}`\n\n"
219
+ + sentiment_line +
220
+ "\n\nTip: Neutral can be high when tweets are mostly informative (news/links) or ambiguous."
221
+ )
222
 
223
+ # Plots
224
+ bar_fig = make_bar_plot(counts)
225
+ wc_fig = make_wordcloud(subset['text'].tolist()) if include_wordcloud else None
226
+
227
+ # Output table (limit rows for UI responsiveness)
228
+ out_df = subset[['text','pred_label','pred_score','date']]
229
 
230
+ return summary, bar_fig, wc_fig, out_df
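+ # End-to-end sketch (arguments illustrative, not part of the commit):
+ # analyze("tweet_eval", "tariff, trade war", 500, False)
+ # -> (markdown summary, bar-chart figure, None, DataFrame with text/pred_label/pred_score/date).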
231
+
232
+
233
+ # ----------------------------
234
+ # Gradio UI
235
+ # ----------------------------
236
+ with gr.Blocks(title="Tariff Tweet Sentiment (No Twitter API)") as demo:
237
  gr.Markdown(
238
+ """
239
+ # Tariff Tweet Sentiment
240
+ Analyze how people talk about **U.S. tariff policy** using public Twitter corpora (no API key required).
241
+
242
+ **How it works**
243
+ - Choose a public dataset (e.g., `sentiment140` or `tweet_eval/sentiment`).
244
+ - Filter tweets by keywords like *tariff*, *trade war*, *Section 301*, etc.
245
+ - Run a Twitter-optimized sentiment model.
246
+ - View distribution, word cloud, and the matching tweets.
247
+
248
+ *Note:* Public corpora may skew older or topical; results are a **snapshot**, not a live feed.
249
+ """
250
  )
251
 
252
+ with gr.Row():
253
+ dataset_choice = gr.Dropdown(
254
+ choices=["sentiment140", "tweet_eval"],
255
+ value="sentiment140",
256
+ label="Dataset"
257
+ )
258
+ max_rows = gr.Slider(100, 5000, value=1500, step=50, label="Max tweets to analyze (after keyword filter)")
259
+ keywords_csv = gr.Textbox(value=", ".join(TARIFF_KEYWORDS_DEFAULT), label="Keywords (comma‑separated)")
260
+ include_wordcloud = gr.Checkbox(value=True, label="Include word cloud (optional)")
261
+
262
+ run_btn = gr.Button("Run Analysis", variant="primary")
263
+
264
+ summary_md = gr.Markdown()
265
+ bar_plot = gr.Plot(label="Sentiment distribution")
266
+ wc_plot = gr.Plot(label="Word cloud (optional)")
267
+ table = gr.Dataframe(headers=["text","pred_label","pred_score","date"], wrap=True, interactive=False)
268
+ csv = gr.File(label="Download CSV of results", visible=True)
269
+
270
+ def _go(dataset_choice, keywords_csv, max_rows, include_wordcloud):
271
+ summary, bar_fig, wc_fig, df = analyze(dataset_choice, keywords_csv, int(max_rows), bool(include_wordcloud))
272
+ # Save CSV
273
+ out_path = "tariff_tweets_sentiment.csv"
274
+ df.to_csv(out_path, index=False)
275
+ return summary, bar_fig, wc_fig, df, out_path
276
+
277
+ run_btn.click(_go, [dataset_choice, keywords_csv, max_rows, include_wordcloud], [summary_md, bar_plot, wc_plot, table, csv])
278
+
279
  if __name__ == "__main__":
280
  demo.launch()