File size: 7,922 Bytes
f9abc90
26db3f0
 
 
f9abc90
 
 
 
 
 
 
 
 
 
 
 
 
 
833555d
f9abc90
 
 
 
 
 
833555d
f9abc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833555d
 
26db3f0
 
f9abc90
 
833555d
 
 
f9abc90
 
 
 
833555d
 
6064267
833555d
 
 
f9abc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833555d
f9abc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import argparse
import os
import subprocess
from pathlib import Path
from typing import Optional
import uuid
import sys
sys.path.append(str(Path(__file__).parent.parent))

import fal
import modal

from wandml.utils.storage import download_http, upload_gcs
from wandml.services.fal.utils import get_commit_hash, get_requirements, install_wandml
from wandml import WandAuth

from qwenimage.training import run_training
# Absolute paths of the two pip requirements files. os.path.abspath resolves
# them against the current working directory at import time, so the script is
# expected to be launched from the directory that contains requirements.txt
# and scripts/ (presumably the repository root -- TODO confirm).
# Both paths feed get_requirements(...) for the fal function and
# pip_install_from_requirements(...) for the Modal image below.
REQUIREMENTS_PATH = os.path.abspath("requirements.txt")
WAND_REQUIREMENTS_PATH = os.path.abspath("scripts/wand_requirements.txt")

# Local Python packages shipped to the remote workers: passed as
# local_python_modules to fal and to add_local_python_source for Modal.
local_modules = ["qwenimage","wandml","scripts"]

## Fal zone
@fal.function(
    machine_type="GPU-H100",
    requirements=get_requirements(REQUIREMENTS_PATH, WAND_REQUIREMENTS_PATH),
    local_python_modules=local_modules,
    max_concurrency=16,
    request_timeout=6 * 60 * 60,
)
def run_training_on_fal(**kwargs):
    """Fetch the training configs into /tmp and run training on a fal worker.

    Keyword Args:
        commit_hash: Forwarded to ``install_wandml`` before anything else runs.
        yaml_file_url: URL of the main YAML config; saved under ``/tmp`` using
            the last path segment of the URL as the file name.
        update_yaml_file_urls: Optional list of override YAML URLs. Each one is
            saved as ``/tmp/update_<index>_<basename>``. May be absent or None.

    Returns:
        Whatever ``run_training`` returns.

    Raises:
        KeyError: If ``commit_hash`` or ``yaml_file_url`` is missing.
        RuntimeError: If ``download_http`` returns None for any file.
    """
    install_wandml(commit_hash=kwargs["commit_hash"])

    tmp_dir = Path("/tmp")
    main_url = kwargs["yaml_file_url"]
    config_path = tmp_dir / main_url.split("/")[-1]
    if download_http(main_url, config_path) is None:
        raise RuntimeError("Failed to download training config file")

    # A missing key and an explicit None both mean "no overrides".
    override_urls = kwargs.get("update_yaml_file_urls")
    if override_urls is None:
        override_urls = []

    override_paths = []
    for position, override_url in enumerate(override_urls):
        # The index prefix keeps overrides with identical basenames apart.
        target = tmp_dir / f"update_{position}_{override_url.split('/')[-1]}"
        if download_http(override_url, target) is None:
            raise RuntimeError(f"Failed to download update config file {override_url}")
        override_paths.append(target)

    # run_training receives None rather than an empty list when there are no overrides.
    return run_training(config_path, update_config_paths=override_paths or None)
## End Fal zone


## Modal zone


# Modal application object; the functions decorated with @modalapp.* below
# are registered on it.
modalapp = modal.App("next-stroke")
# Container image used by the app: Debian slim with Python 3.10, a few system
# packages, both pip requirements files, and the local Python packages listed
# in local_modules.
modalapp.image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git", "ffmpeg", "libsm6", "libxext6")
    .pip_install_from_requirements(REQUIREMENTS_PATH)
    .pip_install_from_requirements(WAND_REQUIREMENTS_PATH)
    .add_local_python_source(*local_modules)
)


@modalapp.function(
    gpu="B200",
    max_containers=1,
    timeout=4 * 60 * 60,
    volumes={
        "/data/wand_cache": modal.Volume.from_name("FLUX_MODELS"),
        "/data/checkpoints": modal.Volume.from_name("training_checkpoints", create_if_missing=True),
        "/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),

        "/root/.cache/huggingface/hub":  modal.Volume.from_name("hf_cache", create_if_missing=True),
        "/root/.cache/huggingface/datasets":  modal.Volume.from_name("hf_cache_datasets", create_if_missing=True),

        "/data/regression_data": modal.Volume.from_name("regression_data"),
        "/data/edit_data": modal.Volume.from_name("edit_data"),
    },
    secrets=[
        modal.Secret.from_name("wand-modal-gcloud-keyfile"),
        modal.Secret.from_name("elea-huggingface-secret"),
    ],
)
def run_training_on_modal(yaml_file_url: str, update_yaml_file_urls: Optional[list[str]] = None):
    """Download the training configs into /tmp and run training on Modal.

    Args:
        yaml_file_url: URL of the main YAML config; saved under ``/tmp`` using
            the last path segment of the URL as the file name.
        update_yaml_file_urls: Optional override YAML URLs. Each one is saved
            as ``/tmp/update_<index>_<basename>``.

    Returns:
        Whatever ``run_training`` returns.

    Raises:
        RuntimeError: If ``download_http`` returns None for any file. This is
            the same failure convention ``run_training_on_fal`` relies on;
            without the check a failed download only surfaces later, inside
            ``run_training``, as a missing-file error.
    """
    config_path = Path("/tmp") / yaml_file_url.split("/")[-1]
    if download_http(yaml_file_url, config_path) is None:
        raise RuntimeError("Failed to download training config file")

    update_paths = []
    for idx, url in enumerate(update_yaml_file_urls or []):
        # The index prefix keeps overrides with identical basenames apart.
        update_path = Path("/tmp") / f"update_{idx}_{url.split('/')[-1]}"
        if download_http(url, update_path) is None:
            raise RuntimeError(f"Failed to download update config file {url}")
        update_paths.append(update_path)

    # run_training receives None rather than an empty list when there are no overrides.
    return run_training(config_path, update_config_paths=update_paths or None)

@modalapp.local_entrypoint()
def run_modal_local(yaml: str, update: Optional[str] = None):
    """Local entrypoint for ``modal run``: stage configs and start remote training.

    Any config given as a local path is uploaded to the ``wand-finetune`` GCS
    bucket under ``<stem><8 random hex chars><suffix>`` so the remote worker
    can fetch it by URL; values that already start with ``http`` are passed
    through unchanged.

    Args:
        yaml: Path or URL of the main YAML config.
        update: Optional override configs as a single ``|``-separated string
            of paths and/or URLs (the ``__main__`` launcher in this file joins
            them this way).

    Returns:
        The result of ``run_training_on_modal.remote(...)``.

    Raises:
        RuntimeError: If ``upload_gcs`` returns None for any uploaded file.
    """
    WandAuth(ignore=True)

    if yaml.startswith("http"):
        yaml_file_url = yaml
    else:
        yamlp = Path(yaml)
        name = yamlp.stem + str(uuid.uuid4())[:8] + yamlp.suffix
        uploaded = upload_gcs(yaml, "wand-finetune", name, public=True)
        # Fail here rather than handing a None URL to the remote worker; the
        # override uploads below already follow the same rule.
        if uploaded is None:
            raise RuntimeError(f"Failed to upload {yaml} to GCS")
        yaml_file_url = uploaded

    update_urls: Optional[list[str]] = None
    if update is not None and len(update) > 0:
        update_urls = []
        for upd in update.split("|"):
            if upd.startswith("http"):
                update_urls.append(upd)
                continue
            up = Path(upd)
            uname = up.stem + str(uuid.uuid4())[:8] + up.suffix
            update_url = upload_gcs(str(up), "wand-finetune", uname, public=True)
            if update_url is None:
                raise RuntimeError(f"Failed to upload {upd} to GCS")
            update_urls.append(update_url)

    return run_training_on_modal.remote(yaml_file_url, update_yaml_file_urls=update_urls)

## End modal zone


def _resolve_yaml_source(source, where, local_name=None):
    """Make one YAML path/URL usable by the backend selected with ``where``.

    * ``where == "local"``: a URL is downloaded into ``/tmp`` and the local
      ``Path`` is returned; a local path is returned unchanged.
    * any other backend: a local path is uploaded to the ``wand-finetune`` GCS
      bucket and its URL is returned; a URL is returned unchanged.

    Args:
        source: Path or URL string (URLs are recognised by an ``http`` prefix).
        where: Backend name from ``--where``.
        local_name: File name to use under ``/tmp`` when downloading; defaults
            to the last path segment of the URL.

    Raises:
        RuntimeError: If the download or the upload returns None.
    """
    is_url = source.startswith("http")
    if where == "local":
        if not is_url:
            return source
        local_path = Path("/tmp") / (local_name or source.split("/")[-1])
        if download_http(source, local_path) is None:
            raise RuntimeError(f"Failed to download {source}")
        return local_path

    if is_url:
        return source
    src = Path(source)
    # A short random suffix keeps repeated uploads of the same file name apart.
    remote_name = src.stem + str(uuid.uuid4())[:8] + src.suffix
    url = upload_gcs(str(src), "wand-finetune", remote_name, public=True)
    if url is None:
        raise RuntimeError(f"Failed to upload {source} to GCS")
    return url


def parse_args():
    """Parse the command line and stage config files for the chosen backend.

    After parsing, ``args.config`` and every entry of ``args.update`` are
    rewritten so the selected backend can read them: URLs are downloaded to
    ``/tmp`` for ``--where local``, and local paths are uploaded to GCS for
    ``fal`` / ``modal``.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``update`` (list or None),
        ``where`` and ``detached``.

    Raises:
        RuntimeError: If a required download or upload fails.
        SystemExit: On invalid arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(description="Run training.")
    parser.add_argument("config", type=str, help="Path or Url to YAML configuration file")
    parser.add_argument(
        "--update",
        type=str,
        action="append",
        help="Optional secondary YAML with overrides (path or URL). Can be specified multiple times.",
    )
    # Required: without a backend the script used to upload the config to GCS
    # and only then fail with an empty ValueError.
    parser.add_argument(
        "--where",
        choices=["local", "fal", "modal"],
        required=True,
        help="Backend on which to run the training",
    )
    parser.add_argument(
        "-d",
        "--detached",
        action="store_true",
        default=False,
        help="Run Modal in detached mode (-d). Only valid when --where modal",
    )
    args = parser.parse_args()
    if args.detached and args.where != "modal":
        parser.error("--detached is only valid when --where modal")

    args.config = _resolve_yaml_source(args.config, args.where)

    if args.update:
        # The index prefix stops two override URLs with the same basename from
        # overwriting each other in /tmp (same naming as the remote workers).
        args.update = [
            _resolve_yaml_source(
                upd, args.where, local_name=f"update_{idx}_{upd.split('/')[-1]}"
            )
            for idx, upd in enumerate(args.update)
        ]
    return args

if __name__ == "__main__":
    # Script entry point: authenticate, stage the config files via
    # parse_args(), then dispatch to the backend selected with --where.
    WandAuth()

    args = parse_args()

    if args.where == "fal":
        out = run_training_on_fal(
            yaml_file_url=args.config,
            commit_hash=get_commit_hash(),
            update_yaml_file_urls=args.update,
        )
    elif args.where == "modal":
        # Modal runs go through the `modal run` CLI against this same file,
        # which invokes the run_modal_local entrypoint defined above.
        cmd = ["modal", "run"]
        if args.detached:
            cmd.append("-d")
        cmd += [os.path.abspath(__file__), "--yaml", args.config]
        if args.update:
            # The entrypoint takes a single string, so overrides are joined
            # with "|" and split again on the other side.
            cmd += ["--update", "|".join(str(upd) for upd in args.update)]
        out = subprocess.run(cmd)
    elif args.where == "local":
        update_paths = [Path(u) for u in args.update] if args.update else None
        out = run_training(args.config, update_config_paths=update_paths)
    else:
        raise ValueError(
            f"Unsupported --where value: {args.where!r} (expected 'local', 'fal' or 'modal')"
        )
    print(out)
    # Propagate a failed `modal run` so shells and CI see a non-zero status
    # instead of a silent exit code 0.
    if isinstance(out, subprocess.CompletedProcess) and out.returncode != 0:
        sys.exit(out.returncode)