jbilcke-hf committed on
Commit f1d2589
1 Parent(s): b7ab178

Update app.py

Files changed (1)
  1. app.py +100 -9
app.py CHANGED
@@ -466,6 +466,54 @@ def concatenate_images_vertical(image1, image2):
 
     return new_image
 
+def relate_anything(input_image, k):
+    logger.info(f'relate_anything_1_{input_image.size}_')
+    w, h = input_image.size
+    max_edge = 1500
+    if w > max_edge or h > max_edge:
+        ratio = max(w, h) / max_edge
+        new_size = (int(w / ratio), int(h / ratio))
+        input_image.thumbnail(new_size)
+
+    logger.info(f'relate_anything_2_')
+    # load image
+    pil_image = input_image.convert('RGBA')
+    image = np.array(input_image)
+    sam_masks = sam_mask_generator.generate(image)
+    filtered_masks = sort_and_deduplicate(sam_masks)
+
+    logger.info(f'relate_anything_3_')
+    feat_list = []
+    for fm in filtered_masks:
+        feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
+        feat_list.append(feat)
+    feat = torch.cat(feat_list, dim=1).to(device)
+    matrix_output, rel_triplets = ram_model.predict(feat)
+
+    logger.info(f'relate_anything_4_')
+    pil_image_list = []
+    for i, rel in enumerate(rel_triplets[:k]):
+        s, o, r = int(rel[0]), int(rel[1]), int(rel[2])
+        relation = relation_classes[r]
+
+        mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
+        mask_draw = ImageDraw.Draw(mask_image)
+
+        draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
+        draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
+
+        current_pil_image = pil_image.copy()
+        current_pil_image.alpha_composite(mask_image)
+
+        title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
+        concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
+        pil_image_list.append(concate_pil_image)
+
+    logger.info(f'relate_anything_5_{len(pil_image_list)}')
+    return pil_image_list
+
+mask_source_draw = "draw a mask on input image"
+mask_source_segment = "type what to detect below"
 
 def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, iou_threshold, cleaner_size_limit=1080):
 
@@ -492,7 +540,7 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
     output_images.append(input_image)
 
     size = image_pil.size
-
+
     # run grounding dino model
     groundingdino_device = 'cpu'
     if device != 'cpu':
@@ -506,7 +554,7 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
         groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
     )
     if boxes_filt.size(0) == 0:
-        logger.info(f'run_anything_task_[{file_temp}]_[{text_prompt}]_1_[No objects detected, please try others.]_')
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
         return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
     boxes_filt_ori = copy.deepcopy(boxes_filt)
 
@@ -516,15 +564,52 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold, i
         "labels": pred_phrases,
     }
 
-    image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
-    output_images.append(image_with_box)
+    # disabled: we don't want to see the boxes
+    # image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
+    # output_images.append(image_with_box)
 
 
-    return pred_dict
-
+    if task_type == 'segment':
+        image = np.array(input_img)
+        sam_predictor.set_image(image)
+
+        H, W = size[1], size[0]
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+
+        boxes_filt = boxes_filt.to(sam_device)
+        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+
+        masks, _, _, _ = sam_predictor.predict_torch(
+            point_coords = None,
+            point_labels = None,
+            boxes = transformed_boxes,
+            multimask_output = False,
+        )
+        # masks: [9, 1, 512, 512]
+        assert sam_checkpoint, 'sam_checkpoint is not found!'
+        # draw output image
+        plt.figure(figsize=(10, 10))
+        plt.imshow(image)
+        for mask in masks:
+            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+        for box, label in zip(boxes_filt, pred_phrases):
+            show_box(box.cpu().numpy(), plt.gca(), label)
+        plt.axis('off')
+        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
+        plt.savefig(image_path, bbox_inches="tight")
+        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+        os.remove(image_path)
+        output_images.append(segment_image_result)
+
+
+    results = zip(boxes_filt, pred_phrases)
+    return results, output_images, gr.Gallery.update(label='result images')
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("VideoQuest segmentation module", add_help=True)
+    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
     parser.add_argument("--share", action="store_true", help="share the app")
     args = parser.parse_args()
@@ -547,7 +632,7 @@ if __name__ == "__main__":
                 input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
 
                 text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
-
+                inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
                 run_button = gr.Button(label="Run", visible=True)
                 with gr.Accordion("Advanced options", open=False) as advanced_options:
                     box_threshold = gr.Slider(
@@ -559,9 +644,15 @@ if __name__ == "__main__":
                     iou_threshold = gr.Slider(
                         label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
                     )
+
+
+            with gr.Column():
+                image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
+                    ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
 
         run_button.click(fn=run_anything_task, inputs=[
-                input_image, text_prompt, box_threshold, text_threshold, iou_threshold], outputs=[gr.outputs.JSON()], show_progress=True, queue=True)
+                input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold], outputs=[gr.outputs.JSON(), image_gallery, image_gallery], show_progress=True, queue=True)
+
 
         DESCRIPTION = f'### This space is used by the experimental VideoQuest game. <br> It is based on <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true">Grounded-Segment-Anything</a>'
         gr.Markdown(DESCRIPTION)
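
For orientation, the new segment branch above first rescales GroundingDINO's normalized center-format boxes (cx, cy, w, h) into absolute corner coordinates (x0, y0, x1, y1) before handing them to SAM. The snippet below is a minimal standalone sketch of that conversion only; the tensor name boxes_cxcywh, the image size, and the box values are illustrative and do not come from app.py.

import torch

# Example image size (W, H) and two GroundingDINO-style boxes in normalized
# (cx, cy, w, h) format. All values here are illustrative only.
W, H = 640, 480
boxes_cxcywh = torch.tensor([[0.50, 0.50, 0.20, 0.40],
                             [0.25, 0.75, 0.10, 0.10]])

# Scale to pixel units, then convert center/size to corners (x0, y0, x1, y1),
# mirroring the per-box loop in run_anything_task's segment branch.
boxes_xyxy = boxes_cxcywh * torch.tensor([W, H, W, H])
boxes_xyxy[:, :2] -= boxes_xyxy[:, 2:] / 2   # top-left corner
boxes_xyxy[:, 2:] += boxes_xyxy[:, :2]       # bottom-right corner

print(boxes_xyxy)
# tensor([[256., 144., 384., 336.],
#         [128., 336., 192., 384.]])

In the diff, boxes in this absolute xyxy form are then passed through sam_predictor.transform.apply_boxes_torch before sam_predictor.predict_torch.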