kitooo committed
Commit 69abdb1 · verified · Parent: 65c9ebd

updated app

Files changed (1): app.py (+8 -9)
app.py CHANGED
@@ -8,23 +8,22 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
 processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
-model = SamModel(config=model_config)
+# model = SamModel(config=model_config)
 model = SamModel.from_pretrained('kitooo/sidewalk-seg-base')
 model.to(device)
 
 def segment_sidewalk(image, threshold):
-    # init data
     width, height = image.size
     prompt = [0, 0, width, height]
     inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
-    # make prediction
-    outputs = model(pixel_values=inputs['pixel_values'].to(device),
-                    input_boxes=inputs['input_boxes'].to(device),
-                    multimask_output=False)
+    # outputs = model(pixel_values=inputs['pixel_values'].to(device),
+    #                 input_boxes=inputs['input_boxes'].to(device),
+    #                 multimask_output=False)
+    with torch.no_grad():
+        outputs = model(**inputs, multimask_output=False)
     prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
     prediction = (prob_map > threshold).float()
     prob_map, prediction = prob_map.numpy(), prediction.numpy()
-    # visualize results
     save_image(image, 'image.png')
     save_image(prob_map, 'prob.png', cmap='jet')
     save_image(prediction, 'mask.png', cmap='gray')
@@ -45,8 +44,8 @@ with gr.Blocks() as demo:
         threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
         segment_button = gr.Button('Segment')
     with gr.Column():
-        prediction = gr.Image(type='pil', label='Segmentation Result')
-        prob_map = gr.Image(type='pil', label='Probability Map')
+        prediction = gr.Image(type='pil', label='Predicted Mask')
+        prob_map = gr.Image(type='pil', label='Predicted Probability Map')
     segment_button.click(
         segment_sidewalk,
         inputs=[image_input, threshold_slider],
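
Note on the new forward pass: the model is moved to `device`, but `model(**inputs, multimask_output=False)` now receives the processor output still on the CPU, whereas the replaced code moved `pixel_values` and `input_boxes` to the device explicitly; on a GPU Space this would raise a device-mismatch error. Below is a minimal sketch of the same inference path with that one adjustment. It is not part of this commit, the helper name `predict_sidewalk` is hypothetical, and it assumes the `.to(device)` method that transformers processor outputs (BatchFeature) provide.

    # Sketch only, not part of this commit: keep torch.no_grad(), but move the
    # processed inputs onto the same device as the model before the forward pass.
    import torch
    from transformers import SamModel, SamProcessor

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
    model = SamModel.from_pretrained('kitooo/sidewalk-seg-base').to(device)

    def predict_sidewalk(image, threshold=0.5):
        width, height = image.size
        prompt = [0, 0, width, height]  # one box prompt covering the whole image
        inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt').to(device)
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model(**inputs, multimask_output=False)
        prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
        prediction = (prob_map > threshold).float()
        return prediction.numpy(), prob_map.numpy()

The same effect could be had inside `segment_sidewalk` with a single `.to(device)` on the processor output, leaving the rest of the function as committed.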