Joel Lundgren committed
Commit a22ca8b · 1 Parent(s): 215c956

onnx and ui improvements

Files changed (2):
  1. app.py +212 -43
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 from PIL import Image, ImageDraw
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, AutoTokenizer
+from optimum.onnxruntime import ORTModelForCausalLM
 import torch
 
 # Load the object detection pipeline
@@ -49,27 +50,71 @@ def detect_objects(image):
 
     return annotated_image, detected_objects_str
 
-# Cache for LLM models and tokenizers
+# Cache for LLM models and tokenizers (ONNX Runtime)
 llm_cache = {}
 
-def get_llm(model_name):
-    if model_name in llm_cache:
-        return llm_cache[model_name]
+def get_llm(model_name, preferred_file: str | None = None):
+    cache_key = (model_name, preferred_file or "auto")
+    if cache_key in llm_cache:
+        return llm_cache[cache_key]
 
-    model_map = {
+    # ONNX model repositories on the Hub
+    onnx_repo_map = {
+        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
+        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
+    }
+    # Original repos to fetch correct tokenizer + chat templates
+    tokenizer_repo_map = {
         "gemma3:1b": "google/gemma-3-1b-it",
-        "qwen3:0.6b": "Qwen/Qwen3-0.6B-Instruct"
+        "qwen3:0.6b": "Qwen/Qwen3-0.6B-Instruct",
     }
-    hf_model_name = model_map[model_name]
 
-    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        hf_model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto"
-    )
+    onnx_repo = onnx_repo_map[model_name]
+    tokenizer_repo = tokenizer_repo_map[model_name]
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
+
+    # Try a few common ONNX filenames found in community repos to avoid the
+    # "Too many ONNX model files were found" ambiguity.
+    candidate_files = [
+        "model_q4.onnx",
+        "model_quantized.onnx",
+        "model_int8.onnx",
+        "model.onnx",
+    ]
 
-    llm_cache[model_name] = (model, tokenizer)
+    model = None
+    last_err = None
+    ordered = candidate_files
+    if preferred_file and preferred_file in candidate_files:
+        # Put preferred file first
+        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
+    elif preferred_file and preferred_file not in candidate_files:
+        # If user typed a specific known filename not in our shortlist, try it first anyway
+        ordered = [preferred_file] + candidate_files
+
+    for fname in ordered:
+        try:
+            model = ORTModelForCausalLM.from_pretrained(
+                onnx_repo,
+                subfolder="onnx",
+                file_name=fname,
+            )
+            break
+        except Exception as e:
+            last_err = e
+            continue
+    if model is None:
+        raise RuntimeError(f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}")
+
+    # Disable cache to avoid past_key_values shape issues on some ONNX builds
+    if hasattr(model.config, "use_cache"):
+        try:
+            model.config.use_cache = False
+        except Exception:
+            pass
+
+    llm_cache[cache_key] = (model, tokenizer)
     return model, tokenizer
 
 def update_user_prompt(detected_objects, current_prompt):
@@ -83,40 +128,115 @@ def update_user_prompt(detected_objects, current_prompt):
 
     return new_prompt
 
-def generate_text(model_name, system_prompt, user_prompt):
-    model, tokenizer = get_llm(model_name)
+def generate_text(
+    model_name,
+    onnx_file_choice,
+    system_prompt,
+    user_prompt,
+    do_sample,
+    temperature,
+    top_p,
+    top_k,
+    repetition_penalty,
+    max_new_tokens,
+):
+    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)
 
     messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
     ]
 
-    chat_template_args = {
+    chat_template_kwargs = {
        "tokenize": False,
-        "add_generation_prompt": True
+        "add_generation_prompt": True,
     }
-
-    if 'qwen' in model_name.lower():
-        chat_template_args['enable_thinking'] = False
+    # Disable "thinking" for Qwen models
+    if "qwen" in model_name.lower():
+        chat_template_kwargs["enable_thinking"] = False
 
     text = tokenizer.apply_chat_template(
        messages,
-        **chat_template_args
+        **chat_template_kwargs,
     )
 
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-    generated_ids = model.generate(
-        model_inputs.input_ids,
-        max_new_tokens=512
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+    inputs = tokenizer([text], return_tensors="pt")
+
+    with torch.inference_mode():
+        gen_ids = model.generate(
+            **inputs,
+            max_new_tokens=int(max_new_tokens),
+            do_sample=bool(do_sample),
+            temperature=float(temperature),
+            top_p=float(top_p),
+            top_k=int(top_k),
+            repetition_penalty=float(repetition_penalty),
+        )
+
+    # Decode only the newly generated tokens beyond the input length
+    trimmed = [
+        output_ids[len(input_ids):]
+        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
     ]
-
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-    return response
+    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    return response
+
+def chat_respond(
+    model_name,
+    onnx_file_choice,
+    system_prompt,
+    message,
+    history,
+    do_sample,
+    temperature,
+    top_p,
+    top_k,
+    repetition_penalty,
+    max_new_tokens,
+):
+    """Builds a chat messages list from history + current user message, generates a reply, and returns updated history and an empty input box."""
+    # Guard: empty message
+    if not (message and message.strip()):
+        return history, gr.update(value="")
+
+    # Build messages: system, then alternating user/assistant from history, then current user
+    messages = [{"role": "system", "content": system_prompt}]
+    for u, a in (history or []):
+        if u:
+            messages.append({"role": "user", "content": u})
+        if a:
+            messages.append({"role": "assistant", "content": a})
+    messages.append({"role": "user", "content": message})
+
+    # Generate using the same path as generate_text, but inline to avoid extra serialization
+    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)
+
+    chat_template_kwargs = {
+        "tokenize": False,
+        "add_generation_prompt": True,
+    }
+    if "qwen" in model_name.lower():
+        chat_template_kwargs["enable_thinking"] = False
+
+    text = tokenizer.apply_chat_template(messages, **chat_template_kwargs)
+    inputs = tokenizer([text], return_tensors="pt")
+
+    with torch.inference_mode():
+        gen_ids = model.generate(
+            **inputs,
+            max_new_tokens=int(max_new_tokens),
+            do_sample=bool(do_sample),
+            temperature=float(temperature),
+            top_p=float(top_p),
+            top_k=int(top_k),
+            repetition_penalty=float(repetition_penalty),
+        )
+
+    trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, gen_ids)]
+    reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+
+    new_history = (history or []) + [(message, reply)]
+    return new_history, gr.update(value="")
 
 with gr.Blocks() as demo:
     gr.Markdown("# Black Box: Object Detection and LLM Chat")
@@ -130,10 +250,25 @@ with gr.Blocks() as demo:
 
     with gr.Tab("LLM Chat"):
         model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
+        onnx_file_selector = gr.Dropdown(
+            choices=["auto", "model_q4.onnx", "model_int8.onnx", "model_quantized.onnx", "model.onnx"],
+            value="auto",
+            label="ONNX file variant"
+        )
         system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
-        user_prompt_input = gr.Textbox(label="User Prompt")
-        llm_output = gr.Textbox(label="LLM Response")
-        llm_button = gr.Button("Generate")
+        chat_bot = gr.Chatbot(height=360, label="Conversation")
+        chat_history = gr.State([])
+        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)
+        with gr.Accordion("Generation settings", open=False):
+            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
+            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
+            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
+            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
+            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
+            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")
+        with gr.Row():
+            send_btn = gr.Button("Send", variant="primary")
+            clear_btn = gr.Button("Clear chat")
 
     # Connect object detection components
     object_detection_button.click(
@@ -142,18 +277,52 @@ with gr.Blocks() as demo:
         outputs=[detected_image_output, detected_objects_output]
     )
 
-    # Connect LLM components
-    llm_button.click(
-        fn=generate_text,
-        inputs=[model_selector, system_prompt_input, user_prompt_input],
-        outputs=llm_output
+    # Connect LLM chat components
+    send_btn.click(
+        fn=chat_respond,
+        inputs=[
+            model_selector,
+            onnx_file_selector,
+            system_prompt_input,
+            user_prompt_input,
+            chat_history,
+            do_sample_cb,
+            temperature_sl,
+            top_p_sl,
+            top_k_sl,
+            repetition_penalty_sl,
+            max_new_tokens_sl,
+        ],
+        outputs=[chat_bot, user_prompt_input],
+    )
+    # Also submit on Enter
+    user_prompt_input.submit(
+        fn=chat_respond,
+        inputs=[
+            model_selector,
+            onnx_file_selector,
+            system_prompt_input,
+            user_prompt_input,
+            chat_history,
+            do_sample_cb,
+            temperature_sl,
+            top_p_sl,
+            top_k_sl,
+            repetition_penalty_sl,
+            max_new_tokens_sl,
+        ],
+        outputs=[chat_bot, user_prompt_input],
     )
+    # Clear chat
+    def _clear_chat():
+        return [], gr.update(value="")
+    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, user_prompt_input])
 
-    # Connect detected objects to user prompt
+    # Connect detected objects to user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
-        outputs=user_prompt_input
+        outputs=user_prompt_input,
    )

demo.launch()
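
For reviewers who want to exercise the new loading path outside the Gradio app, the following is a condensed, standalone sketch of the pattern get_llm() now uses: the tokenizer and chat template come from the original model repo, the ONNX export is loaded through optimum's ORTModelForCausalLM, and several candidate file names are tried in turn. The repo names, the onnx/ subfolder, and the file-name shortlist are copied from the diff above and assumed to match what the onnx-community repos currently publish; they are not guaranteed Hub contents.

# Standalone sketch of the get_llm() loading pattern. Assumes optimum[onnxruntime]
# is installed; repo and file names are taken from the diff and may change upstream.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

ONNX_REPO = "onnx-community/Qwen3-0.6B-ONNX"
TOKENIZER_REPO = "Qwen/Qwen3-0.6B-Instruct"
CANDIDATE_FILES = ["model_q4.onnx", "model_quantized.onnx", "model_int8.onnx", "model.onnx"]

def load_onnx_lm():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
    last_err = None
    for fname in CANDIDATE_FILES:
        try:
            # file_name pins one specific export when a repo ships several ONNX variants
            model = ORTModelForCausalLM.from_pretrained(ONNX_REPO, subfolder="onnx", file_name=fname)
            return model, tokenizer
        except Exception as err:  # missing file, download error, etc.
            last_err = err
    raise RuntimeError(f"No ONNX variant could be loaded from {ONNX_REPO}: {last_err}")

The cache key (model_name, preferred_file) in the diff exists for the same reason: switching the ONNX variant in the UI forces a fresh load instead of reusing a session built from a different file.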
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
 torch
 transformers
 pillow
-accelerate
+accelerate
+optimum[onnxruntime]
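
Since requirements.txt now pulls in optimum[onnxruntime], a quick way to confirm the runtime side is wired up on a Space or a local machine is to list the available execution providers; CPUExecutionProvider should always be present. This is a small sanity-check sketch for illustration, not part of the app.

# Sanity check for the new onnxruntime dependency (illustrative only).
import onnxruntime as ort

print("onnxruntime version:", ort.__version__)
print("available providers:", ort.get_available_providers())  # e.g. ['CPUExecutionProvider']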