jbilcke-hf committed on
Commit
48104ec
·
1 Parent(s): 86e8282

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -108
app.py CHANGED
@@ -515,12 +515,9 @@ def relate_anything(input_image, k):
515
  mask_source_draw = "draw a mask on input image"
516
  mask_source_segment = "type what to detect below"
517
 
518
- def run_anything_task(input_image, text_prompt, inpaint_prompt, box_threshold, text_threshold,
519
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
520
  task_type = "segment"
521
- if (task_type == 'relate anything'):
522
- output_images = relate_anything(input_image['image'], num_relation)
523
- return output_images, gr.Gallery.update(label='relate images')
524
 
525
  text_prompt = text_prompt.strip()
526
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
@@ -536,10 +533,7 @@ def run_anything_task(input_image, text_prompt, inpaint_prompt, box_threshold, t
536
  output_images = []
537
 
538
  # load image
539
- if mask_source_radio == mask_source_draw:
540
- input_mask_pil = input_image['mask']
541
- input_mask = np.array(input_mask_pil.convert("L"))
542
-
543
  if isinstance(input_image, dict):
544
  image_pil, image = load_image(input_image['image'].convert("RGB"))
545
  input_img = input_image['image']
@@ -619,93 +613,12 @@ def run_anything_task(input_image, text_prompt, inpaint_prompt, box_threshold, t
619
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
620
  if task_type == 'detection' or task_type == 'segment':
621
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
622
- return output_images, gr.Gallery.update(label='result images')
623
- elif task_type == 'inpainting' or task_type == 'remove':
624
- if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
625
- task_type = 'remove'
626
-
627
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
628
- if mask_source_radio == mask_source_draw:
629
- mask_pil = input_mask_pil
630
- mask = input_mask
631
- else:
632
- masks_ori = copy.deepcopy(masks)
633
- if inpaint_mode == 'merge':
634
- masks = torch.sum(masks, dim=0).unsqueeze(0)
635
- masks = torch.where(masks > 0, True, False)
636
- mask = masks[0][0].cpu().numpy()
637
- mask_pil = Image.fromarray(mask)
638
- output_images.append(mask_pil.convert("RGB"))
639
-
640
- if task_type == 'inpainting':
641
- # inpainting pipeline
642
- image_source_for_inpaint = image_pil.resize((512, 512))
643
- image_mask_for_inpaint = mask_pil.resize((512, 512))
644
- image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
645
- else:
646
- # remove from mask
647
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
648
- if mask_source_radio == mask_source_segment:
649
- mask_imgs = []
650
- masks_shape = masks_ori.shape
651
- boxes_filt_ori_array = boxes_filt_ori.numpy()
652
- if inpaint_mode == 'merge':
653
- extend_shape_0 = masks_shape[0]
654
- extend_shape_1 = masks_shape[1]
655
- else:
656
- extend_shape_0 = 1
657
- extend_shape_1 = 1
658
- for i in range(extend_shape_0):
659
- for j in range(extend_shape_1):
660
- mask = masks_ori[i][j].cpu().numpy()
661
- mask_pil = Image.fromarray(mask)
662
-
663
- if remove_mode == 'segment':
664
- useRectangle = False
665
- else:
666
- useRectangle = True
667
-
668
- try:
669
- remove_mask_extend = int(remove_mask_extend)
670
- except:
671
- remove_mask_extend = 10
672
- mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
673
- xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
674
- extend_pixels=remove_mask_extend, useRectangle=useRectangle)
675
- mask_imgs.append(mask_pil_exp)
676
- mask_pil = mix_masks(mask_imgs)
677
- output_images.append(mask_pil.convert("RGB"))
678
-
679
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
680
- image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
681
- # output_images.append(image_inpainting)
682
-
683
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
684
- image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
685
- output_images.append(image_inpainting)
686
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
687
- return output_images, gr.Gallery.update(label='result images')
688
  else:
689
  logger.info(f"task_type:{task_type} error!")
690
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
691
  return output_images, gr.Gallery.update(label='result images')
692
 
693
- def change_radio_display(task_type, mask_source_radio):
694
- text_prompt_visible = True
695
- inpaint_prompt_visible = False
696
- mask_source_radio_visible = False
697
- num_relation_visible = False
698
- if task_type == "inpainting":
699
- inpaint_prompt_visible = True
700
- if task_type == "inpainting" or task_type == "remove":
701
- mask_source_radio_visible = True
702
- if mask_source_radio == mask_source_draw:
703
- text_prompt_visible = False
704
- if task_type == "relate anything":
705
- text_prompt_visible = False
706
- num_relation_visible = True
707
- return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
708
-
709
  if __name__ == "__main__":
710
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
711
  parser.add_argument("--debug", action="store_true", help="using debug mode")
@@ -728,14 +641,8 @@ if __name__ == "__main__":
728
  with gr.Row():
729
  with gr.Column():
730
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
731
- task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything"], value="detection",
732
- label='Task type', visible=True)
733
- mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
734
- value=mask_source_segment, label="Mask from",
735
- visible=False)
736
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
737
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
738
- num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
739
  run_button = gr.Button(label="Run", visible=True)
740
  with gr.Accordion("Advanced options", open=False) as advanced_options:
741
  box_threshold = gr.Slider(
@@ -747,22 +654,13 @@ if __name__ == "__main__":
747
  iou_threshold = gr.Slider(
748
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
749
  )
750
- inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
751
- with gr.Row():
752
- with gr.Column(scale=1):
753
- remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
754
- with gr.Column(scale=1):
755
- remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
756
 
757
  with gr.Column():
758
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
759
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
760
 
761
  run_button.click(fn=run_anything_task, inputs=[
762
- input_image, text_prompt, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
763
-
764
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
765
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
766
 
767
  DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
768
  DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
 
515
  mask_source_draw = "draw a mask on input image"
516
  mask_source_segment = "type what to detect below"
517
 
518
+ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
519
+ iou_threshold, cleaner_size_limit=1080):
520
  task_type = "segment"
 
 
 
521
 
522
  text_prompt = text_prompt.strip()
523
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
 
533
  output_images = []
534
 
535
  # load image
536
+
 
 
 
537
  if isinstance(input_image, dict):
538
  image_pil, image = load_image(input_image['image'].convert("RGB"))
539
  input_img = input_image['image']
 
613
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
614
  if task_type == 'detection' or task_type == 'segment':
615
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
616
+ return output_images, gr.Gallery.update(label='result images')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  else:
618
  logger.info(f"task_type:{task_type} error!")
619
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
620
  return output_images, gr.Gallery.update(label='result images')
621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
  if __name__ == "__main__":
623
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
624
  parser.add_argument("--debug", action="store_true", help="using debug mode")
 
641
  with gr.Row():
642
  with gr.Column():
643
  input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
644
+
 
 
 
 
645
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
 
 
646
  run_button = gr.Button(label="Run", visible=True)
647
  with gr.Accordion("Advanced options", open=False) as advanced_options:
648
  box_threshold = gr.Slider(
 
654
  iou_threshold = gr.Slider(
655
  label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
656
  )
 
 
 
 
 
 
657
 
658
  with gr.Column():
659
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
660
  ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
661
 
662
  run_button.click(fn=run_anything_task, inputs=[
663
+ input_image, text_prompt, box_threshold, text_threshold, iou_threshold], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
 
 
 
664
 
665
  DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
666
  DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'