import { NextRequest, NextResponse } from 'next/server'; import { spawn } from 'child_process'; import { writeFile } from 'fs/promises'; import path from 'path'; import { tmpdir } from 'os'; export async function POST(request: NextRequest) { try { const body = await request.json(); const { action, token, hardware, namespace, jobConfig, datasetRepo } = body; switch (action) { case 'checkStatus': try { if (!token || !jobConfig?.hf_job_id) { return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 }); } const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id); return NextResponse.json({ status: jobStatus }); } catch (error: any) { console.error('Job status check error:', error); return NextResponse.json({ error: error.message }, { status: 500 }); } case 'generateScript': try { const uvScript = generateUVScript({ jobConfig, datasetRepo, namespace, token: token || 'YOUR_HF_TOKEN', }); return NextResponse.json({ script: uvScript, filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py` }); } catch (error: any) { return NextResponse.json({ error: error.message }, { status: 500 }); } case 'submitJob': try { if (!token || !hardware) { return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 }); } // Generate UV script const uvScript = generateUVScript({ jobConfig, datasetRepo, namespace, token, }); // Write script to temporary file const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`); await writeFile(scriptPath, uvScript); // Submit HF job using uv run const jobId = await submitHFJobUV(token, hardware, scriptPath); return NextResponse.json({ success: true, jobId, message: `Job submitted successfully with ID: ${jobId}` }); } catch (error: any) { console.error('Job submission error:', error); return NextResponse.json({ error: error.message }, { status: 500 }); } default: return NextResponse.json({ error: 'Invalid action' }, { status: 400 }); } } catch (error: any) { console.error('HF Jobs API error:', error); return NextResponse.json({ error: error.message }, { status: 500 }); } } function generateUVScript({ jobConfig, datasetRepo, namespace, token }: { jobConfig: any; datasetRepo: string; namespace: string; token: string; }) { const config = jobConfig.config; const process = config.process[0]; return `# /// script # dependencies = [ # "torch>=2.0.0", # "torchvision", # "torchao==0.10.0", # "safetensors", # "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63", # "transformers==4.52.4", # "lycoris-lora==1.8.3", # "flatten_json", # "pyyaml", # "oyaml", # "tensorboard", # "kornia", # "invisible-watermark", # "einops", # "accelerate", # "toml", # "albumentations==1.4.15", # "albucore==0.0.16", # "pydantic", # "omegaconf", # "k-diffusion", # "open_clip_torch", # "timm", # "prodigyopt", # "controlnet_aux==0.0.10", # "python-dotenv", # "bitsandbytes", # "hf_transfer", # "lpips", # "pytorch_fid", # "optimum-quanto==0.2.4", # "sentencepiece", # "huggingface_hub", # "peft", # "python-slugify", # "opencv-python-headless", # "pytorch-wavelets==1.3.0", # "matplotlib==3.10.1", # "setuptools==69.5.1", # "datasets==4.0.0", # "pyarrow==20.0.0", # "pillow", # "ftfy", # ] # /// import os import sys import subprocess import argparse import oyaml as yaml from datasets import load_dataset from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download import tempfile import shutil import glob from PIL import Image def setup_ai_toolkit(): """Clone and setup ai-toolkit repository""" repo_dir = "ai-toolkit" if not os.path.exists(repo_dir): print("Cloning ai-toolkit repository...") subprocess.run( ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir], check=True ) sys.path.insert(0, os.path.abspath(repo_dir)) return repo_dir def download_dataset(dataset_repo: str, local_path: str): """Download dataset from HF Hub as files""" print(f"Downloading dataset from {dataset_repo}...") # Create local dataset directory os.makedirs(local_path, exist_ok=True) # Use snapshot_download to get the dataset files directly from huggingface_hub import snapshot_download try: # First try to download as a structured dataset dataset = load_dataset(dataset_repo, split="train") # Download images and captions from structured dataset for i, item in enumerate(dataset): # Save image if "image" in item: image_path = os.path.join(local_path, f"image_{i:06d}.jpg") image = item["image"] # Convert RGBA to RGB if necessary (for JPEG compatibility) if image.mode == 'RGBA': # Create a white background and paste the RGBA image on it background = Image.new('RGB', image.size, (255, 255, 255)) background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask image = background elif image.mode not in ['RGB', 'L']: # Convert any other mode to RGB image = image.convert('RGB') image.save(image_path, 'JPEG') # Save caption if "text" in item: caption_path = os.path.join(local_path, f"image_{i:06d}.txt") with open(caption_path, "w", encoding="utf-8") as f: f.write(item["text"]) print(f"Downloaded {len(dataset)} items to {local_path}") except Exception as e: print(f"Failed to load as structured dataset: {e}") print("Attempting to download raw files...") # Download the dataset repository as files temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset") # Copy all image and text files to the local path import glob import shutil print(f"Downloaded repo to: {temp_repo_path}") print(f"Contents: {os.listdir(temp_repo_path)}") # Find all image files image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG'] image_files = [] for ext in image_extensions: pattern = os.path.join(temp_repo_path, "**", ext) found_files = glob.glob(pattern, recursive=True) image_files.extend(found_files) print(f"Pattern {pattern} found {len(found_files)} files") # Find all text files text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True) print(f"Found {len(image_files)} image files and {len(text_files)} text files") # Copy image files for i, img_file in enumerate(image_files): dest_path = os.path.join(local_path, f"image_{i:06d}.jpg") # Load and convert image if needed try: with Image.open(img_file) as image: if image.mode == 'RGBA': background = Image.new('RGB', image.size, (255, 255, 255)) background.paste(image, mask=image.split()[-1]) image = background elif image.mode not in ['RGB', 'L']: image = image.convert('RGB') image.save(dest_path, 'JPEG') except Exception as img_error: print(f"Error processing image {img_file}: {img_error}") continue # Copy text files (captions) for i, txt_file in enumerate(text_files[:len(image_files)]): # Match number of images dest_path = os.path.join(local_path, f"image_{i:06d}.txt") try: shutil.copy2(txt_file, dest_path) except Exception as txt_error: print(f"Error copying text file {txt_file}: {txt_error}") continue print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}") def create_config(dataset_path: str, output_path: str): """Create training configuration""" import json # Load config from JSON string and fix boolean/null values for Python config_str = """${JSON.stringify(jobConfig, null, 2)}""" config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None') config = eval(config_str) # Update paths for cloud environment config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path config["config"]["process"][0]["training_folder"] = output_path # Remove sqlite_db_path as it's not needed for cloud training if "sqlite_db_path" in config["config"]["process"][0]: del config["config"]["process"][0]["sqlite_db_path"] # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies if config["config"]["process"][0]["type"] == "ui_trainer": config["config"]["process"][0]["type"] = "sd_trainer" return config def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict): """Upload trained model to HF Hub with README generation and proper file organization""" import tempfile import shutil import glob import re import yaml from datetime import datetime from huggingface_hub import create_repo, upload_file, HfApi try: repo_id = f"{namespace}/{model_name}" # Create repository create_repo(repo_id=repo_id, token=token, exist_ok=True) print(f"Uploading model to {repo_id}...") # Create temporary directory for organized upload with tempfile.TemporaryDirectory() as temp_upload_dir: api = HfApi() # 1. Find and upload model files to root directory safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True) json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True) txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True) uploaded_files = [] # Upload .safetensors files to root for file_path in safetensors_files: filename = os.path.basename(file_path) print(f"Uploading {filename} to repository root...") api.upload_file( path_or_fileobj=file_path, path_in_repo=filename, repo_id=repo_id, token=token ) uploaded_files.append(filename) # Upload relevant JSON config files to root (skip metadata.json and other internal files) config_files_uploaded = [] for file_path in json_files: filename = os.path.basename(file_path) # Only upload important config files, skip internal metadata if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']): print(f"Uploading {filename} to repository root...") api.upload_file( path_or_fileobj=file_path, path_in_repo=filename, repo_id=repo_id, token=token ) uploaded_files.append(filename) config_files_uploaded.append(filename) # 2. Handle sample images samples_uploaded = [] samples_dir = os.path.join(output_path, "samples") if os.path.isdir(samples_dir): print("Uploading sample images...") # Create samples directory in repo for filename in os.listdir(samples_dir): if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): file_path = os.path.join(samples_dir, filename) repo_path = f"samples/{filename}" api.upload_file( path_or_fileobj=file_path, path_in_repo=repo_path, repo_id=repo_id, token=token ) samples_uploaded.append(repo_path) # 3. Generate and upload README.md readme_content = generate_model_card_readme( repo_id=repo_id, config=config, model_name=model_name, samples_dir=samples_dir if os.path.isdir(samples_dir) else None, uploaded_files=uploaded_files ) # Create README.md file and upload to root readme_path = os.path.join(temp_upload_dir, "README.md") with open(readme_path, "w", encoding="utf-8") as f: f.write(readme_content) print("Uploading README.md to repository root...") api.upload_file( path_or_fileobj=readme_path, path_in_repo="README.md", repo_id=repo_id, token=token ) print(f"Model uploaded successfully to https://huggingface.co/{repo_id}") print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md") except Exception as e: print(f"Failed to upload model: {e}") raise e def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str: """Generate README.md content for the model card based on AI Toolkit's implementation""" import re import yaml import os try: # Extract configuration details process_config = config.get("config", {}).get("process", [{}])[0] model_config = process_config.get("model", {}) train_config = process_config.get("train", {}) sample_config = process_config.get("sample", {}) # Gather model info base_model = model_config.get("name_or_path", "unknown") trigger_word = process_config.get("trigger_word") arch = model_config.get("arch", "") # Determine license based on base model if "FLUX.1-schnell" in base_model: license_info = {"license": "apache-2.0"} elif "FLUX.1-dev" in base_model: license_info = { "license": "other", "license_name": "flux-1-dev-non-commercial-license", "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md" } else: license_info = {"license": "creativeml-openrail-m"} # Generate tags based on model architecture tags = ["text-to-image"] if "xl" in arch.lower(): tags.append("stable-diffusion-xl") if "flux" in arch.lower(): tags.append("flux") if "lumina" in arch.lower(): tags.append("lumina2") if "sd3" in arch.lower() or "v3" in arch.lower(): tags.append("sd3") # Add LoRA-specific tags tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"]) # Generate widgets from sample images and prompts widgets = [] if samples_dir and os.path.isdir(samples_dir): sample_prompts = sample_config.get("samples", []) if not sample_prompts: # Fallback to old format sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])] # Get sample image files sample_files = [] if os.path.isdir(samples_dir): for filename in os.listdir(samples_dir): if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): # Parse filename pattern: timestamp__steps_index.jpg match = re.search(r"__(\d+)_(\d+)\.jpg$", filename) if match: steps, index = int(match.group(1)), int(match.group(2)) # Only use samples from final training step final_steps = train_config.get("steps", 1000) if steps == final_steps: sample_files.append((index, f"samples/{filename}")) # Sort by index and create widgets sample_files.sort(key=lambda x: x[0]) for i, prompt_obj in enumerate(sample_prompts): prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj) if i < len(sample_files): _, image_path = sample_files[i] widgets.append({ "text": prompt, "output": {"url": image_path} }) # Determine torch dtype based on model dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16" # Find the main safetensors file for usage example main_safetensors = f"{model_name}.safetensors" if uploaded_files: safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')] if safetensors_files: main_safetensors = safetensors_files[0] # Construct YAML frontmatter frontmatter = { "tags": tags, "base_model": base_model, **license_info } if widgets: frontmatter["widget"] = widgets if trigger_word: frontmatter["instance_prompt"] = trigger_word # Get first prompt for usage example usage_prompt = trigger_word or "a beautiful landscape" if widgets: usage_prompt = widgets[0]["text"] elif trigger_word: usage_prompt = trigger_word # Construct README content trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined." # Build YAML frontmatter string frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip() readme_content = f"""--- {frontmatter_yaml} --- # {model_name} Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) ## Trigger words {trigger_section} ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc. Weights for this model are available in Safetensors format. [Download]({repo_id}/tree/main) them in the Files & versions tab. ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) \`\`\`py from diffusers import AutoPipelineForText2Image import torch pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda') pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}') image = pipeline('{usage_prompt}').images[0] image.save("my_image.png") \`\`\` For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) """ return readme_content except Exception as e: print(f"Error generating README: {e}") # Fallback simple README return f"""# {model_name} Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) ## Download model Weights for this model are available in Safetensors format. [Download]({repo_id}/tree/main) them in the Files & versions tab. """ def main(): # Setup environment - token comes from HF Jobs secrets if "HF_TOKEN" not in os.environ: raise ValueError("HF_TOKEN environment variable not set") # Install system dependencies for headless operation print("Installing system dependencies...") try: subprocess.run(["apt-get", "update"], check=True, capture_output=True) subprocess.run([ "apt-get", "install", "-y", "libgl1-mesa-glx", "libglib2.0-0", "libsm6", "libxext6", "libxrender-dev", "libgomp1", "ffmpeg" ], check=True, capture_output=True) print("System dependencies installed successfully") except subprocess.CalledProcessError as e: print(f"Failed to install system dependencies: {e}") print("Continuing without system dependencies...") # Setup ai-toolkit toolkit_dir = setup_ai_toolkit() # Create temporary directories with tempfile.TemporaryDirectory() as temp_dir: dataset_path = os.path.join(temp_dir, "dataset") output_path = os.path.join(temp_dir, "output") # Download dataset download_dataset("${datasetRepo}", dataset_path) # Create config config = create_config(dataset_path, output_path) config_path = os.path.join(temp_dir, "config.yaml") with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False) # Run training print("Starting training...") os.chdir(toolkit_dir) subprocess.run([ sys.executable, "run.py", config_path ], check=True) print("Training completed!") # Upload results model_name = f"${jobConfig.config.name}-lora" upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config) if __name__ == "__main__": main() `; } async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise { return new Promise((resolve, reject) => { // Ensure token is available if (!token) { reject(new Error('HF_TOKEN is required')); return; } console.log('Setting up environment with HF_TOKEN for job submission'); console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`); // Use hf jobs uv run command with timeout and detach to get job ID const childProcess = spawn('hf', [ 'jobs', 'uv', 'run', '--flavor', hardware, '--timeout', '5h', '--secrets', 'HF_TOKEN', '--detach', scriptPath ], { env: { ...process.env, HF_TOKEN: token } }); let output = ''; let error = ''; childProcess.stdout.on('data', (data) => { const text = data.toString(); output += text; console.log('HF Jobs stdout:', text); }); childProcess.stderr.on('data', (data) => { const text = data.toString(); error += text; console.log('HF Jobs stderr:', text); }); childProcess.on('close', (code) => { console.log('HF Jobs process closed with code:', code); console.log('Full output:', output); console.log('Full error:', error); if (code === 0) { // With --detach flag, the output should be just the job ID const fullText = (output + ' ' + error).trim(); // Updated patterns to handle variable-length hex job IDs (16-24+ characters) const jobIdPatterns = [ /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac" /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac" /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac" /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac" /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac" /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string ]; let jobId = 'unknown'; for (const pattern of jobIdPatterns) { const match = fullText.match(pattern); if (match && match[1] && match[1] !== 'started') { jobId = match[1]; console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`); break; } } resolve(jobId); } else { reject(new Error(error || output || 'Failed to submit job')); } }); childProcess.on('error', (err) => { console.error('HF Jobs process error:', err); reject(new Error(`Process error: ${err.message}`)); }); }); } async function checkHFJobStatus(token: string, jobId: string): Promise { return new Promise((resolve, reject) => { console.log(`Checking HF Job status for: ${jobId}`); const childProcess = spawn('hf', [ 'jobs', 'inspect', jobId ], { env: { ...process.env, HF_TOKEN: token } }); let output = ''; let error = ''; childProcess.stdout.on('data', (data) => { const text = data.toString(); output += text; }); childProcess.stderr.on('data', (data) => { const text = data.toString(); error += text; }); childProcess.on('close', (code) => { if (code === 0) { try { // Parse the JSON output from hf jobs inspect const jobInfo = JSON.parse(output); if (Array.isArray(jobInfo) && jobInfo.length > 0) { const job = jobInfo[0]; resolve({ id: job.id, status: job.status?.stage || 'UNKNOWN', message: job.status?.message, created_at: job.created_at, flavor: job.flavor, url: job.url, }); } else { reject(new Error('Invalid job info response')); } } catch (parseError: any) { console.error('Failed to parse job status:', parseError, output); reject(new Error('Failed to parse job status')); } } else { reject(new Error(error || output || 'Failed to check job status')); } }); childProcess.on('error', (err) => { console.error('HF Jobs inspect process error:', err); reject(new Error(`Process error: ${err.message}`)); }); }); }