| """ | |
| ACE-Step 1.5 Custom Edition - Main Application | |
| A comprehensive music generation system with three main interfaces: | |
| 1. Standard ACE-Step GUI | |
| 2. Custom Timeline-based Workflow | |
| 3. LoRA Training Studio | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| import json | |
| from typing import Optional, List, Tuple | |
try:
    import spaces
except ImportError:
    # Local dev: make @spaces.GPU a no-op so decorated functions still work
    class _Spaces:
        def GPU(self, fn=None, **kwargs):
            # Supports both @spaces.GPU and @spaces.GPU(duration=...)
            return fn if fn else (lambda f: f)
    spaces = _Spaces()
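
# Assumed usage on HF Spaces (the decorator is not applied anywhere in this
# file): GPU-heavy entry points would be wrapped so ZeroGPU can allocate a
# device for the duration of the call, e.g.
#
#     @spaces.GPU(duration=120)
#     def standard_generate(...): ...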

from src.ace_step_engine import ACEStepEngine
from src.timeline_manager import TimelineManager
from src.lora_trainer import download_hf_dataset, upload_dataset_json_to_hf
from src.audio_processor import AudioProcessor
from src.utils import setup_logging, load_config
from acestep.training.dataset_builder import DatasetBuilder
from acestep.training.configs import LoRAConfig, TrainingConfig
from acestep.training.trainer import LoRATrainer as FabricLoRATrainer

# Setup
logger = setup_logging()
config = load_config()

# Lazily initialized components (created on first use)
ace_engine = None
timeline_manager = None
dataset_builder = None
audio_processor = None

# Module-level mutable dict for the training stop signal
# (gr.State is not shared between concurrent Gradio calls)
_training_control = {"should_stop": False}
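
# The pattern this enables: lora_stop_training() (wired to the Stop button)
# sets the flag, and the training loop in lora_train_real() polls it between
# steps, e.g.
#
#     _training_control["should_stop"] = True         # Stop button handler
#     if _training_control.get("should_stop"): break  # inside the training loop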

def get_ace_engine():
    """Lazy-load the ACE-Step engine."""
    global ace_engine
    if ace_engine is None:
        ace_engine = ACEStepEngine(config)
    return ace_engine


def get_timeline_manager():
    """Lazy-load the timeline manager."""
    global timeline_manager
    if timeline_manager is None:
        timeline_manager = TimelineManager(config)
    return timeline_manager


def get_dataset_builder():
    """Lazy-load the dataset builder."""
    global dataset_builder
    if dataset_builder is None:
        dataset_builder = DatasetBuilder()
    return dataset_builder


def get_audio_processor():
    """Lazy-load the audio processor."""
    global audio_processor
    if audio_processor is None:
        audio_processor = AudioProcessor(config)
    return audio_processor

# ==================== TAB 1: STANDARD ACE-STEP GUI ====================

def standard_generate(
    prompt: str,
    lyrics: str,
    duration: int,
    temperature: float,
    top_p: float,
    seed: int,
    style: str,
    use_lora: bool,
    lora_path: Optional[str] = None,
) -> Tuple[Optional[str], str]:
    """Standard ACE-Step generation with all original features."""
    try:
        logger.info(f"Standard generation: {prompt[:50]}...")
        engine = get_ace_engine()
        audio_path = engine.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            style=style,
            lora_path=lora_path if use_lora else None,
        )
        info = f"✅ Generated {duration}s audio successfully"
        return audio_path, info
    except Exception as e:
        logger.error(f"Standard generation failed: {e}")
        return None, f"❌ Error: {str(e)}"


def standard_variation(audio_path: str, variation_strength: float) -> Tuple[Optional[str], str]:
    """Generate a variation of existing audio."""
    try:
        result = get_ace_engine().generate_variation(audio_path, variation_strength)
        return result, "✅ Variation generated"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"


def standard_repaint(
    audio_path: str,
    start_time: float,
    end_time: float,
    new_prompt: str,
) -> Tuple[Optional[str], str]:
    """Repaint a specific section of the audio."""
    try:
        result = get_ace_engine().repaint(audio_path, start_time, end_time, new_prompt)
        return result, f"✅ Repainted {start_time}s-{end_time}s"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"


def standard_lyric_edit(
    audio_path: str,
    new_lyrics: str,
) -> Tuple[Optional[str], str]:
    """Edit lyrics while keeping the underlying music."""
    try:
        result = get_ace_engine().edit_lyrics(audio_path, new_lyrics)
        return result, "✅ Lyrics edited"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"

# ==================== TAB 2: CUSTOM TIMELINE WORKFLOW ====================

def timeline_generate(
    prompt: str,
    lyrics: str,
    context_length: int,
    style: str,
    temperature: float,
    seed: int,
    session_state: dict,
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
    """
    Generate a 32-second clip: 2s lead-in, 28s main section, 2s lead-out.
    Blends with previous clips based on context_length.
    """
    try:
        # Initialize session state if None
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}
        logger.info(f"Timeline generation with {context_length}s context")
        tm = get_timeline_manager()
        engine = get_ace_engine()
        ap = get_audio_processor()
        # Pull the trailing context_length seconds of the timeline for style guidance
        context_audio = tm.get_context(
            session_state.get("timeline_id"),
            context_length,
        )
        # Generate the 32s clip
        clip = engine.generate_clip(
            prompt=prompt,
            lyrics=lyrics,
            duration=32,
            context_audio=context_audio,
            style=style,
            temperature=temperature,
            seed=seed,
        )
        # Blend with the timeline (2s lead-in and lead-out crossfades)
        blended_clip = ap.blend_clip(
            clip,
            tm.get_last_clip(session_state.get("timeline_id")),
            lead_in=2.0,
            lead_out=2.0,
        )
        # Add to timeline
        timeline_id = tm.add_clip(
            session_state.get("timeline_id"),
            blended_clip,
            metadata={
                "prompt": prompt,
                "lyrics": lyrics,
                "context_length": context_length,
            },
        )
        # Update session
        session_state["timeline_id"] = timeline_id
        session_state["total_clips"] = session_state.get("total_clips", 0) + 1
        # Export the full timeline audio and its visualization
        full_audio = tm.export_timeline(timeline_id)
        timeline_viz = tm.visualize_timeline(timeline_id)
        info = f"✅ Clip {session_state['total_clips']} added • Total: {tm.get_duration(timeline_id):.1f}s"
        return blended_clip, full_audio, timeline_viz, session_state, info
    except Exception as e:
        logger.error(f"Timeline generation failed: {e}")
        return None, None, None, session_state, f"❌ Error: {str(e)}"
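
# Illustrative sketch only (never called at runtime): AudioProcessor.blend_clip
# is assumed to crossfade the new clip's 2s lead-in with the tail of the
# previous clip. An equal-power crossfade over that overlap would look like:
def _equal_power_crossfade_sketch(prev_tail: np.ndarray, new_head: np.ndarray) -> np.ndarray:
    """Blend two equal-length mono sample arrays with an equal-power crossfade."""
    t = np.linspace(0.0, np.pi / 2, len(prev_tail))
    # cos^2 + sin^2 = 1, so perceived loudness stays constant across the blend
    return prev_tail * np.cos(t) + new_head * np.sin(t)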

def timeline_extend(
    prompt: str,
    lyrics: str,
    context_length: int,
    session_state: dict,
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
    """Extend the current timeline (default style, temperature 0.7, random seed)."""
    return timeline_generate(
        prompt, lyrics, context_length, "auto", 0.7, -1, session_state
    )


def timeline_inpaint(
    start_time: float,
    end_time: float,
    new_prompt: str,
    session_state: dict,
) -> Tuple[Optional[str], Optional[str], dict, str]:
    """Inpaint a specific region of the timeline."""
    try:
        # Initialize session state if None
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}
        tm = get_timeline_manager()
        timeline_id = session_state.get("timeline_id")
        tm.inpaint_region(
            timeline_id,
            start_time,
            end_time,
            new_prompt,
        )
        full_audio = tm.export_timeline(timeline_id)
        timeline_viz = tm.visualize_timeline(timeline_id)
        info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
        return full_audio, timeline_viz, session_state, info
    except Exception as e:
        return None, None, session_state, f"❌ Error: {str(e)}"


def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
    """Reset the timeline to start fresh."""
    if session_state is None:
        session_state = {"timeline_id": None, "total_clips": 0}
    elif session_state.get("timeline_id"):
        get_timeline_manager().delete_timeline(session_state["timeline_id"])
        session_state = {"timeline_id": None, "total_clips": 0}
    return None, None, "Timeline cleared", session_state

# ==================== TAB 3: LORA TRAINING STUDIO ====================

DATAFRAME_HEADERS = ["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"]
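# Column indices used when reading edits back in lora_save_edits():
#   row[0]="#" (sample index), row[5]="BPM", row[6]="Key", row[7]="Caption".
# Only those three trailing columns are treated as editable.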

def _build_progress_summary():
    """Build a one-line progress summary from the current dataset builder state."""
    builder = get_dataset_builder()
    total = builder.get_sample_count()
    labeled = builder.get_labeled_count()
    preprocessed = builder.get_preprocessed_count()
    remaining = total - labeled
    return f"Total: {total} | Labeled: {labeled} | Preprocessed: {preprocessed} | Remaining: {remaining}"


def _build_review_dataframe():
    """Build editable dataframe rows from the current dataset builder state."""
    builder = get_dataset_builder()
    return builder.get_samples_dataframe_data()


def lora_download_hf(dataset_id, custom_tag, max_files, hf_offset, training_state):
    """Download a HuggingFace dataset batch, restore labels from the HF repo, and scan."""
    try:
        if not dataset_id or not dataset_id.strip():
            return "Enter a dataset ID (e.g. username/dataset-name)", training_state, int(hf_offset or 0), _build_progress_summary()
        offset_val = int(hf_offset or 0)
        max_files_val = int(max_files)
        local_dir, dl_status = download_hf_dataset(
            dataset_id.strip(),
            max_files=max_files_val,
            offset=offset_val,
        )
        if not local_dir:
            return f"Download failed: {dl_status}", training_state, offset_val, _build_progress_summary()
        builder = get_dataset_builder()
        # Set the trigger word for LoRA training
        tag = custom_tag.strip() if custom_tag else ""
        if tag:
            builder.set_custom_tag(tag)
        # Restore labels/flags from the dataset.json pulled from the HF repo
        dataset_json_path = str(Path(local_dir) / "dataset.json")
        if Path(dataset_json_path).exists():
            builder.load_dataset(dataset_json_path)
            dl_status += " | Restored labels from HF repo"
        # Scan the directory (already-tracked files are skipped via the existing_paths check)
        samples, scan_status = builder.scan_directory(local_dir)
        training_state = training_state or {}
        training_state["audio_dir"] = local_dir
        training_state["dataset_id"] = dataset_id.strip()
        training_state["dataset_path"] = dataset_json_path
        next_offset = offset_val + max_files_val
        return f"{dl_status} | {scan_status}", training_state, next_offset, _build_progress_summary()
    except Exception as e:
        logger.error(f"HF download failed: {e}")
        return f"Error: {e}", training_state or {}, int(hf_offset or 0), _build_progress_summary()

def lora_save_dataset_to_json(training_state):
    """Explicitly save the current dataset to JSON."""
    try:
        builder = get_dataset_builder()
        if builder.get_sample_count() == 0:
            return "No samples to save"
        training_state = training_state or {}
        dataset_path = training_state.get("dataset_path")
        if not dataset_path:
            audio_dir = training_state.get("audio_dir", "lora_training")
            dataset_path = str(Path(audio_dir) / "dataset.json")
            training_state["dataset_path"] = dataset_path
        return builder.save_dataset(dataset_path)
    except Exception as e:
        logger.error(f"Save dataset failed: {e}")
        return f"Error: {e}"


def lora_auto_label(label_batch_size, training_state, progress=gr.Progress()):
    """Auto-label unlabeled samples in batches using LLM analysis, then auto-save."""
    try:
        builder = get_dataset_builder()
        if builder.get_sample_count() == 0:
            return [], "No samples loaded. Upload files or download a dataset first.", training_state, _build_progress_summary()
        engine = get_ace_engine()
        if not engine.is_initialized():
            return [], "ACE-Step engine not initialized. Models may still be loading.", training_state, _build_progress_summary()

        def progress_callback(msg):
            progress(0, desc=msg)

        samples, status = builder.label_all_samples(
            dit_handler=engine.dit_handler,
            llm_handler=engine.llm_handler,
            only_unlabeled=True,
            max_count=int(label_batch_size),
            progress_callback=progress_callback,
        )
        training_state = training_state or {}
        dataset_path = training_state.get("dataset_path")
        if not dataset_path:
            audio_dir = training_state.get("audio_dir", "lora_training")
            dataset_path = str(Path(audio_dir) / "dataset.json")
            training_state["dataset_path"] = dataset_path
        save_status = builder.save_dataset(dataset_path)
        status += f"\n{save_status}"
        # Sync to the HF repo so labels persist across sessions
        dataset_id = training_state.get("dataset_id")
        if dataset_id:
            hf_status = upload_dataset_json_to_hf(dataset_id, dataset_path)
            status += f"\n{hf_status}"
        return _build_review_dataframe(), status, training_state, _build_progress_summary()
    except Exception as e:
        logger.error(f"Auto-label failed: {e}")
        return [], f"Error: {e}", training_state or {}, _build_progress_summary()

def lora_save_edits(df_data, training_state):
    """Save user edits from the review dataframe back to the samples."""
    try:
        builder = get_dataset_builder()
        if df_data is None:
            return "No data to save"
        # Gradio may return a DataFrame or a list of rows depending on version
        if isinstance(df_data, pd.DataFrame):
            if df_data.empty:
                return "No data to save"
            rows = df_data.values.tolist()
        elif isinstance(df_data, list):
            if len(df_data) == 0:
                return "No data to save"
            rows = df_data
        else:
            return "No data to save"
        updated = 0
        for row in rows:
            idx = int(row[0])  # "#" column: sample index
            updates = {}
            # Map editable columns back to sample fields ("-" means unset)
            bpm_val = row[5]  # "BPM"
            if bpm_val and bpm_val != "-":
                try:
                    updates["bpm"] = int(bpm_val)
                except (ValueError, TypeError):
                    pass
            key_val = row[6]  # "Key"
            if key_val and key_val != "-":
                updates["keyscale"] = str(key_val)
            caption_val = row[7]  # "Caption"
            if caption_val and caption_val != "-":
                updates["caption"] = str(caption_val)
            if updates:
                builder.update_sample(idx, **updates)
                updated += 1
        return f"Updated {updated} samples"
    except Exception as e:
        logger.error(f"Save edits failed: {e}")
        return f"Error: {e}"

def lora_preprocess(preprocess_batch_size, training_state, progress=gr.Progress()):
    """Preprocess labeled samples into training tensors, in batches."""
    try:
        builder = get_dataset_builder()
        if builder.get_labeled_count() == 0:
            return "No labeled samples. Run auto-label first.", _build_progress_summary()
        engine = get_ace_engine()
        if not engine.is_initialized():
            return "ACE-Step engine not initialized.", _build_progress_summary()
        tensor_dir = str(Path("lora_training") / "tensors")

        def progress_callback(msg):
            progress(0, desc=msg)

        output_paths, status = builder.preprocess_to_tensors(
            dit_handler=engine.dit_handler,
            output_dir=tensor_dir,
            max_count=int(preprocess_batch_size),
            progress_callback=progress_callback,
        )
        training_state = training_state or {}
        training_state["tensor_dir"] = tensor_dir
        # Auto-save so preprocessed flags persist across sessions
        dataset_path = training_state.get("dataset_path")
        if not dataset_path:
            audio_dir = training_state.get("audio_dir", "lora_training")
            dataset_path = str(Path(audio_dir) / "dataset.json")
            training_state["dataset_path"] = dataset_path
        save_status = builder.save_dataset(dataset_path)
        status += f"\n{save_status}"
        # Sync to the HF repo as well
        dataset_id = training_state.get("dataset_id")
        if dataset_id:
            hf_status = upload_dataset_json_to_hf(dataset_id, dataset_path)
            status += f"\n{hf_status}"
        return status, _build_progress_summary()
    except Exception as e:
        logger.error(f"Preprocess failed: {e}")
        return f"Error: {e}", _build_progress_summary()

def lora_train_real(
    lr, batch_size, epochs, rank, alpha,
    grad_accum, model_name, training_state,
    progress=gr.Progress(),
):
    """Train a LoRA using the real Fabric-based trainer."""
    try:
        training_state = training_state or {}
        tensor_dir = training_state.get("tensor_dir", "")
        if not tensor_dir or not Path(tensor_dir).exists():
            return "", "No preprocessed tensors found. Run preprocessing first."
        engine = get_ace_engine()
        if not engine.is_initialized():
            return "", "ACE-Step engine not initialized."
        lora_cfg = LoRAConfig(r=int(rank), alpha=int(alpha))
        output_dir = str(Path("lora_training") / "models" / (model_name or "lora_model"))
        # Effective batch size = batch_size * gradient_accumulation_steps
        train_cfg = TrainingConfig(
            learning_rate=float(lr),
            batch_size=int(batch_size),
            max_epochs=int(epochs),
            gradient_accumulation_steps=int(grad_accum),
            output_dir=output_dir,
        )
        trainer = FabricLoRATrainer(
            dit_handler=engine.dit_handler,
            lora_config=lora_cfg,
            training_config=train_cfg,
        )
        _training_control["should_stop"] = False
        last_msg = ""
        for step, loss, message in trainer.train_from_preprocessed(
            tensor_dir=tensor_dir,
            training_state=_training_control,
        ):
            last_msg = f"Step {step} | Loss: {loss:.4f} | {message}"
            progress(0, desc=last_msg)
            # Honor the stop flag set by lora_stop_training()
            if _training_control.get("should_stop"):
                trainer.stop()
                last_msg = f"Training stopped at step {step} (loss: {loss:.4f})"
                break
        final_path = str(Path(output_dir) / "final")
        return final_path, last_msg
    except Exception as e:
        logger.error(f"Training failed: {e}")
        return "", f"Error: {e}"

def lora_stop_training():
    """Signal the training loop to stop."""
    _training_control["should_stop"] = True
    return "Stop signal sent. Training will stop after the current step."


def lora_download_model(model_path):
    """Zip the LoRA model directory and return the zip for Gradio file download."""
    import shutil

    if not model_path or not Path(model_path).exists():
        return None
    path = Path(model_path)
    if path.is_dir():
        # make_archive writes "<model dir>.zip" next to the directory itself
        shutil.make_archive(str(path), "zip", root_dir=str(path.parent), base_dir=path.name)
        return str(path) + ".zip"
    return model_path

# ==================== GRADIO UI ====================

def create_ui():
    """Create the three-tab Gradio interface."""
    with gr.Blocks(title="ACE-Step 1.5 Custom Edition", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🎵 ACE-Step 1.5 Custom Edition
        **Three powerful interfaces for music generation and training**

        Models download automatically on first use (~7GB from HuggingFace).
        """)
        with gr.Tabs():
            # ============ TAB 1: STANDARD ACE-STEP ============
            with gr.Tab("🎼 Standard ACE-Step"):
                gr.Markdown("### Full-featured standard ACE-Step 1.5 interface")
                with gr.Row():
                    with gr.Column():
                        std_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe the music style, mood, instruments...",
                            lines=3,
                        )
                        std_lyrics = gr.Textbox(
                            label="Lyrics (optional)",
                            placeholder="Enter lyrics here...",
                            lines=5,
                        )
                        with gr.Row():
                            std_duration = gr.Slider(
                                minimum=10, maximum=240, value=30, step=10,
                                label="Duration (seconds)",
                            )
                            std_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "classical", "electronic", "hip-hop"],
                                value="auto",
                                label="Style",
                            )
                        with gr.Row():
                            std_temperature = gr.Slider(
                                minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                                label="Temperature",
                            )
                            std_top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                                label="Top P",
                            )
                        std_seed = gr.Number(label="Seed (-1 for random)", value=-1)
                        with gr.Row():
                            std_use_lora = gr.Checkbox(label="Use LoRA", value=False)
                            std_lora_path = gr.Textbox(
                                label="LoRA Path",
                                placeholder="Path to LoRA model (if using)",
                            )
                        std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")
                    with gr.Column():
                        gr.Markdown("### Audio Input (Optional)")
                        gr.Markdown("*Upload or record audio to use as style guidance*")
                        # NOTE: not yet wired into standard_generate
                        std_audio_input = gr.Audio(
                            label="Style Reference Audio",
                            type="filepath",
                        )
                        gr.Markdown("### Generated Output")
                        std_audio_out = gr.Audio(label="Generated Audio")
                        std_info = gr.Textbox(label="Status", lines=2)
                gr.Markdown("### Advanced Controls")
                with gr.Accordion("🔄 Generate Variation", open=False):
                    std_var_strength = gr.Slider(0.1, 1.0, 0.5, label="Variation Strength")
                    std_var_btn = gr.Button("Generate Variation")
                with gr.Accordion("🎨 Repaint Section", open=False):
                    std_repaint_start = gr.Number(label="Start Time (s)", value=0)
                    std_repaint_end = gr.Number(label="End Time (s)", value=10)
                    std_repaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                    std_repaint_btn = gr.Button("Repaint")
                with gr.Accordion("✏️ Edit Lyrics", open=False):
                    std_edit_lyrics = gr.Textbox(label="New Lyrics", lines=4)
                    std_edit_btn = gr.Button("Edit Lyrics")

                # Event handlers
                std_generate_btn.click(
                    fn=standard_generate,
                    inputs=[std_prompt, std_lyrics, std_duration, std_temperature,
                            std_top_p, std_seed, std_style, std_use_lora, std_lora_path],
                    outputs=[std_audio_out, std_info],
                )
                std_var_btn.click(
                    fn=standard_variation,
                    inputs=[std_audio_out, std_var_strength],
                    outputs=[std_audio_out, std_info],
                )
                std_repaint_btn.click(
                    fn=standard_repaint,
                    inputs=[std_audio_out, std_repaint_start, std_repaint_end, std_repaint_prompt],
                    outputs=[std_audio_out, std_info],
                )
                std_edit_btn.click(
                    fn=standard_lyric_edit,
                    inputs=[std_audio_out, std_edit_lyrics],
                    outputs=[std_audio_out, std_info],
                )

            # ============ TAB 2: CUSTOM TIMELINE ============
            with gr.Tab("⏱️ Timeline Workflow"):
                gr.Markdown("""
                ### Custom Timeline-based Generation
                Generate 32-second clips that blend seamlessly together on a master timeline.
                """)
                # Per-session timeline state
                timeline_state = gr.State(value=None)
                with gr.Row():
                    with gr.Column():
                        tl_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe this section...",
                            lines=3,
                        )
                        tl_lyrics = gr.Textbox(
                            label="Lyrics for this clip",
                            placeholder="Enter lyrics for this 32s section...",
                            lines=4,
                        )
                        gr.Markdown("*How far back to reference for style guidance*")
                        tl_context_length = gr.Slider(
                            minimum=0, maximum=120, value=30, step=10,
                            label="Context Length (seconds)",
                        )
                        with gr.Row():
                            tl_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "electronic"],
                                value="auto",
                                label="Style",
                            )
                            tl_temperature = gr.Slider(
                                minimum=0.5, maximum=1.0, value=0.7, step=0.05,
                                label="Temperature",
                            )
                        tl_seed = gr.Number(label="Seed (-1 for random)", value=-1)
                        with gr.Row():
                            tl_generate_btn = gr.Button("🎵 Generate Clip", variant="primary", size="lg")
                            tl_extend_btn = gr.Button("➕ Extend", size="lg")
                        tl_reset_btn = gr.Button("🔄 Reset Timeline", variant="secondary")
                        tl_info = gr.Textbox(label="Status", lines=2)
                    with gr.Column():
                        tl_clip_audio = gr.Audio(label="Latest Clip")
                        tl_full_audio = gr.Audio(label="Full Timeline")
                        tl_timeline_viz = gr.Image(label="Timeline Visualization")
                        with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
                            tl_inpaint_start = gr.Number(label="Start Time (s)", value=0)
                            tl_inpaint_end = gr.Number(label="End Time (s)", value=10)
                            tl_inpaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                            tl_inpaint_btn = gr.Button("Inpaint Region")

                # Event handlers
                tl_generate_btn.click(
                    fn=timeline_generate,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, tl_style,
                            tl_temperature, tl_seed, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info],
                )
                tl_extend_btn.click(
                    fn=timeline_extend,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info],
                )
                tl_reset_btn.click(
                    fn=timeline_reset,
                    inputs=[timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_info, timeline_state],
                )
                tl_inpaint_btn.click(
                    fn=timeline_inpaint,
                    inputs=[tl_inpaint_start, tl_inpaint_end, tl_inpaint_prompt, timeline_state],
                    outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info],
                )

            # ============ TAB 3: LORA TRAINING STUDIO ============
            with gr.Tab("🎓 LoRA Training Studio"):
                gr.Markdown("""
                ### Train Custom LoRA Models
                Step-by-step wizard: provide audio data, auto-label with the LLM, preprocess, and train.
                """)
                training_state = gr.State(value={})
                lora_progress = gr.Textbox(
                    label="Progress",
                    value="Total: 0 | Labeled: 0 | Preprocessed: 0 | Remaining: 0",
                    interactive=False,
                )
                with gr.Tabs():
                    # ---------- Sub-tab 1: Data Source ----------
                    with gr.Tab("1. Data Source"):
                        gr.Markdown(
                            "Download audio from a HuggingFace dataset repo. "
                            "Labels and progress are synced back to the repo automatically."
                        )
                        lora_hf_id = gr.Textbox(
                            label="Dataset ID",
                            placeholder="username/dataset-name",
                        )
                        lora_custom_tag = gr.Textbox(
                            label="Custom Tag (trigger word for the LoRA)",
                            placeholder="lofi, synthwave, jazz-piano…",
                        )
                        with gr.Row():
                            lora_hf_max = gr.Slider(
                                minimum=1, maximum=500, value=50, step=1,
                                label="Batch size",
                            )
                            lora_hf_offset = gr.Number(
                                label="Offset (auto-increments)",
                                value=0,
                                precision=0,
                            )
                        lora_hf_btn = gr.Button(
                            "Download Batch & Scan", variant="primary"
                        )
                        lora_source_status = gr.Textbox(
                            label="Status", lines=2, interactive=False
                        )
                    # ---------- Sub-tab 2: Label & Review ----------
                    with gr.Tab("2. Label & Review"):
                        gr.Markdown(
                            "Auto-label samples using the LLM, then review and edit the metadata."
                        )
                        lora_label_batch_size = gr.Slider(
                            minimum=1, maximum=500, value=50, step=1,
                            label="Label batch size (samples per run)",
                        )
                        lora_label_btn = gr.Button(
                            "Label Batch (+ auto-save)",
                            variant="primary",
                        )
                        lora_label_status = gr.Textbox(
                            label="Label Status", lines=3, interactive=False
                        )
                        lora_review_df = gr.Dataframe(
                            headers=DATAFRAME_HEADERS,
                            label="Sample Review (editable: BPM, Key, Caption)",
                            interactive=True,
                            wrap=True,
                        )
                        with gr.Row():
                            lora_save_btn = gr.Button("Save Edits")
                            lora_save_dataset_btn = gr.Button(
                                "Save Dataset to JSON", variant="secondary"
                            )
                        lora_save_status = gr.Textbox(
                            label="Save Status", interactive=False
                        )
                    # ---------- Sub-tab 3: Preprocess ----------
                    with gr.Tab("3. Preprocess"):
                        gr.Markdown(
                            "Encode audio through the VAE and text encoders to create training tensors."
                        )
                        lora_preprocess_batch_size = gr.Slider(
                            minimum=1, maximum=500, value=50, step=1,
                            label="Preprocess batch size (samples per run)",
                        )
                        lora_preprocess_btn = gr.Button(
                            "Preprocess Batch (+ auto-save)", variant="primary"
                        )
                        lora_preprocess_status = gr.Textbox(
                            label="Preprocess Status", lines=3, interactive=False
                        )
                    # ---------- Sub-tab 4: Train ----------
                    with gr.Tab("4. Train"):
                        gr.Markdown("Configure and run LoRA training.")
                        with gr.Row():
                            with gr.Column():
                                lora_model_name = gr.Textbox(
                                    label="Model Name",
                                    value="my_lora",
                                    placeholder="my_lora",
                                )
                                with gr.Row():
                                    lora_lr = gr.Number(
                                        label="Learning Rate", value=1e-4
                                    )
                                    lora_batch_size = gr.Slider(
                                        minimum=1, maximum=8, value=1, step=1,
                                        label="Batch Size",
                                    )
                                with gr.Row():
                                    lora_epochs = gr.Slider(
                                        minimum=1, maximum=500, value=100, step=1,
                                        label="Epochs",
                                    )
                                    lora_grad_accum = gr.Slider(
                                        minimum=1, maximum=16, value=4, step=1,
                                        label="Gradient Accumulation",
                                    )
                                with gr.Row():
                                    lora_rank = gr.Slider(
                                        minimum=4, maximum=128, value=8, step=4,
                                        label="LoRA Rank",
                                    )
                                    lora_alpha = gr.Slider(
                                        minimum=4, maximum=128, value=16, step=4,
                                        label="LoRA Alpha",
                                    )
                                with gr.Row():
                                    lora_train_btn = gr.Button(
                                        "Start Training",
                                        variant="primary",
                                        size="lg",
                                    )
                                    lora_stop_btn = gr.Button(
                                        "Stop Training",
                                        variant="stop",
                                        size="lg",
                                    )
                            with gr.Column():
                                lora_train_status = gr.Textbox(
                                    label="Training Status",
                                    lines=4,
                                    interactive=False,
                                )
                                lora_model_path = gr.Textbox(
                                    label="Model Path",
                                    interactive=False,
                                )
                                lora_dl_btn = gr.Button("Download Model")
                                lora_dl_file = gr.File(label="Download")
                        gr.Markdown("""
                        #### Tips
                        - Use 10+ audio samples for best results
                        - Keep samples consistent in style and quality
                        - Higher rank means more capacity but slower training
                        - The defaults (rank=8, lr=1e-4, 100 epochs) are a good starting point
                        """)

                # ---------- Event handlers ----------
                # Data Source
                lora_hf_btn.click(
                    fn=lora_download_hf,
                    inputs=[lora_hf_id, lora_custom_tag, lora_hf_max, lora_hf_offset, training_state],
                    outputs=[lora_source_status, training_state, lora_hf_offset, lora_progress],
                )
                # Label & Review
                lora_label_btn.click(
                    fn=lora_auto_label,
                    inputs=[lora_label_batch_size, training_state],
                    outputs=[lora_review_df, lora_label_status, training_state, lora_progress],
                )
                lora_save_btn.click(
                    fn=lora_save_edits,
                    inputs=[lora_review_df, training_state],
                    outputs=[lora_save_status],
                )
                lora_save_dataset_btn.click(
                    fn=lora_save_dataset_to_json,
                    inputs=[training_state],
                    outputs=[lora_save_status],
                )
                # Preprocess
                lora_preprocess_btn.click(
                    fn=lora_preprocess,
                    inputs=[lora_preprocess_batch_size, training_state],
                    outputs=[lora_preprocess_status, lora_progress],
                )
                # Train
                lora_train_btn.click(
                    fn=lora_train_real,
                    inputs=[
                        lora_lr, lora_batch_size, lora_epochs,
                        lora_rank, lora_alpha, lora_grad_accum,
                        lora_model_name, training_state,
                    ],
                    outputs=[lora_model_path, lora_train_status],
                )
                lora_stop_btn.click(
                    fn=lora_stop_training,
                    inputs=[],
                    outputs=[lora_train_status],
                )
                lora_dl_btn.click(
                    fn=lora_download_model,
                    inputs=[lora_model_path],
                    outputs=[lora_dl_file],
                )

        gr.Markdown("""
        ---
        ### About
        ACE-Step 1.5 Custom Edition by Gamahea | Based on [ACE-Step](https://ace-step.github.io/)
        """)
    return app

# ==================== MAIN ====================

if __name__ == "__main__":
    logger.info("Starting ACE-Step 1.5 Custom Edition...")
    try:
        app = create_ui()

        # Monkey-patch get_api_info: Gradio's API schema generation can raise on
        # components it cannot serialize, which would break the /info endpoint,
        # so fall back to a minimal schema instead of crashing.
        original_get_api_info = app.get_api_info

        def safe_get_api_info(*args, **kwargs):
            """Patched get_api_info that returns minimal info to avoid schema errors."""
            try:
                return original_get_api_info(*args, **kwargs)
            except (TypeError, AttributeError, KeyError) as e:
                logger.warning(f"API info generation failed, returning minimal info: {e}")
                return {
                    "named_endpoints": {},
                    "unnamed_endpoints": {},
                }

        app.get_api_info = safe_get_api_info
        logger.info("✓ Patched get_api_info method")

        # Launch the app
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
        )
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        import traceback
        traceback.print_exc()
        raise