jbilcke-hf committed
Commit 3260133 · 1 Parent(s): a84c111

Update app.py

Files changed (1): app.py (+62 -66)
app.py CHANGED
@@ -499,72 +499,68 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
     size = image_pil.size
 
     # run grounding dino model
-    if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
-        pass
-    else:
-        groundingdino_device = 'cpu'
-        if device != 'cpu':
-            try:
-                from groundingdino import _C
-                groundingdino_device = 'cuda:0'
-            except:
-                warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
-
-        boxes_filt, pred_phrases = get_grounding_output(
-            groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
-        )
-        if boxes_filt.size(0) == 0:
-            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
-            return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
-        boxes_filt_ori = copy.deepcopy(boxes_filt)
-
-        pred_dict = {
-            "boxes": boxes_filt,
-            "size": [size[1], size[0]],  # H,W
-            "labels": pred_phrases,
-        }
-
-        image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
-        output_images.append(image_with_box)
-
-    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
-    if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
-        image = np.array(input_img)
-        sam_predictor.set_image(image)
-
-        H, W = size[1], size[0]
-        for i in range(boxes_filt.size(0)):
-            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-            boxes_filt[i][2:] += boxes_filt[i][:2]
-
-        boxes_filt = boxes_filt.to(sam_device)
-        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
-
-        masks, _, _, _ = sam_predictor.predict_torch(
-            point_coords = None,
-            point_labels = None,
-            boxes = transformed_boxes,
-            multimask_output = False,
-        )
-        # masks: [9, 1, 512, 512]
-        assert sam_checkpoint, 'sam_checkpoint is not found!'
-        # draw output image
-        plt.figure(figsize=(10, 10))
-        # we don't draw the background image
-        # plt.imshow(image)
-        for mask in masks:
-            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
-        # for box, label in zip(boxes_filt, pred_phrases):
-        #     show_box(box.cpu().numpy(), plt.gca(), label)
-        plt.axis('off')
-        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.png")
-        plt.savefig(image_path, bbox_inches="tight")
-        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
-        os.remove(image_path)
-        output_images.append(segment_image_result)
-
-    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
+    groundingdino_device = 'cpu'
+    if device != 'cpu':
+        try:
+            from groundingdino import _C
+            groundingdino_device = 'cuda:0'
+        except:
+            warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
+
+    boxes_filt, pred_phrases = get_grounding_output(
+        groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
+    )
+    if boxes_filt.size(0) == 0:
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
+        return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
+    boxes_filt_ori = copy.deepcopy(boxes_filt)
+
+    pred_dict = {
+        "boxes": boxes_filt,
+        "size": [size[1], size[0]],  # H,W
+        "labels": pred_phrases,
+    }
+
+    image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
+    output_images.append(image_with_box)
+
+    # now we generate the segmentation
+    image = np.array(input_img)
+    sam_predictor.set_image(image)
+
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    boxes_filt = boxes_filt.to(sam_device)
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+
+    masks, _, _, _ = sam_predictor.predict_torch(
+        point_coords = None,
+        point_labels = None,
+        boxes = transformed_boxes,
+        multimask_output = False,
+    )
+    # masks: [9, 1, 512, 512]
+    assert sam_checkpoint, 'sam_checkpoint is not found!'
+    # draw output image
+    plt.figure(figsize=(10, 10))
+    # we don't draw the background image
+    # plt.imshow(image)
+    boxes_with_labels = zip(boxes_filt, pred_phrases)
+    for mask in masks:
+        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+    for box, label in boxes_with_labels:
+        show_box(box.cpu().numpy(), plt.gca(), label)
+    plt.axis('off')
+    image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.png")
+    plt.savefig(image_path, bbox_inches="tight")
+    segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+    os.remove(image_path)
+    output_images.append(segment_image_result)
+
     return pred_dict, output_images, gr.Gallery.update(label='result images')
 
 if __name__ == "__main__":
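
A note for readers following the hunk: the per-box loop rescales GroundingDINO's normalized center-format boxes (cx, cy, w, h) into the absolute corner-format (x0, y0, x1, y1) pixel coordinates that SamPredictor expects. A vectorized sketch of the same transform (a hypothetical helper for illustration, not part of app.py):

import torch

def cxcywh_norm_to_xyxy_abs(boxes: torch.Tensor, W: int, H: int) -> torch.Tensor:
    # Scale normalized coordinates up to pixel units.
    boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
    xy0 = boxes[:, :2] - boxes[:, 2:] / 2  # top-left corner: center minus half-extent
    xy1 = xy0 + boxes[:, 2:]               # bottom-right corner: top-left plus extent
    return torch.cat([xy0, xy1], dim=1)

# e.g. a centered half-size box in a 100x200 image:
# cxcywh_norm_to_xyxy_abs(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), W=100, H=200)
# -> tensor([[ 25.,  50.,  75., 150.]])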
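
One behavioral detail in the new version: boxes_with_labels = zip(boxes_filt, pred_phrases) replaces the previously commented-out show_box loop, so labeled boxes are now drawn on top of the masks. Since zip() returns a one-shot iterator, this works only because the pairs are consumed exactly once; a second pass over the same object would silently draw nothing. A minimal standalone illustration with made-up data:

pairs = zip([1, 2], ["a", "b"])
for x, y in pairs:
    print(x, y)            # prints: 1 a, then 2 b
for x, y in pairs:
    print("again", x, y)   # never runs: the iterator is already exhausted

If the pairs ever needed to be traversed again, materializing them with list(zip(...)) would be the safer choice.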