"""Gradio app for training Z-Image LoRA adapters with ostris/ai-toolkit.

The app has two steps: upload a set of images (each gets a one-line caption
file next to it), then launch ``ai-toolkit/run.py`` on a generated JSON config
and offer the resulting ``.safetensors`` file for download.
"""

import json
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

import gradio as gr
import torch

# Setup directories
DATASET_DIR = Path("./datasets")
OUTPUT_DIR = Path("./output")
DATASET_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

# Local checkout of the trainer; cloned on first use by _ensure_ai_toolkit().
TOOLKIT_DIR = Path("./ai-toolkit")
TOOLKIT_REPO_URL = "https://github.com/ostris/ai-toolkit.git"

# File extensions (lower-case) that are accepted as training images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')

# Upper bound on how much of the trainer's stderr is echoed back to the UI.
MAX_STDERR_CHARS = 4000

# Global variable to store dataset path
current_dataset_path = None


def check_gpu() -> str:
    """Return a one-line, human-readable description of CUDA availability."""
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        return f"✅ GPU Available: {gpu_name}"
    return "⚠️ No GPU detected - training will be slow"


def _safe_name(raw_name, fallback_prefix):
    """Reduce a user-supplied name to a single safe path component.

    Anything other than letters, digits, ``-``, ``_`` and ``.`` becomes ``_``
    and leading/trailing dots are dropped, so values such as ``../x`` cannot
    escape the parent directory. An empty result falls back to
    ``<fallback_prefix>_<unix time>``.
    """
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_." else "_"
        for ch in (raw_name or "").strip()
    ).strip(".")
    return cleaned or f"{fallback_prefix}_{int(time.time())}"


def _uploaded_path(file):
    """Return the on-disk path of one uploaded file.

    Accepts both objects exposing the path as ``.name`` and plain path
    strings, since the upload component's payload type differs between
    Gradio versions.
    """
    return Path(getattr(file, "name", file))


def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Copy uploaded images into a dataset folder and write caption files.

    Args:
        files: Uploaded files (path strings or objects with a ``.name`` path).
        dataset_name: Folder name under ``DATASET_DIR``; sanitised, and
            replaced by a timestamped name when empty.
        trigger_word: Caption text written for every image; ``"a photo"`` is
            used when empty.

    Returns:
        A ``(status_message, dataset_path_or_None, ready_message)`` tuple
        matching the three Gradio outputs wired to the upload button.
    """
    global current_dataset_path

    if not files:
        return "❌ Please upload at least one image", None, ""

    dataset_name = _safe_name(dataset_name, "dataset")
    trigger_word = (trigger_word or "").strip()

    # Create dataset directory
    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    caption_text = trigger_word if trigger_word else "a photo"

    # Save images, each with a sibling caption file of the same stem
    image_count = 0
    for file in files:
        source = _uploaded_path(file)
        if not source.name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        destination = dataset_path / source.name
        shutil.copy(source, destination)
        # Explicit UTF-8 so non-ASCII trigger words survive on any platform.
        destination.with_suffix('.txt').write_text(caption_text, encoding="utf-8")
        image_count += 1

    if image_count == 0:
        return (
            "❌ No valid images found. Upload PNG, JPG, JPEG, WEBP, or BMP files.",
            None,
            "",
        )

    current_dataset_path = str(dataset_path)

    status = f"✅ Successfully uploaded {image_count} images\n"
    status += f"📁 Dataset: {dataset_name}\n"
    if trigger_word:
        status += f"🏷️ Trigger word: '{trigger_word}'\n"
    status += f"💾 Location: {current_dataset_path}"

    return status, current_dataset_path, f"Dataset ready: {dataset_name}"


def _build_training_config(
    dataset_path, project_name, output_path, trigger_word,
    steps, learning_rate, lora_rank, resolution,
):
    """Build the ai-toolkit job config dict for one LoRA training run."""
    steps = int(steps)
    rank = int(lora_rank)
    size = int(resolution)
    # Save/sample four times per run, but never more often than every 100 steps.
    interval = max(100, steps // 4)

    return {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": rank,
                    "linear_alpha": rank,
                },
                "save": {
                    "dtype": "float16",
                    "save_every": interval,
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [size, size],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": steps,
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": interval,
                    "width": size,
                    "height": size,
                    "prompts": [
                        f"{trigger_word} high quality photo" if trigger_word else "high quality photo",
                        f"{trigger_word} beautiful scene" if trigger_word else "beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }]
        }
    }


def _ensure_ai_toolkit():
    """Clone ai-toolkit and install its requirements if it is not present.

    Every command runs with ``cwd=`` instead of ``os.chdir`` so a failure can
    never leave the process in the wrong working directory. A failed install
    removes the partial checkout so that the next attempt starts clean.

    Raises:
        subprocess.CalledProcessError: A git or pip command exited non-zero.
        OSError: An executable (e.g. ``git``) could not be started.
    """
    if TOOLKIT_DIR.exists():
        return
    try:
        subprocess.run(
            ["git", "clone", TOOLKIT_REPO_URL, str(TOOLKIT_DIR)],
            check=True,
            capture_output=True,
        )
        subprocess.run(
            ["git", "submodule", "update", "--init", "--recursive"],
            check=True,
            capture_output=True,
            cwd=TOOLKIT_DIR,
        )
        # sys.executable guarantees the packages land in the interpreter that
        # will later run the trainer.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
            check=True,
            cwd=TOOLKIT_DIR,
        )
    except (subprocess.CalledProcessError, OSError):
        shutil.rmtree(TOOLKIT_DIR, ignore_errors=True)
        raise


def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train a LoRA with ai-toolkit and return the resulting weights file.

    Args:
        dataset_path: Dataset folder produced by ``upload_and_prepare_dataset``.
        project_name: Output folder / job name; sanitised, and replaced by a
            timestamped name when empty.
        trigger_word: Optional word injected into the config and sample prompts.
        steps: Number of training steps.
        learning_rate: Optimizer learning rate.
        lora_rank: LoRA rank (also used as alpha).
        resolution: Square training/sample resolution in pixels.
        progress: Gradio progress tracker (supplied by Gradio).

    Returns:
        A ``(status_message, lora_file_path_or_None)`` tuple matching the two
        Gradio outputs wired to the train button.
    """
    if not dataset_path or not os.path.isdir(dataset_path):
        return "❌ Please upload a dataset first!", None

    project_name = _safe_name(project_name, "lora")
    trigger_word = (trigger_word or "").strip()

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    # Create and save training config
    config = _build_training_config(
        dataset_path, project_name, output_path, trigger_word,
        steps, learning_rate, lora_rank, resolution,
    )
    config_path = output_path / "config.json"
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    progress(0.1, desc="Installing AI Toolkit...")

    # Install AI Toolkit if not exists
    try:
        _ensure_ai_toolkit()
    except (subprocess.CalledProcessError, OSError) as e:
        return f"❌ Failed to install AI Toolkit: {str(e)}", None

    progress(0.3, desc="Starting training...")

    # Run training
    try:
        result = subprocess.run(
            [sys.executable, str(TOOLKIT_DIR / "run.py"), str(config_path)],
            capture_output=True,
            text=True,
            timeout=3600  # 1 hour timeout
        )
    except subprocess.TimeoutExpired:
        return "❌ Training timeout (> 1 hour). Try reducing steps.", None
    except OSError as e:
        return f"❌ Training error: {str(e)}", None

    if result.returncode != 0:
        # Only the tail is shown; the root cause is normally the last traceback.
        return f"❌ Training failed:\n{result.stderr[-MAX_STDERR_CHARS:]}", None

    progress(0.9, desc="Training complete! Finding LoRA file...")

    # Search recursively: the trainer may write into a sub-folder of
    # training_folder. Sort by mtime because glob order is unspecified.
    lora_files = sorted(
        output_path.rglob("*.safetensors"),
        key=lambda p: p.stat().st_mtime,
    )
    if not lora_files:
        return "⚠️ Training completed but no LoRA file found", None

    lora_file = lora_files[-1]  # Most recently written
    success_msg = "✅ Training Complete!\n\n"
    success_msg += f"📦 LoRA saved: {lora_file.name}\n"
    success_msg += f"💾 Size: {lora_file.stat().st_size / (1024*1024):.2f} MB\n"
    if trigger_word:
        success_msg += f"🏷️ Use trigger word: '{trigger_word}' in your prompts"

    return success_msg, str(lora_file)


# Gradio Interface
with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Z-Image LoRA Trainer

    Train custom LoRA models for Z-Image-Base (6B parameter model)

    **Quick Start:**
    1. Upload 10-50 images of your subject
    2. Enter a trigger word (e.g., "mycharacter", "mystyle")
    3. Click Train
    4. Download your LoRA when complete

    ⚠️ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
    """)

    # GPU Status
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    with gr.Tab("📤 Upload Dataset"):
        with gr.Row():
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("📤 Upload Dataset", variant="primary", size="lg")

            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                # Hidden field that carries the dataset path to the Train tab.
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    with gr.Tab("🚀 Train LoRA"):
        with gr.Row():
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )

                gr.Markdown("### Training Settings")

                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )
                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )
                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )
                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )

                train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")

            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    with gr.Tab("ℹ️ Help"):
        gr.Markdown("""
        ## 📚 How to Use

        ### Step 1: Prepare Your Images
        - **10-50 images** of your subject (more is better for complex subjects)
        - **Consistent subject** across images
        - **Good variety** in poses, angles, lighting
        - **High quality** photos (clear, well-lit)

        ### Step 2: Upload Dataset
        - Choose a descriptive **dataset name**
        - Add a **trigger word** (e.g., "sks person", "mystyle")
        - Upload your images

        ### Step 3: Configure Training
        - **Project name**: Name for your LoRA
        - **Steps**:
          - 500-1000 for simple subjects
          - 1000-2000 for complex subjects/styles
        - **Learning rate**: Keep default (0.0001)
        - **LoRA Rank**: 16 is good for most cases

        ### Step 4: Train
        - Click "Start Training"
        - Wait 10-30 minutes (don't close tab)
        - Download your LoRA when complete

        ### Step 5: Use Your LoRA
        - Load in ComfyUI, Automatic1111, or other Z-Image tools
        - Use your trigger word in prompts
        - Example: "a photo of [trigger_word] in a forest"

        ## 🎯 Tips for Best Results

        - **Good dataset** = good results
        - **Consistent subject** across images
        - **Unique trigger word** (not common words)
        - **Start with 1000 steps**, adjust if needed
        - **Don't overtrain** (if quality decreases, reduce steps)

        ## ⚠️ Troubleshooting

        **Training fails with OOM error:**
        - Reduce resolution to 768 or 512
        - Use fewer steps
        - Upload fewer images

        **LoRA doesn't look like subject:**
        - Upload more images (20-30+)
        - Increase steps to 1500-2000
        - Ensure images are consistent

        **LoRA is too strong/weak:**
        - Adjust LoRA weight in your inference tool (0.5-1.5)

        ## 📖 Resources

        - **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
        - **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
        - **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
        """)

    # Event handlers
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )

    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )

if __name__ == "__main__":
    demo.launch()