# F5-TTS Vietnamese text-to-speech — Hugging Face Spaces app
| import os, sys | |
| import hashlib | |
| import soundfile as sf | |
| import gradio as gr | |
| import re | |
| import numpy as np | |
| import spaces | |
| from pathlib import Path | |
| from datetime import datetime | |
| from huggingface_hub import login | |
| from cached_path import cached_path | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "src")) | |
| # Import hàm infer gốc của f5_tts | |
| from f5_tts.infer.utils_infer import ( | |
| preprocess_ref_audio_text, | |
| load_vocoder, | |
| load_model, | |
| infer_process, | |
| speed, | |
| mel_spec_type as default_mel_spec_type, | |
| target_sample_rate as default_target_sample_rate, | |
| ) | |
| from f5_tts.model import DiT | |
| from omegaconf import OmegaConf | |
| from importlib.resources import files | |
| import unicodedata | |
# Mapping used to normalize typographic punctuation before TTS:
# curly/angle quotes become straight ASCII quotes, and a non-breaking
# space becomes a regular space.
QUOTES_MAP = {
    "“": '"',
    "”": '"',
    "‘": "'",
    "’": "'",
    "«": '"',
    "»": '"',
    "\u00a0": " ",  # NBSP -> space
}
def normalize_for_tts(s: str, to_lower: bool = True) -> str:
    """Normalize free text into a form suitable for the TTS model.

    Steps: Unicode NFC composition (the correct form for accented
    Vietnamese), typographic-quote/NBSP normalization via QUOTES_MAP,
    whitespace collapsing, optional vinorm number/abbreviation
    expansion, optional lowercasing, and removal of straight quotes
    (quotes should not be spoken).

    Args:
        s: Raw input text; None is treated as the empty string.
        to_lower: Lowercase the result before stripping quotes.

    Returns:
        The normalized text.
    """
    # NFC composition is correct for Vietnamese diacritics.
    s = unicodedata.normalize("NFC", s or "")
    # Single C-level translate pass instead of one .replace() call
    # per QUOTES_MAP entry (all keys are single characters).
    s = s.translate(str.maketrans(QUOTES_MAP))
    # Collapse every run of whitespace to one space.
    s = re.sub(r"\s+", " ", s).strip()
    # vinorm (optional dependency) expands numbers/abbreviations for
    # Vietnamese; silently skipped when it is not installed.
    try:
        from vinorm import TTSnorm

        s = TTSnorm(s)
    except Exception:
        pass
    if to_lower:
        s = s.lower()
    # Quotes are not read aloud -> drop them before feeding the model.
    s = s.replace('"', "").replace("'", "")
    return s
# Log in to the Hugging Face Hub when a token is provided via the
# environment; needed to download the model checkpoint/vocab below.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)
| # ============================= | |
| # 🔹 Hàm cache output | |
| # ============================= | |
def get_audio_cache_path(
    text, ref_audio_path, model="F5TTS_Base", cache_dir="tts_cache"
):
    """Return the cache file path for a (text, reference audio, model) triple.

    The path is deterministic: the SHA-256 digest of the three inputs
    joined with '|' names a .wav file under *cache_dir*. The cache
    directory is created on first use.

    NOTE(review): the key hashes the reference-audio *path*, not its
    contents — replacing the file in place would serve stale audio.
    """
    os.makedirs(cache_dir, exist_ok=True)
    digest = hashlib.sha256(
        f"{text}|{ref_audio_path}|{model}".encode("utf-8")
    ).hexdigest()
    return os.path.join(cache_dir, f"{digest}.wav")
| # ============================= | |
| # 🔹 Wrapper cho F5-TTS | |
| # ============================= | |
def infer_tts(
    ref_audio_orig: str,
    ref_text_input: str,
    gen_text: str,
    speed: float = 1.0,
    request: gr.Request = None,
):
    """Generator: synthesize Vietnamese speech from text with F5-TTS.

    Splits *gen_text* into sentence-ish chunks, synthesizes each chunk
    (serving previously generated chunks from the disk cache), and
    yields (audio, status) pairs so the Gradio UI can stream progress.

    Args:
        ref_audio_orig: Path to the sample/reference voice audio.
        ref_text_input: Optional transcript of the reference audio.
        gen_text: Text to synthesize.
        speed: Speed multiplier passed to infer_process (this parameter
            shadows the ``speed`` imported at module level).
        request: Gradio request object (unused in the body).

    Yields:
        (None, status_message) while chunks are being processed, then
        ((sample_rate, waveform), final_message) once finished.

    Raises:
        gr.Error: If no reference audio or no generation text is given.
    """
    args = {
        "model": "F5TTS_Base",
        "ckpt_file": str(
            cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")
        ),
        "vocab_file": str(
            cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")
        ),
        "ref_audio": ref_audio_orig,
        "ref_text": ref_text_input,
        "gen_text": gen_text,
        "speed": speed,
    }
    model = args["model"]
    ckpt_file = args["ckpt_file"]
    vocab_file = args["vocab_file"]
    # Load model
    # NOTE(review): vocoder and EMA model are (re)loaded on every
    # request; caching them at module level would avoid the repeated cost.
    vocoder = load_vocoder(
        vocoder_name=default_mel_spec_type, is_local=False, local_path=None
    )
    model_cfg = OmegaConf.load(
        str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
    ).model
    # Resolve the backbone class (e.g. DiT) by name from module globals.
    model_cls = globals()[model_cfg.backbone]
    ema_model = load_model(
        model_cls,
        model_cfg.arch,
        ckpt_file,
        mel_spec_type=default_mel_spec_type,
        vocab_file=vocab_file,
    )
    # NOTE(review): these input checks run after the expensive model
    # load above; moving them first would fail much faster.
    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")
    # Normalize the reference audio/text pair.
    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text_input or ""
    )
    ref_text = ref_text.strip()
    gen_text = unicodedata.normalize("NFC", gen_text or "")
    for src, dst in QUOTES_MAP.items():
        gen_text = gen_text.replace(src, dst)
    # Split on sentence punctuation (. ! ? … and commas), capturing the
    # separators so they can be re-attached to their sentence below.
    parts = re.split(r"([\.!\?…,,]+)\s+", gen_text)
    chunks = []
    for i in range(0, len(parts), 2):
        s = parts[i] or ""
        sep = parts[i + 1] if i + 1 < len(parts) else ""
        sent = (s + sep).strip()
        if not sent:
            continue
        # Strip surrounding quotes/brackets so no chunk starts with an
        # opening quote or parenthesis.
        sent = sent.strip("\"'" + "\u201c\u201d\u00ab\u00bb()[]{}")
        if sent:
            chunks.append(sent)
    final_audio_segments = []
    sample_rate = default_target_sample_rate
    # 0.2 s of silence inserted after each chunk.
    silence = np.zeros(int(0.2 * sample_rate))
    for i, text_chunk in enumerate(chunks, 1):
        text_chunk = text_chunk.strip()
        if not text_chunk:
            continue
        # Progress update streamed to the UI log box.
        yield None, f"🔄 Processing chunk {i}/{len(chunks)}: {text_chunk[:50]}..."
        cache_path = get_audio_cache_path(text_chunk, ref_audio_orig, model)
        if os.path.exists(cache_path):
            print(f"Using cached audio: {cache_path}")
            wave, sample_rate = sf.read(cache_path)
        else:
            # Normalize the chunk (quotes, whitespace, vinorm, casing).
            clean_chunk = normalize_for_tts(text_chunk)
            wave, sample_rate, _ = infer_process(
                ref_audio,
                ref_text,
                clean_chunk,
                ema_model,
                vocoder,
                speed=speed,
                nfe_step=16,  # fewer denoising steps -> lighter compute
            )
            print(f"[CACHE] Saved new audio to: {cache_path}")
            sf.write(cache_path, wave, sample_rate)
        final_audio_segments.append(wave)
        final_audio_segments.append(silence)
    # Concatenate all chunks (with trailing silences) into one waveform.
    final_wave = (
        np.concatenate(final_audio_segments) if final_audio_segments else np.array([])
    )
    yield (sample_rate, final_wave), "✅ Done synthesizing!"
| # ============================= | |
| # 🔹 Gradio UI | |
| # ============================= | |
# Build the Gradio interface: sample voice + optional transcript in,
# generated audio + streamed log messages out.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 F5-TTS: Vietnamese Text-to-Speech")
    with gr.Row():
        ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        ref_text = gr.Textbox(label="📝 Reference Transcript (optional)")
    gen_text = gr.Textbox(label="📝 Text", lines=3)
    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    btn_synthesize = gr.Button("🔥 Generate Voice")
    with gr.Row():
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        log_box = gr.Textbox(label="📜 Logs", lines=6)
    # Run infer_tts as a generator -> stream log/audio updates.
    btn_synthesize.click(
        infer_tts,
        inputs=[ref_audio, ref_text, gen_text, speed],
        outputs=[output_audio, log_box],
    )
if __name__ == "__main__":
    # queue() enables the generator handler's streamed updates.
    demo.queue().launch()