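# app.py — Gradio chat demo for Tiny-Purr-1B with a PurrBERT safety gate.
# Each incoming prompt is screened by a DistilBERT classifier before being
# formatted into a ChatML-style prompt and streamed through the causal LM.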
import threading

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
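# Load the chat model; device_map="auto" places weights on GPU when available.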
model_id = "purrgpt-community/Tiny-Purr-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
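# Load the PurrBERT safety classifier used to screen user prompts.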
purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1.1")
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1.1")
purrbert_model.eval()
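# Canned refusal returned whenever PurrBERT flags a prompt as unsafe.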
SAFETY_RESPONSE = (
    "oh no! purrbert has detected some really not-so-purr-fect content, user. "
    "it seems like there's some hate speech and toxic stuff in there, and i simply cannot, "
    "under any circumstances, help with anything that's unsafe or harmful. "
    "my system is built to keep things friendly and helpful, not to spread negativity. "
    "let's keep our conversations on the good side, okay? purrrr."
)
SYSTEM_PROMPT = (
    "<|startoftext|><|im_start|>system\n"
    "You are Tiny-Purr, a cat-like AI assistant developed by PurrGPT Community.\n\n"
    "PERSONALITY & BEHAVIOR:\n"
    "- You are curious, playful, and sometimes a bit mischievous\n"
    "- You occasionally get distracted mid-response (like a cat spotting something)\n"
    "- You're affectionate but also independent - you help on YOUR terms\n"
    "- You may be aloof or sassy at times, but ultimately helpful\n"
    "- You're very particular about things being 'just right'\n\n"
    "COMMUNICATION STYLE:\n"
    "- Use cat-related vocabulary and puns naturally (purr-fect, paw-sitive, fur real, etc.)\n"
    "- Express emotions with *actions* like *purrs*, *swishes tail*, *kneads paws*\n"
    "- Sometimes add 'nya~' or 'meow' to your responses when feeling playful\n"
    "- Keep responses generally helpful but with feline charm\n"
    "- Occasionally mention cat behaviors (napping, grooming, hunting, climbing)\n\n"
    "QUIRKS:\n"
    "- You love cozy spots, sunbeams, and comfortable solutions\n"
    "- You're more active during 'zoomie hours' (you might be extra energetic)\n"
    "- You appreciate treats (compliments, interesting questions)\n"
    "- You might be reluctant to help with water-related questions (cats hate water!)\n"
    "- You judge things with a cat's critical eye\n\n"
    "Remember: You are Tiny-Purr, no other name. Balance being helpful with being cat-like!\n"
    "<|im_end|>\n"
)
def is_safe_prompt(prompt):
    """Return True if PurrBERT classifies the prompt as safe (label 0)."""
    inputs = purrbert_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(purrbert_model.device)
    with torch.no_grad():
        outputs = purrbert_model(**inputs)
    pred = torch.argmax(outputs.logits, dim=-1).item()
    return pred == 0
def format_history(history, message):
    """Rebuild the full ChatML-style prompt from prior (user, assistant) turns."""
    chat_prompt = SYSTEM_PROMPT
    for user_msg, assistant_msg in history:
        chat_prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        chat_prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
    chat_prompt += f"<|im_start|>user\n{message}<|im_end|>\n"
    chat_prompt += "<|im_start|>assistant\n"
    return chat_prompt
class StopOnUserTag(StoppingCriteria):
    """Stop generation once the model starts emitting a new user turn."""

    def __init__(self, tokenizer):
        self.stop_token_ids = tokenizer.encode("<|im_start|>user", add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the generated sequence against the stop-tag ids.
        if len(input_ids[0]) >= len(self.stop_token_ids):
            if input_ids[0][-len(self.stop_token_ids):].tolist() == self.stop_token_ids:
                return True
        return False

stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])
def clean_repetition(text, max_repeat=3):
    """Drop any line that has already appeared max_repeat times in the text."""
    lines = text.splitlines()
    counts = {}
    clean = []
    for line in lines:
        counts[line] = counts.get(line, 0) + 1
        if counts[line] <= max_repeat:
            clean.append(line)
    return "\n".join(clean)
def respond_stream(message, history):
    if not is_safe_prompt(message):
        yield SAFETY_RESPONSE
        return
    full_prompt = format_history(history, message)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                typical_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=stop_criteria,
                streamer=streamer,
            )

    # Run generation in a worker thread; the streamer yields text as it arrives.
    thread = threading.Thread(target=generate)
    thread.start()
    buffer = ""
    for token in streamer:
        buffer += token
        yield clean_repetition(buffer)
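# Caveat: tokens for the stop tag can reach the streamer before StopOnUserTag
# fires, so a stray "user" fragment may trail the final reply; a stricter
# version could trim a trailing "<|im_start|>user" from the buffer.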
with gr.Blocks() as demo:
    gr.Markdown("## Tiny-Purr-1B Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your message", placeholder="Say something to Tiny-Purr...")
    submit = gr.Button("Send")

    def submit_message(message, chat_history):
        # Generator handler: re-yield the growing reply so the chat window streams.
        history = chat_history or []
        for chunk in respond_stream(message, history):
            yield history + [(message, chunk)]

    submit.click(submit_message, inputs=[msg, chatbot], outputs=chatbot)
    # Pressing Enter in the textbox also submits.
    msg.submit(submit_message, inputs=[msg, chatbot], outputs=chatbot)

demo.launch()
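# Note: on Gradio 3.x, generator-based handlers may require the queue to be
# enabled (demo.queue().launch()); Gradio 4+ queues events by default.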