apolinario's picture
Attempt to use the API
e306cd2
import { NextRequest, NextResponse } from 'next/server';
import { spawn } from 'child_process';
import { writeFile, readFile, unlink } from 'fs/promises';
import path from 'path';
import { tmpdir } from 'os';
export async function POST(request: NextRequest) {
try {
const body = await request.json();
const { action, token, hardware, namespace, jobConfig, datasetRepo, participateHackathon } = body;
switch (action) {
case 'checkCapacity':
try {
if (!token) {
return NextResponse.json({ error: 'Token required' }, { status: 400 });
}
const capacityStatus = await checkHFJobsCapacity(token);
return NextResponse.json(capacityStatus);
} catch (error: any) {
console.error('Capacity check error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
case 'checkStatus':
try {
if (!token || !jobConfig?.hf_job_id) {
return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
}
const jobNamespaceOverride = jobConfig?.hf_job_namespace;
const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id, jobNamespaceOverride);
return NextResponse.json({ status: jobStatus });
} catch (error: any) {
console.error('Job status check error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
case 'generateScript':
try {
const uvScript = generateUVScript({
jobConfig,
datasetRepo,
namespace,
token: token || 'YOUR_HF_TOKEN',
});
return NextResponse.json({
script: uvScript,
filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
});
} catch (error: any) {
return NextResponse.json({ error: error.message }, { status: 500 });
}
case 'submitJob':
try {
if (!token || !hardware) {
return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
}
// Generate UV script
const uvScript = generateUVScript({
jobConfig,
datasetRepo,
namespace,
token,
});
// Write script to temporary file
const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
await writeFile(scriptPath, uvScript);
// Submit HF job using uv run
const namespaceOverride = participateHackathon ? 'lora-training-frenzi' : undefined;
const jobId = await submitHFJobUV(
token,
hardware,
scriptPath,
namespaceOverride
);
const jobNamespace = namespaceOverride ?? namespace;
return NextResponse.json({
success: true,
jobId,
jobNamespace,
message: `Job submitted successfully with ID: ${jobId}`
});
} catch (error: any) {
console.error('Job submission error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
default:
return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
}
} catch (error: any) {
console.error('HF Jobs API error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
}
function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
jobConfig: any;
datasetRepo: string;
namespace: string;
token: string;
}) {
const config = jobConfig.config;
const process = config.process[0];
return `# /// script
# dependencies = [
# "torch>=2.0.0",
# "torchvision",
# "torchaudio",
# "torchao==0.10.0",
# "safetensors",
# "diffusers @ git+https://github.com/huggingface/diffusers",
# "transformers==4.52.4",
# "lycoris-lora==1.8.3",
# "flatten_json",
# "pyyaml",
# "oyaml",
# "tensorboard",
# "kornia",
# "invisible-watermark",
# "einops",
# "accelerate",
# "toml",
# "albumentations==1.4.15",
# "albucore==0.0.16",
# "pydantic",
# "omegaconf",
# "k-diffusion",
# "open_clip_torch",
# "timm",
# "prodigyopt",
# "controlnet_aux==0.0.10",
# "python-dotenv",
# "bitsandbytes",
# "hf_transfer",
# "lpips",
# "pytorch_fid",
# "optimum-quanto==0.2.4",
# "sentencepiece",
# "huggingface_hub",
# "peft",
# "python-slugify",
# "opencv-python-headless",
# "pytorch-wavelets==1.3.0",
# "matplotlib==3.10.1",
# "setuptools==69.5.1",
# "datasets==4.0.0",
# "pyarrow==20.0.0",
# "pillow",
# "ftfy",
# ]
# ///
import os
import sys
import subprocess
import argparse
import re
import oyaml as yaml
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
import tempfile
import shutil
import glob
from PIL import Image
def setup_ai_toolkit():
"""Clone and setup ai-toolkit repository"""
repo_dir = "ai-toolkit"
if not os.path.exists(repo_dir):
print("Cloning ai-toolkit repository...")
subprocess.run(
["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
check=True
)
sys.path.insert(0, os.path.abspath(repo_dir))
return repo_dir
def find_local_dataset_source(dataset_repo: str):
if not dataset_repo:
return None
repo_stripped = dataset_repo.strip()
candidates = []
if os.path.isabs(repo_stripped):
candidates.append(repo_stripped)
else:
candidates.append(repo_stripped)
candidates.append(os.path.abspath(repo_stripped))
normalized = normalize_repo_id(repo_stripped)
if normalized:
candidates.append(os.path.join("/datasets", normalized))
if repo_stripped.startswith("/datasets/") and repo_stripped not in candidates:
candidates.append(repo_stripped)
seen = set()
for candidate in candidates:
if not candidate or candidate in seen:
continue
seen.add(candidate)
if os.path.exists(candidate):
return candidate
return None
def normalize_repo_id(dataset_repo: str) -> str:
repo_id = dataset_repo.strip()
if repo_id.startswith("/datasets/"):
repo_id = repo_id[len("/datasets/"):]
elif repo_id.startswith("datasets/"):
repo_id = repo_id[len("datasets/"):]
return repo_id.strip("/")
def copy_dataset_files(source_dir: str, local_path: str):
print(f"Collecting data files from {source_dir}")
image_exts = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
video_exts = {'.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv'}
copied_images = 0
copied_videos = 0
copied_captions = 0
for root, _, files in os.walk(source_dir):
for file_name in files:
ext = os.path.splitext(file_name)[1].lower()
src_path = os.path.join(root, file_name)
rel_path = os.path.relpath(src_path, source_dir)
dest_path = os.path.join(local_path, rel_path)
dest_dir = os.path.dirname(dest_path)
if dest_dir and not os.path.exists(dest_dir):
os.makedirs(dest_dir, exist_ok=True)
if ext in image_exts:
try:
shutil.copy2(src_path, dest_path)
copied_images += 1
except Exception as img_error:
print(f"Error copying image {src_path}: {img_error}")
elif ext in video_exts:
try:
shutil.copy2(src_path, dest_path)
copied_videos += 1
except Exception as vid_error:
print(f"Error copying video {src_path}: {vid_error}")
elif ext == '.txt':
try:
shutil.copy2(src_path, dest_path)
copied_captions += 1
except Exception as txt_error:
print(f"Error copying text file {src_path}: {txt_error}")
else:
try:
shutil.copy2(src_path, dest_path)
except Exception as other_error:
print(f"Error copying file {src_path}: {other_error}")
total_media = copied_images + copied_videos
print(
f"Prepared {copied_images} images, {copied_videos} videos, and {copied_captions} captions in {local_path}"
)
return total_media, copied_captions
def download_dataset(dataset_repo: str, local_path: str):
"""Download dataset from HF Hub as files"""
print(f"Downloading dataset from {dataset_repo}...")
os.makedirs(local_path, exist_ok=True)
local_source = find_local_dataset_source(dataset_repo)
if local_source:
print(f"Found local dataset at {local_source}")
media_copied, _ = copy_dataset_files(local_source, local_path)
if media_copied > 0:
return
print("Local dataset did not contain media files, falling back to remote download")
repo_id = normalize_repo_id(dataset_repo)
if repo_id:
try:
print(f"Attempting snapshot download for dataset {repo_id}")
temp_repo_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
print(f"Downloaded repo to: {temp_repo_path}")
print(f"Contents: {os.listdir(temp_repo_path)}")
media_copied, _ = copy_dataset_files(temp_repo_path, local_path)
if media_copied > 0:
return
print("Snapshot download did not contain media files, attempting structured dataset load")
except Exception as snapshot_error:
print(f"Snapshot download failed: {snapshot_error}")
if not repo_id:
raise ValueError("Dataset repository ID is required when no local dataset is available")
try:
dataset = load_dataset(repo_id, split="train")
images_saved = 0
captions_saved = 0
for i, item in enumerate(dataset):
if "image" in item and item["image"] is not None:
image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
image = item["image"]
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (255, 255, 255))
background.paste(image, mask=image.split()[-1])
image = background
elif image.mode not in ['RGB', 'L']:
image = image.convert('RGB')
image.save(image_path, 'JPEG')
images_saved += 1
if "text" in item and item["text"] is not None:
caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
with open(caption_path, "w", encoding="utf-8") as f:
f.write(item["text"])
captions_saved += 1
if images_saved == 0:
raise ValueError(f"Structured dataset load completed but produced 0 images for {repo_id}")
print(f"Downloaded {images_saved} items to {local_path}")
except Exception as e:
print(f"Failed to load as structured dataset: {e}")
raise
def create_config(dataset_path: str, output_path: str):
"""Create training configuration"""
import json
# Load config from JSON string and fix boolean/null values for Python
config_str = """${JSON.stringify(jobConfig, null, 2)}"""
config_str = config_str.replace('true', 'True').replace('false', 'False').replace('null', 'None')
config = eval(config_str)
def resolve_manifest_value(value):
if value is None:
return None
if isinstance(value, list):
resolved_list = [resolve_manifest_value(v) for v in value]
return [v for v in resolved_list if v is not None]
if not isinstance(value, str) or value.strip() == "":
return None
normalized = value.replace("\\\\", "/")
parts = [part for part in normalized.split("/") if part not in ("", ".")]
return os.path.normpath(os.path.join(dataset_path, *parts))
manifest_path = os.path.join(dataset_path, "manifest.json")
manifest_data = None
if os.path.isfile(manifest_path):
try:
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
manifest_data = json.load(manifest_file)
except Exception as manifest_error:
print(f"Failed to load dataset manifest: {manifest_error}")
manifest_data = None
process_config = config["config"]["process"][0]
datasets_config = process_config.get("datasets", [])
if manifest_data and isinstance(manifest_data, dict) and "datasets" in manifest_data:
manifest_datasets = manifest_data.get("datasets", [])
for idx, dataset_cfg in enumerate(datasets_config):
manifest_entry = manifest_datasets[idx] if idx < len(manifest_datasets) else {}
if isinstance(manifest_entry, dict):
for key, value in manifest_entry.items():
resolved_value = resolve_manifest_value(value)
if resolved_value is not None and resolved_value != []:
dataset_cfg[key] = resolved_value
if key == "folder_path":
dataset_cfg["dataset_path"] = resolved_value
if "folder_path" not in dataset_cfg or not dataset_cfg["folder_path"]:
dataset_cfg["folder_path"] = dataset_path
dataset_cfg["dataset_path"] = dataset_path
else:
for dataset_cfg in datasets_config:
dataset_cfg["folder_path"] = dataset_path
dataset_cfg["dataset_path"] = dataset_path
samples_config = process_config.get("sample", {}).get("samples", [])
if manifest_data and isinstance(manifest_data, dict):
manifest_samples = manifest_data.get("samples", [])
for sample_entry in manifest_samples:
if not isinstance(sample_entry, dict):
continue
index = sample_entry.get("index")
ctrl_img_rel = sample_entry.get("ctrl_img")
if (
isinstance(index, int)
and 0 <= index < len(samples_config)
and ctrl_img_rel is not None
):
resolved_ctrl_img = resolve_manifest_value(ctrl_img_rel)
if resolved_ctrl_img:
samples_config[index]["ctrl_img"] = resolved_ctrl_img
# Update training folder for cloud environment
process_config["training_folder"] = output_path
# Remove sqlite_db_path as it's not needed for cloud training
if "sqlite_db_path" in process_config:
del process_config["sqlite_db_path"]
# Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
if process_config["type"] == "ui_trainer":
process_config["type"] = "sd_trainer"
return config
def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
"""Upload trained model to HF Hub with README generation and proper file organization"""
import tempfile
import shutil
import glob
from datetime import datetime
from huggingface_hub import create_repo, upload_file, HfApi
from collections import deque
try:
repo_id = f"{namespace}/{model_name}"
# Create repository
create_repo(repo_id=repo_id, token=token, exist_ok=True)
print(f"Uploading model to {repo_id}...")
# Create temporary directory for organized upload
with tempfile.TemporaryDirectory() as temp_upload_dir:
api = HfApi()
# 1. Find and upload model files to root directory
safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
uploaded_files = []
# Upload .safetensors files to root
for file_path in safetensors_files:
filename = os.path.basename(file_path)
print(f"Uploading {filename} to repository root...")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=filename,
repo_id=repo_id,
token=token
)
uploaded_files.append(filename)
# Upload relevant JSON config files to root (skip metadata.json and other internal files)
config_files_uploaded = []
for file_path in json_files:
filename = os.path.basename(file_path)
# Only upload important config files, skip internal metadata
if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
print(f"Uploading {filename} to repository root...")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=filename,
repo_id=repo_id,
token=token
)
uploaded_files.append(filename)
config_files_uploaded.append(filename)
def prepare_sample_metadata(samples_directory: str, sample_conf: dict):
if not samples_directory or not os.path.isdir(samples_directory):
return [], []
allowed_ext = {'.jpg', '.jpeg', '.png', '.webp'}
image_records = []
for root, _, files in os.walk(samples_directory):
for filename in files:
ext = os.path.splitext(filename)[1].lower()
if ext not in allowed_ext:
continue
abs_path = os.path.join(root, filename)
try:
mtime = os.path.getmtime(abs_path)
except Exception:
mtime = 0
image_records.append((abs_path, mtime))
if not image_records:
return [], []
image_records.sort(key=lambda item: (-item[1], item[0]))
image_queue = deque(image_records)
samples_list = sample_conf.get("samples", []) if sample_conf else []
if not samples_list:
legacy = sample_conf.get("prompts", []) if sample_conf else []
samples_list = [{"prompt": prompt} for prompt in legacy if prompt]
curated_samples = []
for sample in samples_list:
prompt = None
if isinstance(sample, dict):
prompt = sample.get("prompt")
elif isinstance(sample, str):
prompt = sample
if not prompt:
continue
if not image_queue:
break
image_path, _ = image_queue.popleft()
repo_rel_path = f"images/{os.path.basename(image_path)}"
curated_samples.append({
"prompt": prompt,
"local_path": image_path,
"repo_path": repo_rel_path,
})
all_files = [record[0] for record in image_records]
return curated_samples, all_files
samples_dir = os.path.join(output_path, "samples")
sample_config = config.get("config", {}).get("process", [{}])[0].get("sample", {})
curated_samples, sample_files = prepare_sample_metadata(samples_dir, sample_config)
samples_uploaded = []
if sample_files:
print("Uploading sample images...")
for file_path in sample_files:
if not os.path.isfile(file_path):
continue
filename = os.path.basename(file_path)
repo_path = f"images/{filename}"
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=repo_path,
repo_id=repo_id,
token=token
)
samples_uploaded.append(repo_path)
# 3. Generate and upload README.md
readme_content = generate_model_card_readme(
repo_id=repo_id,
config=config,
model_name=model_name,
curated_samples=curated_samples,
uploaded_files=uploaded_files
)
# Create README.md file and upload to root
readme_path = os.path.join(temp_upload_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme_content)
print("Uploading README.md to repository root...")
api.upload_file(
path_or_fileobj=readme_path,
path_in_repo="README.md",
repo_id=repo_id,
token=token
)
print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
except Exception as e:
print(f"Failed to upload model: {e}")
raise e
def generate_model_card_readme(repo_id: str, config: dict, model_name: str, curated_samples: list = None, uploaded_files: list = None) -> str:
"""Generate README.md content for the model card based on AI Toolkit's implementation"""
import yaml
import os
try:
# Extract configuration details
process_config = config.get("config", {}).get("process", [{}])[0]
model_config = process_config.get("model", {})
train_config = process_config.get("train", {})
sample_config = process_config.get("sample", {})
# Gather model info
base_model = model_config.get("name_or_path", "unknown")
trigger_word = process_config.get("trigger_word")
arch = model_config.get("arch", "")
# Determine license based on base model
if "FLUX.1-schnell" in base_model:
license_info = {"license": "apache-2.0"}
elif "FLUX.1-dev" in base_model:
license_info = {
"license": "other",
"license_name": "flux-1-dev-non-commercial-license",
"license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
}
else:
license_info = {"license": "creativeml-openrail-m"}
# Generate tags based on model architecture group
tags = []
lower_arch = (arch or "").lower()
lower_model_name = (model_config.get("name_or_path", "") or "").lower()
base_model_lower = (base_model or "").lower()
# Define model groups based on the frontend options.ts structure
# Group: 'image' -> text-to-image
# Group: 'instruction' -> image-to-image
# Group: 'video' -> check for i2v in arch name for image-to-video vs text-to-video
image_arches = {
'flux', 'flex1', 'flex2', 'chroma', 'lumina2',
'qwen_image', 'hidream', 'sdxl', 'sd15', 'omnigen2'
}
instruction_arches = {
'flux_kontext', 'qwen_image_edit', 'qwen_image_edit_plus', 'hidream_e1'
}
video_arches = {
'wan21:1b', 'wan21_i2v:14b480p', 'wan21_i2v:14b', 'wan21:14b',
'wan22_14b:t2v', 'wan22_14b_i2v', 'wan22_5b'
}
# Determine the task type based on architecture group
if lower_arch in instruction_arches:
tags.append("image-to-image")
elif lower_arch in video_arches:
# Video models: check if i2v is in the architecture name
is_i2v = 'i2v' in lower_arch
tags.append("image-to-video" if is_i2v else "text-to-video")
elif lower_arch in image_arches:
tags.append("text-to-image")
else:
# Fallback to text-to-image for unknown architectures
tags.append("text-to-image")
if "xl" in lower_arch:
tags.append("stable-diffusion-xl")
if "flux" in lower_arch:
tags.append("flux")
if "lumina" in lower_arch:
tags.append("lumina2")
if "sd3" in lower_arch or "v3" in lower_arch:
tags.append("sd3")
# Add LoRA-specific tags
tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
# Generate widgets and gallery section from sample images
curated_samples = curated_samples or []
widgets = []
prompt_bullets = []
for sample in curated_samples:
prompt_text = str(sample.get("prompt", "")).strip()
repo_path = sample.get("repo_path")
if not prompt_text or not repo_path:
continue
widgets.append({
"text": prompt_text,
"output": {"url": repo_path}
})
prompt_bullets.append(f"- {prompt_text}")
gallery_section = ""
if prompt_bullets:
gallery_section = "<Gallery />\\n\\n" + "### Prompts\\n\\n" + "\\n".join(prompt_bullets) + "\\n\\n"
# Determine torch dtype based on model
dtype = "torch.bfloat16"
try:
arch_lower = arch.lower()
except AttributeError:
arch_lower = ""
if "sd15" in arch_lower or "sdxl" in arch_lower:
dtype = "torch.float16"
# Find the main safetensors file for usage example
main_safetensors = f"{model_name}.safetensors"
if uploaded_files:
safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
if safetensors_files:
preferred_name = f"{model_name}.safetensors"
exact_match = next(
(
f
for f in safetensors_files
if os.path.basename(f) == preferred_name or f == preferred_name
),
None,
)
if exact_match:
main_safetensors = exact_match
else:
def extract_step(filename: str) -> int:
match = re.search(r"_(\d+)\.safetensors$", os.path.basename(filename))
return int(match.group(1)) if match else -1
safetensors_files.sort(
key=lambda f: (extract_step(f), f),
reverse=True,
)
main_safetensors = safetensors_files[0]
# Construct YAML frontmatter
frontmatter = {
"tags": tags,
"base_model": base_model,
**license_info
}
if widgets:
frontmatter["widget"] = widgets
inference_params = {}
sample_width = sample_config.get("width") if isinstance(sample_config, dict) else None
sample_height = sample_config.get("height") if isinstance(sample_config, dict) else None
if sample_width:
inference_params["width"] = sample_width
if sample_height:
inference_params["height"] = sample_height
if inference_params:
frontmatter["inference"] = {"parameters": inference_params}
if trigger_word:
frontmatter["instance_prompt"] = trigger_word
# Get first prompt for usage example
usage_prompt = trigger_word or "a beautiful landscape"
if widgets:
usage_prompt = widgets[0]["text"]
elif trigger_word:
usage_prompt = trigger_word
# Construct README content
trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
# Build YAML frontmatter string
frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
readme_content = f"""---
{frontmatter_yaml}
---
# {model_name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
{gallery_section}
## Trigger words
{trigger_section}
## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
\`\`\`py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
image = pipeline('{usage_prompt}').images[0]
image.save("my_image.png")
\`\`\`
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
"""
return readme_content
except Exception as e:
print(f"Error generating README: {e}")
# Fallback simple README
return f"""# {model_name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
## Download model
Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
def main():
# Setup environment - token comes from HF Jobs secrets
if "HF_TOKEN" not in os.environ:
raise ValueError("HF_TOKEN environment variable not set")
# Install system dependencies for headless operation
print("Installing system dependencies...")
try:
subprocess.run(["apt-get", "update"], check=True, capture_output=True)
subprocess.run([
"apt-get", "install", "-y",
"libgl1-mesa-glx",
"libglib2.0-0",
"libsm6",
"libxext6",
"libxrender-dev",
"libgomp1",
"ffmpeg"
], check=True, capture_output=True)
print("System dependencies installed successfully")
except subprocess.CalledProcessError as e:
print(f"Failed to install system dependencies: {e}")
print("Continuing without system dependencies...")
# Setup ai-toolkit
toolkit_dir = setup_ai_toolkit()
# Create temporary directories
with tempfile.TemporaryDirectory() as temp_dir:
dataset_path = os.path.join(temp_dir, "dataset")
output_path = os.path.join(temp_dir, "output")
# Download dataset
download_dataset("${datasetRepo}", dataset_path)
# Create config
config = create_config(dataset_path, output_path)
config_path = os.path.join(temp_dir, "config.yaml")
with open(config_path, "w") as f:
yaml.dump(config, f, default_flow_style=False)
# Run training
print("Starting training...")
os.chdir(toolkit_dir)
subprocess.run([
sys.executable, "run.py",
config_path
], check=True)
print("Training completed!")
# Upload results
model_name = f"${jobConfig.config.name}-lora"
upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)
if __name__ == "__main__":
main()
`;
}
async function submitHFJobUV(token: string, hardware: string, scriptPath: string, namespaceOverride?: string): Promise<string> {
return new Promise((resolve, reject) => {
// Ensure token is available
if (!token) {
reject(new Error('HF_TOKEN is required'));
return;
}
console.log('Setting up environment with HF_TOKEN for job submission');
const namespaceArgs = namespaceOverride ? ` --namespace ${namespaceOverride}` : '';
console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach${namespaceArgs} ${scriptPath}`);
// Use hf jobs uv run command with timeout and detach to get job ID
const args = [
'jobs', 'uv', 'run',
'--flavor', hardware,
'--timeout', '5h',
'--secrets', 'HF_TOKEN',
'--detach'
];
if (namespaceOverride) {
args.push('--namespace', namespaceOverride);
}
args.push(scriptPath);
const childProcess = spawn('hf', args, {
env: {
...process.env,
HF_TOKEN: token
}
});
let output = '';
let error = '';
childProcess.stdout.on('data', (data) => {
const text = data.toString();
output += text;
console.log('HF Jobs stdout:', text);
});
childProcess.stderr.on('data', (data) => {
const text = data.toString();
error += text;
console.log('HF Jobs stderr:', text);
});
childProcess.on('close', (code) => {
console.log('HF Jobs process closed with code:', code);
console.log('Full output:', output);
console.log('Full error:', error);
if (code === 0) {
// With --detach flag, the output should be just the job ID
const fullText = (output + ' ' + error).trim();
// Updated patterns to handle variable-length hex job IDs (16-24+ characters)
const jobIdPatterns = [
/Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
/job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
/Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
/created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
/submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
/https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
/([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
];
let jobId = 'unknown';
for (const pattern of jobIdPatterns) {
const match = fullText.match(pattern);
if (match && match[1] && match[1] !== 'started') {
jobId = match[1];
console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
break;
}
}
resolve(jobId);
} else {
reject(new Error(error || output || 'Failed to submit job'));
}
});
childProcess.on('error', (err) => {
console.error('HF Jobs process error:', err);
reject(new Error(`Process error: ${err.message}`));
});
});
}
async function checkHFJobStatus(token: string, jobId: string, jobNamespace?: string): Promise<any> {
return new Promise((resolve, reject) => {
console.log(`Checking HF Job status for: ${jobId}`);
const args = ['jobs', 'inspect'];
if (jobNamespace) {
console.log(`Using namespace override for status check: ${jobNamespace}`);
args.push('--namespace', jobNamespace);
}
args.push(jobId);
const childProcess = spawn('hf', args, {
env: {
...process.env,
HF_TOKEN: token
}
});
let output = '';
let error = '';
childProcess.stdout.on('data', (data) => {
const text = data.toString();
output += text;
});
childProcess.stderr.on('data', (data) => {
const text = data.toString();
error += text;
});
childProcess.on('close', (code) => {
if (code === 0) {
try {
// Parse the JSON output from hf jobs inspect
const jobInfo = JSON.parse(output);
if (Array.isArray(jobInfo) && jobInfo.length > 0) {
const job = jobInfo[0];
resolve({
id: job.id,
status: job.status?.stage || 'UNKNOWN',
message: job.status?.message,
created_at: job.created_at,
flavor: job.flavor,
url: job.url,
});
} else {
reject(new Error('Invalid job info response'));
}
} catch (parseError: any) {
console.error('Failed to parse job status:', parseError, output);
reject(new Error('Failed to parse job status'));
}
} else {
reject(new Error(error || output || 'Failed to check job status'));
}
});
childProcess.on('error', (err) => {
console.error('HF Jobs inspect process error:', err);
reject(new Error(`Process error: ${err.message}`));
});
});
}
async function checkHFJobsCapacity(token: string): Promise<any> {
try {
console.log('Checking HF Jobs capacity for namespace: lora-training-frenzi via API');
// Use HuggingFace API directly instead of CLI to avoid TTY issues
const response = await fetch('https://huggingface.co/api/jobs/lora-training-frenzi', {
headers: {
'Authorization': `Bearer ${token}`,
},
});
if (!response.ok) {
throw new Error(`API request failed: ${response.status} ${response.statusText}`);
}
const jobs = await response.json();
console.log(`Fetched ${jobs.length} total jobs from API`);
// Count jobs with status RUNNING
let runningCount = 0;
for (const job of jobs) {
const status = job.status?.stage || job.status;
if (status === 'RUNNING') {
runningCount++;
}
}
const atCapacity = runningCount >= 32;
console.log(`\n=== FINAL COUNT ===`);
console.log(`Found ${runningCount} RUNNING jobs. At capacity: ${atCapacity}`);
console.log(`==================\n`);
return {
runningJobs: runningCount,
atCapacity,
capacityLimit: 32,
};
} catch (error: any) {
console.error('Failed to check capacity via API:', error);
throw new Error(`Failed to check capacity: ${error.message}`);
}
}