import inspect
import json
import os

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
# ─── 1. Read hyperparameters & mode ───────────────────────────────────────────
model_id = os.environ.get("BASE_MODEL", "HiDream-ai/HiDream-I1-Dev")
trigger_word = os.environ.get("TRIGGER_WORD", "default-style")
num_steps = int(os.environ.get("NUM_STEPS", 100))
lora_r = int(os.environ.get("LORA_R", 16))
lora_alpha = int(os.environ.get("LORA_ALPHA", 16))
LOCAL = os.environ.get("LOCAL_TRAIN", "").lower() in ("1", "true")
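# Example invocation (the filename train.py and the values shown are
# illustrative, not part of this script):
#   LOCAL_TRAIN=1 NUM_STEPS=500 LORA_R=32 python train.py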
# ─── 2. Set up directories ────────────────────────────────────────────────────
if LOCAL:
    DATA_DIR = os.path.join(os.getcwd(), "data")
    OUTPUT_DIR = os.path.join(os.getcwd(), "lora-trained")
    LOCAL_MODEL = os.path.join(os.getcwd(), "hidream-model")
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
else:
    DATA_DIR = "/tmp/data"
    OUTPUT_DIR = "/tmp/lora-trained"
    CACHE_DIR = "/tmp/hidream-model"
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
| print(f"π Dataset directory: {DATA_DIR}", flush=True) | |
| print(f"π₯ Preparing base model: {model_id}", flush=True) | |
# ─── 3. Resolve model path ────────────────────────────────────────────────────
def get_model_path():
    # If running locally and a pre-downloaded model exists, use it
    if LOCAL and os.path.isdir(LOCAL_MODEL) and os.path.isfile(os.path.join(LOCAL_MODEL, "config.json")):
        print(f"✅ Using local model at: {LOCAL_MODEL}", flush=True)
        return LOCAL_MODEL
    # Otherwise download (to ~/.cache/huggingface locally, or /tmp on Spaces)
    download_kwargs = {} if LOCAL else {"local_dir": CACHE_DIR}
    path = snapshot_download(model_id, **download_kwargs)
    print(f"✅ Downloaded model to: {path}", flush=True)
    return path

model_path = get_model_path()
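# Note: /tmp on Spaces is ephemeral, so the snapshot above is re-downloaded
# after every restart; the default HF cache (or persistent storage) avoids that.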
# ─── 4. Patch model_index.json to remove unsupported scheduler ────────────────
mi_file = os.path.join(model_path, "model_index.json")
if os.path.isfile(mi_file):
    with open(mi_file, "r") as f:
        mi = json.load(f)
    # Component entries live at the top level of model_index.json, not under
    # a "pipeline" key, so check (and pop) the scheduler entry there
    if "scheduler" in mi:
        print("🔧 Removing 'scheduler' entry from model_index.json", flush=True)
        mi.pop("scheduler", None)
        with open(mi_file, "w") as f:
            json.dump(mi, f, indent=2)
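# For reference, a scheduler entry in model_index.json typically looks like:
#   "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"]
# (the class name varies by model; this one is shown purely for illustration).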
# ─── 5. Load & filter scheduler_config.json ───────────────────────────────────
sched_cfg_path = os.path.join(model_path, "scheduler", "scheduler_config.json")
filtered_cfg = {}
if os.path.isfile(sched_cfg_path):
    with open(sched_cfg_path, "r") as f:
        raw_cfg = json.load(f)
    sig = inspect.signature(DPMSolverMultistepScheduler.__init__)
    valid_keys = set(sig.parameters.keys()) - {"self", "args", "kwargs"}
    filtered_cfg = {k: v for k, v in raw_cfg.items() if k in valid_keys}
    dropped = set(raw_cfg) - set(filtered_cfg)
    if dropped:
        print(f"⚠️ Dropped unsupported scheduler keys: {dropped}", flush=True)
    try:
        scheduler = DPMSolverMultistepScheduler(**filtered_cfg)
        print("✅ Instantiated DPMSolverMultistepScheduler from config", flush=True)
    except Exception as e:
        print(f"❌ Failed to init scheduler from config ({e}), using defaults", flush=True)
        scheduler = DPMSolverMultistepScheduler()
else:
    print("⚠️ No scheduler_config.json found; using default DPMSolverMultistepScheduler", flush=True)
    scheduler = DPMSolverMultistepScheduler()
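# Note: diffusers schedulers can also be built with
# DPMSolverMultistepScheduler.from_config(raw_cfg), which ignores unknown
# keys on its own; the manual signature filtering above just makes the
# dropped keys explicit in the logs.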
# ─── 6. Load the Stable Diffusion pipeline ────────────────────────────────────
print(f"🔧 Loading pipeline from: {model_path}", flush=True)
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    scheduler=scheduler,
).to("cuda")
# ─── 7. Apply LoRA adapters ───────────────────────────────────────────────────
print(f"🔧 Applying LoRA config (r={lora_r}, α={lora_alpha})", flush=True)
lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    # Target the UNet's attention projections; task_type="CAUSAL_LM" is for
    # language models and does not apply to a diffusion UNet
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
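# PEFT wraps the UNet in a PeftModel, so the usual adapter utilities apply,
# e.g. to confirm how few parameters are actually trainable:
#   pipe.unet.print_trainable_parameters()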
# ─── 8. Training loop stub ────────────────────────────────────────────────────
print(f"🚀 Starting fine-tuning for {num_steps} steps (trigger: {trigger_word})", flush=True)
for step in range(num_steps):
    # TODO: replace this stub with your actual training code:
    #   • Load batches from DATA_DIR
    #   • Forward/backward pass, optimizer.step(), etc.
    print(f"🔁 Step {step + 1}/{num_steps}", flush=True)
# ─── 9. Save the fine-tuned model ─────────────────────────────────────────────
print(f"💾 Saving fine-tuned model to: {OUTPUT_DIR}", flush=True)
pipe.save_pretrained(OUTPUT_DIR)
print("✅ Training complete!", flush=True)