import spaces
import os
import codecs
from huggingface_hub import login
import gradio as gr
from cached_path import cached_path
import tempfile
from vinorm import TTSnorm
from importlib.resources import files
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
    target_sample_rate as default_target_sample_rate,
    n_mel_channels as default_n_mel_channels,
    hop_length as default_hop_length,
    win_length as default_win_length,
    n_fft as default_n_fft,
    mel_spec_type as default_mel_spec_type,
    target_rms as default_target_rms,
    cross_fade_duration as default_cross_fade_duration,
    ode_method as default_ode_method,
    nfe_step as default_nfe_step,  # 16, 32
    cfg_strength as default_cfg_strength,
    sway_sampling_coef as default_sway_sampling_coef,
    speed as default_speed,
    fix_duration as default_fix_duration
)
from pathlib import Path
from omegaconf import OmegaConf
from datetime import datetime
import hashlib
import unicodedata
# Retrieve token from secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")


# Log in to Hugging Face
if hf_token:
    login(token=hf_token)

# Build the cache file path from the text, reference audio, and model
def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
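    # Key the cache on (text, reference audio path, model name) so identical
    # requests can reuse previously synthesized audio.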
    os.makedirs(cache_dir, exist_ok=True)
    hash_input = f"{text}|{ref_audio_path}|{model}"
    hash_val = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{hash_val}.wav")

def post_process(text):
    text = " " + text + " "
    for old, new in ((" . . ", " . "), (" .. ", " . "), (" , , ", " , "), (" ,, ", " , ")):
        text = text.replace(old, new)
    text = text.replace('"', "")
    return " ".join(text.split())
# Main TTS inference entry point; models are loaded on each call, and the
# @spaces.GPU decorator requests a GPU for the duration of the request
@spaces.GPU
def infer_tts(ref_audio_orig: str, ref_text_input: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):
    
    args = {
        "model": "F5TTS_Base",
        "ckpt_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
        "vocab_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
        "ref_audio": ref_audio_orig,
        "ref_text": ref_text_input,
        "gen_text": gen_text,
        "speed": speed
    }
    config = {}  # the CLI version loads this from a TOML file via tomli.load(open(args.config, "rb"))
    # Resolve inference parameters, mirroring the f5_tts command-line interface

    model = args["model"] or config.get("model", "F5TTS_Base")
    ckpt_file = args["ckpt_file"] or config.get("ckpt_file", "")
    vocab_file = args["vocab_file"] or config.get("vocab_file", "")

    ref_audio = args["ref_audio"] or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
    ref_text = args["ref_text"] if args["ref_text"] is not None else config.get("ref_text", "Some call me nature, others call me mother nature.")
    gen_text = args["gen_text"] or config.get("gen_text", "Here we generate something just for test.")
    gen_file = args.get("gen_file", "") or config.get("gen_file", "")
    output_dir = args.get("output_dir", "") or config.get("output_dir", "tests")
    output_file = args.get("output_file", "") or config.get("output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav")
    save_chunk = args.get("save_chunk", False) or config.get("save_chunk", False)
    remove_silence = args.get("remove_silence", False) or config.get("remove_silence", False)
    load_vocoder_from_local = args.get("load_vocoder_from_local", False) or config.get("load_vocoder_from_local", False)
    vocoder_name = args.get("vocoder_name", "") or config.get("vocoder_name", default_mel_spec_type)
    target_rms = args.get("target_rms", None) or config.get("target_rms", default_target_rms)
    cross_fade_duration = args.get("cross_fade_duration", None) or config.get("cross_fade_duration", default_cross_fade_duration)
    nfe_step = args.get("nfe_step", None) or config.get("nfe_step", default_nfe_step)
    cfg_strength = args.get("cfg_strength", None) or config.get("cfg_strength", default_cfg_strength)
    sway_sampling_coef = args.get("sway_sampling_coef", None) or config.get("sway_sampling_coef", default_sway_sampling_coef)
    speed = args.get("speed", None) or config.get("speed", default_speed)
    fix_duration = args.get("fix_duration", None) or config.get("fix_duration", default_fix_duration)
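    # Note: the `x or default` pattern sends any falsy value (0, "", None, False)
    # to the config/default fallback; acceptable here since e.g. speed is never 0.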

    if "infer/examples/" in ref_audio:
        ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
    if "infer/examples/" in gen_file:
        gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
    if "voices" in config:
        for voice in config["voices"]:
            voice_ref_audio = config["voices"][voice]["ref_audio"]
            if "infer/examples/" in voice_ref_audio:
                config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))


    # ignore gen_text if gen_file provided

    if gen_file:
        with codecs.open(gen_file, "r", "utf-8") as f:
            gen_text = f.read()


    # output path

    wave_path = Path(output_dir) / output_file
    # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
    if save_chunk:
        output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
        os.makedirs(output_chunk_dir, exist_ok=True)
        
    # load vocoder

    if vocoder_name == "vocos":
        vocoder_local_path = "../checkpoints/vocos-mel-24khz"
    elif vocoder_name == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    else:
        # guard against an UnboundLocalError below for unknown vocoder names
        raise gr.Error(f"Unsupported vocoder: {vocoder_name}")

    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)

        
    # load TTS model

    model_cfg = OmegaConf.load(
        config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
    ).model
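    # Resolve the backbone class (e.g. DiT, imported above) by the name given in the YAML config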
    model_cls = globals()[model_cfg.backbone]

    repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

    if model != "F5TTS_Base":
        assert vocoder_name == model_cfg.mel_spec.mel_spec_type

    # override for previous models
    if model == "F5TTS_Base":
        if vocoder_name == "vocos":
            ckpt_step = 1200000
        elif vocoder_name == "bigvgan":
            model = "F5TTS_Base_bigvgan"
            ckpt_type = "pt"
    elif model == "E2TTS_Base":
        repo_name = "E2-TTS"
        ckpt_step = 1200000

    if not ckpt_file:
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))

    print(f"Using {model}...")
    ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)

    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")
    if len(gen_text.split()) > 1000:
        raise gr.Error("Please enter text content with at most 1000 words.")
    try:
        # If the user supplied ref_text, use it; otherwise pass an empty string so it is auto-transcribed
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text_input or "")
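        # NFC-normalize so visually identical Vietnamese strings (precomposed vs.
        # combining diacritics) produce the same cache key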
        ref_text = unicodedata.normalize("NFC", ref_text.strip())
        gen_text_ = unicodedata.normalize("NFC", gen_text.strip())
        # --- BEGIN: cache lookup ---
        cache_path = get_audio_cache_path(gen_text_, ref_audio_orig, model)
        import soundfile as sf
        if os.path.exists(cache_path):
            print(f"Using cached audio: {cache_path}")
            final_wave, final_sample_rate = sf.read(cache_path)
            spectrogram = None
        else:
            final_wave, final_sample_rate, spectrogram = infer_process(
                ref_audio, ref_text, gen_text_, ema_model, vocoder, speed=speed
            )
            print(f"[CACHE] Saved new audio to: {cache_path}")
            sf.write(cache_path, final_wave, final_sample_rate)
        # --- END: cache lookup ---
        # Only render a spectrogram for freshly generated audio; cache hits
        # previously returned a path to an empty PNG, now they return None
        spectrogram_path = None
        if spectrogram is not None:
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
                spectrogram_path = tmp_spectrogram.name
            save_spectrogram(spectrogram, spectrogram_path)
        return (final_sample_rate, final_wave), spectrogram_path
    except Exception as e:
        raise gr.Error(f"Error generating voice: {e}")

# Gradio UI

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 F5-TTS: Vietnamese Text-to-Speech Synthesis.
    # The model was trained with approximately 1000 hours of data on a RTX 3090 GPU. 
    Enter text and upload a sample voice to generate natural speech.
    """)
    with gr.Row():
        ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        ref_text = gr.Textbox(label="📝 Reference Transcript (optional)", placeholder="Enter the Vietnamese transcript for the sample voice, if available...", lines=2)
        gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)
    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    btn_synthesize = gr.Button("🔥 Generate Voice")
    with gr.Row():
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        output_spectrogram = gr.Image(label="📊 Spectrogram")
    model_limitations = gr.Textbox(
        value="""1. This model may not perform well with numerical characters, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audios may be inconsistent or choppy => It is recommended to select clearly pronounced sample audios with minimal pauses for better synthesis quality.
3. Default, reference audio text uses the pho-whisper-medium model, which may not always accurately recognize Vietnamese, resulting in poor voice synthesis quality.
4. Inference with overly long paragraphs may produce poor results.""", 
        label="❗ Model Limitations",
        lines=4,
        interactive=False
    )
    btn_synthesize.click(infer_tts, inputs=[ref_audio, ref_text, gen_text, speed], outputs=[output_audio, output_spectrogram])

# Launch the Gradio app; pass share=True to launch() to get a public gradio.live link

if __name__ == "__main__":
    demo.queue().launch()