File size: 2,950 Bytes
d0cd3b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8b9d30
d0cd3b0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os

# Checkpoint extension for every model type stored as
# "<model_name>_config.yaml" + "<model_name><extension>" directly inside the
# model-type directory.
_YAML_CONFIG_CHECKPOINT_EXT = {
    "mel_band_roformer": ".ckpt",
    "bs_roformer": ".ckpt",
    "mdx23c": ".ckpt",
    "scnet": ".ckpt",
    # NOTE(review): ".chpt" (not ".ckpt") is preserved from the original code;
    # confirm it matches the file name the loader expects.
    "bandit": ".chpt",
    "bandit_v2": ".ckpt",
    "htdemucs": ".th",
}


def _download_file(url, destination):
    """Fetch *url* into *destination* with wget, raising RuntimeError on failure."""
    import subprocess

    # An argument list (no shell) keeps paths/URLs containing spaces or shell
    # metacharacters from being split or interpreted.
    command = ["wget", "-nv", "-O", destination, str(url)]
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        # "wget -O" may leave an empty or partial file behind; remove it so a
        # later call does not mistake it for a finished download.
        try:
            os.remove(destination)
        except OSError:
            pass  # Nothing to clean up (or cleanup is impossible); best effort.
        raise RuntimeError(f"Failed to download {url} to {destination}") from err


def download_model(model_paths, model_name, model_type, ckpt_url, conf_url):
    """Ensure the checkpoint and config files of a model exist locally.

    Files live under ``<model_paths>/<model_type>/``. For ``medley_vox`` they
    are placed in a ``<model_name>`` sub-directory and are always named
    ``vocals.json`` / ``vocals.pth``. If both files already exist nothing is
    downloaded; otherwise both are fetched with ``wget``.

    Args:
        model_paths: Root directory that holds one sub-directory per model type.
        model_name: Base name used for the local files (or, for
            ``medley_vox``, for the model sub-directory).
        model_type: One of ``mel_band_roformer``, ``vr``, ``bs_roformer``,
            ``mdx23c``, ``scnet``, ``bandit``, ``bandit_v2``, ``htdemucs``,
            ``medley_vox``.
        ckpt_url: URL of the checkpoint file.
        conf_url: URL of the config file.

    Returns:
        ``(config_path, checkpoint_path)`` for every model type except
        ``medley_vox``, for which the model-type directory
        (``<model_paths>/medley_vox``) is returned.

    Raises:
        ValueError: If ``model_type`` is not supported. No directory is
            created in that case.
        RuntimeError: If a download fails; any partially written file is
            removed first.
    """
    model_dir = os.path.join(model_paths, model_type)

    # Resolve the file layout first so that an unsupported model type is
    # rejected before anything is created on disk.
    if model_type in _YAML_CONFIG_CHECKPOINT_EXT:
        files_dir = model_dir
        config_name = f"{model_name}_config.yaml"
        checkpoint_name = f"{model_name}{_YAML_CONFIG_CHECKPOINT_EXT[model_type]}"
    elif model_type == "vr":
        files_dir = model_dir
        config_name = f"{model_name}.json"
        checkpoint_name = f"{model_name}.pth"
    elif model_type == "medley_vox":
        files_dir = os.path.join(model_dir, model_name)
        config_name = "vocals.json"
        checkpoint_name = "vocals.pth"
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # Creates model_dir as well when files_dir is nested inside it.
    os.makedirs(files_dir, exist_ok=True)
    config_path = os.path.join(files_dir, config_name)
    checkpoint_path = os.path.join(files_dir, checkpoint_name)

    # Skip the download when both files are already present.
    if os.path.exists(checkpoint_path) and os.path.exists(config_path):
        print("Model already downloaded")
    else:
        for local_path, url_model in ((checkpoint_path, ckpt_url), (config_path, conf_url)):
            _download_file(url_model, local_path)

    if model_type == "medley_vox":
        return model_dir
    return config_path, checkpoint_path