jbilcke-hf commited on
Commit
012a344
Β·
1 Parent(s): 302553c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -210,11 +210,7 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
210
 
211
  return boxes_filt, pred_phrases
212
 
213
- def show_mask(mask, ax, random_color=False):
214
- if random_color:
215
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
216
- else:
217
- color = np.array([30/255, 144/255, 255/255, 0.6])
218
  h, w = mask.shape[-2:]
219
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
220
  ax.imshow(mask_image)
@@ -515,14 +511,15 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
515
  return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
516
  boxes_filt_ori = copy.deepcopy(boxes_filt)
517
 
518
- pred_dict = {
519
- "boxes": boxes_filt,
520
- "size": [size[1], size[0]], # H,W
521
- "labels": pred_phrases,
522
- }
523
-
524
- image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
525
- output_images.append(image_with_box)
 
526
 
527
  # now we generate the segmentation
528
  image = np.array(input_img)
@@ -547,7 +544,7 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
547
  assert sam_checkpoint, 'sam_checkpoint is not found!'
548
  # draw output image
549
  plt.figure(figsize=(10, 10))
550
- # we don't draw the background image
551
  # plt.imshow(image)
552
  boxes_with_labels = zip(boxes_filt, pred_phrases)
553
  debug = {
@@ -555,9 +552,13 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
555
  "thing2": boxes_with_labels,
556
  "thing3": pred_phrases
557
  }
558
-
 
 
559
  for mask in masks:
560
- show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
 
 
561
  for box, label in boxes_with_labels:
562
  show_box(box.cpu().numpy(), plt.gca(), label)
563
  plt.axis('off')
@@ -571,7 +572,7 @@ def run_anything_task(input_image, text_prompt, box_threshold, text_threshold,
571
  for i, box in enumerate(boxes_filt):
572
  label, score = pred_phrases[i][:-5], float(pred_phrases[i][-4:-1]) # assuming 'roof(0.70)' format
573
  print("label: " + label)
574
- print("score: " + score)
575
  print(box.tolist())
576
 
577
  return debug, output_images, gr.Gallery.update(label='result images')
 
210
 
211
  return boxes_filt, pred_phrases
212
 
213
+ def show_mask(mask, ax, color):
 
 
 
 
214
  h, w = mask.shape[-2:]
215
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
216
  ax.imshow(mask_image)
 
511
  return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚')
512
  boxes_filt_ori = copy.deepcopy(boxes_filt)
513
 
514
+
515
+ # print bounding boxes only
516
+ #pred_dict = {
517
+ # "boxes": boxes_filt,
518
+ # "size": [size[1], size[0]], # H,W
519
+ # "labels": pred_phrases,
520
+ #}
521
+ # image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
522
+ # output_images.append(image_with_box)
523
 
524
  # now we generate the segmentation
525
  image = np.array(input_img)
 
544
  assert sam_checkpoint, 'sam_checkpoint is not found!'
545
  # draw output image
546
  plt.figure(figsize=(10, 10))
547
+ # we don't draw the background image, we only want the mask
548
  # plt.imshow(image)
549
  boxes_with_labels = zip(boxes_filt, pred_phrases)
550
  debug = {
 
552
  "thing2": boxes_with_labels,
553
  "thing3": pred_phrases
554
  }
555
+
556
+ results = []
557
+
558
  for mask in masks:
559
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
560
+ # color = np.array([30/255, 144/255, 255/255, 0.6])
561
+ show_mask(mask.cpu().numpy(), plt.gca(), color)
562
  for box, label in boxes_with_labels:
563
  show_box(box.cpu().numpy(), plt.gca(), label)
564
  plt.axis('off')
 
572
  for i, box in enumerate(boxes_filt):
573
  label, score = pred_phrases[i][:-5], float(pred_phrases[i][-4:-1]) # assuming 'roof(0.70)' format
574
  print("label: " + label)
575
+ print("score: " + str(score))
576
  print(box.tolist())
577
 
578
  return debug, output_images, gr.Gallery.update(label='result images')