Yeu3ui committed
Commit dbb5a8c · verified · 1 Parent(s): 78a15f3

Create app.py

Files changed (1): app.py +397 -0
app.py ADDED
@@ -0,0 +1,397 @@
import gradio as gr
import os
import subprocess
import shutil
import json
import time
from pathlib import Path
import torch

# Setup directories
DATASET_DIR = Path("./datasets")
OUTPUT_DIR = Path("./output")
DATASET_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

# Global variable to store dataset path
current_dataset_path = None

def check_gpu():
    """Check if GPU is available"""
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        return f"✅ GPU Available: {gpu_name}"
    return "⚠️ No GPU detected - training will be slow"

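# A possible extension (a sketch, not part of the training flow above): also
# report free VRAM so users can judge whether 1024px training will fit.
# torch.cuda.mem_get_info() returns (free_bytes, total_bytes).
def check_vram():
    """Return free/total VRAM in GiB, or a hint when no GPU is present."""
    if not torch.cuda.is_available():
        return "No CUDA device"
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return f"{free_bytes / 2**30:.1f} GiB free of {total_bytes / 2**30:.1f} GiB"
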
def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Upload images and prepare dataset"""
    global current_dataset_path

    if not files:
        return "❌ Please upload at least one image", None, ""

    if not dataset_name:
        dataset_name = f"dataset_{int(time.time())}"

    # Create dataset directory
    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    # Save images
    image_count = 0
    for file in files:
        if file.name.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
            filename = Path(file.name).name
            destination = dataset_path / filename
            shutil.copy(file.name, destination)

            # Create simple caption file
            caption_file = destination.with_suffix('.txt')
            caption_text = trigger_word if trigger_word else "a photo"
            with open(caption_file, 'w') as f:
                f.write(caption_text)

            image_count += 1

    if image_count == 0:
        return "❌ No valid images found. Upload PNG, JPG, JPEG, or WEBP files.", None, ""

    current_dataset_path = str(dataset_path)

    status = f"✅ Successfully uploaded {image_count} images\n"
    status += f"📁 Dataset: {dataset_name}\n"
    if trigger_word:
        status += f"🏷️ Trigger word: '{trigger_word}'\n"
    status += f"💾 Location: {current_dataset_path}"

    return status, current_dataset_path, f"Dataset ready: {dataset_name}"

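# The folder produced above follows the image/caption-pair layout that the
# sd_trainer process reads (see caption_ext: "txt" in the config below), e.g.:
#
#   datasets/my_dataset/photo_01.jpg
#   datasets/my_dataset/photo_01.txt   <- caption, here just the trigger word
#   datasets/my_dataset/photo_02.png
#   datasets/my_dataset/photo_02.txt
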
def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train LoRA model"""

    if not dataset_path or not os.path.exists(dataset_path):
        return "❌ Please upload a dataset first!", None

    if not project_name:
        project_name = f"lora_{int(time.time())}"

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    # Create training config
    config = {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": int(lora_rank),
                    "linear_alpha": int(lora_rank),
                },
                "save": {
                    "dtype": "float16",
                    "save_every": max(100, int(steps / 4)),
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [int(resolution), int(resolution)],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": int(steps),
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": max(100, int(steps / 4)),
                    "width": int(resolution),
                    "height": int(resolution),
                    "prompts": [
                        f"{trigger_word} high quality photo" if trigger_word else "high quality photo",
                        f"{trigger_word} beautiful scene" if trigger_word else "beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }]
        }
    }

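    # Worked example of the checkpoint cadence above: with steps=1000,
    # save_every = max(100, 1000 / 4) = 250, so checkpoints land at steps
    # 250, 500, 750 and 1000, and max_step_saves_to_keep=3 retains only the
    # most recent three. Sample images are generated on the same cadence.
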
    # Save config
    config_path = output_path / "config.json"
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    progress(0.1, desc="Installing AI Toolkit...")

    # Install AI Toolkit if not exists. Use cwd= rather than os.chdir so the
    # app's working directory stays stable even if a step raises.
    if not Path("./ai-toolkit").exists():
        try:
            subprocess.run(
                ["git", "clone", "https://github.com/ostris/ai-toolkit.git"],
                check=True,
                capture_output=True
            )
            subprocess.run(
                ["git", "submodule", "update", "--init", "--recursive"],
                check=True,
                capture_output=True,
                cwd="ai-toolkit"
            )
            subprocess.run(
                ["pip", "install", "-q", "-r", "requirements.txt"],
                check=True,
                cwd="ai-toolkit"
            )
        except Exception as e:
            return f"❌ Failed to install AI Toolkit: {str(e)}", None

    progress(0.3, desc="Starting training...")

    # Run training
    try:
        result = subprocess.run(
            ["python", "ai-toolkit/run.py", str(config_path)],
            capture_output=True,
            text=True,
            timeout=3600  # 1 hour timeout
        )

        if result.returncode != 0:
            return f"❌ Training failed:\n{result.stderr}", None

        progress(0.9, desc="Training complete! Finding LoRA file...")

        # Find the trained LoRA file. Search recursively (checkpoints are
        # typically written into a job-named subfolder) and sort by mtime so
        # [-1] really is the latest save; plain glob() order is not guaranteed.
        lora_files = sorted(output_path.rglob("*.safetensors"),
                            key=lambda p: p.stat().st_mtime)
        if lora_files:
            lora_file = lora_files[-1]  # most recent checkpoint
            success_msg = "✅ Training Complete!\n\n"
            success_msg += f"📦 LoRA saved: {lora_file.name}\n"
            success_msg += f"💾 Size: {lora_file.stat().st_size / (1024*1024):.2f} MB\n"
            success_msg += f"🏷️ Use trigger word: '{trigger_word}' in your prompts"
            return success_msg, str(lora_file)
        else:
            return "⚠️ Training completed but no LoRA file found", None

    except subprocess.TimeoutExpired:
        return "❌ Training timeout (> 1 hour). Try reducing steps.", None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None

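# Note: subprocess.run blocks until training finishes, so the progress bar
# can only jump from 0.3 to 0.9. A hypothetical sketch for live progress
# (not wired into the UI here) would stream the log with Popen instead:
#
#   proc = subprocess.Popen(
#       ["python", "ai-toolkit/run.py", str(config_path)],
#       stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
#   )
#   for line in proc.stdout:   # parse step counters from the log lines
#       ...                    # and call progress(current_step / steps)
#   proc.wait()
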
# Gradio Interface
with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Z-Image LoRA Trainer

    Train custom LoRA models for Z-Image-Base (6B parameter model)

    **Quick Start:**
    1. Upload 10-50 images of your subject
    2. Enter a trigger word (e.g., "mycharacter", "mystyle")
    3. Click Train
    4. Download your LoRA when complete

    ⚠️ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
    """)

    # GPU Status
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    with gr.Tab("📤 Upload Dataset"):
        with gr.Row():
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("📤 Upload Dataset", variant="primary", size="lg")

            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    with gr.Tab("🚀 Train LoRA"):
        with gr.Row():
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )

                gr.Markdown("### Training Settings")

                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )

                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )

                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )

                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )

                train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")

            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    with gr.Tab("ℹ️ Help"):
        gr.Markdown("""
        ## 📚 How to Use

        ### Step 1: Prepare Your Images
        - **10-50 images** of your subject (more is better for complex subjects)
        - **Consistent subject** across images
        - **Good variety** in poses, angles, lighting
        - **High quality** photos (clear, well-lit)

        ### Step 2: Upload Dataset
        - Choose a descriptive **dataset name**
        - Add a **trigger word** (e.g., "sks person", "mystyle")
        - Upload your images

        ### Step 3: Configure Training
        - **Project name**: Name for your LoRA
        - **Steps**:
          - 500-1000 for simple subjects
          - 1000-2000 for complex subjects/styles
        - **Learning rate**: Keep default (0.0001)
        - **LoRA Rank**: 16 is good for most cases

        ### Step 4: Train
        - Click "Start Training"
        - Wait 10-30 minutes (don't close tab)
        - Download your LoRA when complete

        ### Step 5: Use Your LoRA
        - Load in ComfyUI, Automatic1111, or other Z-Image tools
        - Use your trigger word in prompts
        - Example: "a photo of [trigger_word] in a forest"

        ## 🎯 Tips for Best Results

        - **Good dataset** = good results
        - **Consistent subject** across images
        - **Unique trigger word** (not common words)
        - **Start with 1000 steps**, adjust if needed
        - **Don't overtrain** (if quality decreases, reduce steps)

        ## ⚠️ Troubleshooting

        **Training fails with OOM error:**
        - Reduce resolution to 768 or 512
        - Use fewer steps
        - Upload fewer images

        **LoRA doesn't look like subject:**
        - Upload more images (20-30+)
        - Increase steps to 1500-2000
        - Ensure images are consistent

        **LoRA is too strong/weak:**
        - Adjust LoRA weight in your inference tool (0.5-1.5)

        ## 📖 Resources

        - **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
        - **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
        - **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
        """)

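    # A quick sanity check on a finished LoRA file (a sketch; the path is
    # hypothetical, assuming the safetensors package is installed):
    #
    #   from safetensors import safe_open
    #   with safe_open("output/my_lora/my_lora.safetensors", framework="pt") as f:
    #       print(len(f.keys()), "tensors")   # LoRA up/down weight pairs
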
    # Event handlers
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )

    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )

if __name__ == "__main__":
    # queue() keeps long-running training requests alive instead of letting
    # the default request handling time out mid-job
    demo.queue().launch()