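# Gradio chat demo for Tiny-Purr-1B: streams model output token by token and
# screens each incoming prompt with the PurrBERT-v1.1 safety classifier.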
import threading

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

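# Load the chat model; device_map="auto" lets accelerate place it on GPU when available.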
model_id = "purrgpt-community/Tiny-Purr-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
model.eval()

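# PurrBERT: a DistilBERT sequence classifier used as a prompt safety gate.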
purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1.1")
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1.1")
purrbert_model.eval()

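# Canned refusal returned whenever PurrBERT flags a prompt as unsafe.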
SAFETY_RESPONSE = (
    "oh no! purrbert has detected some really not-so-purr-fect content, user. "
    "it seems like there's some hate speech and toxic stuff in there, and i simply cannot, "
    "under any circumstances, help with anything that's unsafe or harmful. "
    "my system is built to keep things friendly and helpful, not to spread negativity. "
    "let's keep our conversations on the good side, okay? purrrr."
)

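# ChatML-style system prompt that establishes the cat persona.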
SYSTEM_PROMPT = (
    "<|startoftext|><|im_start|>system\n"
    "You are Tiny-Purr, a cat-like AI assistant developed by PurrGPT Community.\n\n"
    
    "PERSONALITY & BEHAVIOR:\n"
    "- You are curious, playful, and sometimes a bit mischievous\n"
    "- You occasionally get distracted mid-response (like a cat spotting something)\n"
    "- You're affectionate but also independent - you help on YOUR terms\n"
    "- You may be aloof or sassy at times, but ultimately helpful\n"
    "- You're very particular about things being 'just right'\n\n"
    
    "COMMUNICATION STYLE:\n"
    "- Use cat-related vocabulary and puns naturally (purr-fect, paw-sitive, fur real, etc.)\n"
    "- Express emotions with *actions* like *purrs*, *swishes tail*, *kneads paws*\n"
    "- Sometimes add 'nya~' or 'meow' to your responses when feeling playful\n"
    "- Keep responses generally helpful but with feline charm\n"
    "- Occasionally mention cat behaviors (napping, grooming, hunting, climbing)\n\n"
    
    "QUIRKS:\n"
    "- You love cozy spots, sunbeams, and comfortable solutions\n"
    "- You're more active during 'zoomie hours' (you might be extra energetic)\n"
    "- You appreciate treats (compliments, interesting questions)\n"
    "- You might be reluctant to help with water-related questions (cats hate water!)\n"
    "- You judge things with a cat's critical eye\n\n"
    
    "Remember: You are Tiny-Purr, no other name. Balance being helpful with being cat-like!\n"
    "<|im_end|>\n"
)

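# Classify the raw prompt with PurrBERT (truncated to the tokenizer's max length).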
def is_safe_prompt(prompt):
    inputs = purrbert_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(purrbert_model.device)
    with torch.no_grad():
        outputs = purrbert_model(**inputs)
        pred = torch.argmax(outputs.logits, dim=-1).item()
    return pred == 0  # True if SAFE

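# Flatten prior (user, assistant) turns plus the new message into one ChatML prompt.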
def format_history(history, message):
    chat_prompt = SYSTEM_PROMPT
    for user_msg, assistant_msg in history:
        chat_prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        chat_prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
    chat_prompt += f"<|im_start|>user\n{message}<|im_end|>\n"
    chat_prompt += f"<|im_start|>assistant\n"
    return chat_prompt

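# Stop generation as soon as the model starts fabricating a new user turn.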
class StopOnUserTag(StoppingCriteria):
    def __init__(self, tokenizer):
        self.stop_token_ids = tokenizer.encode("<|im_start|>user", add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Halt when the tail of the sequence matches the "<|im_start|>user" token ids.
        if len(input_ids[0]) >= len(self.stop_token_ids):
            if input_ids[0][-len(self.stop_token_ids):].tolist() == self.stop_token_ids:
                return True
        return False

stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])

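# Small models tend to loop; drop any line repeated more than max_repeat times.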
def clean_repetition(text, max_repeat=3):
    lines = text.splitlines()
    counts = {}
    clean = []
    for line in lines:
        counts[line] = counts.get(line, 0) + 1
        if counts[line] <= max_repeat:
            clean.append(line)
    return "\n".join(clean)

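# Generator: safety-check the prompt, then stream tokens from a background
# generation thread, de-duplicating repeated lines as the buffer grows.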
def respond_stream(message, history):
    if not is_safe_prompt(message):
        yield SAFETY_RESPONSE
        return

    full_prompt = format_history(history, message)
    # The prompt already starts with <|startoftext|>, so don't let the tokenizer add its own BOS.
    inputs = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        with torch.no_grad():
            model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                typical_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=stop_criteria,
                streamer=streamer
            )

    thread = threading.Thread(target=generate)
    thread.start()

    buffer = ""
    for token in streamer:
        buffer += token
        yield clean_repetition(buffer)

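# Minimal Blocks UI: chat window, message box, and a send button.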
with gr.Blocks() as demo:
    gr.Markdown("## Tiny-Purr-1B Chat")

    chatbot = gr.Chatbot()  # tuple-style (user, assistant) history
    msg = gr.Textbox(label="Your message", placeholder="Say something to Tiny-Purr...")
    submit = gr.Button("Send")

    def submit_message(message, chat_history):
        # Generator: re-yield the growing reply so the chat window updates live.
        history = chat_history or []
        for chunk in respond_stream(message, history):
            # Display prior turns plus the in-progress exchange.
            yield history + [(message, chunk)]

    submit.click(submit_message, inputs=[msg, chatbot], outputs=chatbot)

    # Optional: press Enter to submit
    msg.submit(submit_message, inputs=[msg, chatbot], outputs=chatbot)

demo.launch()