File size: 4,623 Bytes
ecae934
cf55b31
961ea12
cf55b31
928a1a1
cf55b31
ecae934
d087946
928a1a1
cf55b31
ecae934
 
 
 
 
928a1a1
 
ecae934
 
14db299
ecae934
 
cf55b31
 
ecae934
cf55b31
961ea12
cf55b31
 
0933041
 
ecae934
 
cf55b31
ecae934
 
 
928a1a1
cf55b31
928a1a1
cf55b31
 
 
ecae934
 
 
 
 
961ea12
ecae934
 
14db299
ecae934
 
 
 
 
 
 
 
 
 
961ea12
ecae934
 
 
 
 
 
928a1a1
961ea12
14db299
ecae934
961ea12
 
14db299
b56d61f
ecae934
 
 
 
928a1a1
ecae934
 
 
 
 
 
 
 
 
 
 
 
cf55b31
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# download_models.py (v4.0 - Versão Definitiva Completa)
import os
import yaml
import logging
from huggingface_hub import snapshot_download

# Configuração do log para ser claro e informativo
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - [%(name)s] - %(message)s')
logger = logging.getLogger("MODEL_LOGISTICS")

def download_repo_snapshot(repo_id, local_dir, desc, allow_patterns=None):
    """
    Baixa um snapshot de um repositório, verificando se ele já existe para evitar
    downloads repetidos. É a forma mais robusta de baixar modelos.
    """
    os.makedirs(local_dir, exist_ok=True)
    
    # Um bom indicador de que o download foi concluído é a presença de um arquivo de configuração.
    # Isso evita downloads parciais em caso de reinicialização.
    completion_marker = os.path.join(local_dir, '.download_completed')
    if os.path.exists(completion_marker):
        logger.info(f"Modelos para '{desc}' parecem já existir e estão completos em: {local_dir}")
        return

    logger.info(f"Baixando modelos para '{desc}' de '{repo_id}' para '{local_dir}'...")
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            #local_dir_use_symlinks=False,
            #resume_download=True,
            allow_patterns=allow_patterns,
            ignore_patterns=["*.md", "*.txt", "*.gitattributes", "*onnx*", "*fp32*"], # Ignora arquivos desnecessários
        )
        # Cria o marcador de conclusão
        with open(completion_marker, 'w') as f:
            f.write('done')
        logger.info(f"Download para '{desc}' concluído.")
    except Exception as e:
        logger.error(f"Falha CRÍTICA ao baixar o snapshot '{desc}'. Erro: {e}")
        raise

def main():
    """
    Ponto de entrada para baixar todos os modelos de IA necessários, lendo as
    configurações do arquivo config.yaml.
    """
    logger.info("--- Iniciando verificação e download de todos os modelos ---")

    try:
        with open("config.yaml", 'r', encoding='utf-8') as f:
            # Passamos o 'f' (o stream do arquivo) para a função safe_load
            config = yaml.safe_load(f).get('specialists', {})
        if not config:
            logger.warning("Seção 'specialists' não encontrada no config.yaml. Nenhum modelo será baixado.")
            return
    except FileNotFoundError:
        logger.error("Arquivo config.yaml não encontrado! Não é possível determinar quais modelos baixar.")
        raise
    except Exception as e:
        logger.error(f"Erro ao ler ou parsear o config.yaml: {e}")
        raise

    # --- 1. Modelos para LTX-Video ---
    if config.get('ltx', {}).get('gpus_required', 0) > 0:
        ltx_models_dir = "/app/LTX-Video/models_downloaded"
        download_repo_snapshot("Lightricks/LTX-Video", ltx_models_dir, "LTX Models", allow_patterns=["*.safetensors", "*.json"])
    
    # --- 2. Modelos para Wan2.2 e LoRA Lightning (Logística "Mix-and-Match") ---
    if config.get('wan', {}).get('gpus_required', 0) > 0:
        wan_config = config['wan']
        main_model_path = f"/app/models/{wan_config['model_id']}"
        opt_transformer_path = f"/app/models/{wan_config['optimized_transformer_id']}"
        lora_dir = "/app/models/loras"
        
        # Baixa os repositórios completos para garantir a estrutura correta dos arquivos
        download_repo_snapshot(wan_config['model_id'], main_model_path, "Wan2.2 Base Components")
        download_repo_snapshot(wan_config['optimized_transformer_id'], opt_transformer_path, "Wan2.2 Optimized Transformers")
        
        lora_filename_only = os.path.basename(wan_config['lora_filename'])
        download_repo_snapshot(wan_config['lora_repo'], lora_dir, "Wan2.2 LoRA Lightning", allow_patterns=f"*{lora_filename_only}")
    
    # --- 3. Modelos para SeedVR ---
    if config.get('seedvr', {}).get('gpus_required', 0) > 0:
        seedvr_models_dir = "/app/ckpts"
        download_repo_snapshot("batuhanince/seedvr_3b_fp16", seedvr_models_dir, "SeedVR Models (FP16)", allow_patterns=["*.safetensors", "*.pt"])
        download_repo_snapshot("ByteDance-Seed/SeedVR2-3B", seedvr_models_dir, "SeedVR Embeddings", allow_patterns=["*.pt"])

    # --- 4. Modelos para MMAudio ---
    if config.get('mmaudio', {}).get('gpus_required', 0) > 0:
        mmaudio_models_dir = "/app/MMAudio/ckpts"
        download_repo_snapshot("hkchengrex/MMAudio-checkpoints", mmaudio_models_dir, "MMAudio Checkpoints")

    logger.info("--- Verificação de modelos concluída com sucesso ---")

if __name__ == "__main__":
    main()