Spaces:
Build error
Build error
Commit
·
0daeefc
1
Parent(s):
d5b9c19
mps
Browse files
app.py
CHANGED
|
@@ -38,7 +38,14 @@ with open(os.path.join(CONFIG), "r") as f:
|
|
| 38 |
|
| 39 |
cfg = dict2namespace(config)
|
| 40 |
|
| 41 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
|
| 43 |
middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
|
| 44 |
model = model.to(device)
|
|
@@ -148,4 +155,4 @@ demo = gr.Interface(
|
|
| 148 |
)
|
| 149 |
|
| 150 |
if __name__ == "__main__":
|
| 151 |
-
demo.launch()
|
|
|
|
| 38 |
|
| 39 |
cfg = dict2namespace(config)
|
| 40 |
|
| 41 |
+
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 42 |
+
if torch.backends.mps.is_available():
|
| 43 |
+
device = "mps"
|
| 44 |
+
torch_dtype = torch.float32
|
| 45 |
+
elif torch.cuda.is_available():
|
| 46 |
+
device = "cuda"
|
| 47 |
+
else:
|
| 48 |
+
device = "cpu"
|
| 49 |
model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
|
| 50 |
middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
|
| 51 |
model = model.to(device)
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
if __name__ == "__main__":
|
| 158 |
+
demo.launch()
|