# TTS-Talker / app_tts.py
# Author: Quang Long — commit 52ff743 ("update ui")
import os, sys
import hashlib
import soundfile as sf
import gradio as gr
import re
import numpy as np
import spaces
from pathlib import Path
from datetime import datetime
from huggingface_hub import login
from cached_path import cached_path
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
# Import hàm infer gốc của f5_tts
from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_process,
speed,
mel_spec_type as default_mel_spec_type,
target_sample_rate as default_target_sample_rate,
)
from f5_tts.model import DiT
from omegaconf import OmegaConf
from importlib.resources import files
import unicodedata
# Map curly quotes / guillemets to ASCII equivalents, and NBSP to a plain space.
QUOTES_MAP = {
    "“": '"',
    "”": '"',
    "‘": "'",
    "’": "'",
    "«": '"',
    "»": '"',
    "\u00a0": " ",  # NBSP -> space
}


def normalize_for_tts(s: str, to_lower=True) -> str:
    """Normalize raw text into a clean form for the TTS model.

    Steps: NFC composition (correct for accented Vietnamese), quote/NBSP
    mapping, whitespace collapsing, optional vinorm number/abbreviation
    expansion, optional lowercasing, and quote removal (quotes are not
    spoken, so they are stripped before reaching the model).
    """
    # NFC composition keeps Vietnamese diacritics in canonical form.
    text = unicodedata.normalize("NFC", s or "")
    # Straighten curly quotes and fix odd spaces in a single C-level pass.
    text = text.translate(str.maketrans(QUOTES_MAP))
    # Collapse any run of whitespace to a single space.
    text = re.sub(r"\s+", " ", text).strip()
    # Best-effort vinorm pass (numbers/abbreviations); skipped if unavailable.
    try:
        from vinorm import TTSnorm

        text = TTSnorm(text)
    except Exception:
        pass
    if to_lower:
        text = text.lower()
    # Quotes should not be read aloud — drop them entirely.
    return text.replace('"', "").replace("'", "")
# Log in to the HuggingFace Hub when an API token is provided via env.
if hf_token := os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    login(token=hf_token)
# =============================
# 🔹 Hàm cache output
# =============================
def get_audio_cache_path(
    text, ref_audio_path, model="F5TTS_Base", cache_dir="tts_cache"
):
    """Return the cache .wav path for a (text, reference-audio, model) triple.

    The filename is the SHA-256 digest of the inputs, so identical requests
    map to the same file; the cache directory is created on first use.
    """
    cache_root = Path(cache_dir)
    cache_root.mkdir(parents=True, exist_ok=True)
    digest = hashlib.sha256(
        f"{text}|{ref_audio_path}|{model}".encode("utf-8")
    ).hexdigest()
    return str(cache_root / f"{digest}.wav")
# =============================
# 🔹 Wrapper cho F5-TTS
# =============================
@spaces.GPU
def infer_tts(
    ref_audio_orig: str,
    ref_text_input: str,
    gen_text: str,
    speed: float = 1.0,
    request: gr.Request = None,
):
    """Synthesize Vietnamese speech from text, cloning the reference voice.

    Generator: intermediate yields are (None, "<progress message>") so the
    Gradio UI can stream logs; the final yield is ((sample_rate, waveform),
    "done" message).

    Args:
        ref_audio_orig: path to the sample/reference voice recording.
        ref_text_input: optional transcript of the reference audio.
        gen_text: text to synthesize.
        speed: speed multiplier forwarded to the model (shadows the module-level
            `speed` import on purpose inside this function).
        request: Gradio request object (unused; kept for the event signature).

    Raises:
        gr.Error: if the reference audio or the generation text is missing.
    """
    # Validate inputs BEFORE downloading/loading the heavy models, so a bad
    # request fails fast instead of after an expensive checkpoint load.
    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")

    model = "F5TTS_Base"
    ckpt_file = str(
        cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")
    )
    vocab_file = str(
        cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")
    )

    # Load vocoder and the EMA TTS model described by the packaged config.
    vocoder = load_vocoder(
        vocoder_name=default_mel_spec_type, is_local=False, local_path=None
    )
    model_cfg = OmegaConf.load(
        str(files("f5_tts").joinpath(f"configs/{model}.yaml"))
    ).model
    model_cls = globals()[model_cfg.backbone]  # backbone class (e.g. DiT) by name
    ema_model = load_model(
        model_cls,
        model_cfg.arch,
        ckpt_file,
        mel_spec_type=default_mel_spec_type,
        vocab_file=vocab_file,
    )

    # Normalize the reference audio/transcript (transcript may be empty).
    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text_input or ""
    )
    ref_text = ref_text.strip()

    # NFC-normalize the generation text and straighten curly quotes.
    gen_text = unicodedata.normalize("NFC", gen_text or "")
    for src, dst in QUOTES_MAP.items():
        gen_text = gen_text.replace(src, dst)

    # Split into sentence-ish chunks on . ! ? … , (incl. fullwidth comma).
    # With a capture group, re.split alternates text parts and separators.
    parts = re.split(r"([\.!\?…,,]+)\s+", gen_text)
    chunks = []
    for i in range(0, len(parts), 2):
        body = parts[i] or ""
        sep = parts[i + 1] if i + 1 < len(parts) else ""
        sent = (body + sep).strip()
        if not sent:
            continue
        # Strip edge quotes/brackets so no chunk starts with punctuation.
        sent = sent.strip("\"'" + "\u201c\u201d\u00ab\u00bb()[]{}")
        if sent:
            chunks.append(sent)

    final_audio_segments = []
    sample_rate = default_target_sample_rate
    silence = np.zeros(int(0.2 * sample_rate))  # 200 ms pause between chunks

    for i, text_chunk in enumerate(chunks, 1):
        text_chunk = text_chunk.strip()
        if not text_chunk:
            continue
        # Stream a progress line to the UI log box.
        yield None, f"🔄 Processing chunk {i}/{len(chunks)}: {text_chunk[:50]}..."
        cache_path = get_audio_cache_path(text_chunk, ref_audio_orig, model)
        if os.path.exists(cache_path):
            print(f"Using cached audio: {cache_path}")
            # NOTE(review): a cached file's sample rate overrides the default
            # for the final yield — assumed consistent across the cache.
            wave, sample_rate = sf.read(cache_path)
        else:
            clean_chunk = normalize_for_tts(text_chunk)
            wave, sample_rate, _ = infer_process(
                ref_audio,
                ref_text,
                clean_chunk,
                ema_model,
                vocoder,
                speed=speed,
                nfe_step=16,  # fewer steps -> lighter compute
            )
            print(f"[CACHE] Saved new audio to: {cache_path}")
            sf.write(cache_path, wave, sample_rate)
        final_audio_segments.append(wave)
        final_audio_segments.append(silence)

    # Concatenate everything (empty array if nothing was synthesized).
    final_wave = (
        np.concatenate(final_audio_segments) if final_audio_segments else np.array([])
    )
    yield (sample_rate, final_wave), "✅ Done synthesizing!"
# =============================
# 🔹 Gradio UI
# =============================
# Gradio front-end: sample-voice inputs on top, generated audio + log stream below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 F5-TTS: Vietnamese Text-to-Speech")

    with gr.Row():
        sample_voice = gr.Audio(label="🔊 Sample Voice", type="filepath")
        sample_transcript = gr.Textbox(label="📝 Reference Transcript (optional)")

    text_to_speak = gr.Textbox(label="📝 Text", lines=3)
    speed_slider = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    generate_btn = gr.Button("🔥 Generate Voice")

    with gr.Row():
        audio_out = gr.Audio(label="🎧 Generated Audio", type="numpy")
        log_stream = gr.Textbox(label="📜 Logs", lines=6)

    # infer_tts is a generator, so progress messages stream into the log box.
    generate_btn.click(
        infer_tts,
        inputs=[sample_voice, sample_transcript, text_to_speak, speed_slider],
        outputs=[audio_out, log_stream],
    )

if __name__ == "__main__":
    demo.queue().launch()