Update groundingdino/util/inference.py
groundingdino/util/inference.py  CHANGED  (+21, -58)
@@ -5,7 +5,6 @@ import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
-from torchvision.ops import box_convert
 import bisect
 
 import groundingdino.datasets.transforms as T
@@ -14,6 +13,19 @@ from groundingdino.util.misc import clean_state_dict
 from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import get_phrases_from_posmap
 
+
+def cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
+    """
+    Convert bounding boxes from [cx, cy, w, h] format to [x1, y1, x2, y2] format.
+    """
+    cx, cy, w, h = boxes.unbind(-1)
+    x1 = cx - 0.5 * w
+    y1 = cy - 0.5 * h
+    x2 = cx + 0.5 * w
+    y2 = cy + 0.5 * h
+    return torch.stack((x1, y1, x2, y2), dim=-1)
+
+
 # ----------------------------------------------------------------------------------------------------------------------
 # OLD API
 # ----------------------------------------------------------------------------------------------------------------------
@@ -67,16 +79,16 @@ def predict(
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
 
-    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]
-    prediction_boxes = outputs["pred_boxes"].cpu()[0]
+    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]
+    prediction_boxes = outputs["pred_boxes"].cpu()[0]
 
     mask = prediction_logits.max(dim=1)[0] > box_threshold
-    logits = prediction_logits[mask]
-    boxes = prediction_boxes[mask]
+    logits = prediction_logits[mask]
+    boxes = prediction_boxes[mask]
 
     tokenizer = model.tokenizer
     tokenized = tokenizer(caption)
-
+
     if remove_combined:
         sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
 
@@ -98,21 +110,9 @@
 
 
 def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
-    """
-    This function annotates an image with bounding boxes and labels.
-
-    Parameters:
-    image_source (np.ndarray): The source image to be annotated.
-    boxes (torch.Tensor): A tensor containing bounding box coordinates.
-    logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
-    phrases (List[str]): A list of labels for each bounding box.
-
-    Returns:
-    np.ndarray: The annotated image.
-    """
     h, w, _ = image_source.shape
     boxes = boxes * torch.Tensor([w, h, w, h])
-    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+    xyxy = cxcywh_to_xyxy(boxes).numpy()
    detections = sv.Detections(xyxy=xyxy)
 
     labels = [
@@ -156,24 +156,6 @@ class Model:
         box_threshold: float = 0.35,
         text_threshold: float = 0.25
     ) -> Tuple[sv.Detections, List[str]]:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections, labels = model.predict_with_caption(
-            image=image,
-            caption=caption,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
-        """
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
             model=self.model,
@@ -197,25 +179,6 @@ class Model:
         box_threshold: float,
         text_threshold: float
     ) -> sv.Detections:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections = model.predict_with_classes(
-            image=image,
-            classes=CLASSES,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections)
-        """
         caption = ". ".join(classes)
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
@@ -256,7 +219,7 @@ class Model:
         logits: torch.Tensor
     ) -> sv.Detections:
         boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        xyxy = cxcywh_to_xyxy(boxes).numpy()
         confidence = logits.numpy()
         return sv.Detections(xyxy=xyxy, confidence=confidence)
 
@@ -270,4 +233,4 @@ class Model:
                 break
             else:
                 class_ids.append(None)
-        return np.array(class_ids)
+        return np.array(class_ids)
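For context, a minimal sanity check (not part of the commit, and assuming torchvision is still installed for the comparison) that the new pure-torch cxcywh_to_xyxy helper reproduces the torchvision box_convert conversion whose import this change removes:

import torch
from torchvision.ops import box_convert

# After this change the helper lives at module level in groundingdino.util.inference.
from groundingdino.util.inference import cxcywh_to_xyxy

# Example boxes in normalized [cx, cy, w, h] format, scaled to a hypothetical 640x480
# image the same way annotate() and post_process_result() scale before conversion.
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4],
                      [0.3, 0.7, 0.1, 0.1]]) * torch.Tensor([640, 480, 640, 480])

xyxy_new = cxcywh_to_xyxy(boxes)                                      # pure torch
xyxy_old = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")  # torchvision
assert torch.allclose(xyxy_new, xyxy_old)

Since the helper is plain tensor arithmetic, inference.py no longer needs torchvision.ops while still producing the xyxy coordinates that sv.Detections expects.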