Commit fcd2e2b
Parent(s): 17225b6
update
Files changed:
- cosyvoice/cli/cosyvoice.py +1 -1
- cosyvoice/cli/frontend.py +4 -1
- requirements.txt +2 -2
cosyvoice/cli/cosyvoice.py

@@ -148,7 +148,7 @@ class CosyVoice2(CosyVoice):
         model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        print(f"Loading configs:{configs}")
+        # print(f"Loading configs:{configs}")
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
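The only change here silences the config dump printed while CosyVoice2 is constructed. An alternative sketch, not part of this commit, would route the message through the standard logging module so it can be re-enabled with a log-level switch instead of editing the source (the logger name and helper below are assumptions):

    import logging

    logger = logging.getLogger("cosyvoice.cli")  # assumed logger name

    def report_configs(configs: dict) -> None:
        # Emitted only when the application enables DEBUG logging,
        # e.g. logging.basicConfig(level=logging.DEBUG).
        logger.debug("Loading configs: %s", configs)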
cosyvoice/cli/frontend.py

@@ -14,6 +14,7 @@
 from functools import partial
 import json
 import onnxruntime
+import spaces
 import torch
 import numpy as np
 import whisper
@@ -35,7 +36,7 @@ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, repl
 
 
 class CosyVoiceFrontEnd:
-
+    @spaces.GPU
     def __init__(self,
                  get_tokenizer: Callable,
                  feat_extractor: Callable,
@@ -51,9 +52,11 @@
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+        print("load campplus model from {}".format(campplus_model))
         self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                      providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                 "CPUExecutionProvider"])
+        print("load speech-tokenizer model from {}".format(speech_tokenizer_model))
         if os.path.exists(spk2info):
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
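The new import and decorator hook the frontend into Hugging Face ZeroGPU scheduling, and the two print calls record where the ONNX models are loaded from. For context, a minimal sketch of the usual spaces.GPU pattern, where the decorator wraps the function that actually performs GPU work; the synthesize function below is hypothetical and only stands in for a CosyVoice inference call:

    import spaces
    import torch

    @spaces.GPU  # on a ZeroGPU Space, a GPU is attached only while this call runs
    def synthesize(text: str) -> torch.Tensor:
        # Hypothetical inference entry point standing in for a CosyVoice call.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # ... run the model on `device` ...
        return torch.zeros(1, 24000, device=device)  # placeholder: 1 s of 24 kHz silence

Because ZeroGPU only grants the GPU inside decorated calls, decorating __init__ as this commit does covers model construction; the Space's inference path typically needs its own decorated entry point as well.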
requirements.txt

@@ -34,8 +34,8 @@ tensorboard==2.14.0
 tensorrt-cu12==10.0.1; sys_platform == 'linux'
 tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
 tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
-torch
-torchaudio
+torch==2.3.1+cu121
+torchaudio==2.3.1+cu121
 transformers==4.40.1
 uvicorn==0.30.0
 wget
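The previously unpinned torch and torchaudio entries are now pinned to CUDA 12.1 builds. Wheels carrying the +cu121 local version tag are hosted on the PyTorch index rather than PyPI, so the environment installing this file needs that index available (for example via --extra-index-url https://download.pytorch.org/whl/cu121); whether this Space's build already configures that is an assumption. A quick sanity check of the resulting environment:

    import torch
    import torchaudio

    # Verify that the CUDA 12.1 builds pinned in requirements.txt were installed.
    print(torch.__version__)         # expected: 2.3.1+cu121
    print(torchaudio.__version__)    # expected: 2.3.1+cu121
    print(torch.version.cuda)        # expected: 12.1
    print(torch.cuda.is_available())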