ASureevaA committed on
Commit
c6a3c71
·
1 Parent(s): fb68e9f
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -2,7 +2,7 @@ import tempfile
2
  from typing import List, Tuple, Any
3
 
4
  import gradio as gr
5
- import soundfile as soundfile_module
6
  import torch
7
  import torch.nn.functional as torch_functional
8
  from gtts import gTTS
@@ -202,7 +202,7 @@ def get_mms_tts_components():
202
  if "mms_tts_pipeline" not in MODEL_STORE:
203
  tts_pipeline = pipeline(
204
  task="text-to-speech",
205
- model="kakao-enterprise/vits-ljs",
206
  )
207
  MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
208
 
@@ -279,22 +279,17 @@ def synthesize_speech(text_value: str, model_key: str):
279
  text_to_speech_engine = gTTS(text=text_value, lang="ru")
280
  text_to_speech_engine.save(file_object.name)
281
  return file_object.name
 
 
 
282
 
283
- if model_key == "vits-ljs":
284
- tts_pipeline = get_mms_tts_components()
285
-
286
- tts_output = tts_pipeline(text_value)
287
-
288
- audio_array = tts_output["audio"]
289
- sampling_rate_value = tts_output["sampling_rate"]
290
 
291
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as file_object:
292
- soundfile_module.write(
293
- file_object.name,
294
- audio_array,
295
- sampling_rate_value,
296
- )
297
- return file_object.name
298
 
299
  raise ValueError(f"Неизвестная модель: {model_key}")
300
 
@@ -692,11 +687,11 @@ def build_interface():
692
  lines=3,
693
  )
694
  tts_model_selector = gr.Dropdown(
695
- choices=["vits-ljs", "Google TTS"],
696
  label="Выберите модель",
697
- value="vits-ljs",
698
  info=(
699
- "kakao-enterprise/vits-ljs\n"
700
  "Google TTS"
701
  ),
702
  )
 
2
  from typing import List, Tuple, Any
3
 
4
  import gradio as gr
5
+ import soundfile as sf
6
  import torch
7
  import torch.nn.functional as torch_functional
8
  from gtts import gTTS
 
202
  if "mms_tts_pipeline" not in MODEL_STORE:
203
  tts_pipeline = pipeline(
204
  task="text-to-speech",
205
+ model="facebook/mms-tts-rus",
206
  )
207
  MODEL_STORE["mms_tts_pipeline"] = tts_pipeline
208
 
 
279
  text_to_speech_engine = gTTS(text=text_value, lang="ru")
280
  text_to_speech_engine.save(file_object.name)
281
  return file_object.name
282
+ elif model_key == "mms":
283
+ model = VitsModel.from_pretrained("facebook/mms-tts-rus")
284
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
285
 
286
+ inputs = tokenizer(text_value, return_tensors="pt")
287
+ with torch.no_grad():
288
+ output = model(**inputs).waveform
 
 
 
 
289
 
290
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
291
+ sf.write(f.name, output.numpy().squeeze(), model.config.sampling_rate)
292
+ return f.name
 
 
 
 
293
 
294
  raise ValueError(f"Неизвестная модель: {model_key}")
295
 
 
687
  lines=3,
688
  )
689
  tts_model_selector = gr.Dropdown(
690
+ choices=["mms", "Google TTS"],
691
  label="Выберите модель",
692
+ value="mms",
693
  info=(
694
+ "facebook/mms-tts-rus\n"
695
  "Google TTS"
696
  ),
697
  )