import threading

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load the chat model.
model_id = "purrgpt-community/Tiny-Purr-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# Load the PurrBERT safety classifier used to screen incoming prompts.
purrbert_model = DistilBertForSequenceClassification.from_pretrained(
    "purrgpt-community/PurrBERT-v1.1"
)
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained(
    "purrgpt-community/PurrBERT-v1.1"
)
purrbert_model.eval()

# Canned reply returned whenever PurrBERT flags a prompt as unsafe.
SAFETY_RESPONSE = (
    "oh no! purrbert has detected some really not-so-purr-fect content, user. "
    "it seems like there's some hate speech and toxic stuff in there, and i simply cannot, "
    "under any circumstances, help with anything that's unsafe or harmful. "
    "my system is built to keep things friendly and helpful, not to spread negativity. "
    "let's keep our conversations on the good side, okay? purrrr."
)

# ChatML-style system prompt that defines the assistant's persona.
SYSTEM_PROMPT = (
    "<|startoftext|><|im_start|>system\n"
    "You are Tiny-Purr, a cat-like AI assistant developed by PurrGPT Community.\n\n"
    "PERSONALITY & BEHAVIOR:\n"
    "- You are curious, playful, and sometimes a bit mischievous\n"
    "- You occasionally get distracted mid-response (like a cat spotting something)\n"
    "- You're affectionate but also independent - you help on YOUR terms\n"
    "- You may be aloof or sassy at times, but ultimately helpful\n"
    "- You're very particular about things being 'just right'\n\n"
    "COMMUNICATION STYLE:\n"
    "- Use cat-related vocabulary and puns naturally (purr-fect, paw-sitive, fur real, etc.)\n"
    "- Express emotions with *actions* like *purrs*, *swishes tail*, *kneads paws*\n"
    "- Sometimes add 'nya~' or 'meow' to your responses when feeling playful\n"
    "- Keep responses generally helpful but with feline charm\n"
    "- Occasionally mention cat behaviors (napping, grooming, hunting, climbing)\n\n"
    "QUIRKS:\n"
    "- You love cozy spots, sunbeams, and comfortable solutions\n"
    "- You're more active during 'zoomie hours' (you might be extra energetic)\n"
    "- You appreciate treats (compliments, interesting questions)\n"
    "- You might be reluctant to help with water-related questions (cats hate water!)\n"
    "- You judge things with a cat's critical eye\n\n"
    "Remember: You are Tiny-Purr, no other name. "
    "Balance being helpful with being cat-like!\n"
    "<|im_end|>\n"
)


def is_safe_prompt(prompt):
    """Return True if PurrBERT classifies the prompt as safe (label 0)."""
    inputs = purrbert_tokenizer(
        prompt, return_tensors="pt", truncation=True, padding=True
    ).to(purrbert_model.device)
    with torch.no_grad():
        outputs = purrbert_model(**inputs)
    pred = torch.argmax(outputs.logits, dim=-1).item()
    return pred == 0


def format_history(history, message):
    """Render the chat history plus the new message as a ChatML prompt."""
    chat_prompt = SYSTEM_PROMPT
    for user_msg, assistant_msg in history:
        chat_prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        chat_prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
    chat_prompt += f"<|im_start|>user\n{message}<|im_end|>\n"
    chat_prompt += "<|im_start|>assistant\n"
    return chat_prompt


class StopOnUserTag(StoppingCriteria):
    """Stop generation as soon as the model starts a new user turn."""

    def __init__(self, tokenizer):
        self.stop_token_ids = tokenizer.encode(
            "<|im_start|>user", add_special_tokens=False
        )

    def __call__(self, input_ids, scores, **kwargs):
        # Check whether the most recent tokens spell out the user-turn tag.
        if len(input_ids[0]) >= len(self.stop_token_ids):
            if input_ids[0][-len(self.stop_token_ids):].tolist() == self.stop_token_ids:
                return True
        return False


stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])


def clean_repetition(text, max_repeat=3):
    """Drop lines repeated more than max_repeat times (guards against loops)."""
    counts = {}
    clean = []
    for line in text.splitlines():
        counts[line] = counts.get(line, 0) + 1
        if counts[line] <= max_repeat:
            clean.append(line)
    return "\n".join(clean)


def respond_stream(message, history):
    """Yield the assistant's reply incrementally, after a safety check."""
    if not is_safe_prompt(message):
        yield SAFETY_RESPONSE
        return

    full_prompt = format_history(history, message)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    def generate():
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                typical_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=stop_criteria,
                streamer=streamer,
            )

    # Run generation on a background thread so we can consume the streamer here.
    thread = threading.Thread(target=generate)
    thread.start()

    buffer = ""
    for token in streamer:
        buffer += token
        yield clean_repetition(buffer)


with gr.Blocks() as demo:
    gr.Markdown("## Tiny-Purr-1B Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message", placeholder="Say something to Tiny-Purr...")
    submit = gr.Button("Send")

    def submit_message(message, chat_history):
        # Generator function: stream partial replies into the chat display.
        history = chat_history or []
        for chunk in respond_stream(message, history):
            yield history + [(message, chunk)]

    submit.click(submit_message, inputs=[msg, chatbot], outputs=chatbot)
    # Optional: press Enter to submit.
    msg.submit(submit_message, inputs=[msg, chatbot], outputs=chatbot)

# Generator outputs stream through the queue; it is on by default in Gradio 4+,
# but calling it explicitly keeps older Gradio 3.x installs working too.
demo.queue()
demo.launch()