Spaces:
Sleeping
Sleeping
Update my_model/object_detection.py
Browse files
my_model/object_detection.py
CHANGED
|
@@ -26,10 +26,11 @@ class ObjectDetector:
|
|
| 26 |
"""
|
| 27 |
Initializes the ObjectDetector class with default values.
|
| 28 |
"""
|
| 29 |
-
|
| 30 |
self.model = None
|
| 31 |
self.processor = None
|
| 32 |
self.model_name = None
|
|
|
|
| 33 |
|
| 34 |
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
| 35 |
"""
|
|
@@ -52,6 +53,7 @@ class ObjectDetector:
|
|
| 52 |
else:
|
| 53 |
raise ValueError(f"Unsupported model name: {model_name}")
|
| 54 |
|
|
|
|
| 55 |
def _load_detic_model(self, pretrained):
|
| 56 |
"""
|
| 57 |
Load the Detic model.
|
|
@@ -62,13 +64,13 @@ class ObjectDetector:
|
|
| 62 |
|
| 63 |
try:
|
| 64 |
model_path = get_model_path('deformable-detr-detic')
|
| 65 |
-
st.write(model_path)
|
| 66 |
self.processor = AutoImageProcessor.from_pretrained(model_path)
|
| 67 |
self.model = AutoModelForObjectDetection.from_pretrained(model_path)
|
| 68 |
except Exception as e:
|
| 69 |
print(f"Error loading Detic model: {e}")
|
| 70 |
raise
|
| 71 |
|
|
|
|
| 72 |
def _load_yolov5_model(self, pretrained, model_version):
|
| 73 |
"""
|
| 74 |
Load the YOLOv5 model.
|
|
@@ -80,7 +82,6 @@ class ObjectDetector:
|
|
| 80 |
|
| 81 |
try:
|
| 82 |
model_path = get_model_path ('yolov5')
|
| 83 |
-
st.write(model_path)
|
| 84 |
if model_path and os.path.exists(model_path):
|
| 85 |
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
|
| 86 |
else:
|
|
@@ -89,6 +90,7 @@ class ObjectDetector:
|
|
| 89 |
print(f"Error loading YOLOv5 model: {e}")
|
| 90 |
raise
|
| 91 |
|
|
|
|
| 92 |
def process_image(self, image_input):
|
| 93 |
"""
|
| 94 |
Process the image from the given path or file-like object.
|
|
@@ -194,6 +196,7 @@ class ObjectDetector:
|
|
| 194 |
detected_objects_list.append((label_name, box_rounded, certainty))
|
| 195 |
return detected_objects_str, detected_objects_list
|
| 196 |
|
|
|
|
| 197 |
def draw_boxes(self, image, detected_objects, show_confidence=True):
|
| 198 |
"""
|
| 199 |
Draw bounding boxes around detected objects in the image.
|
|
@@ -218,7 +221,6 @@ class ObjectDetector:
|
|
| 218 |
for label_name, box, score in detected_objects:
|
| 219 |
if label_name not in label_color_map:
|
| 220 |
label_color_map[label_name] = colors[len(label_color_map) % len(colors)]
|
| 221 |
-
|
| 222 |
color = label_color_map[label_name]
|
| 223 |
draw.rectangle(box, outline=color, width=3)
|
| 224 |
label_text = f"{label_name}"
|
|
|
|
| 26 |
"""
|
| 27 |
Initializes the ObjectDetector class with default values.
|
| 28 |
"""
|
| 29 |
+
|
| 30 |
self.model = None
|
| 31 |
self.processor = None
|
| 32 |
self.model_name = None
|
| 33 |
+
|
| 34 |
|
| 35 |
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
| 36 |
"""
|
|
|
|
| 53 |
else:
|
| 54 |
raise ValueError(f"Unsupported model name: {model_name}")
|
| 55 |
|
| 56 |
+
|
| 57 |
def _load_detic_model(self, pretrained):
|
| 58 |
"""
|
| 59 |
Load the Detic model.
|
|
|
|
| 64 |
|
| 65 |
try:
|
| 66 |
model_path = get_model_path('deformable-detr-detic')
|
|
|
|
| 67 |
self.processor = AutoImageProcessor.from_pretrained(model_path)
|
| 68 |
self.model = AutoModelForObjectDetection.from_pretrained(model_path)
|
| 69 |
except Exception as e:
|
| 70 |
print(f"Error loading Detic model: {e}")
|
| 71 |
raise
|
| 72 |
|
| 73 |
+
|
| 74 |
def _load_yolov5_model(self, pretrained, model_version):
|
| 75 |
"""
|
| 76 |
Load the YOLOv5 model.
|
|
|
|
| 82 |
|
| 83 |
try:
|
| 84 |
model_path = get_model_path ('yolov5')
|
|
|
|
| 85 |
if model_path and os.path.exists(model_path):
|
| 86 |
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
|
| 87 |
else:
|
|
|
|
| 90 |
print(f"Error loading YOLOv5 model: {e}")
|
| 91 |
raise
|
| 92 |
|
| 93 |
+
|
| 94 |
def process_image(self, image_input):
|
| 95 |
"""
|
| 96 |
Process the image from the given path or file-like object.
|
|
|
|
| 196 |
detected_objects_list.append((label_name, box_rounded, certainty))
|
| 197 |
return detected_objects_str, detected_objects_list
|
| 198 |
|
| 199 |
+
|
| 200 |
def draw_boxes(self, image, detected_objects, show_confidence=True):
|
| 201 |
"""
|
| 202 |
Draw bounding boxes around detected objects in the image.
|
|
|
|
| 221 |
for label_name, box, score in detected_objects:
|
| 222 |
if label_name not in label_color_map:
|
| 223 |
label_color_map[label_name] = colors[len(label_color_map) % len(colors)]
|
|
|
|
| 224 |
color = label_color_map[label_name]
|
| 225 |
draw.rectangle(box, outline=color, width=3)
|
| 226 |
label_text = f"{label_name}"
|