import gradio as gr
import timm
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
"Vision"
vit_model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True)
vit_model = vit_model.eval()
data_config = timm.data.resolve_model_data_config(vit_model)
transforms = timm.data.create_transform(**data_config, is_training=False)
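
# Illustrative sanity check (an assumption about this checkpoint, not app code):
# for the 384-px model, the resolved transform should yield a (3, 384, 384) tensor.
#   from PIL import Image
#   assert transforms(Image.new("RGB", (64, 64))).shape == (3, 384, 384)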
"NLP"
tokenizer = RobertaTokenizer.from_pretrained("s-nlp/roberta_toxicity_classifier")
model = RobertaForSequenceClassification.from_pretrained(
    "s-nlp/roberta_toxicity_classifier"
)
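
# Optional device-placement sketch (assumes a CUDA-enabled torch build; the
# handlers below would also need to move their input tensors with .to(device)):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   vit_model.to(device)
#   model.to(device)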
def moderate_image(img):
    # Both models are loaded once at import time; each call only runs inference.
    # Guard against RGBA or paletted uploads before the timm transform.
    img = img.convert("RGB")
    with torch.no_grad():
        output = vit_model(transforms(img).unsqueeze(0)).softmax(dim=-1).cpu()
    class_names = vit_model.pretrained_cfg["label_names"]
    probabilities = output[0].tolist()
    # A deliberately low 0.3 threshold on the first label biases the demo
    # toward flagging borderline images rather than letting them through.
    if probabilities[0] >= 0.3:
        return class_names[0]
    else:
        return class_names[1]
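
# Standalone usage sketch ("example.jpg" is a hypothetical local file;
# assumes Pillow is installed):
#   from PIL import Image
#   print(moderate_image(Image.open("example.jpg")))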
def classify_toxic(text):
    with torch.no_grad():
        # truncation guards against inputs beyond RoBERTa's 512-token limit
        batch = tokenizer.encode(text, return_tensors="pt", truncation=True)
        output = model(batch).logits
        probabilities = torch.nn.functional.softmax(output, dim=-1)
    preds = probabilities.tolist()
    # Index 0 is the model's neutral class; flag the text as toxic once the
    # neutral probability falls to 0.55 or below.
    return "Toxic" if preds[0][0] <= 0.55 else "Safe"
# -----------------------
# Apple-Minimal Styling
# -----------------------
custom_css = """
/* Center container and control width */
.gradio-container {
max-width: 900px !important;
margin: 0 auto !important;
padding: 20px 10px !important;
}
/* Header styling */
.clean-title {
font-size: 1.9rem;
font-weight: 600;
text-align: center;
margin-bottom: 1.2rem;
letter-spacing: -0.4px;
}
/* Apple-like card sections */
.apple-card {
padding: 18px;
border-radius: 12px;
border: 1px solid rgba(var(--block-border-color-rgb), 0.14);
background: var(--block-background-fill);
box-shadow: 0 1px 3px rgba(0,0,0,0.04);
margin-bottom: 18px;
}
/* Button styling: clean, flat, subtle */
.gr-button {
border-radius: 8px !important;
background: var(--button-secondary-background-fill) !important;
border: 1px solid rgba(var(--block-border-color-rgb), 0.22) !important;
transition: 0.2s ease !important;
}
.gr-button:hover {
background: var(--button-secondary-background-fill-hover) !important;
border-color: rgba(var(--block-border-color-rgb), 0.34) !important;
}
.gr-button:active {
background: var(--button-secondary-background-fill-pressed) !important;
}
/* Reduce blank space between elements */
.gr-block {
margin: 6px 0 !important;
}
/* Label style */
label {
font-weight: 500 !important;
}
/* Make body fill full height so footer can stick */
body, .gradio-container {
min-height: 100vh !important;
display: flex;
flex-direction: column;
}
/* Main content should expand, footer sits at bottom */
.main-content {
flex: 1 0 auto;
}
.footer-custom {
flex-shrink: 0;
text-align: center;
font-size: 0.80rem;
opacity: 0.6;
padding: 14px 0;
border-top: 1px solid rgba(var(--block-border-color-rgb), 0.12);
margin-top: 25px;
}
footer {display: none !important}
"""
# -----------------------
# UI Layout
# -----------------------
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"), css=custom_css
) as demo:
    with gr.Column(elem_classes="main-content"):
        gr.Markdown("<div class='clean-title'>Content Safety Demo</div>")

        with gr.Tabs():
            # ---- NSFW Image Classification ---- #
            with gr.Tab("NSFW Image Detection"):
                with gr.Row():
                    with gr.Column(scale=3):
                        with gr.Group(elem_classes="apple-card"):
                            img_in = gr.Image(type="pil", label="Upload Image")
                            classify_img_btn = gr.Button("Classify")
                            img_clear_btn = gr.ClearButton(components=img_in)
                    with gr.Column(scale=2):
                        with gr.Group(elem_classes="apple-card"):
                            img_out = gr.Label(label="Prediction")

                classify_img_btn.click(
                    fn=moderate_image, inputs=img_in, outputs=img_out
                )

            # ---- Toxic Text Classification ---- #
            with gr.Tab("Toxic Text Detection"):
                with gr.Row():
                    with gr.Column(scale=3):
                        with gr.Group(elem_classes="apple-card"):
                            txt_in = gr.Textbox(lines=4, label="Enter Text")
                            classify_txt_btn = gr.Button("Analyze")
                            text_clear_btn = gr.ClearButton(components=txt_in)
                    with gr.Column(scale=2):
                        with gr.Group(elem_classes="apple-card"):
                            txt_out = gr.Label(label="Prediction")

                classify_txt_btn.click(classify_toxic, inputs=txt_in, outputs=txt_out)

    gr.Markdown(
        "<div class='footer-custom'>Demo by 7th • Powered by Transformers</div>"
    )
if __name__ == "__main__":
demo.launch()
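
# Deployment sketch using standard Gradio launch options (not project-specific
# config): demo.launch(server_name="0.0.0.0", share=True) would bind to all
# interfaces and request a temporary public share link.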