Spaces:
Runtime error
Runtime error
| import importlib | |
| from functools import partialmethod | |
| from pathlib import Path | |
| from torchvision.datasets.utils import download_url | |
| import gdown | |
| from utils.aux import ensure_annotation_class | |
def get_class(cls_path: str, *args, **kwargs):
    """Import a class by dotted path and pre-bind constructor arguments.

    Args:
        cls_path: Dotted path of the form ``"package.module.ClassName"``.
        *args: Positional arguments placed ahead of any positionals supplied
            when the returned class is instantiated.
        **kwargs: Keyword arguments pre-applied to the constructor; keywords
            supplied at instantiation time override them.

    Returns:
        The imported class itself when no arguments are bound. Otherwise a
        subclass with the same name, module and docstring whose ``__init__``
        has the given arguments pre-applied, so instances still satisfy
        ``isinstance(obj, OriginalClass)``.

    Raises:
        ValueError: If ``cls_path`` contains no module component.
        ImportError: If the module cannot be imported.
        AttributeError: If the module does not define the requested name.
    """
    module_name, _, class_name = cls_path.rpartition(".")
    if not module_name:
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', "
            f"got {cls_path!r}")

    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)

    # Nothing to bind: hand back the class untouched.
    if not args and not kwargs:
        return class_

    # Bind the arguments on a fresh subclass instead of overwriting
    # ``class_.__init__``. Patching the imported class in place would change
    # it for every other user and stack the bound arguments on each call.
    namespace = {
        "__init__": partialmethod(class_.__init__, *args, **kwargs),
        "__module__": class_.__module__,
        "__qualname__": class_.__qualname__,
        "__doc__": class_.__doc__,
    }
    # Call the class's own metaclass so custom metaclasses keep working.
    return type(class_)(class_.__name__, (class_,), namespace)
def download_ckpt(local_path: Path, global_path: str) -> str:
    """Return the POSIX path of a checkpoint, downloading it if it is absent.

    The download source is chosen from ``global_path`` and ``local_path``:

    * a Google Drive URL containing ``"file"`` is fetched with
      ``gdown.download`` directly to ``local_path``;
    * a Google Drive URL containing ``"folder"`` is fetched with
      ``gdown.download_folder`` into ``local_path``'s parent directory;
    * otherwise, if ``local_path`` ends in ``.safetensors`` or its path
      contains no dot, ``global_path`` is appended to
      ``https://huggingface.co/`` and fetched with ``download_url``.

    Args:
        local_path: Where the checkpoint is expected on disk.
        global_path: Google Drive URL, or a path relative to
            ``https://huggingface.co/``.

    Returns:
        ``local_path`` as a POSIX-style string.

    Raises:
        NotImplementedError: If no download strategy matches the inputs.
        FileNotFoundError: If the download finished but ``local_path`` still
            does not exist.
        Exception: Whatever ``download_url`` raises is re-raised after an
            error message is printed.
    """
    # Already present: nothing to download.
    if local_path.exists():
        return local_path.as_posix()

    # exist_ok avoids a crash when another process creates the directory
    # between an existence check and the mkdir call.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    on_gdrive = "drive.google.com" in global_path
    if on_gdrive and "file" in global_path:
        gdown.download(url=global_path, output=local_path.as_posix(),
                       fuzzy=True)
    elif on_gdrive and "folder" in global_path:
        # The folder is downloaded next to local_path; the existence check
        # below confirms the expected file was part of it.
        gdown.download_folder(url=global_path,
                              output=local_path.parent.as_posix())
    # NOTE(review): the dot test inspects the whole path, not only the file
    # name, so a dotted directory name disables this branch -- confirm intent.
    elif local_path.suffix == ".safetensors" or "." not in local_path.as_posix():
        ckpt_url = f"https://huggingface.co/{global_path}"
        try:
            download_url(ckpt_url, local_path.parent.as_posix(),
                         local_path.name)
        except Exception:
            print(
                f"Error: Failed to download model from {ckpt_url} to {local_path}")
            raise
    else:
        raise NotImplementedError(
            f"Download model file {global_path} not supported")

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not local_path.exists():
        raise FileNotFoundError(f"Missing checkpoint {local_path}")
    return local_path.as_posix()