kitooo commited on
Commit
2fefa77
·
verified ·
1 Parent(s): 69abdb1

updated app

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -12,17 +12,14 @@ processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
12
  model = SamModel.from_pretrained('kitooo/sidewalk-seg-base')
13
  model.to(device)
14
 
15
- def segment_sidewalk(image, threshold):
16
  width, height = image.size
17
  prompt = [0, 0, width, height]
18
  inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
19
- # outputs = model(pixel_values=inputs['pixel_values'].to(device),
20
- # input_boxes=inputs['input_boxes'].to(device),
21
- # multimask_output=False)
22
  with torch.no_grad():
23
  outputs = model(**inputs, multimask_output=False)
24
  prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
25
- prediction = (prob_map > threshold).float()
26
  prob_map, prediction = prob_map.numpy(), prediction.numpy()
27
  save_image(image, 'image.png')
28
  save_image(prob_map, 'prob.png', cmap='jet')
@@ -41,14 +38,15 @@ with gr.Blocks() as demo:
41
  with gr.Row():
42
  with gr.Column():
43
  image_input = gr.Image(type='pil', label='TIFF Image')
44
- threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
45
- segment_button = gr.Button('Segment')
46
  with gr.Column():
47
- prediction = gr.Image(type='pil', label='Predicted Mask')
48
  prob_map = gr.Image(type='pil', label='Predicted Probability Map')
49
  segment_button.click(
50
  segment_sidewalk,
51
- inputs=[image_input, threshold_slider],
52
- outputs=[image_input, prediction, prob_map]
 
53
  )
54
  demo.launch(debug=True, show_error=True)
 
12
  model = SamModel.from_pretrained('kitooo/sidewalk-seg-base')
13
  model.to(device)
14
 
15
+ def segment_sidewalk(image):
16
  width, height = image.size
17
  prompt = [0, 0, width, height]
18
  inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
 
 
 
19
  with torch.no_grad():
20
  outputs = model(**inputs, multimask_output=False)
21
  prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
22
+ prediction = (prob_map > 0.5).float()
23
  prob_map, prediction = prob_map.numpy(), prediction.numpy()
24
  save_image(image, 'image.png')
25
  save_image(prob_map, 'prob.png', cmap='jet')
 
38
  with gr.Row():
39
  with gr.Column():
40
  image_input = gr.Image(type='pil', label='TIFF Image')
41
+ # threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
42
+ segment_button = gr.Button('Get Sidewalk Mask')
43
  with gr.Column():
44
+ mask = gr.Image(type='pil', label='Predicted Mask')
45
  prob_map = gr.Image(type='pil', label='Predicted Probability Map')
46
  segment_button.click(
47
  segment_sidewalk,
48
+ # inputs=[image_input, threshold_slider],
49
+ inputs=[image_input],
50
+ outputs=[image_input, mask, prob_map]
51
  )
52
  demo.launch(debug=True, show_error=True)