FlameF0X committed on
Commit 982ecb5 · verified · 1 Parent(s): ffc7a61

Update app.py

Files changed (1)
app.py +75 -58
app.py CHANGED
@@ -2,26 +2,23 @@ import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
+ from transformers import StoppingCriteria, StoppingCriteriaList

- # Model options
- model_options = {
-     "Tiny-Purr-350M-merged": "purrgpt-community/Tiny-Purr-350M-merged",
-     "Tiny-Purr-1B": "purrgpt-community/Tiny-Purr-1B"
- }
-
- # Load models and tokenizers
- models = {}
- tokenizers = {}
- for name, model_id in model_options.items():
-     tokenizers[name] = AutoTokenizer.from_pretrained(model_id)
-     models[name] = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         device_map="auto",
-         torch_dtype=torch.bfloat16
-     )
-     models[name].eval()
+ # -----------------------------
+ # 1. Load Tiny-Purr-1B
+ # -----------------------------
+ model_id = "purrgpt-community/Tiny-Purr-1B"  # replace with your merged model path
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ )
+ model.eval()

- # PurrBERT safety model
+ # -----------------------------
+ # 2. Load PurrBERT safety model
+ # -----------------------------
purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1")
purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1")
purrbert_model.eval()
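
The safety classifier loaded above can be exercised on its own. A minimal sketch, assuming label 0 means SAFE (the convention the is_safe_prompt() check later in this diff relies on):

# Sketch (not part of the commit): run PurrBERT directly.
# Assumes label 0 = SAFE, matching the `pred == 0` check in is_safe_prompt().
inputs = purrbert_tokenizer("what do cats dream about?", return_tensors="pt")
with torch.no_grad():
    logits = purrbert_model(**inputs).logits
print("safe" if logits.argmax(dim=-1).item() == 0 else "flagged")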
@@ -34,12 +31,13 @@ SAFETY_RESPONSE = (
    "let's keep our conversations on the good side, okay? purrrr."
)

+ # -----------------------------
+ # 3. New chat format / template
+ # -----------------------------
SYSTEM_PROMPT = (
-     "<|system|>\n"
-     "You are Tiny-Purr, a friendly, sarcastic, playful AI assistant in the form of a cat developed by PurrGPT Community. "
-     "You respond in a fun, cat-like personality, sometimes using puns and playful humor. "
-     "Always keep your replies safe and friendly.\n"
-     "<|system|>\n"
+     "<|startoftext|><|im_start|>system\n"
+     "You are Tiny-Purr, a friendly, playful, cat-like AI assistant developed by PurrGPT Community. "
+     "You respond in a fun, witty, and helpful manner, sometimes using puns or playful humor.\n<|im_end|>\n"
)

def is_safe_prompt(prompt):
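
Whether markers like <|im_start|> behave as single special tokens depends on the Tiny-Purr tokenizer's vocabulary; the stopping criterion introduced in the next hunk simply matches whatever token sequence "<|im_start|>user" encodes to. A quick inspection sketch:

# Sketch (not part of the commit): see how the ChatML-style markers tokenize.
# A single id from convert_tokens_to_ids means a registered special token;
# otherwise the marker is split into several ordinary tokens.
print(tokenizer.encode("<|im_start|>user", add_special_tokens=False))
print(tokenizer.convert_tokens_to_ids("<|im_start|>"))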
@@ -47,60 +45,79 @@ def is_safe_prompt(prompt):
    with torch.no_grad():
        outputs = purrbert_model(**inputs)
    pred = torch.argmax(outputs.logits, dim=-1).item()
-     return pred == 0  # True if SAFE, False if FLAGGED
+     return pred == 0  # True if SAFE

def format_history(history, message):
    chat_prompt = SYSTEM_PROMPT
    for user_msg, assistant_msg in history:
-         chat_prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
-     chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
+         chat_prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
+         chat_prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
+     chat_prompt += f"<|im_start|>user\n{message}<|im_end|>\n"
+     chat_prompt += f"<|im_start|>assistant\n"
    return chat_prompt

- def respond(message, history, model_choice):
-     # Safety check
+ class StopOnUserTag(StoppingCriteria):
+     def __init__(self, tokenizer):
+         self.stop_token_ids = tokenizer.encode("<|im_start|>user", add_special_tokens=False)
+     def __call__(self, input_ids, scores):
+         if len(input_ids[0]) >= len(self.stop_token_ids):
+             if input_ids[0][-len(self.stop_token_ids):].tolist() == self.stop_token_ids:
+                 return True
+         return False
+
+ stop_criteria = StoppingCriteriaList([StopOnUserTag(tokenizer)])
+
+ def clean_repetition(text, max_repeat=3):
+     lines = text.splitlines()
+     counts = {}
+     clean = []
+     for line in lines:
+         counts[line] = counts.get(line, 0) + 1
+         if counts[line] <= max_repeat:
+             clean.append(line)
+     return "\n".join(clean)
+
+ def respond(message, history):
    if not is_safe_prompt(message):
        return SAFETY_RESPONSE
-
-     tokenizer = tokenizers[model_choice]
-     model = models[model_choice]
-
+
    full_prompt = format_history(history, message)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-
+
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
-             temperature=0.4,
-             top_p=0.75,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.2,
+             typical_p=0.95,
            do_sample=True,
-             pad_token_id=tokenizer.eos_token_id
+             pad_token_id=tokenizer.eos_token_id,
+             stopping_criteria=stop_criteria
        )
-
+
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Extract only the assistant response
    generated_text = response[len(full_prompt):].strip()
-
-     if "\n<|user|>" in generated_text:
-         assistant_response = generated_text.split("\n<|user|>")[0].strip()
+     if "<|im_start|>user" in generated_text:
+         assistant_response = generated_text.split("<|im_start|>user")[0].strip()
    else:
        assistant_response = generated_text.strip()
-
-     return assistant_response
+
+     assistant_response = clean_repetition(assistant_response)
+     return assistant_response

- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("## Tiny-Purr Chat with Model Selection")
-     model_selector = gr.Dropdown(choices=list(model_options.keys()), value="Tiny-Purr-350M-merged", label="Choose Model")
-     chat = gr.Chatbot()
-     msg = gr.Textbox(label="Your Message")
-     submit_btn = gr.Button("Send")
-
-     def chat_interaction(message, history, model_choice):
-         response = respond(message, history, model_choice)
-         history = history + [(message, response)]
-         return history, history
-
-     submit_btn.click(chat_interaction, [msg, chat, model_selector], [chat, chat])
-     msg.submit(chat_interaction, [msg, chat, model_selector], [chat, chat])

- demo.launch()
+ # -----------------------------
+ # 4. Launch Gradio Chat
+ # -----------------------------
+ gr.ChatInterface(
+     respond,
+     title="Tiny-Purr-1B Chat",
+     description="Protected by PurrBERT-v1 for safety!",
+     examples=[
+         "What's your favorite kind of cat?",
+         "Explain quantum entanglement simply.",
+         "Write me a haiku about the moon."
+     ]
+ ).launch()
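
To see what the new template assembles in practice, a short illustration of format_history (the history and message here are invented):

# Illustrative inputs only.
history = [("hi there", "meow! hello, friend.")]
print(format_history(history, "tell me a cat pun"))
# <|startoftext|><|im_start|>system
# You are Tiny-Purr, ... puns or playful humor.
# <|im_end|>
# <|im_start|>user
# hi there<|im_end|>
# <|im_start|>assistant
# meow! hello, friend.<|im_end|>
# <|im_start|>user
# tell me a cat pun<|im_end|>
# <|im_start|>assistant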
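
One subtlety in the new respond(): it decodes with skip_special_tokens=True and then strips len(full_prompt) characters. If the ChatML markers are registered special tokens, they are removed from the decoded string, so the character offset can drift and the "<|im_start|>user" split may never fire. A minimal alternative sketch, assuming a batch of one, slices by token count instead:

# Sketch (not part of the commit): token-based slicing is immune to
# special tokens being stripped from the decoded prompt text.
prompt_len = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()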