kitooo commited on
Commit
bd2491f
·
verified ·
1 Parent(s): 57bb905

updated app

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -2,11 +2,15 @@ import torch
2
  import gradio as gr
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
- from transformers import SamModel, SamProcessor
6
 
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
8
  processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
9
- model = SamModel.from_pretrained('hmdliu/sidewalks-seg-base')
 
10
  model.to(device)
11
 
12
  def segment_sidewalk(image, threshold):
 
2
  import gradio as gr
3
  import matplotlib.pyplot as plt
4
  from PIL import Image
5
+ from transformers import SamModel, SamConfig, SamProcessor
6
 
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
+
9
+ # model = SamModel.from_pretrained('hmdliu/sidewalks-seg-base')
10
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
11
  processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
12
+ model = SamModel(config=model_config)
13
+ model.load_state_dict(torch.load("/content/drive/MyDrive/Project/sidewalk_model_epoch10.pth"))
14
  model.to(device)
15
 
16
  def segment_sidewalk(image, threshold):