FlameF0X committed (verified)
Commit: e382ac7 · Parent(s): 84077e3

Update app.py

Files changed (1):
  1. app.py +19 -13
app.py CHANGED
```diff
@@ -1,21 +1,21 @@
 import gradio as gr
-from peft import AutoPeftModelForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-model_id = "purrgpt-community/Tiny-Purr-350M"
+# Use the merged model
+model_id = "purrgpt-community/Tiny-Purr-350M-merged"
 
-model = AutoPeftModelForCausalLM.from_pretrained(
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16
 )
 model.eval()
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-
 def format_history(history, message):
-    chat_prompt = "<|system|>\nYou are Tiny-Purr,a friendly, sarcastic, playful ai assistant in the form of a cat.\n<|system|>\n"
+    chat_prompt = "<|system|>\nYou are Tiny-Purr, a friendly, sarcastic, playful AI assistant in the form of a cat.\n<|system|>\n"
     for user_msg, assistant_msg in history:
         chat_prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
     chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
```
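This hunk drops the peft dependency entirely: the Space now loads plain transformers weights instead of resolving a LoRA adapter at startup. A checkpoint like Tiny-Purr-350M-merged is typically produced by folding the adapter into the base weights with PEFT's merge_and_unload(); the sketch below shows that step under the assumption that the old adapter repo is the source. The merge itself is not part of this commit:

```python
# Hypothetical merge step (not in this commit): fold the LoRA adapter
# into its base weights so the app can load the result with
# AutoModelForCausalLM alone, without peft installed.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_id = "purrgpt-community/Tiny-Purr-350M"  # old model_id from the diff

model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_id,
    torch_dtype=torch.bfloat16,
)
merged = model.merge_and_unload()  # returns a plain transformers model

merged.save_pretrained("Tiny-Purr-350M-merged")
AutoTokenizer.from_pretrained(adapter_id).save_pretrained("Tiny-Purr-350M-merged")
```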
```diff
@@ -24,8 +24,10 @@ def format_history(history, message):
 def respond(message, history):
     full_prompt = format_history(history, message)
 
+    # Tokenize the input
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
+    # Generate a response
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
```
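For context on what respond() feeds the tokenizer: format_history() flattens the Gradio chat history into one tagged prompt string. A standalone illustration, assuming the tuple-style history that the loop above unpacks:

```python
# Standalone illustration of the prompt that format_history() builds.
history = [("Hi!", "Meow. What do you want, human?")]
message = "Tell me a joke."

prompt = (
    "<|system|>\nYou are Tiny-Purr, a friendly, sarcastic, playful AI "
    "assistant in the form of a cat.\n<|system|>\n"
)
for user_msg, assistant_msg in history:
    prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{assistant_msg}\n"
prompt += f"<|user|>\n{message}\n<|assistant|>\n"

print(prompt)  # each past turn appears between <|user|>/<|assistant|> tags,
               # ending with an open <|assistant|> tag for the model to complete
```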
```diff
@@ -36,10 +38,9 @@ def respond(message, history):
             pad_token_id=tokenizer.eos_token_id
         )
 
+    # Decode and extract assistant response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    start_index = len(full_prompt)
-    generated_text = response[start_index:].strip()
+    generated_text = response[len(full_prompt):].strip()
 
     if "\n<|user|>" in generated_text:
         assistant_response = generated_text.split("\n<|user|>")[0].strip()
```
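One caveat with the new one-liner: response[len(full_prompt):] assumes tokenizer.decode(..., skip_special_tokens=True) reproduces the prompt character-for-character, which is not guaranteed (special tokens are stripped and whitespace can shift during detokenization), so the offset can drift. A more defensive variant, sketched here rather than taken from the commit, slices at the token level before decoding:

```python
# Sketch: strip the prompt by token count instead of character count.
# `inputs`, `outputs`, and `tokenizer` are the same objects as in respond().
prompt_len = inputs["input_ids"].shape[1]   # number of prompt tokens
new_token_ids = outputs[0][prompt_len:]     # generated ids only
generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
```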
```diff
@@ -48,9 +49,14 @@
 
     return assistant_response
 
+# Launch Gradio chat
 gr.ChatInterface(
     respond,
-    title="Tiny-Purr-350M Chatbot",
-    description="A simple conversational model powered by Tiny-Purr-350M.",
-    examples=["What is the capital of France?", "Tell me a short story about a cat.", "Explain the concept of quantum entanglement in simple terms."]
+    title="Tiny-Purr-350M-merged Chatbot",
+    description="A simple conversational model powered by Tiny-Purr-350M-merged.",
+    examples=[
+        "What is the capital of France?",
+        "Tell me a short story about a cat.",
+        "Explain the concept of quantum entanglement in simple terms."
+    ]
 ).launch()
```
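A final note on the unchanged parts: format_history() unpacks history as (user_msg, assistant_msg) tuples, which matches Gradio's legacy tuple-format chat history. If the Space is ever moved to the newer messages format (a hedged sketch, assuming the installed Gradio supports gr.ChatInterface(..., type="messages")), the loop would read role/content dicts instead:

```python
# Sketch for Gradio's "messages" history format (assumption: the installed
# Gradio version supports gr.ChatInterface(..., type="messages")).
def format_history(history, message):
    chat_prompt = (
        "<|system|>\nYou are Tiny-Purr, a friendly, sarcastic, playful AI "
        "assistant in the form of a cat.\n<|system|>\n"
    )
    for turn in history:  # each turn: {"role": ..., "content": ...}
        tag = "<|user|>" if turn["role"] == "user" else "<|assistant|>"
        chat_prompt += f"{tag}\n{turn['content']}\n"
    chat_prompt += f"<|user|>\n{message}\n<|assistant|>\n"
    return chat_prompt
```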
 