Spaces: Runtime error
Joel Lundgren committed · Commit 9ef29cf · Parent(s): a22ca8b

test
app.py
CHANGED
@@ -73,13 +73,24 @@ def get_llm(model_name, preferred_file: str | None = None):
     tokenizer_repo = tokenizer_repo_map[model_name]
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
+    # Ensure pad token exists (common for decoder-only models)
+    if tokenizer.pad_token_id is None:
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except Exception:
+            pass
 
     # Try a few common ONNX filenames found in community repos to avoid the
     # "Too many ONNX model files were found" ambiguity.
+    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, and finally generic.
     candidate_files = [
+        "model_int8.onnx",
+        "model_q4f16.onnx",
         "model_q4.onnx",
         "model_quantized.onnx",
-        "
+        "model_uint8.onnx",
+        "model_fp16.onnx",
+        "model_bnb4.onnx",
         "model.onnx",
     ]
 
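Rather than probing each candidate filename with a failed download, the same preference order can be resolved against the repo's actual file listing. A minimal sketch, not part of this commit, assuming huggingface_hub is installed; pick_onnx_file and its arguments are illustrative:

# Sketch: discover which ONNX variants a repo actually ships before loading.
from huggingface_hub import list_repo_files

def pick_onnx_file(repo_id: str, preferred: str | None = None) -> str:
    onnx_files = [
        f.split("/")[-1]
        for f in list_repo_files(repo_id)
        if f.startswith("onnx/") and f.endswith(".onnx")
    ]
    if preferred and preferred in onnx_files:
        return preferred
    # Same preference order as candidate_files above
    for name in ("model_int8.onnx", "model_q4f16.onnx", "model_q4.onnx",
                 "model_quantized.onnx", "model_uint8.onnx", "model_fp16.onnx",
                 "model_bnb4.onnx", "model.onnx"):
        if name in onnx_files:
            return name
    raise FileNotFoundError(f"No ONNX model found under onnx/ in {repo_id}")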
@@ -100,6 +111,7 @@ def get_llm(model_name, preferred_file: str | None = None):
                 subfolder="onnx",
                 file_name=fname,
             )
+            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
             break
         except Exception as e:
             last_err = e
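The loop above presumably wraps an Optimum ONNX Runtime loader; the hunk only shows the subfolder and file_name keyword arguments. A minimal sketch of one load attempt under that assumption (the ORTModelForCausalLM class name is inferred, not shown in the diff):

from optimum.onnxruntime import ORTModelForCausalLM

def try_load(onnx_repo: str, fname: str):
    # file_name selects one of several .onnx files in the repo's onnx/ subfolder
    return ORTModelForCausalLM.from_pretrained(
        onnx_repo,
        subfolder="onnx",
        file_name=fname,
    )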
@@ -113,6 +125,12 @@ def get_llm(model_name, preferred_file: str | None = None):
         model.config.use_cache = False
     except Exception:
         pass
+    # Mirror in generation config as well
+    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
+        try:
+            model.generation_config.use_cache = False
+        except Exception:
+            pass
 
     llm_cache[cache_key] = (model, tokenizer)
     return model, tokenizer
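Disabling use_cache only matters if it matches how the ONNX graph was exported: a decoder exported with KV cache exposes past_key_values.* inputs, one exported without does not. A hedged sketch for inspecting a downloaded graph with onnxruntime; the repo id and file name are placeholders, not ones used by this Space:

import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Placeholder repo/file: substitute the repo actually resolved by get_llm()
path = hf_hub_download("onnx-community/example-model-ONNX", "onnx/model_int8.onnx")
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
input_names = [inp.name for inp in sess.get_inputs()]
expects_kv_cache = any(name.startswith("past_key_values") for name in input_names)
print(input_names, expects_kv_cache)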
@@ -161,17 +179,47 @@ def generate_text(
     )
 
     inputs = tokenizer([text], return_tensors="pt")
+    # Ensure attention_mask is present and pad_token is defined
+    if "attention_mask" not in inputs:
+        inputs = tokenizer([text], return_tensors="pt", padding=True)
+    if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    gen_kwargs = {
+        "max_new_tokens": int(max_new_tokens),
+        "do_sample": bool(do_sample),
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "top_k": int(top_k),
+        "repetition_penalty": float(repetition_penalty),
+        "use_cache": False,
+    }
+    if getattr(tokenizer, "eos_token_id", None) is not None:
+        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
 
     with torch.inference_mode():
-
-
-
-
-
-
-
-
-
+        try:
+            gen_ids = model.generate(
+                **inputs,
+                **gen_kwargs,
+            )
+        except Exception as e:
+            msg = str(e)
+            # Retry with int8 if KV cache shape mismatch and user didn't pick int8
+            if (
+                "past_key_values" in msg or "INVALID_ARGUMENT" in msg
+            ) and onnx_file_choice != "model_int8.onnx":
+                # Reload as int8 and retry once
+                model, tokenizer = get_llm(model_name, preferred_file="model_int8.onnx")
+                gen_kwargs["use_cache"] = False
+                gen_ids = model.generate(
+                    **inputs,
+                    **gen_kwargs,
+                )
+                # Mark that we switched variant
+                switched_variant_note = "\n[Note] Switched to model_int8.onnx due to KV-cache shape mismatch."
+            else:
+                raise
 
     # Decode only the newly generated tokens beyond the input length
     trimmed = [
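The retry above keys off the exception text. To separate sampling-parameter problems from ONNX runtime problems, the same gen_kwargs can first be exercised against a tiny PyTorch checkpoint. A sketch under that assumption (the model id is an arbitrary small public checkpoint, not one used by this Space):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tok(["Hello"], return_tensors="pt")
gen_kwargs = {
    "max_new_tokens": 16,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "use_cache": False,
}
with torch.inference_mode():
    out = mdl.generate(**inputs, **gen_kwargs)
print(tok.batch_decode(out, skip_special_tokens=True)[0])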
@@ -179,6 +227,8 @@ def generate_text(
         for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
     ]
     response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    if 'switched_variant_note' in locals():
+        response = response + switched_variant_note
    return response
 
 def chat_respond(
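The trimming works because generate() on a decoder-only model returns the prompt ids followed by the newly generated ids; slicing at the prompt length keeps only the completion. A toy illustration with made-up token ids:

# output_ids echoes the prompt ids before the new ones
input_ids = [101, 7592, 102]                      # prompt tokens (toy values)
output_ids = [101, 7592, 102, 2023, 2003, 1037]   # prompt + 3 generated tokens
new_tokens = output_ids[len(input_ids):]
assert new_tokens == [2023, 2003, 1037]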
@@ -233,7 +283,13 @@ def chat_respond(
     )
 
     trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, gen_ids)]
-
+    try:
+        reply = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
+    except Exception as e:
+        # Gracefully surface decoding issues
+        reply = f"[Error] Failed to decode model output: {e}"
+    if 'switched_variant_note' in locals():
+        reply = reply + switched_variant_note
 
     new_history = (history or []) + [(message, reply)]
     return new_history, gr.update(value="")
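chat_respond builds its reply from a running list of (message, reply) tuples. A hedged sketch of how such a history can be turned into a prompt with the tokenizer's chat template before calling generate(); the helper name is illustrative and assumes the tokenizer ships a chat template (Gemma and Qwen tokenizers do):

def build_prompt(tokenizer, history, message):
    # Convert (user, assistant) tuples into chat-template messages
    messages = []
    for user_turn, assistant_turn in (history or []):
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )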
@@ -251,7 +307,17 @@ with gr.Blocks() as demo:
     with gr.Tab("LLM Chat"):
         model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
         onnx_file_selector = gr.Dropdown(
-            choices=[
+            choices=[
+                "auto",
+                "model_int8.onnx",
+                "model_q4f16.onnx",
+                "model_q4.onnx",
+                "model_quantized.onnx",
+                "model_uint8.onnx",
+                "model_fp16.onnx",
+                "model_bnb4.onnx",
+                "model.onnx",
+            ],
             value="auto",
             label="ONNX file variant"
         )
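The commit only shows the dropdown definitions, not their wiring. A minimal sketch of how the two selections could feed a submit handler; component names and the on_send handler are illustrative:

import gradio as gr

def on_send(message, history, model_name, onnx_file):
    preferred = None if onnx_file == "auto" else onnx_file
    # ... call get_llm(model_name, preferred_file=preferred) / chat_respond here ...
    return (history or []) + [(message, f"(echo) {message}")], ""

with gr.Blocks() as demo:
    model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
    onnx_file_selector = gr.Dropdown(choices=["auto", "model_int8.onnx"], value="auto", label="ONNX file variant")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(on_send, [msg, chatbot, model_selector, onnx_file_selector], [chatbot, msg])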