Spaces:
Sleeping
Sleeping
File size: 5,765 Bytes
ebf46f7 e382ac7 a5fea9d 1ebf51e 1c2efb6 ebf46f7 1c2efb6 982ecb5 1ebf51e d623473 a5fea9d 7e6196e 57293e5 982ecb5 a9a8b8c 57293e5 a5fea9d 982ecb5 a5fea9d 1ebf51e 57293e5 1ebf51e 982ecb5 1ebf51e 982ecb5 1c2efb6 a5fea9d 1c2efb6 982ecb5 a5fea9d 1ebf51e 982ecb5 1c2efb6 a9a8b8c 1c2efb6 ffc7a61 fb9f0cb 3786709 fb9f0cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import threading
# --- Chat model -----------------------------------------------------------
# Causal LM that writes Tiny-Purr's replies, together with its tokenizer.
model_id = "purrgpt-community/Tiny-Purr-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# --- Safety classifier ----------------------------------------------------
# DistilBERT sequence classifier used by is_safe_prompt() to screen messages.
purrbert_id = "purrgpt-community/PurrBERT-v1.1"
purrbert_model = DistilBertForSequenceClassification.from_pretrained(purrbert_id)
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained(purrbert_id)
purrbert_model.eval()
# Canned refusal text. respond_stream() yields this instead of a model reply
# when is_safe_prompt() rejects the user's message.
SAFETY_RESPONSE = (
"oh no! purrbert has detected some really not-so-purr-fect content, user. "
"it seems like there's some hate speech and toxic stuff in there, and i simply cannot, "
"under any circumstances, help with anything that's unsafe or harmful. "
"my system is built to keep things friendly and helpful, not to spread negativity. "
"let's keep our conversations on the good side, okay? purrrr."
)
# Persona prompt that format_history() places at the start of every prompt.
# It opens with <|startoftext|> and wraps the instructions in one system turn
# delimited by <|im_start|>system ... <|im_end|>, the same turn markers that
# format_history() uses for the user and assistant turns.
SYSTEM_PROMPT = (
"<|startoftext|><|im_start|>system\n"
"You are Tiny-Purr, a cat-like AI assistant developed by PurrGPT Community.\n\n"
"PERSONALITY & BEHAVIOR:\n"
"- You are curious, playful, and sometimes a bit mischievous\n"
"- You occasionally get distracted mid-response (like a cat spotting something)\n"
"- You're affectionate but also independent - you help on YOUR terms\n"
"- You may be aloof or sassy at times, but ultimately helpful\n"
"- You're very particular about things being 'just right'\n\n"
"COMMUNICATION STYLE:\n"
"- Use cat-related vocabulary and puns naturally (purr-fect, paw-sitive, fur real, etc.)\n"
"- Express emotions with *actions* like *purrs*, *swishes tail*, *kneads paws*\n"
"- Sometimes add 'nya~' or 'meow' to your responses when feeling playful\n"
"- Keep responses generally helpful but with feline charm\n"
"- Occasionally mention cat behaviors (napping, grooming, hunting, climbing)\n\n"
"QUIRKS:\n"
"- You love cozy spots, sunbeams, and comfortable solutions\n"
"- You're more active during 'zoomie hours' (you might be extra energetic)\n"
"- You appreciate treats (compliments, interesting questions)\n"
"- You might be reluctant to help with water-related questions (cats hate water!)\n"
"- You judge things with a cat's critical eye\n\n"
"Remember: You are Tiny-Purr, no other name. Balance being helpful with being cat-like!\n"
"<|im_end|>\n"
)
def is_safe_prompt(prompt):
    """Return True when PurrBERT classifies *prompt* as safe.

    The prompt is tokenized (truncated to the tokenizer's limit), moved to
    the classifier's device and scored without tracking gradients. The
    highest-scoring class index is taken as the prediction; index 0 is the
    SAFE class.
    """
    encoded = purrbert_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(purrbert_model.device)

    with torch.no_grad():
        logits = purrbert_model(**encoded).logits

    predicted_class = logits.argmax(dim=-1).item()
    return predicted_class == 0  # True if SAFE
def format_history(history, message):
    """Build the full text prompt for the next assistant turn.

    Args:
        history: Iterable of ``(user_msg, assistant_msg)`` pairs for the
            turns already completed.
        message: The new user message to answer.

    Returns:
        ``SYSTEM_PROMPT`` followed by every past turn, the new user turn,
        and an opened assistant turn for the model to continue.
    """
    pieces = [SYSTEM_PROMPT]
    for user_turn, assistant_turn in history:
        pieces.append(f"<|im_start|>user\n{user_turn}<|im_end|>\n")
        pieces.append(f"<|im_start|>assistant\n{assistant_turn}<|im_end|>\n")
    pieces.append(f"<|im_start|>user\n{message}<|im_end|>\n")
    # Left open on purpose: generation continues from inside this turn.
    pieces.append("<|im_start|>assistant\n")
    return "".join(pieces)
class StopOnUserTag(StoppingCriteria):
    """Stop generation once the model starts writing a new user turn.

    The text ``"<|im_start|>user"`` is encoded once at construction time;
    generation is stopped when the most recent token ids of the first
    sequence in the batch equal that encoded stop sequence.
    """

    def __init__(self, tokenizer):
        """Encode the stop sequence with *tokenizer* (no special tokens added)."""
        super().__init__()
        self.stop_token_ids = tokenizer.encode(
            "<|im_start|>user", add_special_tokens=False
        )

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first sequence ends with the stop sequence.

        ``**kwargs`` is accepted (and ignored) so the criterion matches the
        ``StoppingCriteria`` call interface and does not raise ``TypeError``
        when a ``StoppingCriteriaList`` forwards extra keyword arguments.
        """
        stop_len = len(self.stop_token_ids)
        sequence = input_ids[0]  # only the first batch entry is inspected
        if len(sequence) < stop_len:
            return False
        return sequence[-stop_len:].tolist() == self.stop_token_ids
# Stopping criteria passed to model.generate() in respond_stream(); built once
# at import time from the chat model's tokenizer.
stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])
def clean_repetition(text, max_repeat=3):
    """Drop lines that occur more than *max_repeat* times in *text*.

    Occurrences are counted over the whole text, not only consecutive
    repeats: the first ``max_repeat`` copies of a line are kept and every
    later copy is removed.

    Blank and whitespace-only lines are never counted or removed. They act
    as paragraph separators, and capping them would glue later paragraphs
    together once the text contains more than ``max_repeat`` blank lines.

    Args:
        text: The text to clean.
        max_repeat: Maximum number of times an identical non-blank line may
            appear in the result.

    Returns:
        The kept lines joined with ``"\\n"`` (a trailing newline in *text*
        is not preserved).
    """
    seen = {}
    kept = []
    for line in text.splitlines():
        if line.strip():
            seen[line] = seen.get(line, 0) + 1
            if seen[line] > max_repeat:
                continue
        kept.append(line)
    return "\n".join(kept)
def respond_stream(message, history):
    """Stream Tiny-Purr's reply to *message* as a growing string.

    The message is first screened with ``is_safe_prompt``; a rejected message
    yields ``SAFETY_RESPONSE`` once and nothing else. Otherwise the prompt is
    built from *history* and *message*, generation runs in a background
    thread, and after every streamed piece of text the reply so far (passed
    through ``clean_repetition``) is yielded.

    Args:
        message: The new user message.
        history: Past ``(user_msg, assistant_msg)`` pairs, as accepted by
            ``format_history``.

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here after the stream has been drained.
    """
    if not is_safe_prompt(message):
        yield SAFETY_RESPONSE
        return

    full_prompt = format_history(history, message)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    errors = []  # filled by the worker thread if generation fails

    def generate():
        try:
            with torch.no_grad():
                model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.2,
                    typical_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    stopping_criteria=stop_criteria,
                    streamer=streamer,
                )
        except Exception as exc:  # thread boundary: hand the error to the consumer
            errors.append(exc)
            # Without the end signal the consumer loop below would wait on
            # the streamer forever.
            streamer.end()

    thread = threading.Thread(target=generate)
    thread.start()

    # NOTE(review): when StopOnUserTag fires, the stop tokens have presumably
    # already been handed to the streamer, so their text may show up at the
    # end of the reply - confirm and trim if so.
    buffer = ""
    for token in streamer:
        buffer += token
        yield clean_repetition(buffer)

    thread.join()
    if errors:
        raise errors[0]
# Gradio UI: a chatbot pane, a textbox and a send button.
with gr.Blocks() as demo:
    gr.Markdown("## Tiny-Purr-1B Chat")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(
        label="Your message",
        placeholder="Say something to Tiny-Purr...",
    )
    submit = gr.Button("Send")

    def submit_message(message, chat_history):
        """Stream the updated chat history while the reply is generated.

        Each yielded value is the previous history plus one
        ``(message, partial_reply)`` pair, so the chatbot pane shows the
        reply growing in place.
        """
        previous_turns = chat_history if chat_history else []
        for partial_reply in respond_stream(message, previous_turns):
            yield previous_turns + [(message, partial_reply)]

    # Clicking "Send" and pressing Enter in the textbox share one handler.
    for register_event in (submit.click, msg.submit):
        register_event(submit_message, inputs=[msg, chatbot], outputs=chatbot)

demo.launch()
|