"""
RunPod Serverless Handler - Wrapper for AI-Toolkit
Does not modify ai-toolkit code; it only wraps it as a subprocess.
Supports RunPod model caching via a single HuggingFace cache repository.
"""
import os
import sys
import subprocess
import traceback
import logging
import uuid
from pathlib import Path
# =============================================================================
# Environment Setup (must be before other imports)
# =============================================================================
# RunPod cache paths
RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache"
RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub"
# Check if running on RunPod with cache available
IS_RUNPOD_CACHE = os.path.exists("/runpod-volume")
if IS_RUNPOD_CACHE:
# Use RunPod's cache directory for HuggingFace downloads
os.environ["HF_HOME"] = RUNPOD_CACHE_BASE
os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE
os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE
os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets"
# Performance and telemetry settings
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
os.environ["DISABLE_TELEMETRY"] = "YES"
# Get HF token from environment
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit")
import runpod
import torch
import yaml
import gc
import shutil
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Track current loaded model for cleanup
CURRENT_MODEL = None
# =============================================================================
# Model Configuration
# =============================================================================
# Model configs matching ai-toolkit/config/examples exactly
MODEL_PRESETS = {
"wan21_1b": "train_lora_wan21_1b_24gb.yaml",
"wan21_14b": "train_lora_wan21_14b_24gb.yaml",
"wan22_14b": "train_lora_wan22_14b_24gb.yaml",
"qwen_image": "train_lora_qwen_image_24gb.yaml",
"qwen_image_edit": "train_lora_qwen_image_edit_32gb.yaml",
"qwen_image_edit_2509": "train_lora_qwen_image_edit_2509_32gb.yaml",
"flux_dev": "train_lora_flux_24gb.yaml",
"flux_schnell": "train_lora_flux_schnell_24gb.yaml",
}
# All models cached in single HuggingFace repo for RunPod caching
CACHE_REPO = "Aloukik21/trainer"
# Map model keys to subfolder in cache repo
MODEL_CACHE_PATHS = {
"wan21_1b": "wan21-14b", # Uses same base, different config
"wan21_14b": "wan21-14b",
"wan22_14b": "wan22-14b",
"qwen_image": "qwen-image",
"qwen_image_edit": "qwen-image", # Same base model
"qwen_image_edit_2509": "qwen-image",
"flux_dev": "flux-dev",
"flux_schnell": "flux-schnell",
}
# Original HuggingFace repos (fallback if cache not available)
MODEL_HF_REPOS = {
"wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16",
"qwen_image": "Qwen/Qwen-Image",
"qwen_image_edit": "Qwen/Qwen-Image-Edit",
"qwen_image_edit_2509": "Qwen/Qwen-Image-Edit",
"flux_dev": "black-forest-labs/FLUX.1-dev",
"flux_schnell": "black-forest-labs/FLUX.1-schnell",
}
# Accuracy Recovery Adapters path in cache repo
ARA_CACHE_PATH = "accuracy_recovery_adapters"
# =============================================================================
# Cleanup Functions
# =============================================================================
def cleanup_gpu_memory():
"""Aggressively clean up GPU memory."""
logger.info("Cleaning up GPU memory...")
# Clear PyTorch cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Force garbage collection
gc.collect()
# Clear again after GC
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info(f"GPU memory after cleanup: {get_gpu_info()}")
def cleanup_temp_files():
"""Clean up temporary training files."""
logger.info("Cleaning up temporary files...")
# Clean up generated configs (keep example configs)
config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
for f in os.listdir(config_dir):
if f.endswith('.yaml') and f.startswith(('lora_', 'test_', 'my_')):
try:
os.remove(os.path.join(config_dir, f))
logger.info(f"Removed temp config: {f}")
except Exception as e:
logger.warning(f"Failed to remove {f}: {e}")
# Clean up latent cache directories in workspace
workspace_dirs = ["/workspace/dataset", "/workspace/output"]
for ws_dir in workspace_dirs:
if os.path.exists(ws_dir):
for item in os.listdir(ws_dir):
item_path = os.path.join(ws_dir, item)
if item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
try:
if os.path.isdir(item_path):
shutil.rmtree(item_path)
else:
os.remove(item_path)
logger.info(f"Removed cache: {item_path}")
except Exception as e:
logger.warning(f"Failed to remove {item_path}: {e}")
def cleanup_before_training(new_model: str):
"""Full cleanup before starting new model training."""
global CURRENT_MODEL
if CURRENT_MODEL and CURRENT_MODEL != new_model:
logger.info(f"Switching from {CURRENT_MODEL} to {new_model} - performing full cleanup")
cleanup_gpu_memory()
cleanup_temp_files()
elif CURRENT_MODEL == new_model:
logger.info(f"Same model {new_model} - light cleanup only")
cleanup_gpu_memory()
else:
logger.info(f"First training run with {new_model}")
CURRENT_MODEL = new_model
# Final memory check
gpu_info = get_gpu_info()
logger.info(f"Ready for training. GPU: {gpu_info['name']}, Free: {gpu_info['free_gb']}GB")
# =============================================================================
# Utility Functions
# =============================================================================
def get_gpu_info():
"""Get GPU information."""
if not torch.cuda.is_available():
return {"available": False}
props = torch.cuda.get_device_properties(0)
free_mem, total_mem = torch.cuda.mem_get_info(0)
return {
"available": True,
"name": props.name,
"total_gb": round(total_mem / (1024**3), 2),
"free_gb": round(free_mem / (1024**3), 2),
}
def get_environment_info():
"""Get environment information for debugging."""
return {
"is_runpod_cache": IS_RUNPOD_CACHE,
"hf_home": os.environ.get("HF_HOME", "not set"),
"hf_token_set": bool(HF_TOKEN),
"gpu": get_gpu_info(),
"ai_toolkit_dir": AI_TOOLKIT_DIR,
"cache_exists": os.path.exists(RUNPOD_HF_CACHE) if IS_RUNPOD_CACHE else False,
}
def find_cached_model(model_key: str) -> str:
"""
Find cached model path on RunPod from Aloukik21/trainer repo.
Args:
model_key: Model key (e.g., 'flux_dev', 'wan22_14b')
Returns:
Path to cached model subfolder, or original HF repo if not cached
"""
if not IS_RUNPOD_CACHE:
return MODEL_HF_REPOS.get(model_key, "")
# Check for Aloukik21/trainer cache
cache_name = CACHE_REPO.replace("/", "--")
snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
if snapshots_dir.exists():
snapshots = list(snapshots_dir.iterdir())
if snapshots:
# Get the subfolder for this model
subfolder = MODEL_CACHE_PATHS.get(model_key)
if subfolder:
cached_path = snapshots[0] / subfolder
if cached_path.exists():
logger.info(f"Using cached model: {model_key} -> {cached_path}")
return str(cached_path)
# Fallback to original repo
original_repo = MODEL_HF_REPOS.get(model_key, "")
logger.info(f"Model not in cache, using original: {original_repo}")
return original_repo
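# Expected layout of the cache repo snapshot on the RunPod volume (illustrative;
# <hash> stands for whatever snapshot directory the HuggingFace hub cache created):
#
#   /runpod-volume/huggingface-cache/hub/models--Aloukik21--trainer/snapshots/<hash>/
#       flux-dev/...
#       wan22-14b/...
#       qwen-image/...
#       accuracy_recovery_adapters/<adapter_name>.safetensors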
def find_cached_ara(adapter_name: str) -> str:
"""
Find cached accuracy recovery adapter.
Args:
adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')
Returns:
Path to cached adapter, or original HF path
"""
if not IS_RUNPOD_CACHE:
return f"ostris/accuracy_recovery_adapters/{adapter_name}"
cache_name = CACHE_REPO.replace("/", "--")
snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
if snapshots_dir.exists():
snapshots = list(snapshots_dir.iterdir())
if snapshots:
cached_path = snapshots[0] / ARA_CACHE_PATH / adapter_name
if cached_path.exists():
logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
return str(cached_path)
return f"ostris/accuracy_recovery_adapters/{adapter_name}"
def check_model_cache_status(model_key: str) -> dict:
"""Check if model files are cached in Aloukik21/trainer."""
if model_key not in MODEL_CACHE_PATHS:
return {"cached": False, "reason": "unknown model"}
status = {
"model": model_key,
"cache_repo": CACHE_REPO,
"subfolder": MODEL_CACHE_PATHS.get(model_key),
}
# Check if main cache repo exists
cache_name = CACHE_REPO.replace("/", "--")
snapshots_dir = Path(RUNPOD_HF_CACHE) / f"models--{cache_name}" / "snapshots"
if snapshots_dir.exists():
snapshots = list(snapshots_dir.iterdir())
if snapshots:
subfolder = MODEL_CACHE_PATHS.get(model_key)
model_path = snapshots[0] / subfolder
status["cached"] = model_path.exists()
status["path"] = str(model_path) if model_path.exists() else None
else:
status["cached"] = False
else:
status["cached"] = False
status["reason"] = "cache repo not found"
return status
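# Example return value from check_model_cache_status() when a snapshot is present
# (illustrative; the snapshot hash in the path is a placeholder):
#
#   {"model": "flux_dev", "cache_repo": "Aloukik21/trainer", "subfolder": "flux-dev",
#    "cached": True,
#    "path": ".../hub/models--Aloukik21--trainer/snapshots/<hash>/flux-dev"}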
# =============================================================================
# Config Loading and Training
# =============================================================================
def load_example_config(model_key):
"""Load example config from ai-toolkit."""
if model_key not in MODEL_PRESETS:
raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}")
config_file = MODEL_PRESETS[model_key]
config_path = os.path.join(AI_TOOLKIT_DIR, "config", "examples", config_file)
with open(config_path, 'r') as f:
return yaml.safe_load(f)
def run_training(params):
"""Run training using ai-toolkit."""
model_key = params.get("model", "wan22_14b")
# Cleanup before starting new training
cleanup_before_training(model_key)
# Load base config from ai-toolkit examples
config = load_example_config(model_key)
# Override with user params
job_name = params.get("name", f"lora_{model_key}_{uuid.uuid4().hex[:6]}")
config["config"]["name"] = job_name
process = config["config"]["process"][0]
# Dataset
process["datasets"][0]["folder_path"] = params.get("dataset_path", "/workspace/dataset")
# Output
process["training_folder"] = params.get("output_path", "/workspace/output")
# Training params (only override if provided)
if "steps" in params:
process["train"]["steps"] = params["steps"]
if "batch_size" in params:
process["train"]["batch_size"] = params["batch_size"]
if "learning_rate" in params:
process["train"]["lr"] = params["learning_rate"]
if "lora_rank" in params:
process["network"]["linear"] = params["lora_rank"]
process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
if "save_every" in params:
process["save"]["save_every"] = params["save_every"]
if "sample_every" in params:
process["sample"]["sample_every"] = params["sample_every"]
if "resolution" in params:
process["datasets"][0]["resolution"] = params["resolution"]
if "num_frames" in params:
process["datasets"][0]["num_frames"] = params["num_frames"]
if "sample_prompts" in params:
process["sample"]["prompts"] = params["sample_prompts"]
if "trigger_word" in params:
process["trigger_word"] = params["trigger_word"]
# Check if we should use cached model path from Aloukik21/trainer
if "model" in process:
cached_path = find_cached_model(model_key)
if cached_path:
process["model"]["name_or_path"] = cached_path
logger.info(f"Model path set to: {cached_path}")
# Save config
config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
config_path = os.path.join(config_dir, f"{job_name}.yaml")
with open(config_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False)
logger.info(f"Config saved: {config_path}")
logger.info(f"Starting: {job_name}")
# Run ai-toolkit
cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
logger.info(f"Command: {' '.join(cmd)}")
proc = subprocess.Popen(
cmd,
cwd=AI_TOOLKIT_DIR,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
for line in proc.stdout:
logger.info(line.rstrip())
proc.wait()
# Cleanup after training (success or fail)
cleanup_gpu_memory()
if proc.returncode != 0:
raise RuntimeError(f"Training failed with code {proc.returncode}")
return {
"status": "success",
"job_name": job_name,
"output_path": process["training_folder"],
"model": model_key,
}
# =============================================================================
# Handler
# =============================================================================
def handler(job):
"""RunPod handler."""
job_input = job.get("input", {})
action = job_input.get("action", "train")
logger.info(f"Action: {action}, GPU: {get_gpu_info()}")
try:
if action == "list_models":
return {"status": "success", "models": list(MODEL_PRESETS.keys())}
elif action == "status":
return {
"status": "success",
"environment": get_environment_info(),
}
elif action == "check_cache":
model_key = job_input.get("model")
if model_key:
cache_status = check_model_cache_status(model_key)
else:
cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS.keys()}
return {"status": "success", "cache": cache_status}
elif action == "cleanup":
# Manual cleanup action
cleanup_gpu_memory()
cleanup_temp_files()
global CURRENT_MODEL
CURRENT_MODEL = None
return {
"status": "success",
"message": "Cleanup complete",
"gpu": get_gpu_info(),
}
elif action == "train":
params = job_input.get("params", {})
params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
return run_training(params)
else:
return {"status": "error", "error": f"Unknown action: {action}"}
except Exception as e:
logger.error(traceback.format_exc())
return {"status": "error", "error": str(e)}
if __name__ == "__main__":
logger.info("Starting AI-Toolkit RunPod Handler")
logger.info(f"Environment: {get_environment_info()}")
runpod.serverless.start({"handler": handler})
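# Local smoke test (illustrative; calls the handler directly without the RunPod queue):
#
#   python -c "from rp_handler import handler; print(handler({'input': {'action': 'status'}}))"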