#!/usr/bin/env python3
"""
Functional LoRA Trainer for Hugging Face
Based on kohya-ss sd-scripts

Workflow: install dependencies, download a base model, upload a captioned
dataset, create the training configuration, train, and download the LoRA.
"""
import gradio as gr
import os
import sys
import json
import subprocess
import shutil
import zipfile
import tempfile
import toml
import logging
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any
import time
import threading
import queue

# Add the sd-scripts directory to the path
sys.path.insert(0, str(Path(__file__).parent / "sd-scripts"))

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class LoRATrainerHF:
    def __init__(self):
        self.base_dir = Path("/tmp/lora_training")
        self.base_dir.mkdir(exist_ok=True)
        self.models_dir = self.base_dir / "models"
        self.models_dir.mkdir(exist_ok=True)
        self.projects_dir = self.base_dir / "projects"
        self.projects_dir.mkdir(exist_ok=True)
        self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"

        # Base model URLs
        self.model_urls = {
            "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
            "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
            "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
            "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
        }

        self.training_process = None
        self.training_output_queue = queue.Queue()
    def install_dependencies(self) -> str:
        """Install the required Python dependencies."""
        try:
            logger.info("Installing dependencies...")

            # Packages required for training
            packages = [
                "torch>=2.0.0",
                "torchvision>=0.15.0",
                "diffusers>=0.21.0",
                "transformers>=4.25.0",
                "accelerate>=0.20.0",
                "safetensors>=0.3.0",
                "huggingface-hub>=0.16.0",
                "xformers>=0.0.20",
                "bitsandbytes>=0.41.0",
                "opencv-python>=4.7.0",
                "Pillow>=9.0.0",
                "numpy>=1.21.0",
                "tqdm>=4.64.0",
                "toml>=0.10.0",
                "tensorboard>=2.13.0",
                "wandb>=0.15.0",
                "scipy>=1.9.0",
                "matplotlib>=3.5.0",
                "datasets>=2.14.0",
                "peft>=0.5.0",
                "omegaconf>=2.3.0"
            ]

            # Install one package at a time so a single failure does not abort the rest
            for package in packages:
                try:
                    subprocess.run([
                        sys.executable, "-m", "pip", "install", package, "--quiet"
                    ], check=True, capture_output=True, text=True)
                    logger.info(f"✓ {package} installed")
                except subprocess.CalledProcessError as e:
                    logger.warning(f"⚠ Failed to install {package}: {e}")

            return "✅ Dependencies installed successfully!"

        except Exception as e:
            logger.error(f"Error installing dependencies: {e}")
            return f"❌ Error installing dependencies: {e}"
    def download_model(self, model_choice: str, custom_url: str = "") -> str:
        """Download the base model."""
        try:
            if custom_url.strip():
                model_url = custom_url.strip()
                model_name = model_url.split("/")[-1]
            else:
                if model_choice not in self.model_urls:
                    return f"❌ Model '{model_choice}' not found"
                model_url = self.model_urls[model_choice]
                model_name = model_url.split("/")[-1]

            model_path = self.models_dir / model_name
            if model_path.exists():
                return f"✅ Model already exists: {model_name}"

            logger.info(f"Downloading model: {model_url}")

            # Download with wget
            result = subprocess.run([
                "wget", "-O", str(model_path), model_url, "--progress=bar:force"
            ], capture_output=True, text=True)

            if result.returncode == 0:
                return f"✅ Model downloaded: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"
            else:
                return f"❌ Download failed: {result.stderr}"

        except Exception as e:
            logger.error(f"Error downloading model: {e}")
            return f"❌ Error downloading model: {e}"
    def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
        """Extract and analyze the uploaded dataset."""
        try:
            if not dataset_zip:
                return "❌ No dataset was uploaded", ""
            if not project_name.strip():
                return "❌ Project name is required", ""

            project_name = project_name.strip().replace(" ", "_")
            project_dir = self.projects_dir / project_name
            project_dir.mkdir(exist_ok=True)

            dataset_dir = project_dir / "dataset"
            if dataset_dir.exists():
                shutil.rmtree(dataset_dir)
            dataset_dir.mkdir()

            # Extract the ZIP (gr.File may hand back either a path string or a
            # tempfile-like object with a .name attribute, depending on the Gradio version)
            zip_path = dataset_zip if isinstance(dataset_zip, str) else dataset_zip.name
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)

            # Analyze the dataset
            image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}
            images = []
            captions = []

            for file_path in dataset_dir.rglob("*"):
                if file_path.suffix.lower() in image_extensions:
                    images.append(file_path)
                    # Look for a caption file with the same name
                    caption_path = file_path.with_suffix('.txt')
                    if caption_path.exists():
                        captions.append(caption_path)

            info = f"✅ Dataset processed!\n"
            info += f"📁 Project: {project_name}\n"
            info += f"🖼️ Images: {len(images)}\n"
            info += f"📝 Captions: {len(captions)}\n"
            info += f"📂 Directory: {dataset_dir}"

            return info, str(dataset_dir)

        except Exception as e:
            logger.error(f"Error processing dataset: {e}")
            return f"❌ Error processing dataset: {e}", ""
    def create_training_config(self,
                               project_name: str,
                               dataset_dir: str,
                               model_choice: str,
                               custom_model_url: str,
                               resolution: int,
                               batch_size: int,
                               epochs: int,
                               learning_rate: float,
                               text_encoder_lr: float,
                               network_dim: int,
                               network_alpha: int,
                               lora_type: str,
                               optimizer: str,
                               scheduler: str,
                               flip_aug: bool,
                               shuffle_caption: bool,
                               keep_tokens: int,
                               clip_skip: int,
                               mixed_precision: str,
                               save_every_n_epochs: int,
                               max_train_steps: int) -> str:
        """Create the training configuration files."""
        try:
            if not project_name.strip():
                return "❌ Project name is required"

            project_name = project_name.strip().replace(" ", "_")
            project_dir = self.projects_dir / project_name
            project_dir.mkdir(exist_ok=True)

            output_dir = project_dir / "output"
            output_dir.mkdir(exist_ok=True)
            log_dir = project_dir / "logs"
            log_dir.mkdir(exist_ok=True)

            # Resolve the base model
            if custom_model_url.strip():
                model_name = custom_model_url.strip().split("/")[-1]
            else:
                model_name = self.model_urls[model_choice].split("/")[-1]

            model_path = self.models_dir / model_name
            if not model_path.exists():
                return f"❌ Model not found: {model_name}. Download it first."

            # Dataset configuration
            # (face_crop_aug_range is left unset: TOML cannot represent None)
            dataset_config = {
                "general": {
                    "shuffle_caption": shuffle_caption,
                    "caption_extension": ".txt",
                    "keep_tokens": keep_tokens,
                    "flip_aug": flip_aug,
                    "color_aug": False,
                    "random_crop": False,
                    "debug_dataset": False
                },
                "datasets": [{
                    "resolution": resolution,
                    "batch_size": batch_size,
                    "subsets": [{
                        "image_dir": str(dataset_dir),
                        "num_repeats": 1
                    }]
                }]
            }
            # Training configuration
            training_config = {
                "model_arguments": {
                    "pretrained_model_name_or_path": str(model_path),
                    "v2": False,
                    "v_parameterization": False,
                    "clip_skip": clip_skip
                },
                "dataset_arguments": {
                    "dataset_config": str(project_dir / "dataset_config.toml")
                },
                "training_arguments": {
                    "output_dir": str(output_dir),
                    "output_name": project_name,
                    "save_precision": "fp16",
                    "save_every_n_epochs": save_every_n_epochs,
                    "train_batch_size": batch_size,
                    "gradient_accumulation_steps": 1,
                    "learning_rate": learning_rate,
                    "text_encoder_lr": text_encoder_lr,
                    "lr_scheduler": scheduler,
                    "lr_warmup_steps": 0,
                    "optimizer_type": optimizer,
                    "mixed_precision": mixed_precision,
                    "save_model_as": "safetensors",
                    "seed": 42,
                    "max_data_loader_n_workers": 2,
                    "persistent_data_loader_workers": True,
                    "gradient_checkpointing": True,
                    "xformers": True,
                    "lowram": True,
                    "cache_latents": True,
                    "cache_latents_to_disk": True,
                    "logging_dir": str(log_dir),
                    "log_with": "tensorboard"
                },
                "network_arguments": {
                    "network_module": "networks.lora" if lora_type == "LoRA" else "networks.dylora",
                    "network_dim": network_dim,
                    "network_alpha": network_alpha,
                    "network_train_unet_only": False,
                    "network_train_text_encoder_only": False
                }
            }

            # TOML has no null value, so set only one of the two stopping criteria
            if max_train_steps > 0:
                training_config["training_arguments"]["max_train_steps"] = int(max_train_steps)
            else:
                training_config["training_arguments"]["max_train_epochs"] = epochs

            # LoCon-specific arguments: kohya's networks.lora takes module-specific
            # options (conv_dim / conv_alpha) through network_args rather than as
            # top-level training options
            if lora_type == "LoCon":
                training_config["network_arguments"]["network_module"] = "networks.lora"
                training_config["network_arguments"]["network_args"] = [
                    f"conv_dim={max(1, network_dim // 2)}",
                    f"conv_alpha={max(1, network_alpha // 2)}"
                ]
            # Save the configuration files
            dataset_config_path = project_dir / "dataset_config.toml"
            training_config_path = project_dir / "training_config.toml"

            with open(dataset_config_path, 'w') as f:
                toml.dump(dataset_config, f)
            with open(training_config_path, 'w') as f:
                toml.dump(training_config, f)

            return f"✅ Configuration created!\n📁 Dataset: {dataset_config_path}\n⚙️ Training: {training_config_path}"

        except Exception as e:
            logger.error(f"Error creating configuration: {e}")
            return f"❌ Error creating configuration: {e}"
    def start_training(self, project_name: str) -> str:
        """Start training in a background thread."""
        try:
            if not project_name.strip():
                return "❌ Project name is required"

            project_name = project_name.strip().replace(" ", "_")
            project_dir = self.projects_dir / project_name
            training_config_path = project_dir / "training_config.toml"

            if not training_config_path.exists():
                return "❌ Configuration not found. Create the configuration first."

            # Training script
            train_script = self.sd_scripts_dir / "train_network.py"
            if not train_script.exists():
                return "❌ Training script not found"

            # Training command
            cmd = [
                sys.executable,
                str(train_script),
                "--config_file", str(training_config_path)
            ]

            logger.info(f"Starting training: {' '.join(cmd)}")

            # Run in a separate thread so the UI stays responsive
            def run_training():
                try:
                    process = subprocess.Popen(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        text=True,
                        bufsize=1,
                        universal_newlines=True,
                        cwd=str(self.sd_scripts_dir)
                    )
                    self.training_process = process

                    for line in process.stdout:
                        self.training_output_queue.put(line.strip())
                        logger.info(line.strip())

                    process.wait()

                    if process.returncode == 0:
                        self.training_output_queue.put("✅ TRAINING COMPLETED SUCCESSFULLY!")
                    else:
                        self.training_output_queue.put(f"❌ TRAINING FAILED (exit code {process.returncode})")

                except Exception as e:
                    self.training_output_queue.put(f"❌ TRAINING ERROR: {e}")
                finally:
                    self.training_process = None

            # Start the thread
            training_thread = threading.Thread(target=run_training)
            training_thread.daemon = True
            training_thread.start()

            return "🚀 Training started! Follow the progress below."

        except Exception as e:
            logger.error(f"Error starting training: {e}")
            return f"❌ Error starting training: {e}"
    def get_training_output(self) -> str:
        """Drain and return any pending training output."""
        output_lines = []
        try:
            while not self.training_output_queue.empty():
                line = self.training_output_queue.get_nowait()
                output_lines.append(line)
        except queue.Empty:
            pass

        if output_lines:
            return "\n".join(output_lines)
        elif self.training_process and self.training_process.poll() is None:
            return "🔄 Training in progress..."
        else:
            return "⏸️ No active training"
    def stop_training(self) -> str:
        """Stop the running training process."""
        try:
            if self.training_process and self.training_process.poll() is None:
                self.training_process.terminate()
                self.training_process.wait(timeout=10)
                return "⏹️ Training stopped"
            else:
                return "ℹ️ No active training to stop"
        except Exception as e:
            return f"❌ Error stopping training: {e}"
    def list_output_files(self, project_name: str) -> List[str]:
        """List the generated LoRA output files."""
        try:
            if not project_name.strip():
                return []

            project_name = project_name.strip().replace(" ", "_")
            project_dir = self.projects_dir / project_name
            output_dir = project_dir / "output"

            if not output_dir.exists():
                return []

            files = []
            for file_path in output_dir.rglob("*.safetensors"):
                size_mb = file_path.stat().st_size // (1024 * 1024)
                files.append(f"{file_path.name} ({size_mb} MB)")

            return sorted(files, reverse=True)  # Most recent first

        except Exception as e:
            logger.error(f"Error listing files: {e}")
            return []
# Global trainer instance
trainer = LoRATrainerHF()
def create_interface():
    """Build the Gradio interface."""
    with gr.Blocks(title="Functional LoRA Trainer - Hugging Face", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # 🎨 Functional LoRA Trainer for Hugging Face

        **Train your own LoRA models for Stable Diffusion, the professional way!**

        This tool is based on kohya-ss sd-scripts and provides real, working LoRA training.
        """)

        # State shared between the steps
        dataset_dir_state = gr.State("")

        with gr.Tab("🔧 Installation"):
            gr.Markdown("### First, install the required dependencies:")
            install_btn = gr.Button("📦 Install Dependencies", variant="primary", size="lg")
            install_status = gr.Textbox(label="Installation Status", lines=3, interactive=False)

            install_btn.click(
                fn=trainer.install_dependencies,
                outputs=install_status
            )
| with gr.Tab("📁 Configuração do Projeto"): | |
| with gr.Row(): | |
| project_name = gr.Textbox( | |
| label="Nome do Projeto", | |
| placeholder="meu_lora_anime", | |
| info="Nome único para seu projeto (sem espaços especiais)" | |
| ) | |
| gr.Markdown("### 📥 Download do Modelo Base") | |
| with gr.Row(): | |
| model_choice = gr.Dropdown( | |
| choices=list(trainer.model_urls.keys()), | |
| label="Modelo Base Pré-definido", | |
| value="Anime (animefull-final-pruned)", | |
| info="Escolha um modelo base ou use URL personalizada" | |
| ) | |
| custom_model_url = gr.Textbox( | |
| label="URL Personalizada (opcional)", | |
| placeholder="https://huggingface.co/...", | |
| info="URL direta para download de modelo personalizado" | |
| ) | |
| download_btn = gr.Button("📥 Baixar Modelo", variant="primary") | |
| download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False) | |
| gr.Markdown("### 📊 Upload do Dataset") | |
| gr.Markdown(""" | |
| **Formato do Dataset:** | |
| - Crie um arquivo ZIP contendo suas imagens | |
| - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições | |
| - Exemplo: `imagem1.jpg` + `imagem1.txt` | |
| """) | |
| dataset_upload = gr.File( | |
| label="Upload do Dataset (ZIP)", | |
| file_types=[".zip"] | |
| ) | |
| process_btn = gr.Button("📊 Processar Dataset", variant="primary") | |
| dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False) | |
| with gr.Tab("⚙️ Parâmetros de Treinamento"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("#### 🖼️ Configurações de Imagem") | |
| resolution = gr.Slider( | |
| minimum=512, maximum=1024, step=64, value=512, | |
| label="Resolução", | |
| info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)" | |
| ) | |
| batch_size = gr.Slider( | |
| minimum=1, maximum=8, step=1, value=1, | |
| label="Batch Size", | |
| info="Imagens por lote (aumente se tiver GPU potente)" | |
| ) | |
| flip_aug = gr.Checkbox( | |
| label="Flip Augmentation", | |
| info="Espelhar imagens para aumentar dataset" | |
| ) | |
| shuffle_caption = gr.Checkbox( | |
| value=True, | |
| label="Shuffle Caption", | |
| info="Embaralhar ordem das tags" | |
| ) | |
| keep_tokens = gr.Slider( | |
| minimum=0, maximum=5, step=1, value=1, | |
| label="Keep Tokens", | |
| info="Número de tokens iniciais que não serão embaralhados" | |
| ) | |
                with gr.Column():
                    gr.Markdown("#### 🎯 Training Settings")
                    epochs = gr.Slider(
                        minimum=1, maximum=100, step=1, value=10,
                        label="Epochs",
                        info="Number of training epochs"
                    )
                    max_train_steps = gr.Number(
                        value=0,
                        label="Max Train Steps (0 = use epochs)",
                        info="Maximum number of steps (leave 0 to train by epochs)"
                    )
                    save_every_n_epochs = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1,
                        label="Save every N epochs",
                        info="How often checkpoints are saved"
                    )
                    mixed_precision = gr.Dropdown(
                        choices=["fp16", "bf16", "no"],
                        value="fp16",
                        label="Mixed Precision",
                        info="fp16 = faster, bf16 = more stable"
                    )
                    clip_skip = gr.Slider(
                        minimum=1, maximum=12, step=1, value=2,
                        label="CLIP Skip",
                        info="CLIP layers to skip (2 for anime, 1 for realistic)"
                    )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 📚 Learning Rate")
                    learning_rate = gr.Number(
                        value=1e-4,
                        label="Learning Rate (UNet)",
                        info="Main learning rate"
                    )
                    text_encoder_lr = gr.Number(
                        value=5e-5,
                        label="Learning Rate (Text Encoder)",
                        info="Learning rate for the text encoder"
                    )
                    scheduler = gr.Dropdown(
                        choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
                        value="cosine_with_restarts",
                        label="LR Scheduler",
                        info="Learning-rate scheduling algorithm"
                    )
                    optimizer = gr.Dropdown(
                        choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
                        value="AdamW8bit",
                        label="Optimizer",
                        info="AdamW8bit = less memory"
                    )
                with gr.Column():
                    gr.Markdown("#### 🧠 LoRA Architecture")
                    lora_type = gr.Radio(
                        choices=["LoRA", "LoCon"],
                        value="LoRA",
                        label="LoRA Type",
                        info="LoRA = general purpose, LoCon = artistic styles"
                    )
                    network_dim = gr.Slider(
                        minimum=4, maximum=128, step=4, value=32,
                        label="Network Dimension",
                        info="Network rank (larger = more detail, more memory)"
                    )
                    network_alpha = gr.Slider(
                        minimum=1, maximum=128, step=1, value=16,
                        label="Network Alpha",
                        info="Controls the LoRA strength (usually dim/2)"
                    )
| with gr.Tab("🚀 Treinamento"): | |
| create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg") | |
| config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False) | |
| with gr.Row(): | |
| start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg") | |
| stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop") | |
| training_output = gr.Textbox( | |
| label="Output do Treinamento", | |
| lines=15, | |
| interactive=False, | |
| info="Acompanhe o progresso do treinamento em tempo real" | |
| ) | |
| # Auto-refresh do output | |
| def update_output(): | |
| return trainer.get_training_output() | |
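            # Wiring for the periodic refresh above: a minimal sketch that assumes
            # a Gradio version where Blocks.load() accepts `every=` (3.16+). It polls
            # update_output every few seconds while the page is open and requires the
            # queue to be enabled at launch (see interface.queue() in __main__ below).
            interface.load(
                fn=update_output,
                inputs=None,
                outputs=training_output,
                every=3
            )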
| with gr.Tab("📥 Download dos Resultados"): | |
| refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary") | |
| output_files = gr.Dropdown( | |
| label="Arquivos LoRA Gerados", | |
| choices=[], | |
| info="Selecione um arquivo para download" | |
| ) | |
| download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento") | |
        # Event handlers
        download_btn.click(
            fn=trainer.download_model,
            inputs=[model_choice, custom_model_url],
            outputs=download_status
        )

        process_btn.click(
            fn=trainer.process_dataset,
            inputs=[dataset_upload, project_name],
            outputs=[dataset_status, dataset_dir_state]
        )

        create_config_btn.click(
            fn=trainer.create_training_config,
            inputs=[
                project_name, dataset_dir_state, model_choice, custom_model_url,
                resolution, batch_size, epochs, learning_rate, text_encoder_lr,
                network_dim, network_alpha, lora_type, optimizer, scheduler,
                flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
                save_every_n_epochs, max_train_steps
            ],
            outputs=config_status
        )

        start_training_btn.click(
            fn=trainer.start_training,
            inputs=project_name,
            outputs=training_output
        )

        stop_training_btn.click(
            fn=trainer.stop_training,
            outputs=training_output
        )

        # gr.update replaces the dropdown's choices rather than setting its value
        refresh_files_btn.click(
            fn=lambda name: gr.update(choices=trainer.list_output_files(name), value=None),
            inputs=project_name,
            outputs=output_files
        )
    return interface


if __name__ == "__main__":
    print("🚀 Starting the functional LoRA Trainer...")
    interface = create_interface()
    # The queue is required for the periodic training-output refresh (every=) above
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )