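"""Black Box: Object Detection and LLM Chat.

Gradio Space with two tabs: fashion object detection using
valentinafeve/yolos-fashionpedia, and a lightweight chat tab that runs
Gemma 3 1B or Qwen3 0.6B through ONNX Runtime (optimum.onnxruntime).
"""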
import gradio as gr
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from optimum.onnxruntime import ORTModelForCausalLM
import torch
# Load the object detection pipeline
object_detector = pipeline("object-detection", model="valentinafeve/yolos-fashionpedia")
def detect_objects(image):
    """
    Performs object detection on an image using a transformers pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The image with detected objects annotated.
            - str: A string listing the names of detected objects.
    """
    # Perform inference
    results = object_detector(image)

    # Create a copy of the image to draw on
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    detected_objects = []

    # Extract bounding boxes, classes, and confidences
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        xyxy = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        detected_objects.append(label)
        # Draw bounding box
        draw.rectangle(xyxy, outline="red", width=2)
        # Draw label
        draw.text((xyxy[0], xyxy[1]), f"{label} ({score:.2f})", fill="red")

    # Create a unique, comma-separated string of detected objects
    detected_objects_str = ", ".join(list(set(detected_objects)))
    if not detected_objects_str:
        detected_objects_str = "No objects detected."
    return annotated_image, detected_objects_str
# Cache for LLM models and tokenizers (ONNX Runtime)
llm_cache = {}
def get_llm(model_name, preferred_file: str | None = None):
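    """Load (and cache) an ONNX Runtime causal LM plus its tokenizer.

    `model_name` is one of the UI keys ("gemma3:1b" or "qwen3:0.6b");
    `preferred_file` optionally pins a specific ONNX weight file, otherwise
    common quantized variants are tried in order.
    """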
    cache_key = (model_name, preferred_file or "auto")
    if cache_key in llm_cache:
        return llm_cache[cache_key]

    # ONNX model repositories on the Hub
    onnx_repo_map = {
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
    }
    # Original repos to fetch the correct tokenizer + chat templates
    tokenizer_repo_map = {
        "gemma3:1b": "google/gemma-3-1b-it",
        "qwen3:0.6b": "Qwen/Qwen3-0.6B",
    }
    onnx_repo = onnx_repo_map[model_name]
    tokenizer_repo = tokenizer_repo_map[model_name]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

    # Ensure pad token exists (common for decoder-only models)
    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Try a few common ONNX filenames found in community repos to avoid the
    # "Too many ONNX model files were found" ambiguity.
    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, and finally generic.
    candidate_files = [
        "model_int8.onnx",
        "model_q4f16.onnx",
        "model_q4.onnx",
        "model_quantized.onnx",
        "model_uint8.onnx",
        "model_fp16.onnx",
        "model_bnb4.onnx",
        "model.onnx",
    ]

    model = None
    last_err = None
    ordered = candidate_files
    if preferred_file and preferred_file in candidate_files:
        # Put the preferred file first
        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
    elif preferred_file and preferred_file not in candidate_files:
        # If the user typed a specific known filename not in our shortlist, try it first anyway
        ordered = [preferred_file] + candidate_files

    for fname in ordered:
        try:
            model = ORTModelForCausalLM.from_pretrained(
                onnx_repo,
                subfolder="onnx",
                file_name=fname,
            )
            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
            break
        except Exception as e:
            last_err = e
            continue

    if model is None:
        raise RuntimeError(f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}")

    # Disable cache to avoid past_key_values shape issues on some ONNX builds
    if hasattr(model.config, "use_cache"):
        try:
            model.config.use_cache = False
        except Exception:
            pass
    # Mirror in the generation config as well
    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
        try:
            model.generation_config.use_cache = False
        except Exception:
            pass

    llm_cache[cache_key] = (model, tokenizer)
    return model, tokenizer
def update_user_prompt(detected_objects, current_prompt):
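    """Append the detected-objects summary to the current chat message draft."""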
if "No objects detected" in detected_objects:
return current_prompt
if current_prompt:
new_prompt = f"{current_prompt}, {detected_objects}"
else:
new_prompt = f"Objects detected in the image: {detected_objects}"
return new_prompt
def generate_text(
    model_name,
    onnx_file_choice,
    messages,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
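    """Apply the model's chat template to `messages`, generate with the
    selected ONNX model, and return only the newly generated text."""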
    model, tokenizer = get_llm(
        model_name,
        preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice,
    )

    chat_template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    # Disable "thinking" for Qwen models
    if "qwen" in model_name.lower():
        chat_template_kwargs["enable_thinking"] = False

    text = tokenizer.apply_chat_template(
        messages,
        **chat_template_kwargs,
    )
    inputs = tokenizer([text], return_tensors="pt")
    # Ensure attention_mask is present and pad_token is defined
    if "attention_mask" not in inputs:
        inputs = tokenizer([text], return_tensors="pt", padding=True)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "use_cache": False,
    }
    if getattr(tokenizer, "eos_token_id", None) is not None:
        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id

    with torch.inference_mode():
        gen_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )

    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
    ]
    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    return response
def chat_respond(
    model_name,
    onnx_file_choice,
    system_prompt,
    message,
    history,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    """Build a chat messages list from the history plus the current user message,
    generate a reply, and return the updated history (for both the Chatbot display
    and the State) along with an empty input box."""
    # Guard: empty message
    if not (message and message.strip()):
        return history or [], history or [], gr.update(value="")

    # Build messages: system, then alternating user/assistant turns from history,
    # then the current user message
    messages = [{"role": "system", "content": system_prompt}]
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})

    reply = generate_text(
        model_name=model_name,
        onnx_file_choice=onnx_file_choice,
        messages=messages,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )
    new_history = (history or []) + [(message, reply)]
    # Return the history twice so the Chatbot display and the State stay in sync
    return new_history, new_history, gr.update(value="")
with gr.Blocks() as demo:
    gr.Markdown("# Black Box: Object Detection and LLM Chat")

    with gr.Tab("Object Detection"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
            detected_image_output = gr.Image(label="Detected Objects")
        object_detection_button = gr.Button("Detect Objects")
        detected_objects_output = gr.Textbox(label="Detected Objects")

    with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
            choices=[
                "auto",
                "model_int8.onnx",
                "model_q4f16.onnx",
                "model_q4.onnx",
                "model_quantized.onnx",
                "model_uint8.onnx",
                "model_fp16.onnx",
                "model_bnb4.onnx",
                "model.onnx",
            ],
            value="auto",
            label="ONNX file variant",
        )
        system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        chat_bot = gr.Chatbot(height=360, label="Conversation")
        chat_history = gr.State([])
        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)

        with gr.Accordion("Generation settings", open=False):
            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")

        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")

    # Connect object detection components
    object_detection_button.click(
        fn=detect_objects,
        inputs=image_input,
        outputs=[detected_image_output, detected_objects_output],
    )
    # Connect LLM chat components
    send_btn.click(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Also submit on Enter
    user_prompt_input.submit(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Clear chat: reset both the Chatbot display and the history State
    def _clear_chat():
        return [], [], gr.update(value="")

    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, chat_history, user_prompt_input])
    # Connect detected objects to user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
        outputs=user_prompt_input,
    )
demo.launch()