import gradio as gr
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from optimum.onnxruntime import ORTModelForCausalLM
import torch

# Load the object detection pipeline
object_detector = pipeline("object-detection", model="valentinafeve/yolos-fashionpedia")


def detect_objects(image):
    """
    Performs object detection on an image using a transformers pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The image with detected objects annotated.
            - str: A string listing the names of detected objects.
    """
    # Perform inference
    results = object_detector(image)

    # Create a copy of the image to draw on
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)

    detected_objects = []
    # Extract bounding boxes, classes, and confidences
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        xyxy = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        detected_objects.append(label)

        # Draw bounding box
        draw.rectangle(xyxy, outline="red", width=2)
        # Draw label
        draw.text((xyxy[0], xyxy[1]), f"{label} ({score:.2f})", fill="red")

    # Create a unique, comma-separated string of detected objects
    detected_objects_str = ", ".join(list(set(detected_objects)))
    if not detected_objects_str:
        detected_objects_str = "No objects detected."

    return annotated_image, detected_objects_str


# Cache for LLM models and tokenizers (ONNX Runtime)
llm_cache = {}


def get_llm(model_name, preferred_file: str | None = None):
    cache_key = (model_name, preferred_file or "auto")
    if cache_key in llm_cache:
        return llm_cache[cache_key]

    # ONNX model repositories on the Hub
    onnx_repo_map = {
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
    }
    # Original repos to fetch correct tokenizer + chat templates
    tokenizer_repo_map = {
        "gemma3:1b": "google/gemma-3-1b-it",
        "qwen3:0.6b": "Qwen/Qwen3-0.6B",  # Qwen3 has no separate "-Instruct" repo
    }

    onnx_repo = onnx_repo_map[model_name]
    tokenizer_repo = tokenizer_repo_map[model_name]

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    # Ensure pad token exists (common for decoder-only models)
    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Try a few common ONNX filenames found in community repos to avoid the
    # "Too many ONNX model files were found" ambiguity.
    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, and finally generic.
    candidate_files = [
        "model_int8.onnx",
        "model_q4f16.onnx",
        "model_q4.onnx",
        "model_quantized.onnx",
        "model_uint8.onnx",
        "model_fp16.onnx",
        "model_bnb4.onnx",
        "model.onnx",
    ]

    model = None
    last_err = None
    ordered = candidate_files
    if preferred_file and preferred_file in candidate_files:
        # Put the preferred file first
        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
    elif preferred_file and preferred_file not in candidate_files:
        # If the user typed a specific filename not in our shortlist, try it first anyway
        ordered = [preferred_file] + candidate_files

    for fname in ordered:
        try:
            model = ORTModelForCausalLM.from_pretrained(
                onnx_repo,
                subfolder="onnx",
                file_name=fname,
            )
            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
            break
        except Exception as e:
            last_err = e
            continue

    if model is None:
        raise RuntimeError(
            f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}"
        )

    # Disable cache to avoid past_key_values shape issues on some ONNX builds
    if hasattr(model.config, "use_cache"):
        try:
            model.config.use_cache = False
        except Exception:
            pass
    # Mirror in generation config as well
    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
        try:
            model.generation_config.use_cache = False
        except Exception:
            pass

    llm_cache[cache_key] = (model, tokenizer)
    return model, tokenizer


def update_user_prompt(detected_objects, current_prompt):
    if "No objects detected" in detected_objects:
        return current_prompt
    if current_prompt:
        new_prompt = f"{current_prompt}, {detected_objects}"
    else:
        new_prompt = f"Objects detected in the image: {detected_objects}"
    return new_prompt


def generate_text(
    model_name,
    onnx_file_choice,
    messages,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    model, tokenizer = get_llm(
        model_name,
        preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice,
    )

    chat_template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    # Disable "thinking" for Qwen models
    if "qwen" in model_name.lower():
        chat_template_kwargs["enable_thinking"] = False

    text = tokenizer.apply_chat_template(
        messages,
        **chat_template_kwargs,
    )
    inputs = tokenizer([text], return_tensors="pt")

    # Ensure attention_mask is present and pad_token is defined
    if "attention_mask" not in inputs:
        inputs = tokenizer([text], return_tensors="pt", padding=True)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "use_cache": False,
    }
    if getattr(tokenizer, "eos_token_id", None) is not None:
        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id

    with torch.inference_mode():
        gen_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )

    # Strip the prompt tokens so only the newly generated text is decoded
    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
    ]
    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    return response


def chat_respond(
    model_name,
    onnx_file_choice,
    system_prompt,
    message,
    history,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    """Builds a chat messages list from history + the current user message, generates a
    reply, and returns the updated chat display, the synced history state, and an empty
    input box."""
    # Guard: empty message
    if not (message and message.strip()):
        return history or [], history or [], gr.update(value="")

    # Build messages: system, then alternating user/assistant from history, then current user
    messages = [{"role": "system", "content": system_prompt}]
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})

    reply = generate_text(
        model_name=model_name,
        onnx_file_choice=onnx_file_choice,
        messages=messages,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )

    new_history = (history or []) + [(message, reply)]
    # Return the history twice: once for the Chatbot display and once for the State,
    # so multi-turn context is preserved across calls.
    return new_history, new_history, gr.update(value="")


with gr.Blocks() as demo:
    gr.Markdown("# Black Box: Object Detection and LLM Chat")

    with gr.Tab("Object Detection"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
            detected_image_output = gr.Image(label="Detected Objects")
        object_detection_button = gr.Button("Detect Objects")
        detected_objects_output = gr.Textbox(label="Detected Objects")
    with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], value="gemma3:1b", label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
            choices=[
                "auto",
                "model_int8.onnx",
                "model_q4f16.onnx",
                "model_q4.onnx",
                "model_quantized.onnx",
                "model_uint8.onnx",
                "model_fp16.onnx",
                "model_bnb4.onnx",
                "model.onnx",
            ],
            value="auto",
            label="ONNX file variant",
        )
        system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        chat_bot = gr.Chatbot(height=360, label="Conversation")
        chat_history = gr.State([])
        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)

        with gr.Accordion("Generation settings", open=False):
            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")

        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")

    # Connect object detection components
    object_detection_button.click(
        fn=detect_objects,
        inputs=image_input,
        outputs=[detected_image_output, detected_objects_output],
    )

    # Connect LLM chat components (the chat_history State is updated alongside the Chatbot)
    chat_inputs = [
        model_selector,
        onnx_file_selector,
        system_prompt_input,
        user_prompt_input,
        chat_history,
        do_sample_cb,
        temperature_sl,
        top_p_sl,
        top_k_sl,
        repetition_penalty_sl,
        max_new_tokens_sl,
    ]
    chat_outputs = [chat_bot, chat_history, user_prompt_input]

    send_btn.click(
        fn=chat_respond,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )

    # Also submit on Enter
    user_prompt_input.submit(
        fn=chat_respond,
        inputs=chat_inputs,
        outputs=chat_outputs,
    )

    # Clear chat (reset the display, the history state, and the input box)
    def _clear_chat():
        return [], [], gr.update(value="")

    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, chat_history, user_prompt_input])

    # Connect detected objects to the user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
        outputs=user_prompt_input,
    )

demo.launch()
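
# Usage note (assumptions: the script is saved as app.py and dependencies are
# installed from PyPI under their usual names):
#   pip install gradio pillow transformers "optimum[onnxruntime]" torch
#   python app.py
# The first detection and the first chat turn download the YOLOS and ONNX model
# weights from the Hugging Face Hub, so expect a delay on first use.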