jbilcke-hf commited on
Commit
86e8282
Β·
1 Parent(s): 4d2ca45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -72
app.py CHANGED
@@ -515,21 +515,31 @@ def relate_anything(input_image, k):
515
  mask_source_draw = "draw a mask on input image"
516
  mask_source_segment = "type what to detect below"
517
 
518
- def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, iou_threshold, cleaner_size_limit=1080):
 
 
 
 
 
519
 
520
  text_prompt = text_prompt.strip()
521
- if text_prompt == '':
522
- return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
 
523
 
524
  if input_image is None:
525
  return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
526
 
527
  file_temp = int(time.time())
 
528
 
529
  output_images = []
530
 
531
  # load image
532
-
 
 
 
533
  if isinstance(input_image, dict):
534
  image_pil, image = load_image(input_image['image'].convert("RGB"))
535
  input_img = input_image['image']
@@ -542,69 +552,159 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
542
  size = image_pil.size
543
 
544
  # run grounding dino model
545
- groundingdino_device = 'cpu'
546
- if device != 'cpu':
547
- try:
548
- from groundingdino import _C
549
- groundingdino_device = 'cuda:0'
550
- except:
551
- warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
552
-
553
- boxes_filt, pred_phrases = get_grounding_output(
554
- groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
555
- )
556
- if boxes_filt.size(0) == 0:
557
- logger.info(f'run_anything_task_[{file_temp}]_[{text_prompt}]_1_[No objects detected, please try others.]_')
558
- return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
559
- boxes_filt_ori = copy.deepcopy(boxes_filt)
560
-
561
- pred_dict = {
562
- "boxes": boxes_filt,
563
- "size": [size[1], size[0]], # H,W
564
- "labels": pred_phrases,
565
- }
566
-
567
- # disabled: we don't want to see the boxes
568
- image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
569
- output_images.append(image_with_box)
570
-
571
- image = np.array(input_img)
572
- sam_predictor.set_image(image)
573
-
574
- H, W = size[1], size[0]
575
- for i in range(boxes_filt.size(0)):
576
- boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
577
- boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
578
- boxes_filt[i][2:] += boxes_filt[i][:2]
579
-
580
- boxes_filt = boxes_filt.to(sam_device)
581
- transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
582
-
583
- masks, _, _, _ = sam_predictor.predict_torch(
584
- point_coords = None,
585
- point_labels = None,
586
- boxes = transformed_boxes,
587
- multimask_output = False,
588
- )
589
- # masks: [9, 1, 512, 512]
590
- assert sam_checkpoint, 'sam_checkpoint is not found!'
591
- # draw output image
592
- plt.figure(figsize=(10, 10))
593
- plt.imshow(image)
594
- for mask in masks:
595
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
596
- for box, label in zip(boxes_filt, pred_phrases):
597
- show_box(box.cpu().numpy(), plt.gca(), label)
598
- plt.axis('off')
599
- image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
600
- plt.savefig(image_path, bbox_inches="tight")
601
- segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
602
- os.remove(image_path)
603
- output_images.append(segment_image_result)
604
-
605
-
606
- results = zip(boxes_filt, pred_phrases)
607
- return results, output_images, gr.Gallery.update(label='result images')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
 
609
  if __name__ == "__main__":
610
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
@@ -628,9 +728,14 @@ if __name__ == "__main__":
628
  with gr.Row():
629
  with gr.Column():
630
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
631
-
 
 
 
 
632
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
633
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
 
634
  run_button = gr.Button(label="Run", visible=True)
635
  with gr.Accordion("Advanced options", open=False) as advanced_options:
636
  box_threshold = gr.Slider(
@@ -642,17 +747,29 @@ if __name__ == "__main__":
642
  iou_threshold = gr.Slider(
643
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
644
  )
645
-
 
 
 
 
 
646
 
647
  with gr.Column():
648
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
649
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
650
 
651
  run_button.click(fn=run_anything_task, inputs=[
652
- input_image, text_prompt, box_threshold, text_threshold, iou_threshold], outputs=[gr.outputs.JSON(), image_gallery, image_gallery], show_progress=True, queue=True)
653
 
654
-
655
- DESCRIPTION = f'### This space is used by the experimental VideoQuest game. <br> It is based on <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true">Grounded-Segment-Anything</a>'
 
 
 
 
 
 
 
656
  gr.Markdown(DESCRIPTION)
657
 
658
  computer_info()
 
515
  mask_source_draw = "draw a mask on input image"
516
  mask_source_segment = "type what to detect below"
517
 
518
+ def run_anything_task(input_image, text_prompt, inpaint_prompt, box_threshold, text_threshold,
519
+ iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
520
+ task_type = "segment"
521
+ if (task_type == 'relate anything'):
522
+ output_images = relate_anything(input_image['image'], num_relation)
523
+ return output_images, gr.Gallery.update(label='relate images')
524
 
525
  text_prompt = text_prompt.strip()
526
+ if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
527
+ if text_prompt == '':
528
+ return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
529
 
530
  if input_image is None:
531
  return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
532
 
533
  file_temp = int(time.time())
534
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
535
 
536
  output_images = []
537
 
538
  # load image
539
+ if mask_source_radio == mask_source_draw:
540
+ input_mask_pil = input_image['mask']
541
+ input_mask = np.array(input_mask_pil.convert("L"))
542
+
543
  if isinstance(input_image, dict):
544
  image_pil, image = load_image(input_image['image'].convert("RGB"))
545
  input_img = input_image['image']
 
552
  size = image_pil.size
553
 
554
  # run grounding dino model
555
+ if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
556
+ pass
557
+ else:
558
+ groundingdino_device = 'cpu'
559
+ if device != 'cpu':
560
+ try:
561
+ from groundingdino import _C
562
+ groundingdino_device = 'cuda:0'
563
+ except:
564
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
565
+
566
+ boxes_filt, pred_phrases = get_grounding_output(
567
+ groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
568
+ )
569
+ if boxes_filt.size(0) == 0:
570
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
571
+ return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
572
+ boxes_filt_ori = copy.deepcopy(boxes_filt)
573
+
574
+ pred_dict = {
575
+ "boxes": boxes_filt,
576
+ "size": [size[1], size[0]], # H,W
577
+ "labels": pred_phrases,
578
+ }
579
+
580
+ image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
581
+ output_images.append(image_with_box)
582
+
583
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
584
+ if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
585
+ image = np.array(input_img)
586
+ sam_predictor.set_image(image)
587
+
588
+ H, W = size[1], size[0]
589
+ for i in range(boxes_filt.size(0)):
590
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
591
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
592
+ boxes_filt[i][2:] += boxes_filt[i][:2]
593
+
594
+ boxes_filt = boxes_filt.to(sam_device)
595
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
596
+
597
+ masks, _, _, _ = sam_predictor.predict_torch(
598
+ point_coords = None,
599
+ point_labels = None,
600
+ boxes = transformed_boxes,
601
+ multimask_output = False,
602
+ )
603
+ # masks: [9, 1, 512, 512]
604
+ assert sam_checkpoint, 'sam_checkpoint is not found!'
605
+ # draw output image
606
+ plt.figure(figsize=(10, 10))
607
+ plt.imshow(image)
608
+ for mask in masks:
609
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
610
+ for box, label in zip(boxes_filt, pred_phrases):
611
+ show_box(box.cpu().numpy(), plt.gca(), label)
612
+ plt.axis('off')
613
+ image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
614
+ plt.savefig(image_path, bbox_inches="tight")
615
+ segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
616
+ os.remove(image_path)
617
+ output_images.append(segment_image_result)
618
+
619
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
620
+ if task_type == 'detection' or task_type == 'segment':
621
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
622
+ return output_images, gr.Gallery.update(label='result images')
623
+ elif task_type == 'inpainting' or task_type == 'remove':
624
+ if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
625
+ task_type = 'remove'
626
+
627
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
628
+ if mask_source_radio == mask_source_draw:
629
+ mask_pil = input_mask_pil
630
+ mask = input_mask
631
+ else:
632
+ masks_ori = copy.deepcopy(masks)
633
+ if inpaint_mode == 'merge':
634
+ masks = torch.sum(masks, dim=0).unsqueeze(0)
635
+ masks = torch.where(masks > 0, True, False)
636
+ mask = masks[0][0].cpu().numpy()
637
+ mask_pil = Image.fromarray(mask)
638
+ output_images.append(mask_pil.convert("RGB"))
639
+
640
+ if task_type == 'inpainting':
641
+ # inpainting pipeline
642
+ image_source_for_inpaint = image_pil.resize((512, 512))
643
+ image_mask_for_inpaint = mask_pil.resize((512, 512))
644
+ image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
645
+ else:
646
+ # remove from mask
647
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
648
+ if mask_source_radio == mask_source_segment:
649
+ mask_imgs = []
650
+ masks_shape = masks_ori.shape
651
+ boxes_filt_ori_array = boxes_filt_ori.numpy()
652
+ if inpaint_mode == 'merge':
653
+ extend_shape_0 = masks_shape[0]
654
+ extend_shape_1 = masks_shape[1]
655
+ else:
656
+ extend_shape_0 = 1
657
+ extend_shape_1 = 1
658
+ for i in range(extend_shape_0):
659
+ for j in range(extend_shape_1):
660
+ mask = masks_ori[i][j].cpu().numpy()
661
+ mask_pil = Image.fromarray(mask)
662
+
663
+ if remove_mode == 'segment':
664
+ useRectangle = False
665
+ else:
666
+ useRectangle = True
667
+
668
+ try:
669
+ remove_mask_extend = int(remove_mask_extend)
670
+ except:
671
+ remove_mask_extend = 10
672
+ mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
673
+ xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
674
+ extend_pixels=remove_mask_extend, useRectangle=useRectangle)
675
+ mask_imgs.append(mask_pil_exp)
676
+ mask_pil = mix_masks(mask_imgs)
677
+ output_images.append(mask_pil.convert("RGB"))
678
+
679
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
680
+ image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
681
+ # output_images.append(image_inpainting)
682
+
683
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
684
+ image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
685
+ output_images.append(image_inpainting)
686
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
687
+ return output_images, gr.Gallery.update(label='result images')
688
+ else:
689
+ logger.info(f"task_type:{task_type} error!")
690
+ logger.info(f'run_anything_task_[{file_temp}]_9_9_')
691
+ return output_images, gr.Gallery.update(label='result images')
692
+
693
+ def change_radio_display(task_type, mask_source_radio):
694
+ text_prompt_visible = True
695
+ inpaint_prompt_visible = False
696
+ mask_source_radio_visible = False
697
+ num_relation_visible = False
698
+ if task_type == "inpainting":
699
+ inpaint_prompt_visible = True
700
+ if task_type == "inpainting" or task_type == "remove":
701
+ mask_source_radio_visible = True
702
+ if mask_source_radio == mask_source_draw:
703
+ text_prompt_visible = False
704
+ if task_type == "relate anything":
705
+ text_prompt_visible = False
706
+ num_relation_visible = True
707
+ return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
708
 
709
  if __name__ == "__main__":
710
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
 
728
  with gr.Row():
729
  with gr.Column():
730
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
731
+ task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything"], value="detection",
732
+ label='Task type', visible=True)
733
+ mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
734
+ value=mask_source_segment, label="Mask from",
735
+ visible=False)
736
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
737
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
738
+ num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
739
  run_button = gr.Button(label="Run", visible=True)
740
  with gr.Accordion("Advanced options", open=False) as advanced_options:
741
  box_threshold = gr.Slider(
 
747
  iou_threshold = gr.Slider(
748
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
749
  )
750
+ inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
751
+ with gr.Row():
752
+ with gr.Column(scale=1):
753
+ remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
754
+ with gr.Column(scale=1):
755
+ remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
756
 
757
  with gr.Column():
758
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
759
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
760
 
761
  run_button.click(fn=run_anything_task, inputs=[
762
+ input_image, text_prompt, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
763
 
764
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
765
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
766
+
767
+ DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
768
+ DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
769
+ DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
770
+ DESCRIPTION += f'Thanks for their excellent work.'
771
+ DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
772
+ <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
773
  gr.Markdown(DESCRIPTION)
774
 
775
  computer_info()