FlameF0X committed
Commit a5fea9d · verified · 1 Parent(s): e382ac7

Update app.py

Files changed (1):
  1. app.py +25 -14
app.py CHANGED
@@ -1,11 +1,10 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
 import torch
 
-# Use the merged model
+# Tiny-Purr merged model
 model_id = "purrgpt-community/Tiny-Purr-350M-merged"
-
-# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -14,20 +13,33 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
+# PurrBERT safety classifier
+purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1")
+purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1")
+purrbert_model.eval()
+
+def is_safe_prompt(prompt):
+    inputs = purrbert_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(purrbert_model.device)
+    with torch.no_grad():
+        outputs = purrbert_model(**inputs)
+    pred = torch.argmax(outputs.logits, dim=-1).item()
+    return pred == 0  # True if SAFE, False if FLAGGED
+
 def format_history(history, message):
-    chat_prompt = "<|system|>\nYou are Tiny-Purr, a friendly, sarcastic, playful AI assistant in the form of a cat.\n<|system|>\n"
+    chat_prompt = ""
     for user_msg, assistant_msg in history:
         chat_prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
     chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
     return chat_prompt
 
 def respond(message, history):
-    full_prompt = format_history(history, message)
-
-    # Tokenize the input
+    # Safety check using PurrBERT (ChatInterface expects a plain string reply)
+    if not is_safe_prompt(message):
+        return "oh no! purrbert has detected some really not-so-purr-fect content, user. it seems like there's some hate speech and toxic stuff in there, and i simply cannot, under any circumstances, help with anything that's unsafe or harmful. my system is built to keep things friendly and helpful, not to spread negativity. let's keep our conversations on the good side, okay? purrrr."
+
+    full_prompt = format_history(history, message)
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
-    # Generate a response
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -38,7 +50,6 @@ def respond(message, history):
             pad_token_id=tokenizer.eos_token_id
         )
 
-    # Decode and extract assistant response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     generated_text = response[len(full_prompt):].strip()
 
@@ -49,14 +60,14 @@ def respond(message, history):
 
     return assistant_response
 
-# Launch Gradio chat
+# Launch Gradio chat interface
 gr.ChatInterface(
     respond,
-    title="Tiny-Purr-350M-merged Chatbot",
-    description="A simple conversational model powered by Tiny-Purr-350M-merged.",
+    title="Tiny-Purr-350M-merged Chatbot (No System Prompt)",
+    description="A simple conversational chatbot using Tiny-Purr-350M-merged with PurrBERT safety filtering.",
     examples=[
-        "What is the capital of France?",
-        "Tell me a short story about a cat.",
-        "Explain the concept of quantum entanglement in simple terms."
+        "What’s your favorite kind of cat?",
+        "Explain quantum entanglement simply.",
+        "Write me a haiku about the moon."
     ]
 ).launch()
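
The headline change is the PurrBERT gate at the top of respond(). Below is a minimal standalone sketch of that check, handy for smoke-testing the classifier outside Gradio; the helper names (clf_id, clf, tok) and the sample prompts are illustrative, but the checkpoint and the "label 0 = safe" mapping come from the code above.

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# Same safety-classifier checkpoint app.py loads.
clf_id = "purrgpt-community/PurrBERT-v1"
clf = DistilBertForSequenceClassification.from_pretrained(clf_id)
tok = DistilBertTokenizerFast.from_pretrained(clf_id)
clf.eval()

def is_safe_prompt(prompt):
    # Tokenize, run one forward pass, take the argmax over the logits;
    # label index 0 is SAFE, anything else is FLAGGED (per app.py's comment).
    inputs = tok(prompt, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = clf(**inputs).logits
    return logits.argmax(dim=-1).item() == 0

if __name__ == "__main__":
    for prompt in ["Write me a haiku about the moon.", "you are a worthless bot and so is everyone like you"]:
        print("SAFE" if is_safe_prompt(prompt) else "FLAGGED", "->", prompt)

Note that respond() classifies only the newest message, not the accumulated history: one cheap DistilBERT forward pass per turn, at the cost of missing anything unsafe that only emerges across several turns.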
 
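
Since the system block is gone, format_history() now emits only alternating role tags. A quick illustration with a hypothetical one-turn history (the values are invented for the example):

history = [("Hi!", "hello, human! purr.")]
message = "Tell me a joke."
print(format_history(history, message))
# <|user|>
# Hi!
# <|assistant|>
# hello, human! purr.
# <|user|>
# Tell me a joke.
# <|assistant|>

Generation then continues from the trailing <|assistant|> tag, and respond() recovers the reply by slicing response[len(full_prompt):]. That slice assumes the decoded output reproduces the prompt verbatim, i.e. that the role tags are not registered as special tokens that skip_special_tokens=True would strip.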