Joel Lundgren committed
Commit 9ef29cf · 1 Parent(s): a22ca8b
Files changed (1):
  1. app.py +78 -12
app.py CHANGED
@@ -73,13 +73,24 @@ def get_llm(model_name, preferred_file: str | None = None):
     tokenizer_repo = tokenizer_repo_map[model_name]
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
+    # Ensure pad token exists (common for decoder-only models)
+    if tokenizer.pad_token_id is None:
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except Exception:
+            pass
 
     # Try a few common ONNX filenames found in community repos to avoid the
     # "Too many ONNX model files were found" ambiguity.
+    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, and finally generic.
     candidate_files = [
+        "model_int8.onnx",
+        "model_q4f16.onnx",
         "model_q4.onnx",
         "model_quantized.onnx",
-        "model_int8.onnx",
+        "model_uint8.onnx",
+        "model_fp16.onnx",
+        "model_bnb4.onnx",
         "model.onnx",
     ]
 
@@ -100,6 +111,7 @@ def get_llm(model_name, preferred_file: str | None = None):
                 subfolder="onnx",
                 file_name=fname,
             )
+            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
             break
         except Exception as e:
             last_err = e
@@ -113,6 +125,12 @@ def get_llm(model_name, preferred_file: str | None = None):
         model.config.use_cache = False
     except Exception:
         pass
+    # Mirror in generation config as well
+    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
+        try:
+            model.generation_config.use_cache = False
+        except Exception:
+            pass
 
     llm_cache[cache_key] = (model, tokenizer)
     return model, tokenizer
@@ -161,17 +179,47 @@ def generate_text(
     )
 
     inputs = tokenizer([text], return_tensors="pt")
+    # Ensure attention_mask is present and pad_token is defined
+    if "attention_mask" not in inputs:
+        inputs = tokenizer([text], return_tensors="pt", padding=True)
+    if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    gen_kwargs = {
+        "max_new_tokens": int(max_new_tokens),
+        "do_sample": bool(do_sample),
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "top_k": int(top_k),
+        "repetition_penalty": float(repetition_penalty),
+        "use_cache": False,
+    }
+    if getattr(tokenizer, "eos_token_id", None) is not None:
+        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
 
     with torch.inference_mode():
-        gen_ids = model.generate(
-            **inputs,
-            max_new_tokens=int(max_new_tokens),
-            do_sample=bool(do_sample),
-            temperature=float(temperature),
-            top_p=float(top_p),
-            top_k=int(top_k),
-            repetition_penalty=float(repetition_penalty),
-        )
+        try:
+            gen_ids = model.generate(
+                **inputs,
+                **gen_kwargs,
+            )
+        except Exception as e:
+            msg = str(e)
+            # Retry with int8 if KV cache shape mismatch and user didn't pick int8
+            if (
+                "past_key_values" in msg or "INVALID_ARGUMENT" in msg
+            ) and onnx_file_choice != "model_int8.onnx":
+                # Reload as int8 and retry once
+                model, tokenizer = get_llm(model_name, preferred_file="model_int8.onnx")
+                gen_kwargs["use_cache"] = False
+                gen_ids = model.generate(
+                    **inputs,
+                    **gen_kwargs,
+                )
+                # Mark that we switched variant
+                switched_variant_note = "\n[Note] Switched to model_int8.onnx due to KV-cache shape mismatch."
+            else:
+                raise
 
     # Decode only the newly generated tokens beyond the input length
     trimmed = [
@@ -179,6 +227,8 @@ def generate_text(
         for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
     ]
     response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    if 'switched_variant_note' in locals():
+        response = response + switched_variant_note
     return response
 
 def chat_respond(
@@ -233,7 +283,13 @@ def chat_respond(
     )
 
     trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, gen_ids)]
-    reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    try:
+        reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    except Exception as e:
+        # Gracefully surface decoding issues
+        reply = f"[Error] Failed to decode model output: {e}"
+    if 'switched_variant_note' in locals():
+        reply = reply + switched_variant_note
 
     new_history = (history or []) + [(message, reply)]
     return new_history, gr.update(value="")
@@ -251,7 +307,17 @@ with gr.Blocks() as demo:
     with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
-            choices=["auto", "model_q4.onnx", "model_int8.onnx", "model_quantized.onnx", "model.onnx"],
+            choices=[
+                "auto",
+                "model_int8.onnx",
+                "model_q4f16.onnx",
+                "model_q4.onnx",
+                "model_quantized.onnx",
+                "model_uint8.onnx",
+                "model_fp16.onnx",
+                "model_bnb4.onnx",
+                "model.onnx",
+            ],
            value="auto",
            label="ONNX file variant"
        )
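
Context note: the loop that consumes candidate_files (the hunk around new line 111) is only partially visible in this diff, so the sketch below is one plausible shape for that fallback loader, not the committed implementation. It assumes optimum.onnxruntime's ORTModelForCausalLM as the loader class and reuses the names onnx_repo, preferred_file, and last_err that do appear in the changed lines.

    # Minimal sketch of the candidate-file fallback, assuming optimum's ORT loader.
    # The loader class is an assumption; only the kwargs subfolder/file_name and the
    # break/last_err pattern are visible in the diff itself.
    from optimum.onnxruntime import ORTModelForCausalLM

    def load_first_available(onnx_repo: str, candidate_files: list[str], preferred_file: str | None = None):
        files = [preferred_file] + candidate_files if preferred_file else candidate_files
        last_err = None
        for fname in files:
            try:
                model = ORTModelForCausalLM.from_pretrained(
                    onnx_repo,
                    subfolder="onnx",
                    file_name=fname,  # pin one file to avoid the "Too many ONNX model files" ambiguity
                )
                print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
                return model
            except Exception as e:
                last_err = e  # remember the failure and try the next variant
        raise RuntimeError(f"No ONNX variant could be loaded from {onnx_repo}") from last_err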
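The dropdown now mirrors the new candidate order and exposes every variant the loader will try. How its value reaches get_llm is not part of this diff; one plausible wiring, with a hypothetical helper name, looks like this:

    # Hypothetical adapter between the Gradio dropdown and get_llm: "auto" means
    # "let candidate_files decide", any other choice pins that exact file.
    def resolve_preferred_file(onnx_file_choice: str) -> str | None:
        return None if onnx_file_choice == "auto" else onnx_file_choice

    # e.g. inside the click handler:
    # model, tokenizer = get_llm(model_name, preferred_file=resolve_preferred_file(onnx_file_choice))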