# Source: sam3d-body-rerun/src/sam3d_body/build_models.py
# (author: pablovela5620; commit 6da47c0, "init commit")
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from os import PathLike
import torch
from .models.meta_arch import SAM3DBody
from .utils.checkpoint import load_state_dict
from .utils.config import CN, get_config
def load_sam_3d_body(
    checkpoint_path: str | PathLike[str] = "",
    device: str | torch.device = "cuda",
    mhr_path: str | PathLike[str] = "",
) -> tuple[SAM3DBody, CN]:
    """Build a SAM 3D Body model from a local checkpoint and its config file.

    ``model_config.yaml`` is looked up in the checkpoint's directory first and,
    if it is not there, in that directory's parent.

    Args:
        checkpoint_path: Path to the checkpoint file read with ``torch.load``.
            If the loaded object has a ``"state_dict"`` key, that entry is used
            as the weights; otherwise the whole object is.
        device: Device the model is moved to after the weights (loaded on CPU)
            have been applied.
        mhr_path: Value written into ``MODEL.MHR_HEAD.MHR_MODEL_PATH`` of the
            config before the model is constructed.

    Returns:
        ``(model, model_cfg)``: the model in eval mode on ``device`` and the
        frozen config it was built from.

    Raises:
        FileNotFoundError: If ``model_config.yaml`` exists in neither the
            checkpoint's directory nor its parent directory.
    """
    print("Loading SAM 3D Body model...")
    checkpoint_path = os.fspath(checkpoint_path)
    mhr_path = os.fspath(mhr_path)

    # Check the checkpoint's own directory first, then its parent directory.
    ckpt_dir = os.path.dirname(checkpoint_path)
    cfg_candidates = [
        os.path.join(ckpt_dir, "model_config.yaml"),
        os.path.join(os.path.dirname(ckpt_dir), "model_config.yaml"),
    ]
    cfg_path = next((path for path in cfg_candidates if os.path.exists(path)), None)
    if cfg_path is None:
        # Fail early with both searched locations instead of handing a
        # nonexistent path to get_config.
        raise FileNotFoundError(
            "model_config.yaml not found; looked in: " + ", ".join(cfg_candidates)
        )
    model_cfg = get_config(cfg_path)

    # Override the MHR model path in the config. NOTE(review): an empty
    # mhr_path (the default) also overrides whatever the YAML specified.
    model_cfg.defrost()
    model_cfg.MODEL.MHR_HEAD.MHR_MODEL_PATH = mhr_path
    model_cfg.freeze()

    # Initialize the model
    model = SAM3DBody(model_cfg)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # strict=False: presumably tolerates missing/unexpected keys like torch's
    # Module.load_state_dict(strict=False) — confirm in utils.checkpoint.
    load_state_dict(model, state_dict, strict=False)

    model = model.to(device)
    model.eval()
    return model, model_cfg
def _hf_download(repo_id):
    """Download a Hugging Face Hub repo snapshot and locate SAM 3D Body files in it.

    Args:
        repo_id: Hub repository ID passed to ``snapshot_download``.

    Returns:
        A ``(checkpoint_path, mhr_path)`` tuple pointing at ``model.ckpt`` and
        ``assets/mhr_model.pt`` inside the local snapshot directory.
    """
    # Imported inside the function so that importing this module does not
    # require huggingface_hub.
    from huggingface_hub import snapshot_download

    snapshot_root = snapshot_download(repo_id=repo_id)
    ckpt_file = os.path.join(snapshot_root, "model.ckpt")
    mhr_file = os.path.join(snapshot_root, "assets", "mhr_model.pt")
    return ckpt_file, mhr_file
def load_sam_3d_body_hf(repo_id, **kwargs):
    """Download SAM 3D Body weights from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository ID. The snapshot must contain ``model.ckpt`` and
            ``assets/mhr_model.pt``.
        **kwargs: Extra keyword arguments forwarded to ``load_sam_3d_body``
            (e.g. ``device``). ``checkpoint_path`` and ``mhr_path`` are filled
            in from the snapshot, so passing them here raises ``TypeError``.

    Returns:
        ``(model, model_cfg)`` as returned by ``load_sam_3d_body``.
    """
    ckpt_path, mhr_path = _hf_download(repo_id)
    # Forward kwargs so options such as `device` are honoured instead of dropped.
    return load_sam_3d_body(checkpoint_path=ckpt_path, mhr_path=mhr_path, **kwargs)