Spaces:
Paused
Paused
| import os | |
| import logging | |
| import torch | |
| import asyncio | |
| import aiohttp | |
| import requests | |
| from huggingface_hub import hf_hub_download | |
# Configure logging: root logger at DEBUG level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configuration
# Root folder for application data; overridable through the DATA_ROOT
# environment variable, defaulting to /tmp/data.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
# Every model file is stored below this folder.
MODELS_DIR = os.path.join(DATA_ROOT, "models")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face repository information
# Repository that download_hf_file() pulls every model file from.
HF_REPO_ID = "jbilcke-hf/model-cocktail"
# Model files to download.
# Each entry is a path inside the HF_REPO_ID repository; download_hf_file()
# checks for the same relative path under MODELS_DIR before downloading.
MODEL_FILES = [
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",
    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",
    # For animal mode 🐶🐱
    # however they say animal mode doesn't support stitching yet?
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    #"liveportrait-animals/warping_module.pth",
    #"liveportrait-animals/spade_generator.pth",
    #"liveportrait-animals/motion_extractor.pth",
    #"liveportrait-animals/appearance_feature_extractor.pth",
    #"liveportrait-animals/stitching_retargeting_module.pth",
    #"liveportrait-animals/xpose.pth",
    # HACK: this is a workaround; instead we should probably try to
    # fix liveportrait/utils/dependencies/insightface/utils/storage.py
    "insightface/models/buffalo_l.zip",
    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",
    # we don't use those yet
    #"flux-dev/flux-dev-fp8.safetensors",
    #"flux-dev/flux_dev_quantization_map.json",
    #"pulid-flux/pulid_flux_v0.9.0.safetensors",
    #"pulid-flux/pulid_v1.bin"
]
def create_directory(directory):
    """Create a directory if it doesn't exist and log its status.

    Missing parent directories are created as well.

    Args:
        directory: Path of the directory to create.

    Raises:
        OSError: If creation fails for any reason other than the path
            already existing (for example insufficient permissions).
    """
    try:
        # Attempt the creation directly rather than testing for existence
        # first: another process creating the same path between a check and
        # the makedirs call would otherwise make this function crash.
        os.makedirs(directory)
    except FileExistsError:
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
def print_directory_structure(startpath):
    """Log the directory tree rooted at ``startpath``.

    Every directory is logged with a trailing ``/`` followed by its files
    one level deeper; each nesting level is indented by four spaces.

    Args:
        startpath: Directory whose tree is traversed with ``os.walk``.
    """
    for root, _dirs, files in os.walk(startpath):
        # Derive the depth from the path relative to startpath. Stripping
        # the prefix with str.replace miscounts when startpath ends with a
        # separator or when its text appears again deeper in the path.
        relative = os.path.relpath(root, startpath)
        level = 0 if relative == os.curdir else relative.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename is not empty.
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")
async def download_hf_file(filename: str) -> None:
    """Download a file from Hugging Face to the models directory.

    The download is skipped when the destination file already exists.

    Args:
        filename: Path of the file inside the ``HF_REPO_ID`` repository; the
            same relative path is used below ``MODELS_DIR``.

    Raises:
        Exception: Any error raised by ``hf_hub_download`` is logged and
            re-raised after a file left at the destination is removed.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.exists(dest):
        # this is really for debugging purposes only
        logger.debug(f" β {filename}")
        return

    logger.info(f" β³ Downloading {HF_REPO_ID}/{filename}")
    try:
        # hf_hub_download is blocking, so run it in a worker thread; this
        # matches how the LivePortrait pipeline is loaded in this module.
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        logger.error(f"π¨ Error downloading file from Hugging Face: {e}")
        # Do not leave a possibly incomplete file behind: the existence
        # check above would treat it as a finished download next time.
        if os.path.exists(dest):
            os.remove(dest)
        raise
    logger.info(f" β Downloaded {filename}")
async def download_all_models():
    """Fetch every file listed in MODEL_FILES from the Hugging Face repository.

    One download coroutine is created per entry and all of them are awaited
    together as a group.
    """
    logger.info(" π Looking for models...")
    pending = (download_hf_file(model_file) for model_file in MODEL_FILES)
    await asyncio.gather(*pending)
    logger.info(" β All models are available")

    # Debugging aid: to verify that the models were downloaded properly,
    # un-comment the two following lines:
    #logger.info("π‘ Printing directory structure of models:")
    #print_directory_structure(MODELS_DIR)
class ModelLoader:
    """Loads and initializes the models required by the application."""

    def __init__(self):
        # Snapshot the module-level configuration on the instance.
        self.models_dir = MODELS_DIR
        self.device = DEVICE

    async def load_live_portrait(self):
        """Build the LivePortrait pipeline in a worker thread and return it.

        Returns:
            The constructed ``LivePortraitPipeline`` instance.
        """
        # Imported here rather than at module top; presumably so this module
        # can be imported before liveportrait is usable -- TODO confirm.
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" β³ Loading LivePortrait models...")

        # These are the default values, spelled out explicitly.
        inference_cfg = InferenceConfig(
            # stitching: recommended to keep enabled
            flag_stitching=True,
            # use relative motion
            flag_relative=True,
            # paste the animated face crop back into the original image space
            flag_pasteback=True,
            # crop the source portrait to the face-cropping space
            flag_do_crop=True,
            # apply the rotation when flag_do_crop is True
            flag_do_rot=True,
        )
        crop_cfg = CropConfig()

        # Constructing the pipeline is blocking, so do it off the event loop.
        pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=inference_cfg,
            crop_cfg=crop_cfg,
        )

        logger.info(" β LivePortrait models loaded successfully.")
        return pipeline
async def initialize_models():
    """Download the model files, then load and return the LivePortrait pipeline.

    Returns:
        The pipeline object produced by ``ModelLoader.load_live_portrait``.
    """
    logger.info("π Starting model initialization...")

    # The model files must be present before the pipeline is constructed.
    await download_all_models()

    # Create the loader and load the LivePortrait models in one step.
    pipeline = await ModelLoader().load_live_portrait()

    logger.info("β Model initialization completed.")
    return pipeline
# Initial setup
# These statements run when the module is imported, so the models directory
# is created (or confirmed to exist) as an import-time side effect.
logger.info("π Setting up storage directories...")
create_directory(MODELS_DIR)
logger.info("β Storage directories setup completed.")