Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import os | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Optional | |
| import uuid | |
| import sys | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| import fal | |
| import modal | |
| from wandml.utils.storage import download_http, upload_gcs | |
| from wandml.services.fal.utils import get_commit_hash, get_requirements, install_wandml | |
| from wandml import WandAuth | |
| from qwenimage.training import run_training | |
# Dependency manifests, resolved to absolute paths against the current
# working directory at import time.
REQUIREMENTS_PATH = os.path.abspath("requirements.txt")
WAND_REQUIREMENTS_PATH = os.path.abspath("scripts/wand_requirements.txt")

# Local packages that are added to the Modal image as Python source.
local_modules = [
    "qwenimage",
    "wandml",
    "scripts",
]
| ## Fal zone | |
def run_training_on_fal(**kwargs):
    """Download the training config (plus optional overrides) and run training.

    Expected keyword arguments:
        commit_hash: forwarded to ``install_wandml`` before any download.
        yaml_file_url: URL of the main YAML config; it is saved under ``/tmp``
            using the last ``/``-separated segment of the URL as file name.
        update_yaml_file_urls: optional list of override-config URLs; each is
            saved as ``/tmp/update_<index>_<last URL segment>``.

    Returns:
        Whatever ``run_training`` returns.

    Raises:
        KeyError: if ``commit_hash`` or ``yaml_file_url`` is missing.
        RuntimeError: if ``download_http`` returns ``None`` for any file.
    """
    # NOTE(review): `fal` is imported by this file but no fal decorator is
    # visible on this function, and __main__ calls it directly -- confirm
    # whether it is meant to execute remotely.
    install_wandml(commit_hash=kwargs["commit_hash"])

    tmp_dir = Path("/tmp")
    config_path = tmp_dir / kwargs["yaml_file_url"].split("/")[-1]
    if download_http(kwargs["yaml_file_url"], config_path) is None:
        raise RuntimeError("Failed to download training config file")

    override_urls = kwargs.get("update_yaml_file_urls")
    override_paths = []
    if override_urls is not None:
        for position, override_url in enumerate(override_urls):
            # The index prefix keeps overrides that share a basename apart.
            target = tmp_dir / f"update_{position}_{override_url.split('/')[-1]}"
            if download_http(override_url, target) is None:
                raise RuntimeError(f"Failed to download update config file {override_url}")
            override_paths.append(target)

    # An empty override list is passed on as None.
    return run_training(config_path, update_config_paths=override_paths or None)
| ## End Fal zone | |
| ## Modal zone | |
# Modal application object for this script.
modalapp = modal.App("next-stroke")

# Container image used by the app: Debian slim with Python 3.10, a few system
# packages installed through apt, both pip requirement files, and the local
# project packages added as Python source.
modalapp.image = (
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git", "ffmpeg", "libsm6", "libxext6")
    .pip_install_from_requirements(REQUIREMENTS_PATH)
    .pip_install_from_requirements(WAND_REQUIREMENTS_PATH)
    .add_local_python_source(*local_modules)
)
def run_training_on_modal(yaml_file_url: str, update_yaml_file_urls: Optional[list[str]] = None):
    """Download the training config (plus optional overrides) and run training.

    Args:
        yaml_file_url: URL of the main YAML config. It is saved under ``/tmp``
            using the last ``/``-separated segment of the URL as file name.
        update_yaml_file_urls: optional override-config URLs; each is saved as
            ``/tmp/update_<index>_<last URL segment>``.

    Returns:
        Whatever ``run_training`` returns.

    Raises:
        RuntimeError: if ``download_http`` returns ``None`` for any file
            (same failure handling as ``run_training_on_fal``).
    """
    # NOTE(review): run_modal_local invokes this through `.remote(...)`, which
    # needs the function to be registered with the Modal app; no decorator is
    # visible in this file -- confirm how it is registered.
    config_path = Path("/tmp") / yaml_file_url.split("/")[-1]
    if download_http(yaml_file_url, config_path) is None:
        raise RuntimeError("Failed to download training config file")

    update_paths = []
    for idx, url in enumerate(update_yaml_file_urls or []):
        # The index prefix keeps overrides that share a basename apart.
        update_path = Path("/tmp") / f"update_{idx}_{url.split('/')[-1]}"
        if download_http(url, update_path) is None:
            raise RuntimeError(f"Failed to download update config file {url}")
        update_paths.append(update_path)

    # An empty override list is passed on as None.
    return run_training(config_path, update_config_paths=update_paths or None)
def run_modal_local(yaml: str, update: Optional[str] = None):
    """Resolve config locations to URLs and start training through Modal.

    Args:
        yaml: path or URL of the main YAML config. Anything that does not
            start with ``"http"`` is treated as a local file and uploaded to
            the ``wand-finetune`` GCS bucket first.
        update: optional ``"|"``-separated list of override configs (paths or
            URLs), resolved the same way. Empty segments are ignored.

    Returns:
        The result of ``run_training_on_modal.remote(...)``.

    Raises:
        RuntimeError: if an upload to GCS returns ``None``.
    """
    # NOTE(review): `modal run <file> --yaml ...` (see __main__) needs a Modal
    # local entrypoint, and `.remote` below needs run_training_on_modal to be
    # a registered Modal function; neither decorator is visible in this file
    # -- confirm.
    WandAuth(ignore=True)

    def _publish(local_path: str) -> str:
        """Upload a local file under a unique name and return its URL."""
        src = Path(local_path)
        # A short random infix avoids clobbering earlier uploads of the same file.
        unique_name = src.stem + str(uuid.uuid4())[:8] + src.suffix
        url = upload_gcs(local_path, "wand-finetune", unique_name, public=True)  # pyright: ignore
        if url is None:
            raise RuntimeError(f"Failed to upload {local_path} to GCS")
        return url

    yaml_file_url: str = yaml if yaml.startswith("http") else _publish(yaml)

    update_urls: Optional[list[str]] = None
    if update:
        # Drop empty segments (e.g. from a trailing "|") instead of trying to
        # upload an empty path.
        update_urls = [
            upd if upd.startswith("http") else _publish(upd)
            for upd in update.split("|")
            if upd
        ]

    return run_training_on_modal.remote(yaml_file_url, update_yaml_file_urls=update_urls or None)
| ## End modal zone | |
def parse_args():
    """Parse the command line and normalise config locations for the target.

    For ``--where local`` any config given as a URL is downloaded to ``/tmp``
    and replaced by the local ``Path``. For the other targets any config
    given as a local path is uploaded to the ``wand-finetune`` GCS bucket and
    replaced by the returned URL. Values already in the right form are left
    untouched.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``update``, ``where`` and
        ``detached``; ``config`` and the entries of ``update`` are rewritten
        as described above.

    Raises:
        SystemExit: on invalid arguments (raised by argparse), including a
            missing ``--where`` or ``--detached`` without ``--where modal``.
        RuntimeError: if a download or upload returns ``None``.
    """
    parser = argparse.ArgumentParser(description="Run training.")
    parser.add_argument("config", type=str, help="Path or Url to YAML configuration file")
    parser.add_argument(
        "--update",
        type=str,
        action="append",
        help="Optional secondary YAML with overrides (path or URL). Can be specified multiple times.",
    )
    # Required: without a target the config would be uploaded to GCS and the
    # run would then fail anyway because no branch matches.
    parser.add_argument(
        "--where",
        choices=["local", "fal", "modal"],
        required=True,
        help="Where to run the training",
    )
    parser.add_argument(
        "-d",
        "--detached",
        action="store_true",
        default=False,
        help="Run Modal in detached mode (-d). Only valid when --where modal",
    )
    args = parser.parse_args()
    if args.detached and args.where != "modal":
        parser.error("--detached is only valid when --where modal")

    run_locally = args.where == "local"

    def _fetch(url, dest):
        """Download ``url`` to ``dest`` and return ``dest``."""
        if download_http(url, dest) is None:
            raise RuntimeError(f"Failed to download {url}")
        return dest

    def _publish(local_path):
        """Upload a local file under a unique name and return its URL."""
        src = Path(local_path)
        # A short random infix avoids clobbering earlier uploads of the same file.
        unique_name = src.stem + str(uuid.uuid4())[:8] + src.suffix
        url = upload_gcs(local_path, "wand-finetune", unique_name, public=True)
        if url is None:
            raise RuntimeError(f"Failed to upload {local_path} to GCS")
        return url

    if run_locally and args.config.startswith("http"):
        args.config = _fetch(args.config, Path("/tmp") / args.config.split("/")[-1])
    elif not run_locally and not args.config.startswith("http"):
        args.config = _publish(args.config)

    # Handle update paths/urls depending on where
    if args.update:
        processed_updates = []
        for idx, upd in enumerate(args.update):
            is_url = upd.startswith("http")
            if run_locally and is_url:
                # Index prefix (same scheme as the fal/modal runners) so that
                # updates sharing a basename do not overwrite each other or
                # the main config.
                dest = Path("/tmp") / f"update_{idx}_{upd.split('/')[-1]}"
                processed_updates.append(_fetch(upd, dest))
            elif not run_locally and not is_url:
                processed_updates.append(_publish(upd))
            else:
                processed_updates.append(upd)
        args.update = processed_updates
    return args
if __name__ == "__main__":
    # WandAuth() is kept ahead of argument parsing because parse_args() may
    # already upload configs to GCS (presumably this needs credentials --
    # confirm against WandAuth).
    WandAuth()
    args = parse_args()
    if args.where == "fal":
        out = run_training_on_fal(
            yaml_file_url=args.config,
            commit_hash=get_commit_hash(),
            update_yaml_file_urls=args.update,
        )
    elif args.where == "modal":
        # Re-invoke this file through the Modal CLI.
        cmd = ["modal", "run"]
        if args.detached:
            cmd.append("-d")
        cmd += [os.path.abspath(__file__), "--yaml", args.config]
        if args.update:
            # run_modal_local takes all overrides as one "|"-separated string.
            cmd += ["--update", "|".join(str(upd) for upd in args.update)]
        out = subprocess.run(cmd)
    elif args.where == "local":
        update_paths = [Path(u) for u in args.update] if args.update else None
        out = run_training(args.config, update_config_paths=update_paths)
    else:
        raise ValueError(
            f"Unsupported --where value: {args.where!r} (expected 'local', 'fal' or 'modal')"
        )
    print(out)
    # Propagate a failing `modal run` to the caller instead of exiting with 0.
    if isinstance(out, subprocess.CompletedProcess) and out.returncode != 0:
        sys.exit(out.returncode)