# Sentiment classifier: a frozen RoBERTa encoder feeding a BiLSTM with
# attention pooling, trained on IMDB and served through a Gradio demo.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, RobertaModel
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import gradio as gr

# Run on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

# Load the IMDB reviews dataset once ("stanfordnlp/imdb" is the canonical Hub
# id for "imdb") and tokenize it with the RoBERTa tokenizer, padding and
# truncating every review to 128 tokens.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("stanfordnlp/imdb")

def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=128)

encoded_dataset = dataset.map(tokenize, batched=True)
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# RoBERTa encoder feeding a bidirectional LSTM, with additive attention
# pooling over the sequence and a linear classifier on top.
class RobertaBiLSTMAttention(nn.Module):
    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.lstm = nn.LSTM(768, hidden_dim, batch_first=True, bidirectional=True)
        self.attn = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask):
        # The encoder is frozen: no gradients flow into RoBERTa, so only the
        # LSTM, attention, and classifier weights can be trained.
        with torch.no_grad():
            roberta_out = self.roberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.lstm(roberta_out)
        # Score each time step, softmax into attention weights, and take the
        # weighted sum as the sentence representation.
        weights = torch.softmax(self.attn(lstm_out), dim=1)
        context = torch.sum(weights * lstm_out, dim=1)
        return self.fc(self.dropout(context))
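
# Note: the softmax above runs over every position, so padding tokens also
# receive attention weight. A masked variant (a sketch, not part of the
# original model; adopting it would invalidate the pretrained weights loaded
# below) would look like:
#
#     scores = self.attn(lstm_out).squeeze(-1)                 # (B, T)
#     scores = scores.masked_fill(attention_mask == 0, -1e9)   # ignore padding
#     weights = torch.softmax(scores, dim=1).unsqueeze(-1)     # (B, T, 1)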

# Wrap subsets of the train/test splits in DataLoaders to keep the demo fast.
train_loader = DataLoader(encoded_dataset["train"].select(range(20000)), batch_size=16, shuffle=True)
test_loader = DataLoader(encoded_dataset["test"].select(range(2000)), batch_size=16)

model = RobertaBiLSTMAttention().to(device)
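
# No training happens in this script; pretrained weights are loaded below.
# For reference, a minimal fine-tuning loop that could produce such a
# checkpoint might look like this sketch (optimizer, learning rate, and epoch
# count are illustrative assumptions, not the author's recipe). Gradients
# never reach the encoder because forward() wraps it in torch.no_grad().
#
# optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# criterion = nn.CrossEntropyLoss()
# model.train()
# for epoch in range(2):
#     for batch in train_loader:
#         optimizer.zero_grad()
#         logits = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
#         loss = criterion(logits, batch["label"].to(device))
#         loss.backward()
#         optimizer.step()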

# Download the pretrained checkpoint for this architecture from the
# Hugging Face Hub and load it into the model.
model_path = hf_hub_download(
    repo_id="hrnrxb/roberta-bilstm-attention-sentiment",
    filename="pytorch_model.bin",
)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
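
# test_loader is built above but never consumed; a quick accuracy check over
# it (a sketch added for completeness, not part of the original) could be:
correct = total = 0
with torch.no_grad():
    for batch in test_loader:
        logits = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
        preds = torch.argmax(logits, dim=1)
        correct += (preds == batch["label"].to(device)).sum().item()
        total += batch["label"].size(0)
print(f"Test accuracy: {correct / total:.3f}")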

# Sanity-check the model on a few tricky reviews (sarcasm, mixed tone, slang).
samples = [
    "Wow, what a masterpiece. I especially loved the part where nothing happened for two hours.",
    "The movie was boring at times, but the ending completely blew my mind.",
    "It was... fine, I guess. Not bad. Not good. Just there.",
    "Beautiful cinematography can't save a script written by a potato.",
    "Yo that movie was sick af! 🔥🔥",
    "I didn't expect much, and yet it still managed to disappoint me.",
    "10/10 would recommend... if you enjoy falling asleep halfway through.",
    "Absolutely incredible! I only checked my phone 12 times.",
    "Reminded me of my childhood, cheesy but heartwarming.",
    "Bro that film went HARD. Straight banger!",
]

for s in samples:
    tokens = tokenizer(s, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
    with torch.no_grad():
        logits = model(tokens["input_ids"], tokens["attention_mask"])
    pred = torch.argmax(logits, dim=1).item()
    print(f"{s} → {'🟢 Positive' if pred == 1 else '🔴 Negative'}")
|
# Gradio demo. Components only render inside a Blocks context, so the header
# links are placed in one together with the interface.
def predict(text):
    model.eval()
    tokens = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
    with torch.no_grad():
        logits = model(tokens["input_ids"], tokens["attention_mask"])
    prob = torch.softmax(logits, dim=1)
    pred = torch.argmax(prob, dim=1).item()
    conf = prob[0][pred].item()
    label = "🟢 Positive" if pred == 1 else "🔴 Negative"
    return f"{label} ({conf * 100:.1f}%)"

interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter a review"),
    outputs="text",
    description="🔗 [GitHub Repo](https://github.com/hrnrxb/Advanced-Sentiment-Classifier) | 🌐 [My Website](https://hrnrxb.github.io)",
)

with gr.Blocks() as demo:
    gr.HTML("""
    <div style="text-align:center; margin-bottom:10px;">
      <a href="https://github.com/hrnrxb/Advanced-Sentiment-Classifier" target="_blank" style="font-weight:bold; font-size:18px; text-decoration:none; color:#4A90E2;">
        🔗 View on GitHub
      </a> |
      <a href="https://hrnrxb.github.io" target="_blank" style="font-weight:bold; font-size:18px; text-decoration:none; color:#4A90E2;">
        🌐 My Website
      </a>
    </div>
    """)
    interface.render()

demo.launch()