import { NextRequest, NextResponse } from 'next/server';
import { spawn } from 'child_process';
import { writeFile, readFile, unlink } from 'fs/promises';
import path from 'path';
import { tmpdir } from 'os';

export async function POST(request: NextRequest) {
  try {
    const body = await request.json();
    const { action, token, hardware, namespace, jobConfig, datasetRepo, participateHackathon } = body;

    switch (action) {
      case 'checkCapacity':
        try {
          if (!token) {
            return NextResponse.json({ error: 'Token required' }, { status: 400 });
          }
          const capacityStatus = await checkHFJobsCapacity(token);
          return NextResponse.json(capacityStatus);
        } catch (error: any) {
          console.error('Capacity check error:', error);
          return NextResponse.json({ error: error.message }, { status: 500 });
        }

      case 'checkStatus':
        try {
          if (!token || !jobConfig?.hf_job_id) {
            return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
          }
          const jobNamespaceOverride = jobConfig?.hf_job_namespace;
          const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id, jobNamespaceOverride);
          return NextResponse.json({ status: jobStatus });
        } catch (error: any) {
          console.error('Job status check error:', error);
          return NextResponse.json({ error: error.message }, { status: 500 });
        }

      case 'generateScript':
        try {
          const uvScript = generateUVScript({
            jobConfig,
            datasetRepo,
            namespace,
            token: token || 'YOUR_HF_TOKEN',
          });
          return NextResponse.json({
            script: uvScript,
            filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
          });
        } catch (error: any) {
          return NextResponse.json({ error: error.message }, { status: 500 });
        }

      case 'submitJob':
        try {
          if (!token || !hardware) {
            return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
          }
          // Generate UV script
          const uvScript = generateUVScript({
            jobConfig,
            datasetRepo,
            namespace,
            token,
          });
          // Write script to temporary file
          const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
          await writeFile(scriptPath, uvScript);
          // Submit HF job using uv run
          const namespaceOverride = participateHackathon ? 'lora-training-frenzi' : undefined;
          const jobId = await submitHFJobUV(
            token,
            hardware,
            scriptPath,
            namespaceOverride
          );
          const jobNamespace = namespaceOverride ?? namespace;
          return NextResponse.json({
            success: true,
            jobId,
            jobNamespace,
            message: `Job submitted successfully with ID: ${jobId}`
          });
        } catch (error: any) {
          console.error('Job submission error:', error);
          return NextResponse.json({ error: error.message }, { status: 500 });
        }

      default:
        return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
    }
  } catch (error: any) {
    console.error('HF Jobs API error:', error);
    return NextResponse.json({ error: error.message }, { status: 500 });
  }
}
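
// Illustrative client-side call for this route. The route path and field
// values below are assumptions for the example, not taken from the actual
// frontend:
//
//   const res = await fetch('/api/train/hf-jobs', {
//     method: 'POST',
//     headers: { 'Content-Type': 'application/json' },
//     body: JSON.stringify({
//       action: 'submitJob',
//       token: 'hf_...',           // user-supplied HF token
//       hardware: 'a10g-large',    // example flavor
//       namespace: 'my-username',
//       jobConfig,                 // ai-toolkit job config object
//       datasetRepo: 'my-username/my-dataset',
//       participateHackathon: false,
//     }),
//   });
//   const { jobId, jobNamespace } = await res.json();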

function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
  jobConfig: any;
  datasetRepo: string;
  namespace: string;
  token: string;
}) {
  return `# /// script
# dependencies = [
#     "torch>=2.0.0",
#     "torchvision",
#     "torchaudio",
#     "torchao==0.10.0",
#     "safetensors",
#     "diffusers @ git+https://github.com/huggingface/diffusers",
#     "transformers==4.52.4",
#     "lycoris-lora==1.8.3",
#     "flatten_json",
#     "pyyaml",
#     "oyaml",
#     "tensorboard",
#     "kornia",
#     "invisible-watermark",
#     "einops",
#     "accelerate",
#     "toml",
#     "albumentations==1.4.15",
#     "albucore==0.0.16",
#     "pydantic",
#     "omegaconf",
#     "k-diffusion",
#     "open_clip_torch",
#     "timm",
#     "prodigyopt",
#     "controlnet_aux==0.0.10",
#     "python-dotenv",
#     "bitsandbytes",
#     "hf_transfer",
#     "lpips",
#     "pytorch_fid",
#     "optimum-quanto==0.2.4",
#     "sentencepiece",
#     "huggingface_hub",
#     "peft",
#     "python-slugify",
#     "opencv-python-headless",
#     "pytorch-wavelets==1.3.0",
#     "matplotlib==3.10.1",
#     "setuptools==69.5.1",
#     "datasets==4.0.0",
#     "pyarrow==20.0.0",
#     "pillow",
#     "ftfy",
# ]
# ///

import os
import sys
import subprocess
import argparse
import re
import oyaml as yaml
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
import tempfile
import shutil
import glob
from PIL import Image


def setup_ai_toolkit():
    """Clone and set up the ai-toolkit repository."""
    repo_dir = "ai-toolkit"
    if not os.path.exists(repo_dir):
        print("Cloning ai-toolkit repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
            check=True
        )
    sys.path.insert(0, os.path.abspath(repo_dir))
    return repo_dir

def find_local_dataset_source(dataset_repo: str):
    if not dataset_repo:
        return None
    repo_stripped = dataset_repo.strip()
    candidates = []
    if os.path.isabs(repo_stripped):
        candidates.append(repo_stripped)
    else:
        candidates.append(repo_stripped)
        candidates.append(os.path.abspath(repo_stripped))
    normalized = normalize_repo_id(repo_stripped)
    if normalized:
        candidates.append(os.path.join("/datasets", normalized))
    if repo_stripped.startswith("/datasets/") and repo_stripped not in candidates:
        candidates.append(repo_stripped)
    seen = set()
    for candidate in candidates:
        if not candidate or candidate in seen:
            continue
        seen.add(candidate)
        if os.path.exists(candidate):
            return candidate
    return None


def normalize_repo_id(dataset_repo: str) -> str:
    repo_id = dataset_repo.strip()
    if repo_id.startswith("/datasets/"):
        repo_id = repo_id[len("/datasets/"):]
    elif repo_id.startswith("datasets/"):
        repo_id = repo_id[len("datasets/"):]
    return repo_id.strip("/")

def copy_dataset_files(source_dir: str, local_path: str):
    print(f"Collecting data files from {source_dir}")
    image_exts = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    video_exts = {'.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv'}
    copied_images = 0
    copied_videos = 0
    copied_captions = 0
    for root, _, files in os.walk(source_dir):
        for file_name in files:
            ext = os.path.splitext(file_name)[1].lower()
            src_path = os.path.join(root, file_name)
            rel_path = os.path.relpath(src_path, source_dir)
            dest_path = os.path.join(local_path, rel_path)
            dest_dir = os.path.dirname(dest_path)
            if dest_dir and not os.path.exists(dest_dir):
                os.makedirs(dest_dir, exist_ok=True)
            if ext in image_exts:
                try:
                    shutil.copy2(src_path, dest_path)
                    copied_images += 1
                except Exception as img_error:
                    print(f"Error copying image {src_path}: {img_error}")
            elif ext in video_exts:
                try:
                    shutil.copy2(src_path, dest_path)
                    copied_videos += 1
                except Exception as vid_error:
                    print(f"Error copying video {src_path}: {vid_error}")
            elif ext == '.txt':
                try:
                    shutil.copy2(src_path, dest_path)
                    copied_captions += 1
                except Exception as txt_error:
                    print(f"Error copying text file {src_path}: {txt_error}")
            else:
                try:
                    shutil.copy2(src_path, dest_path)
                except Exception as other_error:
                    print(f"Error copying file {src_path}: {other_error}")
    total_media = copied_images + copied_videos
    print(
        f"Prepared {copied_images} images, {copied_videos} videos, and {copied_captions} captions in {local_path}"
    )
    return total_media, copied_captions

def download_dataset(dataset_repo: str, local_path: str):
    """Download dataset from HF Hub as files"""
    print(f"Downloading dataset from {dataset_repo}...")
    os.makedirs(local_path, exist_ok=True)
    local_source = find_local_dataset_source(dataset_repo)
    if local_source:
        print(f"Found local dataset at {local_source}")
        media_copied, _ = copy_dataset_files(local_source, local_path)
        if media_copied > 0:
            return
        print("Local dataset did not contain media files, falling back to remote download")
    repo_id = normalize_repo_id(dataset_repo)
    if repo_id:
        try:
            print(f"Attempting snapshot download for dataset {repo_id}")
            temp_repo_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
            print(f"Downloaded repo to: {temp_repo_path}")
            print(f"Contents: {os.listdir(temp_repo_path)}")
            media_copied, _ = copy_dataset_files(temp_repo_path, local_path)
            if media_copied > 0:
                return
            print("Snapshot download did not contain media files, attempting structured dataset load")
        except Exception as snapshot_error:
            print(f"Snapshot download failed: {snapshot_error}")
    if not repo_id:
        raise ValueError("Dataset repository ID is required when no local dataset is available")
    try:
        dataset = load_dataset(repo_id, split="train")
        images_saved = 0
        captions_saved = 0
        for i, item in enumerate(dataset):
            if "image" in item and item["image"] is not None:
                image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
                image = item["image"]
                if image.mode == 'RGBA':
                    background = Image.new('RGB', image.size, (255, 255, 255))
                    background.paste(image, mask=image.split()[-1])
                    image = background
                elif image.mode not in ['RGB', 'L']:
                    image = image.convert('RGB')
                image.save(image_path, 'JPEG')
                images_saved += 1
            if "text" in item and item["text"] is not None:
                caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
                with open(caption_path, "w", encoding="utf-8") as f:
                    f.write(item["text"])
                captions_saved += 1
        if images_saved == 0:
            raise ValueError(f"Structured dataset load completed but produced 0 images for {repo_id}")
        print(f"Downloaded {images_saved} items to {local_path}")
    except Exception as e:
        print(f"Failed to load as structured dataset: {e}")
        raise

def create_config(dataset_path: str, output_path: str):
    """Create training configuration"""
    import json
    # Parse the job config embedded at script-generation time. A raw string
    # keeps JSON escape sequences intact, and json.loads handles
    # true/false/null natively, so no find-and-replace or eval is needed.
    config_str = r"""${JSON.stringify(jobConfig, null, 2)}"""
    config = json.loads(config_str)

    def resolve_manifest_value(value):
        if value is None:
            return None
        if isinstance(value, list):
            resolved_list = [resolve_manifest_value(v) for v in value]
            return [v for v in resolved_list if v is not None]
        if not isinstance(value, str) or value.strip() == "":
            return None
        normalized = value.replace("\\\\", "/")
        parts = [part for part in normalized.split("/") if part not in ("", ".")]
        return os.path.normpath(os.path.join(dataset_path, *parts))

    manifest_path = os.path.join(dataset_path, "manifest.json")
    manifest_data = None
    if os.path.isfile(manifest_path):
        try:
            with open(manifest_path, "r", encoding="utf-8") as manifest_file:
                manifest_data = json.load(manifest_file)
        except Exception as manifest_error:
            print(f"Failed to load dataset manifest: {manifest_error}")
            manifest_data = None

    process_config = config["config"]["process"][0]
    datasets_config = process_config.get("datasets", [])
    if manifest_data and isinstance(manifest_data, dict) and "datasets" in manifest_data:
        manifest_datasets = manifest_data.get("datasets", [])
        for idx, dataset_cfg in enumerate(datasets_config):
            manifest_entry = manifest_datasets[idx] if idx < len(manifest_datasets) else {}
            if isinstance(manifest_entry, dict):
                for key, value in manifest_entry.items():
                    resolved_value = resolve_manifest_value(value)
                    if resolved_value is not None and resolved_value != []:
                        dataset_cfg[key] = resolved_value
                        if key == "folder_path":
                            dataset_cfg["dataset_path"] = resolved_value
            if "folder_path" not in dataset_cfg or not dataset_cfg["folder_path"]:
                dataset_cfg["folder_path"] = dataset_path
                dataset_cfg["dataset_path"] = dataset_path
    else:
        for dataset_cfg in datasets_config:
            dataset_cfg["folder_path"] = dataset_path
            dataset_cfg["dataset_path"] = dataset_path

    samples_config = process_config.get("sample", {}).get("samples", [])
    if manifest_data and isinstance(manifest_data, dict):
        manifest_samples = manifest_data.get("samples", [])
        for sample_entry in manifest_samples:
            if not isinstance(sample_entry, dict):
                continue
            index = sample_entry.get("index")
            ctrl_img_rel = sample_entry.get("ctrl_img")
            if (
                isinstance(index, int)
                and 0 <= index < len(samples_config)
                and ctrl_img_rel is not None
            ):
                resolved_ctrl_img = resolve_manifest_value(ctrl_img_rel)
                if resolved_ctrl_img:
                    samples_config[index]["ctrl_img"] = resolved_ctrl_img

    # Update training folder for cloud environment
    process_config["training_folder"] = output_path
    # Remove sqlite_db_path as it's not needed for cloud training
    if "sqlite_db_path" in process_config:
        del process_config["sqlite_db_path"]
    # Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
    if process_config["type"] == "ui_trainer":
        process_config["type"] = "sd_trainer"
    return config
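
# Illustrative manifest.json shape consumed by create_config above. The field
# names come from the parsing logic; the values here are made-up examples:
#
#   {
#     "datasets": [{"folder_path": "images/main"}],
#     "samples": [{"index": 0, "ctrl_img": "control/sample_0.png"}]
#   }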

def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
    """Upload trained model to HF Hub with README generation and proper file organization"""
    import tempfile
    import shutil
    import glob
    from datetime import datetime
    from huggingface_hub import create_repo, upload_file, HfApi
    from collections import deque

    try:
        repo_id = f"{namespace}/{model_name}"
        # Create repository
        create_repo(repo_id=repo_id, token=token, exist_ok=True)
        print(f"Uploading model to {repo_id}...")
        # Create temporary directory for organized upload
        with tempfile.TemporaryDirectory() as temp_upload_dir:
            api = HfApi()
            # 1. Find and upload model files to root directory
            safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
            json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
            txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
            uploaded_files = []
            # Upload .safetensors files to root
            for file_path in safetensors_files:
                filename = os.path.basename(file_path)
                print(f"Uploading {filename} to repository root...")
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=repo_id,
                    token=token
                )
                uploaded_files.append(filename)
            # Upload relevant JSON config files to root (skip metadata.json and other internal files)
            config_files_uploaded = []
            for file_path in json_files:
                filename = os.path.basename(file_path)
                # Only upload important config files, skip internal metadata
                if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
                    print(f"Uploading {filename} to repository root...")
                    api.upload_file(
                        path_or_fileobj=file_path,
                        path_in_repo=filename,
                        repo_id=repo_id,
                        token=token
                    )
                    uploaded_files.append(filename)
                    config_files_uploaded.append(filename)

            def prepare_sample_metadata(samples_directory: str, sample_conf: dict):
                if not samples_directory or not os.path.isdir(samples_directory):
                    return [], []
                allowed_ext = {'.jpg', '.jpeg', '.png', '.webp'}
                image_records = []
                for root, _, files in os.walk(samples_directory):
                    for filename in files:
                        ext = os.path.splitext(filename)[1].lower()
                        if ext not in allowed_ext:
                            continue
                        abs_path = os.path.join(root, filename)
                        try:
                            mtime = os.path.getmtime(abs_path)
                        except Exception:
                            mtime = 0
                        image_records.append((abs_path, mtime))
                if not image_records:
                    return [], []
                # Newest samples first; ties broken by path for determinism
                image_records.sort(key=lambda item: (-item[1], item[0]))
                image_queue = deque(image_records)
                samples_list = sample_conf.get("samples", []) if sample_conf else []
                if not samples_list:
                    legacy = sample_conf.get("prompts", []) if sample_conf else []
                    samples_list = [{"prompt": prompt} for prompt in legacy if prompt]
                curated_samples = []
                for sample in samples_list:
                    prompt = None
                    if isinstance(sample, dict):
                        prompt = sample.get("prompt")
                    elif isinstance(sample, str):
                        prompt = sample
                    if not prompt:
                        continue
                    if not image_queue:
                        break
                    image_path, _ = image_queue.popleft()
                    repo_rel_path = f"images/{os.path.basename(image_path)}"
                    curated_samples.append({
                        "prompt": prompt,
                        "local_path": image_path,
                        "repo_path": repo_rel_path,
                    })
                all_files = [record[0] for record in image_records]
                return curated_samples, all_files

            # 2. Upload sample images to images/
            samples_dir = os.path.join(output_path, "samples")
            sample_config = config.get("config", {}).get("process", [{}])[0].get("sample", {})
            curated_samples, sample_files = prepare_sample_metadata(samples_dir, sample_config)
            samples_uploaded = []
            if sample_files:
                print("Uploading sample images...")
                for file_path in sample_files:
                    if not os.path.isfile(file_path):
                        continue
                    filename = os.path.basename(file_path)
                    repo_path = f"images/{filename}"
                    api.upload_file(
                        path_or_fileobj=file_path,
                        path_in_repo=repo_path,
                        repo_id=repo_id,
                        token=token
                    )
                    samples_uploaded.append(repo_path)
            # 3. Generate and upload README.md
            readme_content = generate_model_card_readme(
                repo_id=repo_id,
                config=config,
                model_name=model_name,
                curated_samples=curated_samples,
                uploaded_files=uploaded_files
            )
            # Create README.md file and upload to root
            readme_path = os.path.join(temp_upload_dir, "README.md")
            with open(readme_path, "w", encoding="utf-8") as f:
                f.write(readme_content)
            print("Uploading README.md to repository root...")
            api.upload_file(
                path_or_fileobj=readme_path,
                path_in_repo="README.md",
                repo_id=repo_id,
                token=token
            )
        print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
        print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
    except Exception as e:
        print(f"Failed to upload model: {e}")
        raise e

def generate_model_card_readme(repo_id: str, config: dict, model_name: str, curated_samples: list = None, uploaded_files: list = None) -> str:
    """Generate README.md content for the model card based on AI Toolkit's implementation"""
    import yaml
    import os
    try:
        # Extract configuration details
        process_config = config.get("config", {}).get("process", [{}])[0]
        model_config = process_config.get("model", {})
        train_config = process_config.get("train", {})
        sample_config = process_config.get("sample", {})
        # Gather model info
        base_model = model_config.get("name_or_path", "unknown")
        trigger_word = process_config.get("trigger_word")
        arch = model_config.get("arch", "")
        # Determine license based on base model
        if "FLUX.1-schnell" in base_model:
            license_info = {"license": "apache-2.0"}
        elif "FLUX.1-dev" in base_model:
            license_info = {
                "license": "other",
                "license_name": "flux-1-dev-non-commercial-license",
                "license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
            }
        else:
            license_info = {"license": "creativeml-openrail-m"}
        # Generate tags based on model architecture group
        tags = []
        lower_arch = (arch or "").lower()
        lower_model_name = (model_config.get("name_or_path", "") or "").lower()
        base_model_lower = (base_model or "").lower()
        # Define model groups based on the frontend options.ts structure:
        #   group 'image' -> text-to-image
        #   group 'instruction' -> image-to-image
        #   group 'video' -> check for i2v in arch name for image-to-video vs text-to-video
        image_arches = {
            'flux', 'flex1', 'flex2', 'chroma', 'lumina2',
            'qwen_image', 'hidream', 'sdxl', 'sd15', 'omnigen2'
        }
        instruction_arches = {
            'flux_kontext', 'qwen_image_edit', 'qwen_image_edit_plus', 'hidream_e1'
        }
        video_arches = {
            'wan21:1b', 'wan21_i2v:14b480p', 'wan21_i2v:14b', 'wan21:14b',
            'wan22_14b:t2v', 'wan22_14b_i2v', 'wan22_5b'
        }
        # Determine the task type based on architecture group
        if lower_arch in instruction_arches:
            tags.append("image-to-image")
        elif lower_arch in video_arches:
            # Video models: check if i2v is in the architecture name
            is_i2v = 'i2v' in lower_arch
            tags.append("image-to-video" if is_i2v else "text-to-video")
        elif lower_arch in image_arches:
            tags.append("text-to-image")
        else:
            # Fallback to text-to-image for unknown architectures
            tags.append("text-to-image")
        if "xl" in lower_arch:
            tags.append("stable-diffusion-xl")
        if "flux" in lower_arch:
            tags.append("flux")
        if "lumina" in lower_arch:
            tags.append("lumina2")
        if "sd3" in lower_arch or "v3" in lower_arch:
            tags.append("sd3")
        # Add LoRA-specific tags
        tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
        # Generate widgets and gallery section from sample images
        curated_samples = curated_samples or []
        widgets = []
        prompt_bullets = []
        for sample in curated_samples:
            prompt_text = str(sample.get("prompt", "")).strip()
            repo_path = sample.get("repo_path")
            if not prompt_text or not repo_path:
                continue
            widgets.append({
                "text": prompt_text,
                "output": {"url": repo_path}
            })
            prompt_bullets.append(f"- {prompt_text}")
        gallery_section = ""
        if prompt_bullets:
            gallery_section = "<Gallery />\\n\\n" + "### Prompts\\n\\n" + "\\n".join(prompt_bullets) + "\\n\\n"
        # Determine torch dtype based on model
        dtype = "torch.bfloat16"
        try:
            arch_lower = arch.lower()
        except AttributeError:
            arch_lower = ""
        if "sd15" in arch_lower or "sdxl" in arch_lower:
            dtype = "torch.float16"
        # Find the main safetensors file for usage example
        main_safetensors = f"{model_name}.safetensors"
        if uploaded_files:
            safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
            if safetensors_files:
                preferred_name = f"{model_name}.safetensors"
                exact_match = next(
                    (
                        f
                        for f in safetensors_files
                        if os.path.basename(f) == preferred_name or f == preferred_name
                    ),
                    None,
                )
                if exact_match:
                    main_safetensors = exact_match
                else:
                    # Note: backslashes must be doubled in the TS template so the
                    # generated Python source keeps the regex escapes (\\d, \\.)
                    def extract_step(filename: str) -> int:
                        match = re.search(r"_(\\d+)\\.safetensors$", os.path.basename(filename))
                        return int(match.group(1)) if match else -1
                    # Prefer the checkpoint with the highest training step
                    safetensors_files.sort(
                        key=lambda f: (extract_step(f), f),
                        reverse=True,
                    )
                    main_safetensors = safetensors_files[0]
        # Construct YAML frontmatter
        frontmatter = {
            "tags": tags,
            "base_model": base_model,
            **license_info
        }
        if widgets:
            frontmatter["widget"] = widgets
        inference_params = {}
        sample_width = sample_config.get("width") if isinstance(sample_config, dict) else None
        sample_height = sample_config.get("height") if isinstance(sample_config, dict) else None
        if sample_width:
            inference_params["width"] = sample_width
        if sample_height:
            inference_params["height"] = sample_height
        if inference_params:
            frontmatter["inference"] = {"parameters": inference_params}
        if trigger_word:
            frontmatter["instance_prompt"] = trigger_word
        # Get first prompt for usage example
        usage_prompt = trigger_word or "a beautiful landscape"
        if widgets:
            usage_prompt = widgets[0]["text"]
        elif trigger_word:
            usage_prompt = trigger_word
        # Construct README content
        trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
        # Build YAML frontmatter string
        frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
        readme_content = f"""---
{frontmatter_yaml}
---

# {model_name}

Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)

{gallery_section}
## Trigger words

{trigger_section}

## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.

Weights for this model are available in Safetensors format.

[Download]({repo_id}/tree/main) them in the Files & versions tab.

## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)

\`\`\`py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
image = pipeline('{usage_prompt}').images[0]
image.save("my_image.png")
\`\`\`

For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
"""
        return readme_content
    except Exception as e:
        print(f"Error generating README: {e}")
        # Fallback simple README
        return f"""# {model_name}

Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)

## Download model

Weights for this model are available in Safetensors format.

[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""

def main():
    # Setup environment - token comes from HF Jobs secrets
    if "HF_TOKEN" not in os.environ:
        raise ValueError("HF_TOKEN environment variable not set")
    # Install system dependencies for headless operation
    print("Installing system dependencies...")
    try:
        subprocess.run(["apt-get", "update"], check=True, capture_output=True)
        subprocess.run([
            "apt-get", "install", "-y",
            "libgl1-mesa-glx",
            "libglib2.0-0",
            "libsm6",
            "libxext6",
            "libxrender-dev",
            "libgomp1",
            "ffmpeg"
        ], check=True, capture_output=True)
        print("System dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install system dependencies: {e}")
        print("Continuing without system dependencies...")
    # Setup ai-toolkit
    toolkit_dir = setup_ai_toolkit()
    # Create temporary directories
    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_path = os.path.join(temp_dir, "dataset")
        output_path = os.path.join(temp_dir, "output")
        # Download dataset
        download_dataset("${datasetRepo}", dataset_path)
        # Create config
        config = create_config(dataset_path, output_path)
        config_path = os.path.join(temp_dir, "config.yaml")
        with open(config_path, "w") as f:
            yaml.dump(config, f, default_flow_style=False)
        # Run training
        print("Starting training...")
        os.chdir(toolkit_dir)
        subprocess.run([
            sys.executable, "run.py",
            config_path
        ], check=True)
        print("Training completed!")
        # Upload results
        model_name = f"${jobConfig.config.name}-lora"
        upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)


if __name__ == "__main__":
    main()
`;
}
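
// Note: the script above carries PEP 723 inline metadata (the `# /// script`
// header), so the file returned by the 'generateScript' action should also be
// runnable locally with uv. A sketch, assuming uv is installed and using an
// example filename:
//
//   HF_TOKEN=hf_... uv run train_my_model.py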

async function submitHFJobUV(token: string, hardware: string, scriptPath: string, namespaceOverride?: string): Promise<string> {
  return new Promise((resolve, reject) => {
    // Ensure token is available
    if (!token) {
      reject(new Error('HF_TOKEN is required'));
      return;
    }
    console.log('Setting up environment with HF_TOKEN for job submission');
    const namespaceArgs = namespaceOverride ? ` --namespace ${namespaceOverride}` : '';
    console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach${namespaceArgs} ${scriptPath}`);
    // Use hf jobs uv run command with timeout and detach to get job ID
    const args = [
      'jobs', 'uv', 'run',
      '--flavor', hardware,
      '--timeout', '5h',
      '--secrets', 'HF_TOKEN',
      '--detach'
    ];
    if (namespaceOverride) {
      args.push('--namespace', namespaceOverride);
    }
    args.push(scriptPath);
    const childProcess = spawn('hf', args, {
      env: {
        ...process.env,
        HF_TOKEN: token
      }
    });
    let output = '';
    let error = '';
    childProcess.stdout.on('data', (data) => {
      const text = data.toString();
      output += text;
      console.log('HF Jobs stdout:', text);
    });
    childProcess.stderr.on('data', (data) => {
      const text = data.toString();
      error += text;
      console.log('HF Jobs stderr:', text);
    });
    childProcess.on('close', (code) => {
      console.log('HF Jobs process closed with code:', code);
      console.log('Full output:', output);
      console.log('Full error:', error);
      if (code === 0) {
        // With --detach flag, the output should be just the job ID
        const fullText = (output + ' ' + error).trim();
        // Updated patterns to handle variable-length hex job IDs (16-24+ characters)
        const jobIdPatterns = [
          /Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
          /job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
          /Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
          /created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
          /submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
          /https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
          /([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
        ];
        let jobId = 'unknown';
        for (const pattern of jobIdPatterns) {
          const match = fullText.match(pattern);
          if (match && match[1] && match[1] !== 'started') {
            jobId = match[1];
            console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
            break;
          }
        }
        resolve(jobId);
      } else {
        reject(new Error(error || output || 'Failed to submit job'));
      }
    });
    childProcess.on('error', (err) => {
      console.error('HF Jobs process error:', err);
      reject(new Error(`Process error: ${err.message}`));
    });
  });
}
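
// For reference, the spawn() call above is equivalent to running the CLI
// command below (flavor, namespace, and script path are placeholders) with
// HF_TOKEN set in the environment:
//
//   hf jobs uv run --flavor <flavor> --timeout 5h --secrets HF_TOKEN \
//     --detach --namespace <namespace> /tmp/train_<timestamp>.py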

async function checkHFJobStatus(token: string, jobId: string, jobNamespace?: string): Promise<any> {
  return new Promise((resolve, reject) => {
    console.log(`Checking HF Job status for: ${jobId}`);
    const args = ['jobs', 'inspect'];
    if (jobNamespace) {
      console.log(`Using namespace override for status check: ${jobNamespace}`);
      args.push('--namespace', jobNamespace);
    }
    args.push(jobId);
    const childProcess = spawn('hf', args, {
      env: {
        ...process.env,
        HF_TOKEN: token
      }
    });
    let output = '';
    let error = '';
    childProcess.stdout.on('data', (data) => {
      const text = data.toString();
      output += text;
    });
    childProcess.stderr.on('data', (data) => {
      const text = data.toString();
      error += text;
    });
    childProcess.on('close', (code) => {
      if (code === 0) {
        try {
          // Parse the JSON output from hf jobs inspect
          const jobInfo = JSON.parse(output);
          if (Array.isArray(jobInfo) && jobInfo.length > 0) {
            const job = jobInfo[0];
            resolve({
              id: job.id,
              status: job.status?.stage || 'UNKNOWN',
              message: job.status?.message,
              created_at: job.created_at,
              flavor: job.flavor,
              url: job.url,
            });
          } else {
            reject(new Error('Invalid job info response'));
          }
        } catch (parseError: any) {
          console.error('Failed to parse job status:', parseError, output);
          reject(new Error('Failed to parse job status'));
        }
      } else {
        reject(new Error(error || output || 'Failed to check job status'));
      }
    });
    childProcess.on('error', (err) => {
      console.error('HF Jobs inspect process error:', err);
      reject(new Error(`Process error: ${err.message}`));
    });
  });
}
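
// The object passed to resolve() above has this shape; the values here are
// illustrative only:
//
//   {
//     id: '68b26b73767540db9fc726ac',
//     status: 'RUNNING',                    // job.status.stage, or 'UNKNOWN'
//     message: undefined,                   // job.status.message
//     created_at: '2025-01-01T00:00:00Z',   // assumed timestamp format
//     flavor: 'a10g-large',                 // example flavor
//     url: 'https://huggingface.co/jobs/...',
//   }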

async function checkHFJobsCapacity(token: string): Promise<any> {
  try {
    console.log('Checking HF Jobs capacity for namespace: lora-training-frenzi via API');
    // Use HuggingFace API directly instead of CLI to avoid TTY issues
    const response = await fetch('https://huggingface.co/api/jobs/lora-training-frenzi', {
      headers: {
        'Authorization': `Bearer ${token}`,
      },
    });
    if (!response.ok) {
      throw new Error(`API request failed: ${response.status} ${response.statusText}`);
    }
    const jobs = await response.json();
    console.log(`Fetched ${jobs.length} total jobs from API`);
    // Count jobs with status RUNNING
    let runningCount = 0;
    for (const job of jobs) {
      const status = job.status?.stage || job.status;
      if (status === 'RUNNING') {
        runningCount++;
      }
    }
    const atCapacity = runningCount >= 32;
    console.log(`\n=== FINAL COUNT ===`);
    console.log(`Found ${runningCount} RUNNING jobs. At capacity: ${atCapacity}`);
    console.log(`==================\n`);
    return {
      runningJobs: runningCount,
      atCapacity,
      capacityLimit: 32,
    };
  } catch (error: any) {
    console.error('Failed to check capacity via API:', error);
    throw new Error(`Failed to check capacity: ${error.message}`);
  }
}
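
// The capacity check is a plain REST call, so an equivalent manual check is
// possible from the shell (assuming the same endpoint and a valid token):
//
//   curl -H "Authorization: Bearer $HF_TOKEN" \
//     https://huggingface.co/api/jobs/lora-training-frenzi
//
// The route counts jobs whose status.stage is RUNNING and reports atCapacity
// once 32 or more are running.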