Update app.py
app.py CHANGED
@@ -1,11 +1,10 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
 import torch
 
-#
+# Tiny-Purr merged model
 model_id = "purrgpt-community/Tiny-Purr-350M-merged"
-
-# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -14,20 +13,33 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
+# PurrBERT safety classifier
+purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1")
+purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1")
+purrbert_model.eval()
+
+def is_safe_prompt(prompt):
+    inputs = purrbert_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(purrbert_model.device)
+    with torch.no_grad():
+        outputs = purrbert_model(**inputs)
+    pred = torch.argmax(outputs.logits, dim=-1).item()
+    return pred == 0  # True if SAFE, False if FLAGGED
+
 def format_history(history, message):
-    chat_prompt = "
+    chat_prompt = ""
     for user_msg, assistant_msg in history:
         chat_prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
     chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
     return chat_prompt
 
 def respond(message, history):
-
+    # Safety check using PurrBERT
+    if not is_safe_prompt(message):
+        return "oh no! purrbert has detected some really not-so-purr-fect content, user. it seems like there's some hate speech and toxic stuff in there, and i simply cannot, under any circumstances, help with anything that's unsafe or harmful. my system is built to keep things friendly and helpful, not to spread negativity. let's keep our conversations on the good side, okay? purrrr."
 
-
+    full_prompt = format_history(history, message)
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
-    # Generate a response
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -38,7 +50,6 @@ def respond(message, history):
             pad_token_id=tokenizer.eos_token_id
         )
 
-    # Decode and extract assistant response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     generated_text = response[len(full_prompt):].strip()
 
@@ -49,14 +60,14 @@ def respond(message, history):
 
     return assistant_response
 
-# Launch Gradio chat
+# Launch Gradio chat interface
 gr.ChatInterface(
     respond,
-    title="Tiny-Purr-350M-merged Chatbot",
-    description="A simple conversational
+    title="Tiny-Purr-350M-merged Chatbot (No System Prompt)",
+    description="A simple conversational chatbot using Tiny-Purr-350M-merged with PurrBERT safety filtering.",
     examples=[
-        "What
-        "
-        "
+        "What’s your favorite kind of cat?",
+        "Explain quantum entanglement simply.",
+        "Write me a haiku about the moon."
     ]
 ).launch()
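To sanity-check the new safety gate outside the Space, a minimal standalone sketch along these lines can be used. It assumes, per the inline comment in app.py, that PurrBERT-v1 is a binary classifier whose label index 0 means SAFE; the test prompt and the hypothetical check() helper are illustrative only.

# Standalone sketch: query the PurrBERT safety classifier directly.
# Assumes label 0 == SAFE, as the comment in app.py states.
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

clf_id = "purrgpt-community/PurrBERT-v1"
clf_tokenizer = DistilBertTokenizerFast.from_pretrained(clf_id)
clf_model = DistilBertForSequenceClassification.from_pretrained(clf_id)
clf_model.eval()

def check(prompt):
    # Tokenize the prompt and classify it without tracking gradients.
    inputs = clf_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = clf_model(**inputs).logits
    return "SAFE" if logits.argmax(dim=-1).item() == 0 else "FLAGGED"

print(check("Write me a haiku about the moon."))  # expected: SAFE

Gating on the classifier before calling model.generate keeps flagged prompts from ever reaching the causal model, which is cheaper than generating first and filtering the output afterwards.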