zhangziang commited on
Commit
939a4f3
·
1 Parent(s): a1a407f
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -15,11 +15,12 @@ from huggingface_hub import hf_hub_download
15
  ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
16
  print(ckpt_path)
17
 
18
-
19
- # mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
20
- mark_dtype = torch.float16
21
- # device = 'cuda:0'
22
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
23
 
24
  model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
25
  model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
@@ -143,8 +144,8 @@ def run_inference(image_ref, image_tgt, do_rm_bkg):
143
 
144
 
145
  # ====== Gradio Blocks UI ======
146
- with gr.Blocks(title="Orient-Anything Demo") as demo:
147
- gr.Markdown("# Orient-Anything Demo")
148
  gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")
149
 
150
  with gr.Row():
 
15
  ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
16
  print(ckpt_path)
17
 
18
+ if torch.cuda.is_available():
19
+ mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
20
+ device = torch.device('cuda')
21
+ else:
22
+ mark_dtype = torch.float16
23
+ device = torch.device('cpu')
24
 
25
  model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
26
  model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
 
144
 
145
 
146
  # ====== Gradio Blocks UI ======
147
+ with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
148
+ gr.Markdown("# Orient-Anything-V2 Demo")
149
  gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")
150
 
151
  with gr.Row():