Commit
·
782d74b
1
Parent(s):
9e032ec
(wip)debug
Browse files- .gitmodules +3 -0
- CosyVoice2-0.5B +1 -0
- requirements.txt +2 -1
- tts.py +31 -23
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "CosyVoice2-0.5B"]
|
| 2 |
+
path = CosyVoice2-0.5B
|
| 3 |
+
url = git@hf.co:spaces/FunAudioLLM/CosyVoice2-0.5B
|
CosyVoice2-0.5B
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit b1769de266d0f5f94ea6e8dfb1df1519a2407be2
|
requirements.txt
CHANGED
|
@@ -12,4 +12,5 @@ gunicorn
|
|
| 12 |
waitress
|
| 13 |
fal-client
|
| 14 |
gradio_client==1.7.0
|
| 15 |
-
git+https://github.com/playht/pyht
|
|
|
|
|
|
| 12 |
waitress
|
| 13 |
fal-client
|
| 14 |
gradio_client==1.7.0
|
| 15 |
+
git+https://github.com/playht/pyht
|
| 16 |
+
modelscope
|
tts.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
# Currently just use current TTS router.
|
| 3 |
import os
|
| 4 |
import json
|
|
|
|
|
|
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
import fal_client
|
| 7 |
import requests
|
|
@@ -232,35 +234,41 @@ def predict_spark_tts(text, reference_audio_path=None):
|
|
| 232 |
|
| 233 |
|
| 234 |
def predict_cosyvoice_tts(text, reference_audio_path=None):
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
if not reference_audio_path:
|
| 238 |
raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
prompt_wav_upload=prompt_wav,
|
| 252 |
-
prompt_wav_record=prompt_wav,
|
| 253 |
-
instruct_text="",
|
| 254 |
-
seed=0,
|
| 255 |
-
api_name="/generate_audio"
|
| 256 |
-
)
|
| 257 |
-
print("cosyvoice-2.0 result:", result)
|
| 258 |
-
return result
|
| 259 |
|
| 260 |
|
| 261 |
def predict_maskgct(text, reference_audio_path=None):
|
| 262 |
from gradio_client import Client, handle_file
|
| 263 |
-
client = Client("
|
| 264 |
if not reference_audio_path:
|
| 265 |
raise ValueError("maskgct 需要 reference_audio_path")
|
| 266 |
prompt_wav = handle_file(reference_audio_path)
|
|
|
|
| 2 |
# Currently just use current TTS router.
|
| 3 |
import os
|
| 4 |
import json
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
import fal_client
|
| 9 |
import requests
|
|
|
|
| 234 |
|
| 235 |
|
| 236 |
def predict_cosyvoice_tts(text, reference_audio_path=None):
    """Synthesize ``text`` with a local CosyVoice2-0.5B model in zero-shot mode.

    The model weights are fetched with ``modelscope.snapshot_download`` into
    ``CosyVoice2-0.5B/pretrained_models/CosyVoice2-0.5B`` (next to this file)
    when that directory is missing or empty, and the loaded model is cached in
    the module-level ``_cosyvoice_model`` so later calls reuse it.

    Args:
        text: Text to synthesize.
        reference_audio_path: Path to the reference (prompt) audio. Required.

    Returns:
        Path of a temporary ``.wav`` file holding the synthesized audio. The
        file is created with ``delete=False``; the caller is responsible for
        removing it.

    Raises:
        ValueError: If ``reference_audio_path`` is empty or ``None``.
        RuntimeError: If the model yields no audio at all.
    """
    # Validate up front so a bad request fails before the (expensive) weight
    # download and model load below.
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 需要 reference_audio_path")

    import tempfile

    import numpy as np
    import soundfile as sf
    from modelscope import snapshot_download

    repo_dir = os.path.join(os.path.dirname(__file__), "CosyVoice2-0.5B")
    model_dir = os.path.join(repo_dir, "pretrained_models", "CosyVoice2-0.5B")
    if not os.path.exists(model_dir) or not os.listdir(model_dir):
        snapshot_download('iic/CosyVoice2-0.5B', local_dir=model_dir)

    # The cosyvoice package lives inside the checked-out submodule; register it
    # once instead of appending a duplicate entry on every call.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)
    from cosyvoice.cli.cosyvoice import CosyVoice2
    from cosyvoice.utils.file_utils import load_wav

    # Module-level model cache: build the model on first use only.
    global _cosyvoice_model
    model = globals().get("_cosyvoice_model")
    if model is None:
        model = _cosyvoice_model = CosyVoice2(model_dir)

    # Load the reference audio at 16 kHz.
    prompt_speech_16k = load_wav(reference_audio_path, 16000)
    # The reference transcript is optional here: no ASR is run, so an empty
    # string is passed.
    # NOTE(review): zero-shot quality may depend on a real transcript — confirm.
    prompt_text = ""

    # Inference. The generator may yield several segments; keep all of them
    # rather than only the last one.
    segments = [
        chunk['tts_speech'].numpy().flatten()
        for chunk in model.inference_zero_shot(text, prompt_text, prompt_speech_16k)
    ]
    if not segments:
        raise RuntimeError("cosyvoice-2.0 produced no audio")
    audio = np.concatenate(segments)

    # Prefer the rate reported by the model; fall back to the previously
    # hard-coded 24 kHz when the attribute is absent.
    sample_rate = getattr(model, "sample_rate", 24000)

    # Reserve a temp path and close the handle before soundfile writes to it,
    # so no file descriptor is leaked.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        output_path = temp_file.name
    sf.write(output_path, audio, sample_rate)
    return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
|
| 269 |
def predict_maskgct(text, reference_audio_path=None):
|
| 270 |
from gradio_client import Client, handle_file
|
| 271 |
+
client = Client("cocktailpeanut/maskgct")
|
| 272 |
if not reference_audio_path:
|
| 273 |
raise ValueError("maskgct 需要 reference_audio_path")
|
| 274 |
prompt_wav = handle_file(reference_audio_path)
|