import gradio as gr
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from optimum.onnxruntime import ORTModelForCausalLM
import torch

# Load the object detection pipeline
object_detector = pipeline("object-detection", model="valentinafeve/yolos-fashionpedia")
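# Note: valentinafeve/yolos-fashionpedia is a YOLOS detector fine-tuned on the Fashionpedia
# dataset, so the labels it returns are clothing/apparel categories rather than general
# COCO-style objects.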


def detect_objects(image):
    """
    Performs object detection on an image using a transformers pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The image with detected objects annotated.
            - str: A string listing the names of detected objects.
    """
    # Perform inference
    results = object_detector(image)

    # Create a copy of the image to draw on
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)

    detected_objects = []
    # Extract bounding boxes, classes, and confidences
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        xyxy = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        detected_objects.append(label)
        # Draw bounding box
        draw.rectangle(xyxy, outline="red", width=2)
        # Draw label
        draw.text((xyxy[0], xyxy[1]), f"{label} ({score:.2f})", fill="red")

    # Create a unique, comma-separated string of detected objects
    detected_objects_str = ", ".join(list(set(detected_objects)))
    if not detected_objects_str:
        detected_objects_str = "No objects detected."

    return annotated_image, detected_objects_str
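
# Standalone usage sketch (not wired into the app; "sample.jpg" is a hypothetical file):
#   img = Image.open("sample.jpg").convert("RGB")
#   annotated, labels = detect_objects(img)
#   annotated.save("annotated.jpg")
#   print(labels)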

# Cache for LLM models and tokenizers (ONNX Runtime)
llm_cache = {}


def get_llm(model_name, preferred_file: str | None = None):
    cache_key = (model_name, preferred_file or "auto")
    if cache_key in llm_cache:
        return llm_cache[cache_key]

    # ONNX model repositories on the Hub
    onnx_repo_map = {
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
    }
    # Original repos to fetch the correct tokenizer + chat templates
    tokenizer_repo_map = {
        "gemma3:1b": "google/gemma-3-1b-it",
        "qwen3:0.6b": "Qwen/Qwen3-0.6B",
    }
    onnx_repo = onnx_repo_map[model_name]
    tokenizer_repo = tokenizer_repo_map[model_name]

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    # Ensure a pad token exists (common issue for decoder-only models)
    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Try a few common ONNX filenames found in community repos to avoid the
    # "Too many ONNX model files were found" ambiguity.
    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, bnb4, and finally the generic model.onnx.
    candidate_files = [
        "model_int8.onnx",
        "model_q4f16.onnx",
        "model_q4.onnx",
        "model_quantized.onnx",
        "model_uint8.onnx",
        "model_fp16.onnx",
        "model_bnb4.onnx",
        "model.onnx",
    ]

    model = None
    last_err = None
    ordered = candidate_files
    if preferred_file and preferred_file in candidate_files:
        # Put the preferred file first
        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
    elif preferred_file and preferred_file not in candidate_files:
        # If the user typed a specific known filename not in our shortlist, try it first anyway
        ordered = [preferred_file] + candidate_files

    for fname in ordered:
        try:
            model = ORTModelForCausalLM.from_pretrained(
                onnx_repo,
                subfolder="onnx",
                file_name=fname,
            )
            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
            break
        except Exception as e:
            last_err = e
            continue

    if model is None:
        raise RuntimeError(f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}")

    # Disable cache to avoid past_key_values shape issues on some ONNX builds
    if hasattr(model.config, "use_cache"):
        try:
            model.config.use_cache = False
        except Exception:
            pass
    # Mirror in the generation config as well
    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
        try:
            model.generation_config.use_cache = False
        except Exception:
            pass

    llm_cache[cache_key] = (model, tokenizer)
    return model, tokenizer
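
# Usage sketch (hypothetical, outside the Gradio callbacks): the first call downloads and
# caches a model; later calls with the same arguments return the cached pair instantly.
#   model, tok = get_llm("qwen3:0.6b")                   # auto-pick an ONNX file
#   model, tok = get_llm("gemma3:1b", "model_q4.onnx")   # force a specific variant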


def update_user_prompt(detected_objects, current_prompt):
    if "No objects detected" in detected_objects:
        return current_prompt
    if current_prompt:
        new_prompt = f"{current_prompt}, {detected_objects}"
    else:
        new_prompt = f"Objects detected in the image: {detected_objects}"
    return new_prompt
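
# Illustrative examples of the merge behaviour:
#   update_user_prompt("shirt, pants", "")            -> "Objects detected in the image: shirt, pants"
#   update_user_prompt("shirt", "Describe this look") -> "Describe this look, shirt"
#   update_user_prompt("No objects detected.", "hi")  -> "hi"  (prompt left untouched)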


def generate_text(
    model_name,
    onnx_file_choice,
    messages,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)

    chat_template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    # Disable "thinking" for Qwen models
    if "qwen" in model_name.lower():
        chat_template_kwargs["enable_thinking"] = False

    text = tokenizer.apply_chat_template(
        messages,
        **chat_template_kwargs,
    )

    inputs = tokenizer([text], return_tensors="pt")
    # Ensure attention_mask is present and pad_token is defined
    if "attention_mask" not in inputs:
        inputs = tokenizer([text], return_tensors="pt", padding=True)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "use_cache": False,
    }
    if getattr(tokenizer, "eos_token_id", None) is not None:
        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id

    with torch.inference_mode():
        gen_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )

    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
    ]
    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    return response
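
# Usage sketch (hypothetical; in this app generate_text is only called via chat_respond below):
#   reply = generate_text(
#       "qwen3:0.6b", "auto",
#       [{"role": "system", "content": "You are a helpful assistant."},
#        {"role": "user", "content": "Hello!"}],
#       do_sample=True, temperature=0.7, top_p=0.95, top_k=50,
#       repetition_penalty=1.05, max_new_tokens=128,
#   )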


def chat_respond(
    model_name,
    onnx_file_choice,
    system_prompt,
    message,
    history,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    """Builds a chat messages list from history + the current user message, generates a reply,
    and returns the updated history (for both the Chatbot and the State) plus an empty input box."""
    # Guard: empty message
    if not (message and message.strip()):
        return history, history, gr.update(value="")

    # Build messages: system, then alternating user/assistant turns from history, then the current user message
    messages = [{"role": "system", "content": system_prompt}]
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})

    reply = generate_text(
        model_name=model_name,
        onnx_file_choice=onnx_file_choice,
        messages=messages,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )

    new_history = (history or []) + [(message, reply)]
    return new_history, new_history, gr.update(value="")
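
# Note: chat history is kept in Gradio's (user, assistant) tuple format, which is what the
# tuple-style gr.Chatbot created below expects; switching that component to type="messages"
# would require adapting chat_respond accordingly.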


with gr.Blocks() as demo:
    gr.Markdown("# Black Box: Object Detection and LLM Chat")

    with gr.Tab("Object Detection"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
            detected_image_output = gr.Image(label="Detected Objects")
        object_detection_button = gr.Button("Detect Objects")
        detected_objects_output = gr.Textbox(label="Detected Objects")

    with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
            choices=[
                "auto",
                "model_int8.onnx",
                "model_q4f16.onnx",
                "model_q4.onnx",
                "model_quantized.onnx",
                "model_uint8.onnx",
                "model_fp16.onnx",
                "model_bnb4.onnx",
                "model.onnx",
            ],
            value="auto",
            label="ONNX file variant",
        )
        system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        chat_bot = gr.Chatbot(height=360, label="Conversation")
        chat_history = gr.State([])
        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)

        with gr.Accordion("Generation settings", open=False):
            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")

        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")

    # Connect object detection components
    object_detection_button.click(
        fn=detect_objects,
        inputs=image_input,
        outputs=[detected_image_output, detected_objects_output],
    )

    # Connect LLM chat components; chat_history is also an output so the Chatbot
    # display and the State stay in sync across turns
    send_btn.click(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )

    # Also submit on Enter
    user_prompt_input.submit(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )

    # Clear chat
    def _clear_chat():
        return [], [], gr.update(value="")

    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, chat_history, user_prompt_input])

    # Connect detected objects to user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
        outputs=user_prompt_input,
    )

demo.launch()