Joel Lundgren committed
Commit 51d35d6 · 1 Parent(s): 9ef29cf
Files changed (1)
  1. app.py +21 -72
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 from PIL import Image, ImageDraw
-from transformers import pipeline, AutoTokenizer
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from optimum.onnxruntime import ORTModelForCausalLM
 import torch
 
@@ -74,11 +74,8 @@ def get_llm(model_name, preferred_file: str | None = None):
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
     # Ensure pad token exists (common for decoder-only models)
-    if tokenizer.pad_token_id is None:
-        try:
-            tokenizer.pad_token = tokenizer.eos_token
-        except Exception:
-            pass
+    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
     # Try a few common ONNX filenames found in community repos to avoid the
     # "Too many ONNX model files were found" ambiguity.
@@ -135,6 +132,7 @@ def get_llm(model_name, preferred_file: str | None = None):
     llm_cache[cache_key] = (model, tokenizer)
     return model, tokenizer
 
+
 def update_user_prompt(detected_objects, current_prompt):
     if "No objects detected" in detected_objects:
         return current_prompt
@@ -149,8 +147,7 @@ def update_user_prompt(detected_objects, current_prompt):
 def generate_text(
     model_name,
     onnx_file_choice,
-    system_prompt,
-    user_prompt,
+    messages,
     do_sample,
     temperature,
     top_p,
@@ -160,10 +157,6 @@ def generate_text(
 ):
     model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)
 
-    messages = [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": user_prompt},
-    ]
 
     chat_template_kwargs = {
         "tokenize": False,
@@ -182,8 +175,6 @@ def generate_text(
     # Ensure attention_mask is present and pad_token is defined
     if "attention_mask" not in inputs:
         inputs = tokenizer([text], return_tensors="pt", padding=True)
-    if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
-        tokenizer.pad_token_id = tokenizer.eos_token_id
 
     gen_kwargs = {
         "max_new_tokens": int(max_new_tokens),
@@ -198,37 +189,16 @@ def generate_text(
     gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
 
     with torch.inference_mode():
-        try:
-            gen_ids = model.generate(
-                **inputs,
-                **gen_kwargs,
-            )
-        except Exception as e:
-            msg = str(e)
-            # Retry with int8 if KV cache shape mismatch and user didn't pick int8
-            if (
-                "past_key_values" in msg or "INVALID_ARGUMENT" in msg
-            ) and onnx_file_choice != "model_int8.onnx":
-                # Reload as int8 and retry once
-                model, tokenizer = get_llm(model_name, preferred_file="model_int8.onnx")
-                gen_kwargs["use_cache"] = False
-                gen_ids = model.generate(
-                    **inputs,
-                    **gen_kwargs,
-                )
-                # Mark that we switched variant
-                switched_variant_note = "\n[Note] Switched to model_int8.onnx due to KV-cache shape mismatch."
-            else:
-                raise
-
-    # Decode only the newly generated tokens beyond the input length
+        gen_ids = model.generate(
+            **inputs,
+            **gen_kwargs,
+        )
+
     trimmed = [
         output_ids[len(input_ids):]
         for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
     ]
     response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
-    if 'switched_variant_note' in locals():
-        response = response + switched_variant_note
     return response
 
 def chat_respond(
@@ -258,38 +228,17 @@ def chat_respond(
         messages.append({"role": "assistant", "content": a})
     messages.append({"role": "user", "content": message})
 
-    # Generate using the same path as generate_text, but inline to avoid extra serialization
-    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)
-
-    chat_template_kwargs = {
-        "tokenize": False,
-        "add_generation_prompt": True,
-    }
-    if "qwen" in model_name.lower():
-        chat_template_kwargs["enable_thinking"] = False
-
-    text = tokenizer.apply_chat_template(messages, **chat_template_kwargs)
-    inputs = tokenizer([text], return_tensors="pt")
-
-    with torch.inference_mode():
-        gen_ids = model.generate(
-            **inputs,
-            max_new_tokens=int(max_new_tokens),
-            do_sample=bool(do_sample),
-            temperature=float(temperature),
-            top_p=float(top_p),
-            top_k=int(top_k),
-            repetition_penalty=float(repetition_penalty),
-        )
-
-    trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, gen_ids)]
-    try:
-        reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
-    except Exception as e:
-        # Gracefully surface decoding issues
-        reply = f"[Error] Failed to decode model output: {e}"
-    if 'switched_variant_note' in locals():
-        reply = reply + switched_variant_note
+    reply = generate_text(
+        model_name=model_name,
+        onnx_file_choice=onnx_file_choice,
+        messages=messages,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        max_new_tokens=max_new_tokens,
+    )
 
     new_history = (history or []) + [(message, reply)]
     return new_history, gr.update(value="")