wmaousley committed
Commit c7d81ab · verified · 1 parent: 46d5e73

Update app.py

Files changed (1): app.py +63 -12
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
+import threading
 
 MODEL = "wmaousley/MiniCrit-1.5B"
 
@@ -8,24 +9,74 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.float16,
-    device_map="cpu"  # Spaces CPU runtime
+    device_map="cpu"
 )
 
-def chat_fn(user_input):
-    inputs = tokenizer(user_input, return_tensors="pt")
-    outputs = model.generate(
+def generate_stream(prompt):
+    """Yield the reply incrementally as the model generates it."""
+    inputs = tokenizer(prompt, return_tensors="pt")
+    # skip_prompt=True keeps the echoed conversation prompt out of the stream
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = dict(
         **inputs,
         max_new_tokens=200,
         temperature=0.7,
         do_sample=True,
+        streamer=streamer
     )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-demo = gr.Interface(
-    fn=chat_fn,
-    inputs=gr.Textbox(label="Input"),
-    outputs=gr.Textbox(label="Response"),
-    title="MiniCrit-1.5B Chat"
-)
+    # generate() blocks until done, so run it on a worker thread and
+    # consume the streamer from this generator as tokens arrive
+    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    for new_token in streamer:
+        yield new_token
+
+
+def chat_fn(message, history):
+    """Format the chat history and stream the reply into the Chatbot."""
+    # Build conversation text
+    conversation = ""
+    for user, bot in history:
+        conversation += f"User: {user}\nMiniCrit: {bot}\n"
+    conversation += f"User: {message}\nMiniCrit:"
+
+    # Stream tokens, yielding the updated history so the Chatbot
+    # re-renders the partial reply on every step
+    reply = ""
+    for token in generate_stream(conversation):
+        reply += token
+        yield history + [(message, reply)]
+
+
+# -------- UI --------
+
+with gr.Blocks(theme=gr.themes.Base()) as demo:
+
+    gr.Markdown(
+        """
+        <h1 style='text-align:center; color:#00eaff;'>
+        MiniCrit-1.5B Chat UI 🚀
+        </h1>
+        <p style='text-align:center; color:gray;'>Enhanced Streaming Interface</p>
+        """
+    )
+
+    chatbox = gr.Chatbot(
+        label="MiniCrit-1.5B",
+        height=500
+    )
+
+    with gr.Row():
+        msg = gr.Textbox(
+            placeholder="Ask something...",
+            label="Message",
+            scale=10
+        )
+        send = gr.Button("Send", variant="primary")
+        clear = gr.Button("Clear")
+
+    send.click(chat_fn, [msg, chatbox], chatbox)
+    send.click(lambda: "", None, msg)
+    clear.click(lambda: [], None, chatbox)
 
 
 demo.launch(debug=True)
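
The streaming path can be exercised outside the Space with a minimal, self-contained sketch of the same TextIteratorStreamer-plus-thread pattern. The tiny sshleifer/tiny-gpt2 checkpoint below is an assumption for a fast local download; the Space itself loads wmaousley/MiniCrit-1.5B.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hypothetical stand-in model for a quick local test; swap in
# wmaousley/MiniCrit-1.5B when enough memory is available.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok("User: hello\nMiniCrit:", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread drains the streamer's queue piece by piece.
thread = threading.Thread(
    target=lm.generate,
    kwargs=dict(**inputs, max_new_tokens=20, do_sample=True, streamer=streamer),
)
thread.start()

for piece in streamer:
    print(piece, end="", flush=True)
thread.join()
print()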