from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr

# Load the fine-tuned model and tokenizer from the local save directory
model_dir = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

# Define all 6 labels (Jigsaw-style multi-label toxic comment classification)
labels = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate"
]
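
# Optional guard (a minimal sketch; it assumes the model head was fine-tuned
# on exactly these six labels, in this order). At minimum, confirm the
# classification head width agrees with the label list before serving.
assert model.config.num_labels == len(labels), (
    f"Model outputs {model.config.num_labels} labels, expected {len(labels)}"
)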

# Inference function
def classify(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits)[0]  # Sigmoid for multi-label
        result = {label: float(probs[i]) for i, label in enumerate(labels)}
    return result
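
# A minimal follow-up sketch (the 0.5 threshold is an assumption; tune it on
# validation data): binarize the per-label probabilities into hard flags,
# e.g. flag_labels("some comment") -> {"toxic": False, "severe_toxic": False, ...}
def flag_labels(text, threshold=0.5):
    return {label: prob >= threshold for label, prob in classify(text).items()}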

# Gradio interface
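# (note: launch() blocks and serves the UI locally; launch(share=True) would
# additionally create a temporary public URL for quick demos)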
gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Enter your comment..."),
    outputs=gr.Label(num_top_classes=6),
    title="Toxic Comment Classifier"
).launch()