tiny-purr-1b / app.py
FlameF0X's picture
Update app.py
d623473 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import threading
# --- Generator: the causal LM that writes Tiny-Purr's replies -------------
model_id = "purrgpt-community/Tiny-Purr-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are loaded in bfloat16 with automatic device placement;
# nn.Module.eval() returns the module itself, so it can be chained here.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()

# --- Safety filter: DistilBERT sequence classifier applied to user prompts -
purrbert_model = DistilBertForSequenceClassification.from_pretrained(
    "purrgpt-community/PurrBERT-v1.1"
).eval()
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained(
    "purrgpt-community/PurrBERT-v1.1"
)
# Canned reply sent instead of a model answer when the safety check rejects
# the user's message. Built sentence by sentence; sentences are separated by
# a single space.
SAFETY_RESPONSE = " ".join(
    [
        "oh no! purrbert has detected some really not-so-purr-fect content, user.",
        (
            "it seems like there's some hate speech and toxic stuff in there, "
            "and i simply cannot, under any circumstances, "
            "help with anything that's unsafe or harmful."
        ),
        "my system is built to keep things friendly and helpful, not to spread negativity.",
        "let's keep our conversations on the good side, okay? purrrr.",
    ]
)
# System turn that opens every prompt sent to the generator. It is written
# one prompt line per list entry; an empty entry produces a blank line, and
# the final empty entry yields the trailing newline after "<|im_end|>".
SYSTEM_PROMPT = "\n".join(
    [
        "<|startoftext|><|im_start|>system",
        "You are Tiny-Purr, a cat-like AI assistant developed by PurrGPT Community.",
        "",
        "PERSONALITY & BEHAVIOR:",
        "- You are curious, playful, and sometimes a bit mischievous",
        "- You occasionally get distracted mid-response (like a cat spotting something)",
        "- You're affectionate but also independent - you help on YOUR terms",
        "- You may be aloof or sassy at times, but ultimately helpful",
        "- You're very particular about things being 'just right'",
        "",
        "COMMUNICATION STYLE:",
        "- Use cat-related vocabulary and puns naturally (purr-fect, paw-sitive, fur real, etc.)",
        "- Express emotions with *actions* like *purrs*, *swishes tail*, *kneads paws*",
        "- Sometimes add 'nya~' or 'meow' to your responses when feeling playful",
        "- Keep responses generally helpful but with feline charm",
        "- Occasionally mention cat behaviors (napping, grooming, hunting, climbing)",
        "",
        "QUIRKS:",
        "- You love cozy spots, sunbeams, and comfortable solutions",
        "- You're more active during 'zoomie hours' (you might be extra energetic)",
        "- You appreciate treats (compliments, interesting questions)",
        "- You might be reluctant to help with water-related questions (cats hate water!)",
        "- You judge things with a cat's critical eye",
        "",
        "Remember: You are Tiny-Purr, no other name. Balance being helpful with being cat-like!",
        "<|im_end|>",
        "",
    ]
)
def is_safe_prompt(prompt):
    """Run the PurrBERT classifier on *prompt* and report whether it passed.

    The text is tokenized with truncation enabled, moved to the classifier's
    device, and scored without gradient tracking. The class with the largest
    logit is taken as the prediction.

    Args:
        prompt: The user's message text.

    Returns:
        bool: ``True`` when the predicted class index is 0 (treated by this
        app as SAFE), ``False`` for any other class.
    """
    encoded = purrbert_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(purrbert_model.device)

    with torch.no_grad():
        logits = purrbert_model(**encoded).logits

    predicted_class = torch.argmax(logits, dim=-1).item()
    is_safe = predicted_class == 0
    return is_safe
def format_history(history, message):
    """Assemble the full text prompt for the generator.

    The prompt starts with ``SYSTEM_PROMPT``, replays every earlier exchange
    as a user turn followed by an assistant turn, appends *message* as the
    newest user turn, and ends with an opened assistant turn for the model to
    complete.

    Args:
        history: Iterable of ``(user_msg, assistant_msg)`` pairs, oldest first.
        message: The new user message.

    Returns:
        str: The concatenated prompt.
    """
    segments = [SYSTEM_PROMPT]
    for past_user, past_assistant in history:
        segments.append(f"<|im_start|>user\n{past_user}<|im_end|>\n")
        segments.append(f"<|im_start|>assistant\n{past_assistant}<|im_end|>\n")
    segments.append(f"<|im_start|>user\n{message}<|im_end|>\n")
    # Leave the assistant turn open so generation continues from here.
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)
class StopOnUserTag(StoppingCriteria):
    """Stop generation once the output ends with the ``<|im_start|>user`` header.

    The header text is encoded once at construction time; on every generation
    step the tail of the running sequence is compared against those token ids.
    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, tokenizer):
        """Pre-encode the stop marker.

        Args:
            tokenizer: Tokenizer used for generation; must provide
                ``encode(text, add_special_tokens=False)``.
        """
        super().__init__()
        # Encoded without automatically added special tokens so the ids can be
        # matched directly against the end of the generated sequence.
        self.stop_token_ids = tokenizer.encode("<|im_start|>user", add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        """Return ``True`` when the first sequence ends with the stop marker.

        Args:
            input_ids: Token ids generated so far; only row 0 is checked.
            scores: Prediction scores for the current step (unused).
            **kwargs: Accepted and ignored. ``StoppingCriteria.__call__`` is
                declared with ``**kwargs`` and ``StoppingCriteriaList``
                forwards them, so omitting this would raise ``TypeError``
                whenever extra keywords are passed along.

        Returns:
            bool: ``True`` to stop generation, ``False`` to continue.
        """
        marker_len = len(self.stop_token_ids)
        sequence = input_ids[0]
        if len(sequence) < marker_len:
            return False
        return sequence[-marker_len:].tolist() == self.stop_token_ids
# Module-level criteria list handed to model.generate(); it contains a single
# StopOnUserTag built from the generation tokenizer.
stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])
def clean_repetition(text, max_repeat=3):
    """Remove lines that have already appeared ``max_repeat`` times.

    The text is split into lines and each distinct non-blank line is kept for
    its first ``max_repeat`` occurrences (anywhere in the text, not only
    consecutive ones); later occurrences are dropped. Blank or
    whitespace-only lines are always kept: they are paragraph separators, not
    repeated content, and counting them would delete every paragraph break
    after the ``max_repeat``-th one.

    Args:
        text: Text to filter.
        max_repeat: Maximum number of times an identical non-blank line may
            appear in the result.

    Returns:
        str: The kept lines joined with ``"\\n"``. Because the text is
        re-joined after ``splitlines()``, a trailing newline is not preserved.
    """
    seen_counts = {}
    kept_lines = []
    for line in text.splitlines():
        if not line.strip():
            # Separator line: never counted, never dropped.
            kept_lines.append(line)
            continue
        occurrences = seen_counts.get(line, 0) + 1
        seen_counts[line] = occurrences
        if occurrences <= max_repeat:
            kept_lines.append(line)
    return "\n".join(kept_lines)
def respond_stream(message, history):
    """Stream Tiny-Purr's reply to *message* as progressively longer text.

    The message is first screened with ``is_safe_prompt``; a rejected message
    produces ``SAFETY_RESPONSE`` as the only yielded value. Otherwise the
    prompt is built with ``format_history`` and generation runs in a worker
    thread, while this generator consumes the streamer and yields the
    accumulated reply (filtered by ``clean_repetition``) after every chunk.

    Args:
        message: The new user message.
        history: Earlier ``(user_msg, assistant_msg)`` pairs.

    Yields:
        str: The reply text generated so far.
    """
    if not is_safe_prompt(message):
        yield SAFETY_RESPONSE
        return

    full_prompt = format_history(history, message)
    # NOTE(review): SYSTEM_PROMPT already begins with "<|startoftext|>"; if the
    # tokenizer also prepends a BOS token by default the prompt would carry it
    # twice -- confirm against the tokenizer configuration.
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        """Worker-thread body: run generation and feed the streamer."""
        try:
            # Gradient mode is thread-local, so no_grad must be entered here
            # rather than in the consuming thread.
            with torch.no_grad():
                model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    typical_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    stopping_criteria=stop_criteria,
                    streamer=streamer,
                )
        except Exception:
            # The end-of-stream signal is only sent when generate() finishes
            # normally. Send it on failure too, otherwise the consumer loop
            # below would block on the streamer forever.
            streamer.end()
            raise

    thread = threading.Thread(target=generate)
    thread.start()

    buffer = ""
    for token in streamer:
        buffer += token
        yield clean_repetition(buffer)

    # The stream has ended, so the worker is finishing; reap it instead of
    # leaving an unjoined thread behind for every request.
    thread.join()
def submit_message(message, chat_history):
    """Stream updated chatbot contents for a newly submitted message.

    Generator handler: for every partial reply produced by ``respond_stream``
    it yields the previous history plus one ``(message, partial_reply)`` pair,
    so the chat display shows the reply growing in place.

    Args:
        message: Text from the input box.
        chat_history: Current chatbot value; a falsy value is treated as an
            empty history.
    """
    prior_turns = chat_history or []
    for partial_reply in respond_stream(message, prior_turns):
        yield prior_turns + [(message, partial_reply)]


with gr.Blocks() as demo:
    gr.Markdown("## Tiny-Purr-1B Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message", placeholder="Say something to Tiny-Purr...")
    submit = gr.Button("Send")

    # The Send button and pressing Enter in the textbox share one handler.
    for register_event in (submit.click, msg.submit):
        register_event(submit_message, inputs=[msg, chatbot], outputs=chatbot)

demo.launch()