File size: 2,899 Bytes
8e25e9b
5e4386d
 
 
 
8e25e9b
c23ecfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a5c091
c23ecfc
8e25e9b
5e4386d
367613d
c23ecfc
 
 
367613d
8e25e9b
c23ecfc
 
 
 
 
 
367613d
c23ecfc
 
367613d
c23ecfc
 
367613d
c23ecfc
 
 
367613d
c23ecfc
 
 
 
 
5e4386d
 
 
 
367613d
c23ecfc
8e25e9b
c23ecfc
5e4386d
8e25e9b
c23ecfc
8e25e9b
cd931dd
5e4386d
82105a8
c23ecfc
5e4386d
367613d
8e25e9b
5e4386d
 
cd931dd
82105a8
c23ecfc
8e25e9b
cd931dd
5e4386d
 
82105a8
c23ecfc
5e4386d
367613d
5e4386d
 
 
 
 
8e25e9b
5e4386d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import tempfile
import uuid

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import VitsModel, VitsTokenizer, set_seed

# --- Model setup -----------------------------------------------------------
# MMS-TTS English checkpoint (a lighter alternative to Bark).
MODEL_ID = "facebook/mms-tts-eng"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = VitsTokenizer.from_pretrained(MODEL_ID)
model = VitsModel.from_pretrained(MODEL_ID).to(device)

# Seed the random number generators once, at import time. This makes the
# sequence of outputs reproducible for a freshly started process. The RNG state
# keeps advancing with every request, though, so (assuming synthesis is
# stochastic — TODO confirm) repeating the same text later in the session is
# not guaranteed to give identical audio.
set_seed(555)

# Longer inputs are truncated to this many characters, for speed and stability.
MAX_CHARS = 300


def generate_speech(text: str) -> str:
    """
    Synthesize English speech for ``text`` with MMS-TTS and save it as a WAV file.

    The text is stripped, truncated to ``MAX_CHARS`` characters and lowercased
    before it is tokenized and passed through the model.

    Args:
        text: Sentence entered in the Gradio textbox.

    Returns:
        Path to a newly written, uniquely named WAV file in the system temp
        directory. This is the value ``gr.Audio(type="filepath")`` expects.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only. Also raised if the
            tokenizer produces no tokens for it (e.g. input made only of
            characters outside the English vocabulary), so there is nothing
            to read aloud.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text 🙂")

    # Long inputs are cut silently to keep synthesis fast and stable.
    text = text.strip()[:MAX_CHARS]

    # MMS-TTS expects lowercase text. Only lowercasing is done here; characters
    # outside the vocabulary are presumably filtered by the tokenizer itself
    # (TODO confirm).
    normalized_text = text.lower()

    # 1) Tokenize
    inputs = tokenizer(text=normalized_text, return_tensors="pt").to(device)

    # If every character was filtered out, the sequence is empty. Fail early
    # with a readable message instead of passing an empty tensor to the model.
    if inputs["input_ids"].numel() == 0:
        raise gr.Error(
            "No readable English characters found. Please enter English text 🙂"
        )

    # 2) Forward pass: inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # 3) Take the first (only) batch item as a float32 NumPy array.
    waveform = outputs.waveform[0].cpu().numpy().astype(np.float32)
    sr = model.config.sampling_rate

    # 4) Write to the platform's temp directory rather than a hard-coded
    #    "/tmp", which does not exist on Windows. The uuid keeps concurrent
    #    requests from overwriting each other's files.
    filepath = os.path.join(tempfile.gettempdir(), f"tts_{uuid.uuid4().hex}.wav")
    sf.write(filepath, waveform, sr)

    # 5) Gradio's Audio component (type="filepath") loads the file from this path.
    return filepath


with gr.Blocks() as demo:
    # Title (Mongolian: "Turn English text into speech") plus an English subtitle.
    gr.Markdown(f"# 🗣️ Англи текстийг яриа болгох \n\n --- Simple TTS with {MODEL_ID}")
    # Instructions (Mongolian): "Type your sentence in English here, press the
    # **Generate speech** button and listen to the English speech."
    # The bold button name must match the actual button label below.
    gr.Markdown(
        "Энд англи дээр өгүүлбэрээ бичээд **Яриаг үүсгэнэ үү** товчийг дарж англи яриаг сонсоорой. \n\n"
        f"Model: `{MODEL_ID}` (MMS-TTS, VITS-based)."
    )

    with gr.Row():
        # Left, wider column: text input and the generate button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Яриа болгох англи өгүүлбэр",  # "English sentence to turn into speech"
                placeholder="Жишээ: Hello, this is my text-to-speech demo",  # "Example: ..."
                lines=3,
            )
            generate_button = gr.Button("Яриаг үүсгэнэ үү", variant="primary")  # "Generate speech"
        # Right, narrower column: the synthesized audio.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Үүссэн бичлэг",  # "Generated recording"
                type="filepath",  # generate_speech returns a path string
            )

    generate_button.click(
        fn=generate_speech,
        inputs=text_input,
        outputs=audio_output,
    )

if __name__ == "__main__":
    # Launch the app with Gradio's server-side rendering mode turned off.
    demo.launch(ssr_mode=False)