Spaces:
Runtime error
Runtime error
Martin Tomov
commited on
spaces.GPU
Browse files
app.py
CHANGED
|
@@ -92,19 +92,21 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> L
|
|
| 92 |
masks[idx] = cv2.fillPoly(np.zeros(shape, dtype=np.uint8), [polygon], 1)
|
| 93 |
return list(masks)
|
| 94 |
|
|
|
|
| 95 |
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
|
| 96 |
detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
|
| 97 |
-
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection")
|
| 98 |
labels = [label if label.endswith(".") else label + "." for label in labels]
|
| 99 |
results = object_detector(image, candidate_labels=labels, threshold=threshold)
|
| 100 |
return [DetectionResult.from_dict(result) for result in results]
|
| 101 |
|
|
|
|
| 102 |
def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
|
| 103 |
segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
|
| 104 |
-
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id)
|
| 105 |
processor = AutoProcessor.from_pretrained(segmenter_id)
|
| 106 |
boxes = get_boxes(detection_results)
|
| 107 |
-
inputs = processor(images=image, input_boxes=boxes, return_tensors="pt")
|
| 108 |
outputs = segmentator(**inputs)
|
| 109 |
masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
|
| 110 |
masks = refine_masks(masks, polygon_refinement)
|
|
|
|
| 92 |
masks[idx] = cv2.fillPoly(np.zeros(shape, dtype=np.uint8), [polygon], 1)
|
| 93 |
return list(masks)
|
| 94 |
|
| 95 |
+
@spaces.GPU
|
| 96 |
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
|
| 97 |
detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
|
| 98 |
+
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=0)
|
| 99 |
labels = [label if label.endswith(".") else label + "." for label in labels]
|
| 100 |
results = object_detector(image, candidate_labels=labels, threshold=threshold)
|
| 101 |
return [DetectionResult.from_dict(result) for result in results]
|
| 102 |
|
| 103 |
+
@spaces.GPU
|
| 104 |
def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
|
| 105 |
segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
|
| 106 |
+
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
|
| 107 |
processor = AutoProcessor.from_pretrained(segmenter_id)
|
| 108 |
boxes = get_boxes(detection_results)
|
| 109 |
+
inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to("cuda")
|
| 110 |
outputs = segmentator(**inputs)
|
| 111 |
masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
|
| 112 |
masks = refine_masks(masks, polygon_refinement)
|