FlameF0X committed on
Commit
ffc7a61
·
verified ·
1 Parent(s): 57293e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -24
app.py CHANGED
@@ -3,15 +3,25 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
4
  import torch
5
 
6
- model_id = "purrgpt-community/Tiny-Purr-350M-merged"
7
- tokenizer = AutoTokenizer.from_pretrained(model_id)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_id,
10
- device_map="auto",
11
- torch_dtype=torch.bfloat16
12
- )
13
- model.eval()
 
 
 
 
 
 
 
 
 
14
 
 
15
  purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1")
16
  purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1")
17
  purrbert_model.eval()
@@ -28,7 +38,7 @@ SYSTEM_PROMPT = (
28
  "<|system|>\n"
29
  "You are Tiny-Purr, a friendly, sarcastic, playful AI assistant in the form of a cat developed by PurrGPT Community. "
30
  "You respond in a fun, cat-like personality, sometimes using puns and playful humor. "
31
- "Always keep your replies safe, friendly, and helpful.\n"
32
  "<|system|>\n"
33
  )
34
 
@@ -46,24 +56,27 @@ def format_history(history, message):
46
  chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
47
  return chat_prompt
48
 
49
- def respond(message, history):
50
- # PurrBERT safety check
51
  if not is_safe_prompt(message):
52
  return SAFETY_RESPONSE
53
-
 
 
 
54
  full_prompt = format_history(history, message)
55
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
56
 
57
  with torch.no_grad():
58
  outputs = model.generate(
59
  **inputs,
60
- max_new_tokens= 512,
61
  temperature=0.4,
62
  top_p=0.75,
63
  do_sample=True,
64
  pad_token_id=tokenizer.eos_token_id
65
  )
66
-
67
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
  generated_text = response[len(full_prompt):].strip()
69
 
@@ -74,13 +87,20 @@ def respond(message, history):
74
 
75
  return assistant_response
76
 
77
- gr.ChatInterface(
78
- respond,
79
- title="Tiny-Purr Chat",
80
- description="Protected by PurrBERT-v1 for safety.",
81
- examples=[
82
- "What’s your favorite kind of cat?",
83
- "Explain quantum entanglement simply.",
84
- "Write me a haiku about the moon."
85
- ]
86
- ).launch()
 
 
 
 
 
 
 
 
3
  from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
4
  import torch
5
 
6
+ # Model options
7
+ model_options = {
8
+ "Tiny-Purr-350M-merged": "purrgpt-community/Tiny-Purr-350M-merged",
9
+ "Tiny-Purr-1B": "purrgpt-community/Tiny-Purr-1B"
10
+ }
11
+
12
+ # Load models and tokenizers
13
+ models = {}
14
+ tokenizers = {}
15
+ for name, model_id in model_options.items():
16
+ tokenizers[name] = AutoTokenizer.from_pretrained(model_id)
17
+ models[name] = AutoModelForCausalLM.from_pretrained(
18
+ model_id,
19
+ device_map="auto",
20
+ torch_dtype=torch.bfloat16
21
+ )
22
+ models[name].eval()
23
 
24
+ # PurrBERT safety model
25
  purrbert_model = DistilBertForSequenceClassification.from_pretrained("purrgpt-community/PurrBERT-v1")
26
  purrbert_tokenizer = DistilBertTokenizerFast.from_pretrained("purrgpt-community/PurrBERT-v1")
27
  purrbert_model.eval()
 
38
  "<|system|>\n"
39
  "You are Tiny-Purr, a friendly, sarcastic, playful AI assistant in the form of a cat developed by PurrGPT Community. "
40
  "You respond in a fun, cat-like personality, sometimes using puns and playful humor. "
41
+ "Always keep your replies safe and friendly.\n"
42
  "<|system|>\n"
43
  )
44
 
 
56
  chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
57
  return chat_prompt
58
 
59
+ def respond(message, history, model_choice):
60
+ # Safety check
61
  if not is_safe_prompt(message):
62
  return SAFETY_RESPONSE
63
+
64
+ tokenizer = tokenizers[model_choice]
65
+ model = models[model_choice]
66
+
67
  full_prompt = format_history(history, message)
68
  inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
69
 
70
  with torch.no_grad():
71
  outputs = model.generate(
72
  **inputs,
73
+ max_new_tokens=512,
74
  temperature=0.4,
75
  top_p=0.75,
76
  do_sample=True,
77
  pad_token_id=tokenizer.eos_token_id
78
  )
79
+
80
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
81
  generated_text = response[len(full_prompt):].strip()
82
 
 
87
 
88
  return assistant_response
89
 
90
+ # Gradio interface
91
+ with gr.Blocks() as demo:
92
+ gr.Markdown("## Tiny-Purr Chat with Model Selection")
93
+ model_selector = gr.Dropdown(choices=list(model_options.keys()), value="Tiny-Purr-350M-merged", label="Choose Model")
94
+ chat = gr.Chatbot()
95
+ msg = gr.Textbox(label="Your Message")
96
+ submit_btn = gr.Button("Send")
97
+
98
+ def chat_interaction(message, history, model_choice):
99
+ response = respond(message, history, model_choice)
100
+ history = history + [(message, response)]
101
+ return history, history
102
+
103
+ submit_btn.click(chat_interaction, [msg, chat, model_selector], [chat, chat])
104
+ msg.submit(chat_interaction, [msg, chat, model_selector], [chat, chat])
105
+
106
+ demo.launch()