Joel Lundgren committed on
Commit f32efcc · 1 Parent(s): 64a7b3c

test with new layout

Files changed (2)
  1. app.py +146 -3
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,150 @@
  import gradio as gr
+ from PIL import Image, ImageDraw
+ from ultralytics import YOLO
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM

- def greet(name):
-     return "Hello " + name + "!!"
+ # Load a pre-trained YOLO model
+ model = YOLO('yolov8n.pt')
+
+ def detect_objects(image):
+     """
+     Performs object detection on an image using the YOLO model.
+
+     Args:
+         image (PIL.Image.Image): The input image.
+
+     Returns:
+         tuple: A tuple containing:
+             - PIL.Image.Image: The image with detected objects annotated.
+             - str: A string listing the names of detected objects.
+     """
+     # Perform inference
+     results = model(image)
+
+     # Get the first result
+     result = results[0]
+
+     # Create a copy of the image to draw on
+     annotated_image = image.copy()
+     draw = ImageDraw.Draw(annotated_image)
+
+     detected_objects = []
+
+     # Extract bounding boxes, classes, and confidences
+     for box in result.boxes:
+         xyxy = box.xyxy[0].tolist()
+         label = result.names[int(box.cls)]
+         confidence = box.conf[0].item()
+
+         detected_objects.append(label)
+
+         # Draw bounding box
+         draw.rectangle(xyxy, outline="red", width=2)
+         # Draw label
+         draw.text((xyxy[0], xyxy[1]), f"{label} ({confidence:.2f})", fill="red")
+
+     # Create a unique, comma-separated string of detected objects
+     detected_objects_str = ", ".join(list(set(detected_objects)))
+     if not detected_objects_str:
+         detected_objects_str = "No objects detected."
+
+     return annotated_image, detected_objects_str
+
+ # Cache for LLM models and tokenizers
+ llm_cache = {}
+
+ def get_llm(model_name):
+     if model_name in llm_cache:
+         return llm_cache[model_name]
+
+     model_map = {
+         "qwen3:0.6b": "Qwen/Qwen3-0.6B",
+         "gemma3:1b": "google/gemma-3-1b-it"
+     }
+     hf_model_name = model_map[model_name]
+
+     tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
+     model = AutoModelForCausalLM.from_pretrained(hf_model_name)
+
+     llm_cache[model_name] = (model, tokenizer)
+     return model, tokenizer
+
+ def update_user_prompt(detected_objects, current_prompt):
+     if "No objects detected" in detected_objects:
+         return current_prompt
+
+     if current_prompt:
+         new_prompt = f"{current_prompt}, {detected_objects}"
+     else:
+         new_prompt = f"Objects detected in the image: {detected_objects}"
+
+     return new_prompt
+
+ def generate_text(model_name, system_prompt, user_prompt):
+     model, tokenizer = get_llm(model_name)
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": user_prompt},
+     ]
+
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     model_inputs = tokenizer([text], return_tensors="pt")
+
+     generated_ids = model.generate(
+         model_inputs.input_ids,
+         max_new_tokens=512
+     )
+     generated_ids = [
+         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+
+     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     return response
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Black Box: Object Detection and LLM Chat")
+
+     with gr.Tab("Object Detection"):
+         with gr.Row():
+             image_input = gr.Image(type="pil", label="Upload Image or Use Webcam", sources=["upload", "webcam"])
+             detected_image_output = gr.Image(label="Detected Objects")
+         object_detection_button = gr.Button("Detect Objects")
+         detected_objects_output = gr.Textbox(label="Detected Objects")
+
+     with gr.Tab("LLM Chat"):
+         model_selector = gr.Dropdown(choices=["qwen3:0.6b", "gemma3:1b"], label="Select LLM Model")
+         system_prompt_input = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
+         user_prompt_input = gr.Textbox(label="User Prompt")
+         llm_output = gr.Textbox(label="LLM Response")
+         llm_button = gr.Button("Generate")
+
+     # Connect object detection components
+     object_detection_button.click(
+         fn=detect_objects,
+         inputs=image_input,
+         outputs=[detected_image_output, detected_objects_output]
+     )
+
+     # Connect LLM components
+     llm_button.click(
+         fn=generate_text,
+         inputs=[model_selector, system_prompt_input, user_prompt_input],
+         outputs=llm_output
+     )
+
+     # Connect detected objects to user prompt
+     detected_objects_output.change(
+         fn=update_user_prompt,
+         inputs=[detected_objects_output, user_prompt_input],
+         outputs=user_prompt_input
+     )

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
  demo.launch()
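
The detection path in the new app.py can be exercised without launching the Gradio UI. The snippet below is a minimal, standalone sketch (not part of the commit) that mirrors detect_objects() and the prompt handoff; it assumes the packages in requirements.txt are installed, that ultralytics can download the yolov8n.pt weights on first use, and that "example.jpg" is a hypothetical local image file.

# Standalone sketch of the detection flow in app.py (not part of the commit).
# Assumptions: requirements.txt is installed, "example.jpg" is a hypothetical
# local image, and ultralytics fetches yolov8n.pt automatically on first use.
from PIL import Image, ImageDraw
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
image = Image.open("example.jpg")

result = model(image)[0]          # first (and only) result for a single image
annotated = image.copy()
draw = ImageDraw.Draw(annotated)
labels = []

for box in result.boxes:
    xyxy = box.xyxy[0].tolist()
    label = result.names[int(box.cls)]
    confidence = box.conf[0].item()
    labels.append(label)
    draw.rectangle(xyxy, outline="red", width=2)
    draw.text((xyxy[0], xyxy[1]), f"{label} ({confidence:.2f})", fill="red")

annotated.save("annotated.jpg")
detected = ", ".join(sorted(set(labels))) or "No objects detected."

# Fold the detection string into a user prompt, as update_user_prompt()
# does in app.py when the prompt box is empty.
user_prompt = f"Objects detected in the image: {detected}"
print(user_prompt)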
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ ultralytics
+ torch
+ transformers
+ pillow
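
With only the packages listed above, the LLM path can be smoke-tested outside Gradio as well. The sketch below follows the chat-template flow of generate_text() in app.py; it is a rough example, assuming the "Qwen/Qwen3-0.6B" checkpoint (the repo id the model_map entry appears to intend) can be downloaded from the Hugging Face Hub, and generation will be slow on CPU.

# Standalone sketch of the generate_text() flow in app.py (not part of the commit).
# Assumptions: transformers and torch are installed; the first run downloads the
# Qwen/Qwen3-0.6B weights from the Hugging Face Hub.
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_model_name = "Qwen/Qwen3-0.6B"  # assumed target of the "qwen3:0.6b" entry
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
model = AutoModelForCausalLM.from_pretrained(hf_model_name)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Objects detected in the image: person, dog. Describe the scene."},
]

# Build the prompt with the model's chat template, then generate a short reply.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt")
output_ids = model.generate(inputs.input_ids, max_new_tokens=64)

# Decode only the newly generated tokens, as generate_text() does.
new_tokens = output_ids[0][inputs.input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))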