jbilcke-hf commited on
Commit
8dce505
·
verified ·
1 Parent(s): 58240d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -24,6 +24,7 @@ import argparse
24
  import copy
25
  import re
26
  import json
 
27
 
28
  import numpy as np
29
  import torch
@@ -66,6 +67,24 @@ from lama_cleaner.helper import (
66
 
67
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
70
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
71
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
@@ -589,12 +608,18 @@ def run_anything_task(secret_token, input_image, text_prompt, box_threshold, tex
589
  # show_box(box.cpu().numpy(), plt.gca(), label)
590
  plt.axis('off')
591
  image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.png")
 
 
592
  plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
593
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
594
  os.remove(image_path)
595
- output_images.append(segment_image_result)
596
 
597
- return json.dumps(results), output_images, gr.Gallery.update(label='result images')
 
 
 
 
598
 
599
  if __name__ == "__main__":
600
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
@@ -623,28 +648,22 @@ if __name__ == "__main__":
623
  </div>
624
  </div>""")
625
  with gr.Row():
626
- with gr.Column():
627
- secret_token = gr.Textbox()
628
- text_prompt = gr.Textbox()
629
-
630
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
631
-
632
- text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
633
- run_button = gr.Button(label="Run", visible=True)
634
- with gr.Accordion("Advanced options", open=False) as advanced_options:
635
- box_threshold = gr.Slider(
636
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
637
- )
638
- text_threshold = gr.Slider(
639
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
640
- )
641
- iou_threshold = gr.Slider(
642
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
643
- )
644
 
645
- with gr.Column():
646
- image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
647
- ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
 
 
 
 
 
 
 
 
 
 
 
 
648
 
649
  run_button.click(
650
  fn=run_anything_task,
@@ -656,13 +675,7 @@ if __name__ == "__main__":
656
  text_threshold,
657
  iou_threshold
658
  ],
659
- outputs=[
660
- gr.Textbox(),
661
- image_gallery,
662
- image_gallery
663
- ],
664
- show_progress=False,
665
- queue=True
666
  )
667
 
668
  block.queue(max_size=20).launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
 
24
  import copy
25
  import re
26
  import json
27
+ import base64
28
 
29
  import numpy as np
30
  import torch
 
67
 
68
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
69
 
70
+ # Regex pattern to match data URI scheme
71
+ data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
72
+
73
+ def readb64(b64):
74
+ # Remove any data URI scheme prefix with regex
75
+ b64 = data_uri_pattern.sub("", b64)
76
+ # Decode and open the image with PIL
77
+ img = Image.open(BytesIO(base64.b64decode(b64)))
78
+ return img
79
+
80
+ # convert from PIL to base64
81
+ def writeb64(image):
82
+ buffered = BytesIO()
83
+ image.save(buffered, format="PNG")
84
+ b64image = base64.b64encode(buffered.getvalue())
85
+ b64image_str = b64image.decode("utf-8")
86
+ return b64image_str
87
+
88
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
89
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
90
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
 
608
  # show_box(box.cpu().numpy(), plt.gca(), label)
609
  plt.axis('off')
610
  image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.png")
611
+
612
+ # do we really need to write to the disk to get an image? seems inneficient
613
  plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
614
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
615
  os.remove(image_path)
616
+ # output_images.append(segment_image_result)
617
 
618
+ response_object = {
619
+ "data": results,
620
+ "bitmap": writeb64(segment_image_result) # save as PNG base64
621
+ }
622
+ return json.dumps(response_object)
623
 
624
  if __name__ == "__main__":
625
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
 
648
  </div>
649
  </div>""")
650
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
 
652
+ secret_token = gr.Textbox()
653
+ text_prompt = gr.Textbox()
654
+ input_image = gr.Textbox()
655
+ text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
656
+ run_button = gr.Button(label="Run", visible=True)
657
+ with gr.Accordion("Advanced options", open=False) as advanced_options:
658
+ box_threshold = gr.Slider(
659
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
660
+ )
661
+ text_threshold = gr.Slider(
662
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
663
+ )
664
+ iou_threshold = gr.Slider(
665
+ label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
666
+ )
667
 
668
  run_button.click(
669
  fn=run_anything_task,
 
675
  text_threshold,
676
  iou_threshold
677
  ],
678
+ outputs=gr.Textbox()
 
 
 
 
 
 
679
  )
680
 
681
  block.queue(max_size=20).launch(server_name='0.0.0.0', debug=args.debug, share=args.share)