"""Gradio demo: synthesize speech with a causal LM and decode it with MioCodec.

The LLM named by ``MODEL_REPO`` generates ``<|s_N|>`` speech tokens for the
input text; the codec named by ``CODEC_REPO`` turns those tokens plus a speaker
("global") embedding into audio. The speaker embedding comes either from a
preset ``.pt`` file or from user-uploaded reference audio.
"""

import os
import random
import re
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from miocodec import MioCodecModel
from text import normalize_text

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_REPO = os.environ.get("MODEL_REPO", "Aratako/MioTTS-1.7B")
CODEC_REPO = os.environ.get("CODEC_REPO", "Aratako/MioCodec-25Hz-24kHz")

# Populated once by load_models().
_model = None
_tokenizer = None
_codec = None

# Directory holding preset speaker embeddings as "<preset_id>.pt" files.
PRESETS_DIR = "presets"

# Uploaded reference audio is truncated to this many seconds.
MAX_REFERENCE_SECONDS = 20

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TOKEN_PATTERN = re.compile(r"<\|s_(\d+)\|>")


def seed_everything(seed: Optional[int]) -> int:
    """Seed Python, NumPy and torch RNGs and return the seed that was used.

    When ``seed`` is ``None`` a random seed in ``[0, 2**31 - 1]`` is drawn
    from the system RNG and reported on stdout.
    """
    if seed is None:
        seed = random.SystemRandom().randint(0, 2**31 - 1)
        print(f"[Info] No seed provided; using random seed {seed}")
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed


def parse_speech_tokens(text: str) -> list[int]:
    """Extract the integer ids of every ``<|s_N|>`` token found in ``text``.

    Raises:
        ValueError: if ``text`` contains no speech token.
    """
    tokens = [int(value) for value in TOKEN_PATTERN.findall(text)]
    if not tokens:
        raise ValueError("No speech tokens found in LLM output.")
    return tokens


# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def load_models():
    """Load the tokenizer, LLM and codec into module globals (idempotent)."""
    global _model, _tokenizer, _codec

    if _model is not None:
        return

    print(f"[Info] Loading LLM from {MODEL_REPO}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.bfloat16,
    ).to(device)
    _model.eval()

    print(f"[Info] Loading codec from {CODEC_REPO}...")
    _codec = MioCodecModel.from_pretrained(CODEC_REPO)
    _codec = _codec.eval().to(device)

    print("[Info] Models loaded successfully.")


def get_preset_list() -> list[str]:
    """Return the sorted preset ids (``.pt`` file stems) in ``PRESETS_DIR``.

    Returns an empty list when the directory does not exist.
    """
    # isdir (rather than exists) so a stray file named like the directory
    # does not make os.listdir raise.
    if not os.path.isdir(PRESETS_DIR):
        return []
    return sorted(
        name.removesuffix(".pt")
        for name in os.listdir(PRESETS_DIR)
        if name.endswith(".pt")
    )


def load_preset_embedding(preset_id: str) -> torch.Tensor:
    """Load the speaker embedding stored for ``preset_id``.

    The file may contain either a tensor or a dict with a
    ``"global_embedding"`` entry. The result is squeezed.

    Raises:
        FileNotFoundError: if no file exists for ``preset_id``.
        ValueError: if the file does not contain a usable tensor.
    """
    path = os.path.join(PRESETS_DIR, f"{preset_id}.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Preset '{preset_id}' not found.")

    embedding = torch.load(path, map_location="cpu", weights_only=True)
    if isinstance(embedding, dict):
        embedding = embedding.get("global_embedding", embedding)
    if not isinstance(embedding, torch.Tensor):
        # Fail with a clear message instead of an AttributeError on .squeeze().
        raise ValueError(
            f"Preset '{preset_id}' does not contain a global embedding tensor."
        )
    return embedding.squeeze()


def _reference_audio_to_waveform(
    reference_audio: tuple[int, np.ndarray],
    target_sr: int,
    device: str,
) -> torch.Tensor:
    """Convert Gradio ``(sample_rate, samples)`` audio to a mono float waveform.

    The audio is down-mixed to mono, resampled to ``target_sr`` when needed,
    truncated to ``MAX_REFERENCE_SECONDS`` and moved to ``device``.
    """
    sr, audio = reference_audio
    audio = np.asarray(audio)
    if audio.size == 0:
        raise ValueError("Reference audio is empty.")

    if np.issubdtype(audio.dtype, np.integer):
        # Gradio's numpy audio is normally integer PCM (int16). Scale it to
        # [-1, 1]; NOTE(review): assumes the codec expects normalized float
        # waveforms -- confirm against MioCodec's encode() contract.
        info = np.iinfo(audio.dtype)
        scale = float(max(abs(info.min), info.max))
        audio = audio.astype(np.float32) / scale

    # Multi-channel input is laid out (samples, channels); average channels.
    mono = audio if audio.ndim == 1 else audio.mean(axis=1)
    waveform = torch.from_numpy(mono).float()

    if sr != target_sr:
        import torchaudio

        resampler = torchaudio.transforms.Resample(sr, target_sr)
        waveform = resampler(waveform.unsqueeze(0)).squeeze(0)

    max_samples = int(target_sr * MAX_REFERENCE_SECONDS)
    if waveform.shape[0] > max_samples:
        waveform = waveform[:max_samples]
        print("[Info] Reference audio trimmed to 20 seconds")

    return waveform.to(device)


# ---------------------------------------------------------------------------
# GPU-decorated Inference Functions
# ---------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_inference_gpu(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: Optional[int],
    num_samples: int = 1,
) -> list[tuple[int, np.ndarray]]:
    """Synthesize ``num_samples`` audio clips for ``target_text``.

    Args:
        target_text: Text to synthesize; passed through ``normalize_text``.
        reference_mode: ``"upload"`` to use ``reference_audio`` or
            ``"preset"`` to use ``preset_id``.
        reference_audio: ``(sample_rate, samples)`` tuple from Gradio.
        preset_id: Name of a preset embedding in ``PRESETS_DIR``.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``generate``.
        max_tokens: Maximum number of newly generated tokens.
        seed: RNG seed, or ``None`` for a random one.
        num_samples: Number of sequences to sample.

    Returns:
        One ``(sample_rate, waveform)`` tuple per sample that produced at
        least one speech token.

    Raises:
        ValueError: if no reference is available for the selected mode, or
            no sample produced any speech token.
    """
    load_models()

    used_seed = seed_everything(None if seed is None else int(seed))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Normalize text
    normalized_text = normalize_text(target_text)
    print(f"[Info] Normalized text: {normalized_text}")

    # Prepare reference
    reference_waveform = None
    global_embedding = None

    if reference_mode == "upload" and reference_audio is not None:
        reference_waveform = _reference_audio_to_waveform(
            reference_audio, _codec.config.sample_rate, device
        )
    elif reference_mode == "preset" and preset_id:
        global_embedding = load_preset_embedding(preset_id).to(device)
    else:
        raise ValueError("Either reference audio or preset must be provided.")

    # Tokenize input
    messages = [{"role": "user", "content": normalized_text}]
    input_text = _tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _tokenizer(input_text, return_tensors="pt").to(device)

    # Remove token_type_ids if present (not used by this model)
    inputs.pop("token_type_ids", None)
    prompt_len = inputs["input_ids"].shape[1]

    # Generate (batch)
    with torch.no_grad():
        outputs = _model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=_tokenizer.eos_token_id,
            num_return_sequences=num_samples,
        )

    # Parse all generated sequences
    tokens_list = []
    for i, sequence in enumerate(outputs):
        # Drop the prompt by token count: slicing the decoded string by
        # len(input_text) is only right if decoding reproduces the prompt
        # text character-for-character.
        generated_part = _tokenizer.decode(
            sequence[prompt_len:], skip_special_tokens=False
        )
        try:
            tokens_list.append(parse_speech_tokens(generated_part))
        except ValueError as e:
            print(f"[Warning] Sample {i + 1}: {e}")

    if not tokens_list:
        raise ValueError("No valid speech tokens generated.")

    # Decode audio (batch)
    results = []
    sample_rate = _codec.config.sample_rate
    batch_size = len(tokens_list)

    with torch.no_grad():
        # Right-pad every token sequence with zeros to the longest length;
        # the true lengths are passed separately as content_lengths.
        content_lengths = [len(tokens) for tokens in tokens_list]
        batch_tokens = torch.zeros(
            (batch_size, max(content_lengths)), dtype=torch.long, device=device
        )
        for i, tokens in enumerate(tokens_list):
            batch_tokens[i, : len(tokens)] = torch.tensor(
                tokens, dtype=torch.long, device=device
            )

        # Get global embeddings
        if reference_waveform is not None:
            # Extract global embedding from reference waveform
            ref_features = _codec.encode(
                reference_waveform, return_content=False, return_global=True
            )
            global_embedding = ref_features.global_embedding
        # The same speaker embedding is shared by every sample in the batch.
        global_embeddings = global_embedding.unsqueeze(0).expand(batch_size, -1)

        # Batch decode
        audio_batch, audio_lengths = _codec.decode_batch(
            global_embeddings=global_embeddings,
            content_token_indices=batch_tokens,
            content_lengths=content_lengths,
        )

        for i in range(batch_size):
            audio_len = int(audio_lengths[i])
            audio_np = audio_batch[i, :audio_len].cpu().numpy()
            results.append((sample_rate, audio_np))

    print(f"[Info] Seed used: {used_seed}")
    return results


# Load models at startup
load_models()

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
MAX_NUM_SAMPLES = 32


def _parse_seed(seed: Optional[str]) -> Optional[int]:
    """Parse the seed textbox; blank/"None"/"none" mean "no seed".

    Raises:
        gr.Error: if the text is not a number.
    """
    seed_text = (seed or "").strip()
    if seed_text in {"", "None", "none"}:
        return None
    try:
        # Accepts integer text as well as float text such as "42.0".
        return int(float(seed_text))
    except (ValueError, OverflowError) as err:
        raise gr.Error(f"Invalid seed: {seed_text!r}") from err


def gradio_inference(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: str,
    num_samples: int,
):
    """Gradio click handler: run inference and fill the output audio slots.

    Returns a list of ``MAX_NUM_SAMPLES`` ``gr.update`` objects; slots beyond
    the number of generated samples are cleared and hidden. Blank input text
    hides every slot without running inference.

    Raises:
        gr.Error: on an invalid seed or any inference failure.
    """
    if not target_text or not target_text.strip():
        return [gr.update(value=None, visible=False) for _ in range(MAX_NUM_SAMPLES)]

    seed_val = _parse_seed(seed)

    try:
        results = run_inference_gpu(
            target_text=target_text,
            reference_mode=reference_mode,
            reference_audio=reference_audio,
            preset_id=preset_id,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            max_tokens=max_tokens,
            seed=seed_val,
            num_samples=int(num_samples),
        )
    except Exception as e:
        # UI boundary: surface any failure to the user as a Gradio error.
        print(f"[Error] {e}")
        raise gr.Error(str(e)) from e

    outputs = []
    for i in range(MAX_NUM_SAMPLES):
        if i < len(results):
            outputs.append(gr.update(value=results[i], visible=True))
        else:
            outputs.append(gr.update(value=None, visible=False))
    return outputs


def build_demo():
    """Build and return the Gradio ``Blocks`` UI for the demo."""
    presets = get_preset_list()

    model_link = f"https://huggingface.co/{MODEL_REPO}"
    github_repo = "https://github.com/Aratako/MioTTS-Inference"

    # Derive the title from the configured repo so it always matches the
    # model that is actually loaded.
    model_name = MODEL_REPO.rsplit("/", 1)[-1]
    title = f"# {model_name} Demo"
    description = f"""
- **Model**: [{MODEL_REPO}]({model_link})
- For faster and more efficient inference, see [MioTTS-Inference]({github_repo})

**Usage:**
- Select a preset voice OR upload your own reference audio (max 20 seconds)
- Enter text to synthesize
- Adjust generation parameters as needed
"""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                reference_mode = gr.Radio(
                    choices=["preset", "upload"],
                    value="preset",
                    label="Reference Mode",
                )
                preset_id = gr.Dropdown(
                    choices=presets,
                    value=presets[0] if presets else None,
                    label="Preset Voice",
                    allow_custom_value=False,
                    visible=True,
                )
                reference_audio = gr.Audio(
                    label="Reference Audio",
                    type="numpy",
                    visible=False,
                )

                def update_reference_visibility(mode):
                    """Show the preset dropdown or the audio upload, per mode."""
                    if mode == "preset":
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                reference_mode.change(
                    fn=update_reference_visibility,
                    inputs=[reference_mode],
                    outputs=[preset_id, reference_audio],
                )

                target_text = gr.Textbox(
                    label="Text to Synthesize",
                    value="",
                    placeholder="Enter text to synthesize",
                    lines=3,
                )

                with gr.Row():
                    seed_box = gr.Textbox(
                        label="Seed (optional)",
                        value="",
                        placeholder="Leave blank for random",
                    )
                    num_samples = gr.Slider(
                        label="Number of Samples",
                        minimum=1,
                        maximum=MAX_NUM_SAMPLES,
                        step=1,
                        value=1,
                    )

                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.8
                    )
                    top_p = gr.Slider(
                        label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=1.0
                    )
                    top_k = gr.Slider(
                        label="Top-k", minimum=0, maximum=100, step=1, value=50
                    )

                with gr.Row():
                    repetition_penalty = gr.Slider(
                        label="Repetition Penalty",
                        minimum=1.0,
                        maximum=1.5,
                        step=0.05,
                        value=1.0,
                    )
                    max_tokens = gr.Slider(
                        label="Max Tokens",
                        minimum=100,
                        maximum=1000,
                        step=50,
                        value=700,
                    )

                generate_button = gr.Button("Generate", variant="primary")

        # Output audio components
        output_audios = []
        cols_per_row = 4
        num_rows = (MAX_NUM_SAMPLES + cols_per_row - 1) // cols_per_row
        with gr.Column():
            for row_idx in range(num_rows):
                with gr.Row():
                    for col_idx in range(cols_per_row):
                        i = row_idx * cols_per_row + col_idx
                        if i >= MAX_NUM_SAMPLES:
                            break
                        audio = gr.Audio(
                            label=f"Sample #{i+1}",
                            type="numpy",
                            interactive=False,
                            # Only the first slot is visible before generation.
                            visible=(i == 0),
                        )
                        output_audios.append(audio)

        generate_button.click(
            fn=gradio_inference,
            inputs=[
                target_text,
                reference_mode,
                reference_audio,
                preset_id,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_tokens,
                seed_box,
                num_samples,
            ],
            outputs=output_audios,
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()