ASureevaA
commited on
Commit
·
c6a3c71
1
Parent(s):
fb68e9f
fix mms
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import tempfile
|
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
-
import soundfile as
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as torch_functional
|
| 8 |
from gtts import gTTS
|
|
@@ -202,7 +202,7 @@ def get_mms_tts_components():
|
|
| 202 |
if "mms_tts_pipeline" not in MODEL_STORE:
|
| 203 |
tts_pipeline = pipeline(
|
| 204 |
task="text-to-speech",
|
| 205 |
-
model="
|
| 206 |
)
|
| 207 |
MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
|
| 208 |
|
|
@@ -279,22 +279,17 @@ def synthesize_speech(text_value: str, model_key: str):
|
|
| 279 |
text_to_speech_engine = gTTS(text=text_value, lang="ru")
|
| 280 |
text_to_speech_engine.save(file_object.name)
|
| 281 |
return file_object.name
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
tts_output = tts_pipeline(text_value)
|
| 287 |
-
|
| 288 |
-
audio_array = tts_output["audio"]
|
| 289 |
-
sampling_rate_value = tts_output["sampling_rate"]
|
| 290 |
|
| 291 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
audio_array,
|
| 295 |
-
sampling_rate_value,
|
| 296 |
-
)
|
| 297 |
-
return file_object.name
|
| 298 |
|
| 299 |
raise ValueError(f"Неизвестная модель: {model_key}")
|
| 300 |
|
|
@@ -692,11 +687,11 @@ def build_interface():
|
|
| 692 |
lines=3,
|
| 693 |
)
|
| 694 |
tts_model_selector = gr.Dropdown(
|
| 695 |
-
choices=["
|
| 696 |
label="Выберите модель",
|
| 697 |
-
value="
|
| 698 |
info=(
|
| 699 |
-
"
|
| 700 |
"Google TTS"
|
| 701 |
),
|
| 702 |
)
|
|
|
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
+
import soundfile as sf
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as torch_functional
|
| 8 |
from gtts import gTTS
|
|
|
|
| 202 |
if "mms_tts_pipeline" not in MODEL_STORE:
|
| 203 |
tts_pipeline = pipeline(
|
| 204 |
task="text-to-speech",
|
| 205 |
+
model="facebook/mms-tts-rus",
|
| 206 |
)
|
| 207 |
MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
|
| 208 |
|
|
|
|
| 279 |
text_to_speech_engine = gTTS(text=text_value, lang="ru")
|
| 280 |
text_to_speech_engine.save(file_object.name)
|
| 281 |
return file_object.name
|
| 282 |
+
elif model_key == "mms":
|
| 283 |
+
model = VitsModel.from_pretrained("facebook/mms-tts-rus")
|
| 284 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
|
| 285 |
|
| 286 |
+
inputs = tokenizer(text_value, return_tensors="pt")
|
| 287 |
+
with torch.no_grad():
|
| 288 |
+
output = model(**inputs).waveform
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
| 291 |
+
sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
|
| 292 |
+
return f.name
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
raise ValueError(f"Неизвестная модель: {model_key}")
|
| 295 |
|
|
|
|
| 687 |
lines=3,
|
| 688 |
)
|
| 689 |
tts_model_selector = gr.Dropdown(
|
| 690 |
+
choices=["mms", "Google TTS"],
|
| 691 |
label="Выберите модель",
|
| 692 |
+
value="mms",
|
| 693 |
info=(
|
| 694 |
+
"facebook/mms-tts-rus\n"
|
| 695 |
"Google TTS"
|
| 696 |
),
|
| 697 |
)
|