jbilcke-hf committed
Commit 2f2223a · 1 Parent(s): f1d2589

Update app.py

Files changed (1)
  1. app.py +37 -39
app.py CHANGED
@@ -554,7 +554,7 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
         groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
     )
     if boxes_filt.size(0) == 0:
-        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
+        logger.info(f'run_anything_task_[{file_temp}]_[{text_prompt}]_1_[No objects detected, please try others.]_')
         return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
     boxes_filt_ori = copy.deepcopy(boxes_filt)
 
@@ -565,44 +565,42 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
     }
 
     # disabled: we don't want to see the boxes
-    # image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
-    # output_images.append(image_with_box)
-
-
-    if task_type == 'segment':
-        image = np.array(input_img)
-        sam_predictor.set_image(image)
-
-        H, W = size[1], size[0]
-        for i in range(boxes_filt.size(0)):
-            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-            boxes_filt[i][2:] += boxes_filt[i][:2]
-
-        boxes_filt = boxes_filt.to(sam_device)
-        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
-
-        masks, _, _, _ = sam_predictor.predict_torch(
-            point_coords = None,
-            point_labels = None,
-            boxes = transformed_boxes,
-            multimask_output = False,
-        )
-        # masks: [9, 1, 512, 512]
-        assert sam_checkpoint, 'sam_checkpoint is not found!'
-        # draw output image
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        for mask in masks:
-            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
-        for box, label in zip(boxes_filt, pred_phrases):
-            show_box(box.cpu().numpy(), plt.gca(), label)
-        plt.axis('off')
-        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
-        plt.savefig(image_path, bbox_inches="tight")
-        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
-        os.remove(image_path)
-        output_images.append(segment_image_result)
+    image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
+    output_images.append(image_with_box)
+
+    image = np.array(input_img)
+    sam_predictor.set_image(image)
+
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    boxes_filt = boxes_filt.to(sam_device)
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+
+    masks, _, _, _ = sam_predictor.predict_torch(
+        point_coords = None,
+        point_labels = None,
+        boxes = transformed_boxes,
+        multimask_output = False,
+    )
+    # masks: [9, 1, 512, 512]
+    assert sam_checkpoint, 'sam_checkpoint is not found!'
+    # draw output image
+    plt.figure(figsize=(10, 10))
+    plt.imshow(image)
+    for mask in masks:
+        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+    for box, label in zip(boxes_filt, pred_phrases):
+        show_box(box.cpu().numpy(), plt.gca(), label)
+    plt.axis('off')
+    image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
+    plt.savefig(image_path, bbox_inches="tight")
+    segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+    os.remove(image_path)
+    output_images.append(segment_image_result)
 
 
     results = zip(boxes_filt, pred_phrases)
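
For reference, the per-box loop in this change rescales GroundingDINO's normalized (cx, cy, w, h) boxes into absolute (x1, y1, x2, y2) pixel corners before handing them to SAM. A minimal standalone sketch of that conversion, assuming a hypothetical helper name and a 512x512 image (neither appears in app.py):

# Illustrative sketch, not part of this commit: convert GroundingDINO's normalized
# (cx, cy, w, h) boxes into absolute (x1, y1, x2, y2) pixel corners, the layout
# SAM's box prompts expect.
import torch

def cxcywh_norm_to_xyxy_abs(boxes: torch.Tensor, W: int, H: int) -> torch.Tensor:
    """boxes: (N, 4) normalized (cx, cy, w, h) in [0, 1]; returns (N, 4) absolute xyxy."""
    boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)  # scale to pixels
    xyxy = torch.empty_like(boxes)
    xyxy[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2  # top-left = center - half size
    xyxy[:, 2:] = boxes[:, :2] + boxes[:, 2:] / 2  # bottom-right = center + half size
    return xyxy

# Example: a box centered in a 512x512 image, covering half of each dimension.
print(cxcywh_norm_to_xyxy_abs(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), W=512, H=512))
# tensor([[128., 128., 384., 384.]])

The in-place per-row loop in the diff computes the same corners; it mutates boxes_filt directly, so the same tensor is then moved to the SAM device and passed through apply_boxes_torch.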