Spaces: Running on Zero
| import os | |
| import random | |
| import re | |
| from typing import Optional | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from miocodec import MioCodecModel | |
| from text import normalize_text | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repo ids; both overridable via environment variables so the
# same app can serve alternative checkpoints without code changes.
MODEL_REPO = os.environ.get("MODEL_REPO", "Aratako/MioTTS-1.7B")  # causal LM that emits speech tokens
CODEC_REPO = os.environ.get("CODEC_REPO", "Aratako/MioCodec-25Hz-24kHz")  # codec used to decode tokens to audio

# Global variables for lazy loading (populated once by load_models()).
_model = None      # AutoModelForCausalLM instance
_tokenizer = None  # matching AutoTokenizer
_codec = None      # MioCodecModel instance

# Presets directory: expected to contain <name>.pt speaker-embedding files.
PRESETS_DIR = "presets"

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Matches serialized speech tokens of the form <|s_123|>, capturing the id.
TOKEN_PATTERN = re.compile(r"<\|s_(\d+)\|>")
def seed_everything(seed: Optional[int]) -> int:
    """Seed every RNG source (hash seed, random, numpy, torch) for reproducibility.

    Args:
        seed: Seed to apply; when None a fresh random seed is drawn and logged.

    Returns:
        The seed that was actually applied.
    """
    chosen = random.SystemRandom().randint(0, 2**31 - 1) if seed is None else seed
    if seed is None:
        print(f"[Info] No seed provided; using random seed {chosen}")
    os.environ["PYTHONHASHSEED"] = str(chosen)
    # Apply the same seed to every RNG the pipeline touches.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(chosen)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(chosen)
        # Deterministic cuDNN kernels so repeated runs with the same seed match.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    return chosen
def parse_speech_tokens(text: str) -> list[int]:
    """Return the integer ids of every `<|s_N|>` speech token found in *text*.

    Raises:
        ValueError: If *text* contains no speech-token markers.
    """
    matches = re.findall(r"<\|s_(\d+)\|>", text)
    if not matches:
        raise ValueError("No speech tokens found in LLM output.")
    return [int(m) for m in matches]
| # --------------------------------------------------------------------------- | |
| # Model Loading | |
| # --------------------------------------------------------------------------- | |
def load_models():
    """Lazily load the LLM, its tokenizer, and the codec into module globals.

    Idempotent: returns immediately when the LLM is already loaded.
    """
    global _model, _tokenizer, _codec
    if _model is not None:
        return
    target = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[Info] Loading LLM from {MODEL_REPO}...")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    llm = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.bfloat16,
    )
    _model = llm.to(target)
    _model.eval()
    print(f"[Info] Loading codec from {CODEC_REPO}...")
    _codec = MioCodecModel.from_pretrained(CODEC_REPO).eval().to(target)
    print("[Info] Models loaded successfully.")
def get_preset_list(presets_dir: Optional[str] = None) -> list[str]:
    """List available voice presets (``.pt`` files, extension stripped).

    Args:
        presets_dir: Directory to scan. Defaults to the module-level
            PRESETS_DIR (generalized from the previous hard-coded constant).

    Returns:
        Alphabetically sorted preset names; empty list when the directory
        does not exist (or is not a directory — previously a plain file at
        that path would have crashed os.listdir).
    """
    directory = PRESETS_DIR if presets_dir is None else presets_dir
    if not os.path.isdir(directory):
        return []
    return sorted(f[:-3] for f in os.listdir(directory) if f.endswith(".pt"))
def load_preset_embedding(preset_id: str, presets_dir: Optional[str] = None) -> torch.Tensor:
    """Load the speaker global-embedding tensor for *preset_id*.

    Args:
        preset_id: Preset name (file stem of a ``.pt`` file).
        presets_dir: Directory holding presets; defaults to PRESETS_DIR.

    Returns:
        The embedding tensor with singleton dimensions squeezed out.

    Raises:
        FileNotFoundError: If no ``.pt`` file exists for *preset_id*.
        ValueError: If the file holds a dict without a "global_embedding" key.
            (Previously dict.get(..., embedding) returned the dict itself and
            the subsequent .squeeze() failed with an opaque AttributeError.)
    """
    directory = PRESETS_DIR if presets_dir is None else presets_dir
    path = os.path.join(directory, f"{preset_id}.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Preset '{preset_id}' not found.")
    # weights_only=True avoids arbitrary-code execution from untrusted files.
    embedding = torch.load(path, map_location="cpu", weights_only=True)
    if isinstance(embedding, dict):
        if "global_embedding" not in embedding:
            raise ValueError(f"Preset '{preset_id}' has no 'global_embedding' entry.")
        embedding = embedding["global_embedding"]
    return embedding.squeeze()
| # --------------------------------------------------------------------------- | |
| # GPU-decorated Inference Functions | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): the section header says "GPU-decorated" and `spaces` is
# imported at the top of the file, but no @spaces.GPU decorator is applied
# here — confirm whether ZeroGPU allocation is expected for this function.
def run_inference_gpu(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: Optional[int],
    num_samples: int = 1,
) -> list[tuple[int, np.ndarray]]:
    """Run end-to-end TTS: LLM speech-token generation then codec decoding.

    Args:
        target_text: Text to synthesize (normalized before tokenization).
        reference_mode: "upload" to voice-clone from reference_audio,
            "preset" to use a stored embedding named by preset_id.
        reference_audio: (sample_rate, waveform ndarray) from Gradio, or None.
        preset_id: Name of a stored preset embedding, or None.
        temperature: Sampling temperature for generation.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.
        max_tokens: Maximum new tokens to generate per sample.
        seed: RNG seed; None draws a random one (logged).
        num_samples: Number of parallel generations (num_return_sequences).

    Returns:
        List of (sample_rate, waveform ndarray) pairs, one per sample that
        produced valid speech tokens.

    Raises:
        ValueError: If neither reference source is provided, or if no sample
            yields any speech tokens.
    """
    load_models()
    used_seed = seed_everything(None if seed is None else int(seed))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Normalize text
    normalized_text = normalize_text(target_text)
    print(f"[Info] Normalized text: {normalized_text}")
    # Prepare reference: exactly one of reference_waveform / global_embedding
    # is populated below, selecting the conditioning path used at decode time.
    reference_waveform = None
    global_embedding = None
    if reference_mode == "upload" and reference_audio is not None:
        sr, audio = reference_audio
        # Convert to tensor
        if audio.ndim == 1:
            audio_tensor = torch.from_numpy(audio).float()
        else:
            # Multi-channel upload: average channels down to mono.
            audio_tensor = torch.from_numpy(audio.mean(axis=1)).float()
        # Resample if needed (codec expects its own sample rate)
        codec_sr = _codec.config.sample_rate
        if sr != codec_sr:
            # Local import: torchaudio is only needed on this path.
            import torchaudio
            audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, codec_sr)
            audio_tensor = resampler(audio_tensor).squeeze(0)
        # Trim to max 20 seconds
        max_samples = int(codec_sr * 20)
        if audio_tensor.shape[0] > max_samples:
            audio_tensor = audio_tensor[:max_samples]
            print(f"[Info] Reference audio trimmed to 20 seconds")
        reference_waveform = audio_tensor.to(device)
    elif reference_mode == "preset" and preset_id:
        global_embedding = load_preset_embedding(preset_id).to(device)
    else:
        raise ValueError("Either reference audio or preset must be provided.")
    # Tokenize input using the model's chat template
    messages = [{"role": "user", "content": normalized_text}]
    input_text = _tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _tokenizer(input_text, return_tensors="pt").to(device)
    # Remove token_type_ids if present (not used by this model)
    inputs.pop("token_type_ids", None)
    # Generate (batch): num_samples sequences sampled in one generate() call
    with torch.no_grad():
        outputs = _model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=_tokenizer.eos_token_id,
            num_return_sequences=num_samples,
        )
    # Parse all generated sequences
    tokens_list = []
    for i in range(outputs.shape[0]):
        generated_text = _tokenizer.decode(outputs[i], skip_special_tokens=False)
        # NOTE(review): slicing off len(input_text) assumes decode() reproduces
        # the prompt text exactly — verify this holds for this tokenizer.
        generated_part = generated_text[len(input_text):]
        try:
            speech_tokens = parse_speech_tokens(generated_part)
            tokens_list.append(speech_tokens)
        except ValueError as e:
            # A sample with no speech tokens is skipped, not fatal.
            print(f"[Warning] Sample {i + 1}: {e}")
    if not tokens_list:
        raise ValueError("No valid speech tokens generated.")
    # Decode audio (batch)
    results = []
    sample_rate = _codec.config.sample_rate
    with torch.no_grad():
        # Prepare batch tokens, zero-padded to the longest sequence; real
        # lengths are passed separately via content_lengths.
        max_len = max(len(t) for t in tokens_list)
        batch_tokens = torch.zeros((len(tokens_list), max_len), dtype=torch.long, device=device)
        content_lengths = []
        for i, tokens in enumerate(tokens_list):
            batch_tokens[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            content_lengths.append(len(tokens))
        # Get global embeddings, broadcast to one row per sample
        if reference_waveform is not None:
            # Extract global embedding from reference waveform
            ref_features = _codec.encode(reference_waveform, return_content=False, return_global=True)
            global_embeddings = ref_features.global_embedding.unsqueeze(0).expand(len(tokens_list), -1)
        else:
            global_embeddings = global_embedding.unsqueeze(0).expand(len(tokens_list), -1)
        # Batch decode
        audio_batch, audio_lengths = _codec.decode_batch(
            global_embeddings=global_embeddings,
            content_token_indices=batch_tokens,
            content_lengths=content_lengths,
        )
        # Strip padding from each decoded waveform and move to CPU numpy.
        for i in range(len(tokens_list)):
            audio_len = int(audio_lengths[i])
            audio_np = audio_batch[i, :audio_len].cpu().numpy()
            results.append((sample_rate, audio_np))
    print(f"[Info] Seed used: {used_seed}")
    return results
# Load models at startup (eager, at import time — so the first request does
# not pay the model-download/load latency).
load_models()

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Upper bound on the number of output audio widgets shown per generation.
MAX_NUM_SAMPLES = 32
def gradio_inference(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: str,
    num_samples: int,
):
    """Gradio click handler: synthesize speech and fan results out to the UI.

    Args mirror the UI controls; *seed* arrives as free text from a Textbox.

    Returns:
        A list of MAX_NUM_SAMPLES gr.update objects — one per output audio
        slot; slots beyond the number of generated samples are hidden.

    Raises:
        gr.Error: On an unparsable seed or any inference failure, so the
            message is shown in the UI rather than as a raw traceback.
    """
    # Empty prompt: clear and hide every output slot instead of erroring.
    if not target_text.strip():
        return [gr.update(value=None, visible=False) for _ in range(MAX_NUM_SAMPLES)]
    # Parse the free-text seed field; blank / "None" means "random seed".
    seed_val = None
    if seed.strip() not in {"", "None", "none"}:
        try:
            seed_val = int(float(seed))
        except ValueError:
            # Fix: previously garbage input here raised an uncaught ValueError
            # (outside the try below); surface a friendly UI error instead.
            raise gr.Error(f"Invalid seed: {seed!r}. Leave blank for a random seed.")
    try:
        results = run_inference_gpu(
            target_text=target_text,
            reference_mode=reference_mode,
            reference_audio=reference_audio,
            preset_id=preset_id,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            max_tokens=max_tokens,
            seed=seed_val,
            num_samples=int(num_samples),
        )
    except Exception as e:
        print(f"[Error] {e}")
        raise gr.Error(str(e))
    # Show one audio widget per generated sample; hide the rest.
    return [
        gr.update(value=results[i], visible=True)
        if i < len(results)
        else gr.update(value=None, visible=False)
        for i in range(MAX_NUM_SAMPLES)
    ]
def build_demo():
    """Construct and return the Gradio Blocks UI for the TTS demo."""
    presets = get_preset_list()
    MODEL_LINK = f"https://huggingface.co/{MODEL_REPO}"
    GITHUB_REPO = "https://github.com/Aratako/MioTTS-Inference"
    # Fix: derive the title from the configured repo instead of the stale
    # hard-coded "MioTTS-0.1B" (the default MODEL_REPO is MioTTS-1.7B).
    title = f"# {MODEL_REPO.split('/')[-1]} Demo"
    description = f"""
- **Model**: [{MODEL_REPO}]({MODEL_LINK})
- For faster and more efficient inference, see [MioTTS-Inference]({GITHUB_REPO})
**Usage:**
- Select a preset voice OR upload your own reference audio (max 20 seconds)
- Enter text to synthesize
- Adjust generation parameters as needed
"""
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                # Reference selection: preset dropdown OR uploaded audio;
                # only one of the two widgets is visible at a time.
                reference_mode = gr.Radio(
                    choices=["preset", "upload"],
                    value="preset",
                    label="Reference Mode",
                )
                preset_id = gr.Dropdown(
                    choices=presets,
                    value=presets[0] if presets else None,
                    label="Preset Voice",
                    allow_custom_value=False,
                    visible=True,
                )
                reference_audio = gr.Audio(
                    label="Reference Audio",
                    type="numpy",
                    visible=False,
                )

                def update_reference_visibility(mode):
                    # Toggle which reference widget is shown for the mode.
                    if mode == "preset":
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                reference_mode.change(
                    fn=update_reference_visibility,
                    inputs=[reference_mode],
                    outputs=[preset_id, reference_audio],
                )
                target_text = gr.Textbox(
                    label="Text to Synthesize",
                    value="",
                    placeholder="Enter text to synthesize",
                    lines=3,
                )
                with gr.Row():
                    seed_box = gr.Textbox(
                        label="Seed (optional)",
                        value="",
                        placeholder="Leave blank for random",
                    )
                    num_samples = gr.Slider(
                        label="Number of Samples",
                        minimum=1,
                        maximum=MAX_NUM_SAMPLES,
                        step=1,
                        value=1,
                    )
                # LLM sampling controls
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.8
                    )
                    top_p = gr.Slider(
                        label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=1.0
                    )
                    top_k = gr.Slider(
                        label="Top-k", minimum=0, maximum=100, step=1, value=50
                    )
                with gr.Row():
                    repetition_penalty = gr.Slider(
                        label="Repetition Penalty",
                        minimum=1.0,
                        maximum=1.5,
                        step=0.05,
                        value=1.0,
                    )
                    max_tokens = gr.Slider(
                        label="Max Tokens",
                        minimum=100,
                        maximum=1000,
                        step=50,
                        value=700,
                    )
                generate_button = gr.Button("Generate", variant="primary")
        # Output audio components: a fixed grid of MAX_NUM_SAMPLES widgets,
        # cols_per_row per row; visibility is driven by gradio_inference.
        output_audios = []
        cols_per_row = 4
        num_rows = (MAX_NUM_SAMPLES + cols_per_row - 1) // cols_per_row
        with gr.Column():
            for row_idx in range(num_rows):
                with gr.Row():
                    for col_idx in range(cols_per_row):
                        i = row_idx * cols_per_row + col_idx
                        if i >= MAX_NUM_SAMPLES:
                            break
                        audio = gr.Audio(
                            label=f"Sample #{i+1}",
                            type="numpy",
                            interactive=False,
                            visible=(i == 0),
                        )
                        output_audios.append(audio)
        generate_button.click(
            fn=gradio_inference,
            inputs=[
                target_text,
                reference_mode,
                reference_audio,
                preset_id,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_tokens,
                seed_box,
                num_samples,
            ],
            outputs=output_audios,
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the (blocking) Gradio server when run directly.
    demo = build_demo()
    demo.launch()