# black-box / app.py
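"""Black Box: a small Gradio demo with two tabs.

The "Object Detection" tab annotates an uploaded or webcam image with a YOLOS
Fashionpedia pipeline; the "LLM Chat" tab chats with small Gemma 3 / Qwen3
models exported to ONNX and run through Optimum's ONNX Runtime backend, with a
selectable quantized ONNX file variant.
"""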
import gradio as gr
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch
# Load the object detection pipeline
object_detector = pipeline("object-detection", model="valentinafeve/yolos-fashionpedia")
def detect_objects(image):
    """
    Performs object detection on an image using a transformers pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The image with detected objects annotated.
            - str: A string listing the names of detected objects.
    """
    # Guard: the button can be clicked before an image is provided
    if image is None:
        return None, "No image provided."
    # Perform inference
    results = object_detector(image)
    # Create a copy of the image to draw on
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    detected_objects = []
    # Extract bounding boxes, classes, and confidences
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        xyxy = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
        detected_objects.append(label)
        # Draw bounding box
        draw.rectangle(xyxy, outline="red", width=2)
        # Draw label
        draw.text((xyxy[0], xyxy[1]), f"{label} ({score:.2f})", fill="red")
    # Create a unique, comma-separated string of detected objects
    detected_objects_str = ", ".join(list(set(detected_objects)))
    if not detected_objects_str:
        detected_objects_str = "No objects detected."
    return annotated_image, detected_objects_str
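# Illustrative standalone usage (not executed here); "photo.jpg" is a
# placeholder path, not a file shipped with this app:
#   annotated, labels = detect_objects(Image.open("photo.jpg"))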
# Cache for LLM models and tokenizers (ONNX Runtime)
llm_cache = {}
def get_llm(model_name, preferred_file: str | None = None):
    cache_key = (model_name, preferred_file or "auto")
    if cache_key in llm_cache:
        return llm_cache[cache_key]
    # ONNX model repositories on the Hub
    onnx_repo_map = {
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
    }
    # Original repos to fetch correct tokenizer + chat templates
    tokenizer_repo_map = {
        "gemma3:1b": "google/gemma-3-1b-it",
        "qwen3:0.6b": "Qwen/Qwen3-0.6B",
    }
    onnx_repo = onnx_repo_map[model_name]
    tokenizer_repo = tokenizer_repo_map[model_name]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    # Ensure pad token exists (common for decoder-only models)
    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Try a few common ONNX filenames found in community repos to avoid the
    # "Too many ONNX model files were found" ambiguity.
    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, bnb4,
    # and finally the generic model.onnx.
    candidate_files = [
        "model_int8.onnx",
        "model_q4f16.onnx",
        "model_q4.onnx",
        "model_quantized.onnx",
        "model_uint8.onnx",
        "model_fp16.onnx",
        "model_bnb4.onnx",
        "model.onnx",
    ]
    model = None
    last_err = None
    ordered = candidate_files
    if preferred_file and preferred_file in candidate_files:
        # Put preferred file first
        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
    elif preferred_file and preferred_file not in candidate_files:
        # If user typed a specific known filename not in our shortlist, try it first anyway
        ordered = [preferred_file] + candidate_files
    for fname in ordered:
        try:
            model = ORTModelForCausalLM.from_pretrained(
                onnx_repo,
                subfolder="onnx",
                file_name=fname,
            )
            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
            break
        except Exception as e:
            last_err = e
            continue
    if model is None:
        raise RuntimeError(f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}")
    # Disable cache to avoid past_key_values shape issues on some ONNX builds
    if hasattr(model.config, "use_cache"):
        try:
            model.config.use_cache = False
        except Exception:
            pass
    # Mirror in generation config as well
    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
        try:
            model.generation_config.use_cache = False
        except Exception:
            pass
    llm_cache[cache_key] = (model, tokenizer)
    return model, tokenizer
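# Illustrative usage (not executed here): load with the default file search
# order, or pin a specific quantized variant from the list above.
#   model, tokenizer = get_llm("qwen3:0.6b")
#   model, tokenizer = get_llm("gemma3:1b", preferred_file="model_q4.onnx")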
def update_user_prompt(detected_objects, current_prompt):
    # Skip placeholder messages such as "No objects detected." / "No image provided."
    if not detected_objects or detected_objects.startswith("No "):
        return current_prompt
    if current_prompt:
        new_prompt = f"{current_prompt}, {detected_objects}"
    else:
        new_prompt = f"Objects detected in the image: {detected_objects}"
    return new_prompt
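# Example: update_user_prompt("shirt, blouse", "") returns
# "Objects detected in the image: shirt, blouse"; with a non-empty prompt the
# labels are appended after a comma instead.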
def generate_text(
    model_name,
    onnx_file_choice,
    messages,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)
    chat_template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    # Disable "thinking" for Qwen models
    if "qwen" in model_name.lower():
        chat_template_kwargs["enable_thinking"] = False
    text = tokenizer.apply_chat_template(
        messages,
        **chat_template_kwargs,
    )
    inputs = tokenizer([text], return_tensors="pt")
    # Ensure attention_mask is present and pad_token is defined
    if "attention_mask" not in inputs:
        inputs = tokenizer([text], return_tensors="pt", padding=True)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "use_cache": False,
    }
    if getattr(tokenizer, "eos_token_id", None) is not None:
        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
    with torch.inference_mode():
        gen_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )
    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
    ]
    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    return response
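# Illustrative call (not executed here); the sampling values mirror the UI
# defaults defined further down:
#   reply = generate_text(
#       "gemma3:1b", "auto",
#       [{"role": "system", "content": "You are a helpful assistant."},
#        {"role": "user", "content": "Hello!"}],
#       do_sample=True, temperature=0.7, top_p=0.95, top_k=50,
#       repetition_penalty=1.05, max_new_tokens=64,
#   )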
def chat_respond(
    model_name,
    onnx_file_choice,
    system_prompt,
    message,
    history,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    """Builds a chat messages list from the history plus the current user message,
    generates a reply, and returns the updated history (for both the Chatbot display
    and the history State) and an empty input box."""
    # Guard: empty message
    if not (message and message.strip()):
        return history or [], history or [], gr.update(value="")
    # Build messages: system, then alternating user/assistant from history, then current user
    messages = [{"role": "system", "content": system_prompt}]
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})
    reply = generate_text(
        model_name=model_name,
        onnx_file_choice=onnx_file_choice,
        messages=messages,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )
    new_history = (history or []) + [(message, reply)]
    # Return the history twice: once for the Chatbot and once for the State that
    # is fed back in on the next turn (otherwise the model never sees prior turns).
    return new_history, new_history, gr.update(value="")
with gr.Blocks() as demo:
    gr.Markdown("# Black Box: Object Detection and LLM Chat")
    with gr.Tab("Object Detection"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
            detected_image_output = gr.Image(label="Detected Objects")
        object_detection_button = gr.Button("Detect Objects")
        detected_objects_output = gr.Textbox(label="Detected Objects")
    with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
            choices=[
                "auto",
                "model_int8.onnx",
                "model_q4f16.onnx",
                "model_q4.onnx",
                "model_quantized.onnx",
                "model_uint8.onnx",
                "model_fp16.onnx",
                "model_bnb4.onnx",
                "model.onnx",
            ],
            value="auto",
            label="ONNX file variant",
        )
        system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        chat_bot = gr.Chatbot(height=360, label="Conversation")
        chat_history = gr.State([])
        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)
        with gr.Accordion("Generation settings", open=False):
            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")
        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")
    # Connect object detection components
    object_detection_button.click(
        fn=detect_objects,
        inputs=image_input,
        outputs=[detected_image_output, detected_objects_output],
    )
    # Connect LLM chat components
    send_btn.click(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Also submit on Enter
    user_prompt_input.submit(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Clear chat
    def _clear_chat():
        return [], [], gr.update(value="")

    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, chat_history, user_prompt_input])
    # Connect detected objects to user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
        outputs=user_prompt_input,
    )

demo.launch()