Spaces:
Paused
Paused
File size: 4,623 Bytes
ecae934 cf55b31 961ea12 cf55b31 928a1a1 cf55b31 ecae934 d087946 928a1a1 cf55b31 ecae934 928a1a1 ecae934 14db299 ecae934 cf55b31 ecae934 cf55b31 961ea12 cf55b31 0933041 ecae934 cf55b31 ecae934 928a1a1 cf55b31 928a1a1 cf55b31 ecae934 961ea12 ecae934 14db299 ecae934 961ea12 ecae934 928a1a1 961ea12 14db299 ecae934 961ea12 14db299 b56d61f ecae934 928a1a1 ecae934 cf55b31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
# download_models.py (v4.0 - Versão Definitiva Completa)
import os
import yaml
import logging
from huggingface_hub import snapshot_download
# Configuração do log para ser claro e informativo
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s')
logger = logging.getLogger("MODEL_LOGISTICS")
def download_repo_snapshot(repo_id, local_dir, desc, allow_patterns=None):
"""
Baixa um snapshot de um repositório, verificando se ele já existe para evitar
downloads repetidos. É a forma mais robusta de baixar modelos.
"""
os.makedirs(local_dir, exist_ok=True)
# Um bom indicador de que o download foi concluído é a presença de um arquivo de configuração.
# Isso evita downloads parciais em caso de reinicialização.
completion_marker = os.path.join(local_dir, '.download_completed')
if os.path.exists(completion_marker):
logger.info(f"Modelos para '{desc}' parecem já existir e estão completos em: {local_dir}")
return
logger.info(f"Baixando modelos para '{desc}' de '{repo_id}' para '{local_dir}'...")
try:
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
#local_dir_use_symlinks=False,
#resume_download=True,
allow_patterns=allow_patterns,
ignore_patterns=["*.md", "*.txt", "*.gitattributes", "*onnx*", "*fp32*"], # Ignora arquivos desnecessários
)
# Cria o marcador de conclusão
with open(completion_marker, 'w') as f:
f.write('done')
logger.info(f"Download para '{desc}' concluído.")
except Exception as e:
logger.error(f"Falha CRÍTICA ao baixar o snapshot '{desc}'. Erro: {e}")
raise
def main():
"""
Ponto de entrada para baixar todos os modelos de IA necessários, lendo as
configurações do arquivo config.yaml.
"""
logger.info("--- Iniciando verificação e download de todos os modelos ---")
try:
with open("config.yaml", 'r', encoding='utf-8') as f:
# Passamos o 'f' (o stream do arquivo) para a função safe_load
config = yaml.safe_load(f).get('specialists', {})
if not config:
logger.warning("Seção 'specialists' não encontrada no config.yaml. Nenhum modelo será baixado.")
return
except FileNotFoundError:
logger.error("Arquivo config.yaml não encontrado! Não é possível determinar quais modelos baixar.")
raise
except Exception as e:
logger.error(f"Erro ao ler ou parsear o config.yaml: {e}")
raise
# --- 1. Modelos para LTX-Video ---
if config.get('ltx', {}).get('gpus_required', 0) > 0:
ltx_models_dir = "/app/LTX-Video/models_downloaded"
download_repo_snapshot("Lightricks/LTX-Video", ltx_models_dir, "LTX Models", allow_patterns=["*.safetensors", "*.json"])
# --- 2. Modelos para Wan2.2 e LoRA Lightning (Logística "Mix-and-Match") ---
if config.get('wan', {}).get('gpus_required', 0) > 0:
wan_config = config['wan']
main_model_path = f"/app/models/{wan_config['model_id']}"
opt_transformer_path = f"/app/models/{wan_config['optimized_transformer_id']}"
lora_dir = "/app/models/loras"
# Baixa os repositórios completos para garantir a estrutura correta dos arquivos
download_repo_snapshot(wan_config['model_id'], main_model_path, "Wan2.2 Base Components")
download_repo_snapshot(wan_config['optimized_transformer_id'], opt_transformer_path, "Wan2.2 Optimized Transformers")
lora_filename_only = os.path.basename(wan_config['lora_filename'])
download_repo_snapshot(wan_config['lora_repo'], lora_dir, "Wan2.2 LoRA Lightning", allow_patterns=f"*{lora_filename_only}")
# --- 3. Modelos para SeedVR ---
if config.get('seedvr', {}).get('gpus_required', 0) > 0:
seedvr_models_dir = "/app/ckpts"
download_repo_snapshot("batuhanince/seedvr_3b_fp16", seedvr_models_dir, "SeedVR Models (FP16)", allow_patterns=["*.safetensors", "*.pt"])
download_repo_snapshot("ByteDance-Seed/SeedVR2-3B", seedvr_models_dir, "SeedVR Embeddings", allow_patterns=["*.pt"])
# --- 4. Modelos para MMAudio ---
if config.get('mmaudio', {}).get('gpus_required', 0) > 0:
mmaudio_models_dir = "/app/MMAudio/ckpts"
download_repo_snapshot("hkchengrex/MMAudio-checkpoints", mmaudio_models_dir, "MMAudio Checkpoints")
logger.info("--- Verificação de modelos concluída com sucesso ---")
if __name__ == "__main__":
main() |