""" RunPod Serverless Handler - Wrapper for AI-Toolkit Does NOT modify ai-toolkit code, only wraps it Supports RunPod model caching via HuggingFace integration. """ import os import sys import subprocess import traceback import logging import uuid from pathlib import Path # ============================================================================= # Environment Setup (must be before other imports) # ============================================================================= # RunPod cache paths RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache" RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub" # Check if running on RunPod with cache available IS_RUNPOD_CACHE = os.path.exists("/runpod-volume") if IS_RUNPOD_CACHE: # Use RunPod's cache directory for HuggingFace downloads os.environ["HF_HOME"] = RUNPOD_CACHE_BASE os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets" # Performance and telemetry settings os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" os.environ["DISABLE_TELEMETRY"] = "YES" # Get HF token from environment HF_TOKEN = os.environ.get("HF_TOKEN", "") if HF_TOKEN: os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit") import runpod import torch import yaml import gc import shutil logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # Track current loaded model for cleanup CURRENT_MODEL = None # ============================================================================= # Model Configuration # ============================================================================= # Model configs matching ai-toolkit/config/examples exactly MODEL_PRESETS = { "wan21_1b": "train_lora_wan21_1b_24gb.yaml", "wan21_14b": "train_lora_wan21_14b_24gb.yaml", "wan22_14b": "train_lora_wan22_14b_24gb.yaml", "qwen_image": "train_lora_qwen_image_24gb.yaml", "qwen_image_edit": "train_lora_qwen_image_edit_32gb.yaml", "qwen_image_edit_2509": "train_lora_qwen_image_edit_2509_32gb.yaml", "flux_dev": "train_lora_flux_24gb.yaml", "flux_schnell": "train_lora_flux_schnell_24gb.yaml", } # All models cached in single HuggingFace repo for RunPod caching CACHE_REPO = "Aloukik21/trainer" # Map model keys to subfolder in cache repo MODEL_CACHE_PATHS = { "wan21_1b": "wan21-14b", # Uses same base, different config "wan21_14b": "wan21-14b", "wan22_14b": "wan22-14b", "qwen_image": "qwen-image", "qwen_image_edit": "qwen-image", # Same base model "qwen_image_edit_2509": "qwen-image", "flux_dev": "flux-dev", "flux_schnell": "flux-schnell", } # Original HuggingFace repos (fallback if cache not available) MODEL_HF_REPOS = { "wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers", "wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16", "qwen_image": "Qwen/Qwen-Image", "qwen_image_edit": "Qwen/Qwen-Image-Edit", "qwen_image_edit_2509": "Qwen/Qwen-Image-Edit", "flux_dev": "black-forest-labs/FLUX.1-dev", "flux_schnell": "black-forest-labs/FLUX.1-schnell", } # Accuracy Recovery Adapters path in cache repo ARA_CACHE_PATH = "accuracy_recovery_adapters" # ============================================================================= # Cleanup Functions # ============================================================================= def 

# =============================================================================
# Cleanup Functions
# =============================================================================

def cleanup_gpu_memory():
    """Aggressively clean up GPU memory."""
    logger.info("Cleaning up GPU memory...")

    # Clear PyTorch cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Force garbage collection
    gc.collect()

    # Clear again after GC
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logger.info(f"GPU memory after cleanup: {get_gpu_info()}")


def cleanup_temp_files():
    """Clean up temporary training files."""
    logger.info("Cleaning up temporary files...")

    # Clean up generated configs (keep example configs)
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    for f in os.listdir(config_dir):
        if f.endswith('.yaml') and f.startswith(('lora_', 'test_', 'my_')):
            try:
                os.remove(os.path.join(config_dir, f))
                logger.info(f"Removed temp config: {f}")
            except Exception as e:
                logger.warning(f"Failed to remove {f}: {e}")

    # Clean up latent cache directories in workspace
    workspace_dirs = ["/workspace/dataset", "/workspace/output"]
    for ws_dir in workspace_dirs:
        if os.path.exists(ws_dir):
            for item in os.listdir(ws_dir):
                item_path = os.path.join(ws_dir, item)
                if item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
                    try:
                        if os.path.isdir(item_path):
                            shutil.rmtree(item_path)
                        else:
                            os.remove(item_path)
                        logger.info(f"Removed cache: {item_path}")
                    except Exception as e:
                        logger.warning(f"Failed to remove {item_path}: {e}")


def cleanup_before_training(new_model: str):
    """Full cleanup before starting new model training."""
    global CURRENT_MODEL

    if CURRENT_MODEL and CURRENT_MODEL != new_model:
        logger.info(f"Switching from {CURRENT_MODEL} to {new_model} - performing full cleanup")
        cleanup_gpu_memory()
        cleanup_temp_files()
    elif CURRENT_MODEL == new_model:
        logger.info(f"Same model {new_model} - light cleanup only")
        cleanup_gpu_memory()
    else:
        logger.info(f"First training run with {new_model}")

    CURRENT_MODEL = new_model

    # Final memory check (.get() avoids a KeyError if no GPU is visible)
    gpu_info = get_gpu_info()
    logger.info(f"Ready for training. GPU: {gpu_info.get('name')}, Free: {gpu_info.get('free_gb')}GB")


# =============================================================================
# Utility Functions
# =============================================================================

def get_gpu_info():
    """Get GPU information."""
    if not torch.cuda.is_available():
        return {"available": False}

    props = torch.cuda.get_device_properties(0)
    free_mem, total_mem = torch.cuda.mem_get_info(0)

    return {
        "available": True,
        "name": props.name,
        "total_gb": round(total_mem / (1024**3), 2),
        "free_gb": round(free_mem / (1024**3), 2),
    }


def get_environment_info():
    """Get environment information for debugging."""
    return {
        "is_runpod_cache": IS_RUNPOD_CACHE,
        "hf_home": os.environ.get("HF_HOME", "not set"),
        "hf_token_set": bool(HF_TOKEN),
        "gpu": get_gpu_info(),
        "ai_toolkit_dir": AI_TOOLKIT_DIR,
        "cache_exists": os.path.exists(RUNPOD_HF_CACHE) if IS_RUNPOD_CACHE else False,
    }


def find_cached_model(model_key: str) -> str:
    """
    Find cached model path on RunPod from Aloukik21/trainer repo.
    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Path to cached model subfolder, or original HF repo if not cached
    """
    if not IS_RUNPOD_CACHE:
        return MODEL_HF_REPOS.get(model_key, "")

    # Check for Aloukik21/trainer cache
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            # Get the subfolder for this model
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            if subfolder:
                cached_path = snapshots[0] / subfolder
                if cached_path.exists():
                    logger.info(f"Using cached model: {model_key} -> {cached_path}")
                    return str(cached_path)

    # Fallback to original repo
    original_repo = MODEL_HF_REPOS.get(model_key, "")
    logger.info(f"Model not in cache, using original: {original_repo}")
    return original_repo


def find_cached_ara(adapter_name: str) -> str:
    """
    Find cached accuracy recovery adapter.

    Args:
        adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')

    Returns:
        Path to cached adapter, or original HF path
    """
    if not IS_RUNPOD_CACHE:
        return f"ostris/accuracy_recovery_adapters/{adapter_name}"

    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            cached_path = snapshots[0] / ARA_CACHE_PATH / adapter_name
            if cached_path.exists():
                logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
                return str(cached_path)

    return f"ostris/accuracy_recovery_adapters/{adapter_name}"


def check_model_cache_status(model_key: str) -> dict:
    """Check if model files are cached in Aloukik21/trainer."""
    if model_key not in MODEL_CACHE_PATHS:
        return {"cached": False, "reason": "unknown model"}

    status = {
        "model": model_key,
        "cache_repo": CACHE_REPO,
        "subfolder": MODEL_CACHE_PATHS.get(model_key),
    }

    # Check if main cache repo exists
    cache_name = CACHE_REPO.replace("/", "--")
    snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"

    if snapshots_dir.exists():
        snapshots = list(snapshots_dir.iterdir())
        if snapshots:
            subfolder = MODEL_CACHE_PATHS.get(model_key)
            model_path = snapshots[0] / subfolder
            status["cached"] = model_path.exists()
            status["path"] = str(model_path) if model_path.exists() else None
        else:
            status["cached"] = False
    else:
        status["cached"] = False
        status["reason"] = "cache repo not found"

    return status


# =============================================================================
# Config Loading and Training
# =============================================================================
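
# Rough shape of the ai-toolkit example configs, as far as run_training() below
# touches them (values are illustrative placeholders; the real files under
# ai-toolkit/config/examples contain many more settings):
#
#   config:
#     name: "lora_job_name"
#     process:
#       - training_folder: "/workspace/output"
#         trigger_word: "..."
#         network: {linear: 16, linear_alpha: 16}
#         save: {save_every: 250}
#         datasets: [{folder_path: "/workspace/dataset", resolution: [...], num_frames: ...}]
#         train: {steps: 2000, batch_size: 1, lr: 1e-4}
#         sample: {sample_every: 250, prompts: [...]}
#         model: {name_or_path: "..."}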

def load_example_config(model_key):
    """Load example config from ai-toolkit."""
    if model_key not in MODEL_PRESETS:
        raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}")

    config_file = MODEL_PRESETS[model_key]
    config_path = os.path.join(AI_TOOLKIT_DIR, "config", "examples", config_file)

    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def run_training(params):
    """Run training using ai-toolkit."""
    model_key = params.get("model", "wan22_14b")

    # Cleanup before starting new training
    cleanup_before_training(model_key)

    # Load base config from ai-toolkit examples
    config = load_example_config(model_key)

    # Override with user params
    job_name = params.get("name", f"lora_{model_key}_{uuid.uuid4().hex[:6]}")
    config["config"]["name"] = job_name

    process = config["config"]["process"][0]

    # Dataset
    process["datasets"][0]["folder_path"] = params.get("dataset_path", "/workspace/dataset")

    # Output
    process["training_folder"] = params.get("output_path", "/workspace/output")

    # Training params (only override if provided)
    if "steps" in params:
        process["train"]["steps"] = params["steps"]
    if "batch_size" in params:
        process["train"]["batch_size"] = params["batch_size"]
    if "learning_rate" in params:
        process["train"]["lr"] = params["learning_rate"]
    if "lora_rank" in params:
        process["network"]["linear"] = params["lora_rank"]
        process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
    if "save_every" in params:
        process["save"]["save_every"] = params["save_every"]
    if "sample_every" in params:
        process["sample"]["sample_every"] = params["sample_every"]
    if "resolution" in params:
        process["datasets"][0]["resolution"] = params["resolution"]
    if "num_frames" in params:
        process["datasets"][0]["num_frames"] = params["num_frames"]
    if "sample_prompts" in params:
        process["sample"]["prompts"] = params["sample_prompts"]
    if "trigger_word" in params:
        process["trigger_word"] = params["trigger_word"]

    # Check if we should use cached model path from Aloukik21/trainer
    if "model" in process:
        cached_path = find_cached_model(model_key)
        if cached_path:
            process["model"]["name_or_path"] = cached_path
            logger.info(f"Model path set to: {cached_path}")

    # Save config
    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    config_path = os.path.join(config_dir, f"{job_name}.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.info(f"Config saved: {config_path}")
    logger.info(f"Starting: {job_name}")

    # Run ai-toolkit
    cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
    logger.info(f"Command: {' '.join(cmd)}")

    proc = subprocess.Popen(
        cmd,
        cwd=AI_TOOLKIT_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    for line in proc.stdout:
        logger.info(line.rstrip())

    proc.wait()

    # Cleanup after training (success or fail)
    cleanup_gpu_memory()

    if proc.returncode != 0:
        raise RuntimeError(f"Training failed with code {proc.returncode}")

    return {
        "status": "success",
        "job_name": job_name,
        "output_path": process["training_folder"],
        "model": model_key,
    }


# =============================================================================
# Handler
# =============================================================================
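
# Illustrative "train" request body (parameter values are examples only; the
# keys are the ones read by handler() and run_training() above):
#
#   {
#     "input": {
#       "action": "train",
#       "model": "wan22_14b",
#       "params": {
#         "name": "my_first_lora",
#         "dataset_path": "/workspace/dataset",
#         "steps": 2000,
#         "lora_rank": 16,
#         "trigger_word": "mytoken"
#       }
#     }
#   }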

def handler(job):
    """RunPod handler."""
    job_input = job.get("input", {})
    action = job_input.get("action", "train")

    logger.info(f"Action: {action}, GPU: {get_gpu_info()}")

    try:
        if action == "list_models":
            return {"status": "success", "models": list(MODEL_PRESETS.keys())}

        elif action == "status":
            return {
                "status": "success",
                "environment": get_environment_info(),
            }

        elif action == "check_cache":
            model_key = job_input.get("model")
            if model_key:
                cache_status = check_model_cache_status(model_key)
            else:
                cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS.keys()}
            return {"status": "success", "cache": cache_status}

        elif action == "cleanup":
            # Manual cleanup action
            cleanup_gpu_memory()
            cleanup_temp_files()
            global CURRENT_MODEL
            CURRENT_MODEL = None
            return {
                "status": "success",
                "message": "Cleanup complete",
                "gpu": get_gpu_info(),
            }

        elif action == "train":
            params = job_input.get("params", {})
            params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
            return run_training(params)

        else:
            return {"status": "error", "error": f"Unknown action: {action}"}

    except Exception as e:
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}


if __name__ == "__main__":
    logger.info("Starting AI-Toolkit RunPod Handler")
    logger.info(f"Environment: {get_environment_info()}")
    runpod.serverless.start({"handler": handler})
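
# Quick local smoke test (assumes this file is saved as handler.py and that the
# installed runpod SDK version supports the --test_input flag for local runs):
#
#   python handler.py --test_input '{"input": {"action": "list_models"}}'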