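"""Black Box: a Gradio demo combining fashion object detection with an LLM chat tab.

The "Object Detection" tab runs the YOLOS-Fashionpedia pipeline and draws labelled
bounding boxes. The "LLM Chat" tab serves Gemma 3 1B or Qwen3 0.6B through ONNX
Runtime; detected object labels are appended to the chat message box automatically.
"""
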
import gradio as gr
from PIL import Image, ImageDraw
from transformers import pipeline, AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import torch

# Load the object detection pipeline
object_detector = pipeline("object-detection", model="valentinafeve/yolos-fashionpedia")

def detect_objects(image):
    """
    Performs object detection on an image using a transformers pipeline.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        tuple: A tuple containing:
            - PIL.Image.Image: The image with detected objects annotated.
            - str: A string listing the names of detected objects.
    """
    # Perform inference
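    # (each result is a dict with 'score', 'label' and a 'box' of
    #  {'xmin', 'ymin', 'xmax', 'ymax'} pixel coordinates)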
    results = object_detector(image)

    # Create a copy of the image to draw on
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)

    detected_objects = []

    # Extract bounding boxes, classes, and confidences
    for result in results:
        box = result['box']
        label = result['label']
        score = result['score']
        
        xyxy = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

        detected_objects.append(label)

        # Draw bounding box
        draw.rectangle(xyxy, outline="red", width=2)
        # Draw label
        draw.text((xyxy[0], xyxy[1]), f"{label} ({score:.2f})", fill="red")

    # Create a unique, comma-separated string of detected objects (first-seen order)
    detected_objects_str = ", ".join(dict.fromkeys(detected_objects))
    if not detected_objects_str:
        detected_objects_str = "No objects detected."

    return annotated_image, detected_objects_str

# Cache for LLM models and tokenizers (ONNX Runtime)
llm_cache = {}

def get_llm(model_name, preferred_file: str | None = None):
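    """Loads and caches an ONNX Runtime causal LM together with its tokenizer.

    Args:
        model_name: A key of onnx_repo_map below, e.g. "gemma3:1b" or "qwen3:0.6b".
        preferred_file: Optional ONNX file name to try first (e.g. "model_q4.onnx");
            when None, a list of common quantized variants is tried in order.

    Returns:
        tuple: (ORTModelForCausalLM, tokenizer).

    Illustrative usage (assumes the Hub repos below are reachable and that at least
    one of the candidate ONNX files exists in the chosen repo):
        model, tokenizer = get_llm("qwen3:0.6b")                  # auto-select a file
        model, tokenizer = get_llm("gemma3:1b", "model_q4.onnx")  # force a variant
    """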
    cache_key = (model_name, preferred_file or "auto")
    if cache_key in llm_cache:
        return llm_cache[cache_key]

    # ONNX model repositories on the Hub
    onnx_repo_map = {
        "gemma3:1b": "onnx-community/gemma-3-1b-it-ONNX-GQA",
        "qwen3:0.6b": "onnx-community/Qwen3-0.6B-ONNX",
    }
    # Original model repos, used to fetch the matching tokenizer + chat template
    tokenizer_repo_map = {
        "gemma3:1b": "google/gemma-3-1b-it",
        "qwen3:0.6b": "Qwen/Qwen3-0.6B",
    }

    onnx_repo = onnx_repo_map[model_name]
    tokenizer_repo = tokenizer_repo_map[model_name]

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
    # Ensure pad token exists (common for decoder-only models)
    if tokenizer.pad_token_id is None and getattr(tokenizer, "eos_token_id", None) is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Try a few common ONNX filenames found in community repos to avoid the
    # "Too many ONNX model files were found" ambiguity.
    # Order: prefer int8, then q4f16, q4, general quantized, uint8, fp16, and finally generic.
    candidate_files = [
        "model_int8.onnx",
        "model_q4f16.onnx",
        "model_q4.onnx",
        "model_quantized.onnx",
        "model_uint8.onnx",
        "model_fp16.onnx",
        "model_bnb4.onnx",
        "model.onnx",
    ]

    model = None
    last_err = None
    ordered = candidate_files
    if preferred_file and preferred_file in candidate_files:
        # Put preferred file first
        ordered = [preferred_file] + [f for f in candidate_files if f != preferred_file]
    elif preferred_file and preferred_file not in candidate_files:
        # If user typed a specific known filename not in our shortlist, try it first anyway
        ordered = [preferred_file] + candidate_files

    for fname in ordered:
        try:
            model = ORTModelForCausalLM.from_pretrained(
                onnx_repo,
                subfolder="onnx",
                file_name=fname,
            )
            print(f"[ONNX] Loaded {onnx_repo}/onnx/{fname}")
            break
        except Exception as e:
            last_err = e
            continue
    if model is None:
        raise RuntimeError(f"Failed to load ONNX model from {onnx_repo}. Last error: {last_err}")

    # Disable cache to avoid past_key_values shape issues on some ONNX builds
    if hasattr(model.config, "use_cache"):
        try:
            model.config.use_cache = False
        except Exception:
            pass
    # Mirror in generation config as well
    if hasattr(model, "generation_config") and hasattr(model.generation_config, "use_cache"):
        try:
            model.generation_config.use_cache = False
        except Exception:
            pass

    llm_cache[cache_key] = (model, tokenizer)
    return model, tokenizer


def update_user_prompt(detected_objects, current_prompt):
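    """Appends detected object labels to the current chat message.

    Examples of the behaviour implemented below:
        update_user_prompt("shirt, pants", "")         -> "Objects detected in the image: shirt, pants"
        update_user_prompt("shirt", "Describe this")   -> "Describe this, shirt"
        update_user_prompt("No objects detected.", p)  -> p (unchanged)
    """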
    if "No objects detected" in detected_objects:
        return current_prompt
    
    if current_prompt:
        new_prompt = f"{current_prompt}, {detected_objects}"
    else:
        new_prompt = f"Objects detected in the image: {detected_objects}"
        
    return new_prompt

def generate_text(
    model_name,
    onnx_file_choice,
    messages,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
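    """Applies the tokenizer's chat template to `messages`, generates with ONNX
    Runtime using the given sampling settings, and returns only the newly
    generated text (prompt tokens are stripped before decoding).

    `messages` is a list of {"role": ..., "content": ...} dicts, e.g.
        [{"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello"}]
    """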
    model, tokenizer = get_llm(model_name, preferred_file=None if onnx_file_choice == "auto" else onnx_file_choice)


    chat_template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": True,
    }
    # Disable "thinking" for Qwen models
    if "qwen" in model_name.lower():
        chat_template_kwargs["enable_thinking"] = False

    text = tokenizer.apply_chat_template(
        messages,
        **chat_template_kwargs,
    )

    inputs = tokenizer([text], return_tensors="pt")
    # Ensure attention_mask is present and pad_token is defined
    if "attention_mask" not in inputs:
        inputs = tokenizer([text], return_tensors="pt", padding=True)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "use_cache": False,
    }
    if getattr(tokenizer, "eos_token_id", None) is not None:
        gen_kwargs["eos_token_id"] = tokenizer.eos_token_id
    if getattr(tokenizer, "pad_token_id", None) is not None:
        gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    with torch.inference_mode():
        gen_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )

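    # Keep only the newly generated tokens: drop the prompt portion of each
    # sequence before decoding.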
    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, gen_ids)
    ]
    response = tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0]
    return response

def chat_respond(
    model_name,
    onnx_file_choice,
    system_prompt,
    message,
    history,
    do_sample,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    max_new_tokens,
):
    """Builds a chat messages list from the history plus the current user message,
    generates a reply, and returns the updated history (for both the Chatbot display
    and the chat_history State) together with an empty input box.
    """
    # Guard: empty message
    if not (message and message.strip()):
        return history or [], history or [], gr.update(value="")

    # Build messages: system, then alternating user/assistant from history, then current user
    messages = [{"role": "system", "content": system_prompt}]
    for u, a in (history or []):
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})

    reply = generate_text(
        model_name=model_name,
        onnx_file_choice=onnx_file_choice,
        messages=messages,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_new_tokens,
    )

    new_history = (history or []) + [(message, reply)]
    # Return the history twice: once for the Chatbot display and once for the
    # chat_history State, so the next turn sees the full conversation.
    return new_history, new_history, gr.update(value="")

with gr.Blocks() as demo:
    gr.Markdown("# Black Box: Object Detection and LLM Chat")

    with gr.Tab("Object Detection"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
            detected_image_output = gr.Image(label="Detected Objects")
        object_detection_button = gr.Button("Detect Objects")
        detected_objects_output = gr.Textbox(label="Detected Objects")

    with gr.Tab("LLM Chat"):
        model_selector = gr.Dropdown(choices=["gemma3:1b", "qwen3:0.6b"], label="Select LLM Model")
        onnx_file_selector = gr.Dropdown(
            choices=[
                "auto",
                "model_int8.onnx",
                "model_q4f16.onnx",
                "model_q4.onnx",
                "model_quantized.onnx",
                "model_uint8.onnx",
                "model_fp16.onnx",
                "model_bnb4.onnx",
                "model.onnx",
            ],
            value="auto",
            label="ONNX file variant"
        )
        system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        chat_bot = gr.Chatbot(height=360, label="Conversation")
        chat_history = gr.State([])
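        # Holds the (user, assistant) turns; chat_respond returns it as an output
        # so each new turn sees the full conversation so far.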
        user_prompt_input = gr.Textbox(label="Message", placeholder="Type your message and press Send...", lines=3)
        with gr.Accordion("Generation settings", open=False):
            do_sample_cb = gr.Checkbox(value=True, label="do_sample")
            temperature_sl = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="temperature")
            top_p_sl = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="top_p")
            top_k_sl = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="top_k")
            repetition_penalty_sl = gr.Slider(minimum=0.8, maximum=2.0, value=1.05, step=0.01, label="repetition_penalty")
            max_new_tokens_sl = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="max_new_tokens")
        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")

    # Connect object detection components
    object_detection_button.click(
        fn=detect_objects,
        inputs=image_input,
        outputs=[detected_image_output, detected_objects_output]
    )

    # Connect LLM chat components
    send_btn.click(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Also submit on Enter
    user_prompt_input.submit(
        fn=chat_respond,
        inputs=[
            model_selector,
            onnx_file_selector,
            system_prompt_input,
            user_prompt_input,
            chat_history,
            do_sample_cb,
            temperature_sl,
            top_p_sl,
            top_k_sl,
            repetition_penalty_sl,
            max_new_tokens_sl,
        ],
        outputs=[chat_bot, chat_history, user_prompt_input],
    )
    # Clear chat (reset both the Chatbot display and the history State)
    def _clear_chat():
        return [], [], gr.update(value="")
    clear_btn.click(fn=_clear_chat, inputs=None, outputs=[chat_bot, chat_history, user_prompt_input])

    # Connect detected objects to user message input
    detected_objects_output.change(
        fn=update_user_prompt,
        inputs=[detected_objects_output, user_prompt_input],
        outputs=user_prompt_input,
    )

demo.launch()