Commit · 3260133
1 Parent(s): a84c111

Update app.py
app.py CHANGED

@@ -499,72 +499,68 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
     size = image_pil.size
 
     # run grounding dino model
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    os.remove(image_path)
-    output_images.append(segment_image_result)
-
-    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
+    groundingdino_device = 'cpu'
+    if device != 'cpu':
+        try:
+            from groundingdino import _C
+            groundingdino_device = 'cuda:0'
+        except:
+            warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
+
+    boxes_filt, pred_phrases = get_grounding_output(
+        groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
+    )
+    if boxes_filt.size(0) == 0:
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
+        return [], gr.Gallery.update(label='No objects detected, please try others.ππππ')
+    boxes_filt_ori = copy.deepcopy(boxes_filt)
+
+    pred_dict = {
+        "boxes": boxes_filt,
+        "size": [size[1], size[0]],  # H,W
+        "labels": pred_phrases,
+    }
+
+    image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
+    output_images.append(image_with_box)
+
+    # now we generate the segmentation
+    image = np.array(input_img)
+    sam_predictor.set_image(image)
+
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    boxes_filt = boxes_filt.to(sam_device)
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+
+    masks, _, _, _ = sam_predictor.predict_torch(
+        point_coords = None,
+        point_labels = None,
+        boxes = transformed_boxes,
+        multimask_output = False,
+    )
+    # masks: [9, 1, 512, 512]
+    assert sam_checkpoint, 'sam_checkpoint is not found!'
+    # draw output image
+    plt.figure(figsize=(10, 10))
+    # we don't draw the background image
+    # plt.imshow(image)
+    boxes_with_labels = zip(boxes_filt, pred_phrases)
+    for mask in masks:
+        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+    for box, label in boxes_with_labels:
+        show_box(box.cpu().numpy(), plt.gca(), label)
+    plt.axis('off')
+    image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.png")
+    plt.savefig(image_path, bbox_inches="tight")
+    segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+    os.remove(image_path)
+    output_images.append(segment_image_result)
+
     return pred_dict, output_images, gr.Gallery.update(label='result images')
 
 if __name__ == "__main__":
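Two steps in the new hunk deserve a note. The per-box loop at new lines 532-535 multiplies GroundingDINO's normalized [cx, cy, w, h] outputs by [W, H, W, H] and shifts them into absolute [x0, y0, x1, y1] corners, the format SAM's apply_boxes_torch expects. A minimal vectorized sketch of the same conversion (the helper name is illustrative, not part of the commit):

# Vectorized sketch of the committed per-box loop.
# Assumption: `boxes` holds normalized [cx, cy, w, h] rows, as GroundingDINO returns.
import torch

def cxcywh_norm_to_xyxy_abs(boxes: torch.Tensor, W: int, H: int) -> torch.Tensor:
    """Scale normalized center-size boxes to absolute corner boxes [x0, y0, x1, y1]."""
    boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
    xy0 = boxes[:, :2] - boxes[:, 2:] / 2  # top-left = center - half size
    xy1 = xy0 + boxes[:, 2:]               # bottom-right = top-left + size
    return torch.cat([xy0, xy1], dim=1)

For example, on a 640x480 image the normalized box [0.5, 0.5, 0.25, 0.25] maps to [240.0, 180.0, 400.0, 300.0].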
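The device probe at the top of the hunk selects CUDA only when GroundingDINO's compiled extension groundingdino._C imports cleanly; otherwise it warns and stays on CPU. A sketch of the same check as a standalone helper (hypothetical name; it narrows the committed bare except: to ImportError):

# Sketch of the committed device probe, factored out for reuse.
# GroundingDINO's custom C++/CUDA ops ship as the compiled module groundingdino._C.
import warnings

def pick_groundingdino_device(requested: str) -> str:
    """Return 'cuda:0' only if the compiled GroundingDINO ops are importable."""
    if requested == 'cpu':
        return 'cpu'
    try:
        from groundingdino import _C  # noqa: F401  (compiled extension)
        return 'cuda:0'
    except ImportError:
        warnings.warn("Failed to load custom C++ ops; running GroundingDINO on CPU.")
        return 'cpu'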