Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import os | |
| from os import PathLike | |
| import torch | |
| from .models.meta_arch import SAM3DBody | |
| from .utils.checkpoint import load_state_dict | |
| from .utils.config import CN, get_config | |
def load_sam_3d_body(
    checkpoint_path: str | PathLike[str] = "",
    device: str | torch.device = "cuda",
    mhr_path: str | PathLike[str] = "",
) -> tuple[SAM3DBody, CN]:
    """Build a SAM 3D Body model from a checkpoint and return it with its config.

    The model configuration is read from ``model_config.yaml``, looked up first
    in the directory containing the checkpoint and, if absent there, in that
    directory's parent.

    Args:
        checkpoint_path: Path to the checkpoint file passed to ``torch.load``.
        device: Device the model is moved to after the weights are loaded.
        mhr_path: Path written into ``MODEL.MHR_HEAD.MHR_MODEL_PATH`` of the
            config before the model is constructed.

    Returns:
        A ``(model, config)`` pair: the model after ``.to(device)`` and
        ``.eval()``, and the config after it has been re-frozen.
    """
    print("Loading SAM 3D Body model...")

    # Normalise PathLike inputs to plain strings for os.path and the config.
    ckpt_file = os.fspath(checkpoint_path)
    mhr_file = os.fspath(mhr_path)

    # Prefer a config sitting next to the checkpoint; otherwise fall back to
    # the parent directory (no further existence check is made on the fallback).
    ckpt_dir = os.path.dirname(ckpt_file)
    config_file = os.path.join(ckpt_dir, "model_config.yaml")
    if not os.path.exists(config_file):
        config_file = os.path.join(os.path.dirname(ckpt_dir), "model_config.yaml")
    cfg = get_config(config_file)

    # Point the MHR head at the requested MHR model file; the config has to be
    # unfrozen for the assignment and is frozen again right after.
    cfg.defrost()
    cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH = mhr_file
    cfg.freeze()

    # Initialize the model from the finalized config.
    net = SAM3DBody(cfg)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so only checkpoints from trusted sources should be used.
    # The checkpoint is materialised on CPU; the model is moved to `device` below.
    payload = torch.load(ckpt_file, map_location="cpu", weights_only=False)

    # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare mapping.
    weights = payload.get("state_dict", payload)
    # strict=False is forwarded to the project loader — presumably it tolerates
    # missing/unexpected keys; confirm against utils.checkpoint.load_state_dict.
    load_state_dict(net, weights, strict=False)

    net = net.to(device)
    net.eval()
    return net, cfg
def _hf_download(repo_id):
    """Fetch a Hugging Face repo snapshot and return the model file paths.

    Args:
        repo_id: Repository identifier passed to ``snapshot_download``.

    Returns:
        A ``(checkpoint_path, mhr_path)`` tuple pointing at ``model.ckpt`` and
        ``assets/mhr_model.pt`` inside the downloaded snapshot directory. The
        paths are only constructed here; their existence is not verified.
    """
    # Imported at call time — presumably so huggingface_hub is only required
    # when this download path is actually used.
    from huggingface_hub import snapshot_download

    snapshot_root = snapshot_download(repo_id=repo_id)
    checkpoint_file = os.path.join(snapshot_root, "model.ckpt")
    mhr_file = os.path.join(snapshot_root, "assets", "mhr_model.pt")
    return checkpoint_file, mhr_file
def load_sam_3d_body_hf(repo_id, **kwargs):
    """Download a SAM 3D Body snapshot from Hugging Face and load the model.

    Args:
        repo_id: Hugging Face repository identifier to download.
        **kwargs: Extra keyword arguments forwarded to ``load_sam_3d_body``
            (for example ``device``). Any ``checkpoint_path`` or ``mhr_path``
            supplied here is replaced by the paths from the downloaded
            snapshot.

    Returns:
        Whatever ``load_sam_3d_body`` returns: a ``(model, config)`` tuple.
    """
    ckpt_path, mhr_path = _hf_download(repo_id)
    # Previously **kwargs was accepted but dropped, so e.g. device="cpu" was
    # silently ignored. Forward it, letting the downloaded paths take
    # precedence so a duplicate keyword can never raise a TypeError.
    load_kwargs = {**kwargs, "checkpoint_path": ckpt_path, "mhr_path": mhr_path}
    return load_sam_3d_body(**load_kwargs)