Spaces:
Runtime error
Runtime error
Commit
·
ecd3224
1
Parent(s):
bd8dcd1
perf: model lazy load
Browse files
app.py
CHANGED
|
@@ -27,38 +27,48 @@ from huggingface_hub import HfApi
|
|
| 27 |
|
| 28 |
# will use api to restart space on an unrecoverable error
|
| 29 |
api = HfApi(token=HF_TOKEN)
|
| 30 |
-
repo_id = "
|
| 31 |
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
-
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
|
| 35 |
-
model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
model
|
| 41 |
-
model.load_checkpoint(
|
| 42 |
-
config,
|
| 43 |
-
checkpoint_path=os.path.join(model_path, "model.pth"),
|
| 44 |
-
vocab_path=os.path.join(model_path, "vocab.json"),
|
| 45 |
-
eval=True,
|
| 46 |
-
use_deepspeed=False,
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
print("Model loaded")
|
| 55 |
|
| 56 |
# This is for debugging purposes only
|
| 57 |
DEVICE_ASSERT_DETECTED = 0
|
| 58 |
DEVICE_ASSERT_PROMPT = None
|
| 59 |
DEVICE_ASSERT_LANG = None
|
| 60 |
|
| 61 |
-
supported_languages = config.languages
|
| 62 |
|
| 63 |
def predict(
|
| 64 |
prompt,
|
|
@@ -68,6 +78,9 @@ def predict(
|
|
| 68 |
no_lang_auto_detect,
|
| 69 |
agree,
|
| 70 |
):
|
|
|
|
|
|
|
|
|
|
| 71 |
if agree == True:
|
| 72 |
if language not in supported_languages:
|
| 73 |
gr.Warning(
|
|
@@ -184,7 +197,7 @@ def predict(
|
|
| 184 |
|
| 185 |
# HF Space specific.. This error is unrecoverable need to restart space
|
| 186 |
space = api.get_space_runtime(repo_id=repo_id)
|
| 187 |
-
if space.stage!="BUILDING":
|
| 188 |
api.restart_space(repo_id=repo_id)
|
| 189 |
else:
|
| 190 |
print("TRIED TO RESTART but space is building")
|
|
@@ -198,7 +211,9 @@ def predict(
|
|
| 198 |
(
|
| 199 |
gpt_cond_latent,
|
| 200 |
speaker_embedding,
|
| 201 |
-
) = model.get_conditioning_latents(
|
|
|
|
|
|
|
| 202 |
except Exception as e:
|
| 203 |
print("Speaker encoding error", str(e))
|
| 204 |
gr.Warning(
|
|
@@ -215,7 +230,7 @@ def predict(
|
|
| 215 |
# metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
|
| 216 |
|
| 217 |
# temporary comma fix
|
| 218 |
-
prompt= re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)",r"\1 \2\2",prompt)
|
| 219 |
|
| 220 |
wav_chunks = []
|
| 221 |
## Direct mode
|
|
@@ -260,9 +275,9 @@ def predict(
|
|
| 260 |
print(
|
| 261 |
f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
|
| 262 |
)
|
| 263 |
-
#metrics_text += (
|
| 264 |
# f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
| 265 |
-
#)
|
| 266 |
|
| 267 |
wav = torch.cat(wav_chunks, dim=0)
|
| 268 |
print(wav.shape)
|
|
@@ -330,11 +345,11 @@ def predict(
|
|
| 330 |
|
| 331 |
# HF Space specific.. This error is unrecoverable need to restart space
|
| 332 |
space = api.get_space_runtime(repo_id=repo_id)
|
| 333 |
-
if space.stage!="BUILDING":
|
| 334 |
api.restart_space(repo_id=repo_id)
|
| 335 |
else:
|
| 336 |
print("TRIED TO RESTART but space is building")
|
| 337 |
-
|
| 338 |
else:
|
| 339 |
if "Failed to decode" in str(e):
|
| 340 |
print("Speaker encoding error", str(e))
|
|
@@ -459,7 +474,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 459 |
"zh-cn",
|
| 460 |
"ja",
|
| 461 |
"ko",
|
| 462 |
-
"hu"
|
| 463 |
],
|
| 464 |
value="en",
|
| 465 |
)
|
|
@@ -487,14 +502,17 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
| 487 |
|
| 488 |
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
|
| 489 |
|
| 490 |
-
|
| 491 |
with gr.Column():
|
| 492 |
video_gr = gr.Video(label="Waveform Visual")
|
| 493 |
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
| 494 |
out_text_gr = gr.Text(label="Metrics")
|
| 495 |
ref_audio_gr = gr.Audio(label="Reference Audio Used")
|
| 496 |
|
| 497 |
-
tts_button.click(
|
| 498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
print("Starting server")
|
| 500 |
demo.queue().launch(debug=True, show_api=True)
|
|
|
|
| 27 |
|
| 28 |
# will use api to restart space on an unrecoverable error
|
| 29 |
api = HfApi(token=HF_TOKEN)
|
| 30 |
+
repo_id = "JacobLinCool/xtts-v2"
|
| 31 |
|
| 32 |
+
model = None
|
| 33 |
+
supported_languages = None
|
| 34 |
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
def load_model():
    """Load the XTTS v2 model and publish it through the module globals.

    Reads ``config.json``, ``model.pth`` and ``vocab.json`` from the
    ``tts_models--multilingual--multi-dataset--xtts_v2`` directory under
    ``get_user_data_dir("tts")``, moves the model to CUDA when
    ``torch.cuda.is_available()`` is true (CPU otherwise), and then sets the
    globals ``model`` and ``supported_languages``.

    The globals are assigned only after every step has succeeded. If reading
    the config, loading the checkpoint, or moving the model to the device
    raises, ``model`` keeps its previous value, so a caller that guards with
    ``if model is None`` will attempt the load again instead of using a
    model whose checkpoint was never loaded.
    """
    global model
    global supported_languages

    print("loading model")

    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))

    config = XttsConfig()
    config.load_json(os.path.join(model_path, "config.json"))

    # Build into a local first: assigning the global before the checkpoint is
    # loaded would make a failed or in-progress load look like a ready model.
    loaded_model = Xtts.init_from_config(config)
    loaded_model.load_checkpoint(
        config,
        checkpoint_path=os.path.join(model_path, "model.pth"),
        vocab_path=os.path.join(model_path, "vocab.json"),
        eval=True,
        use_deepspeed=False,
    )

    if torch.cuda.is_available():
        loaded_model.cuda()
    else:
        loaded_model.cpu()

    # Publish the language list before the model so that anything observing
    # a non-None ``model`` also sees a populated ``supported_languages``.
    supported_languages = config.languages
    model = loaded_model

    print("Model loaded")
|
| 65 |
|
|
|
|
| 66 |
|
| 67 |
# This is for debugging purposes only
|
| 68 |
DEVICE_ASSERT_DETECTED = 0
|
| 69 |
DEVICE_ASSERT_PROMPT = None
|
| 70 |
DEVICE_ASSERT_LANG = None
|
| 71 |
|
|
|
|
| 72 |
|
| 73 |
def predict(
|
| 74 |
prompt,
|
|
|
|
| 78 |
no_lang_auto_detect,
|
| 79 |
agree,
|
| 80 |
):
|
| 81 |
+
if model is None:
|
| 82 |
+
load_model()
|
| 83 |
+
|
| 84 |
if agree == True:
|
| 85 |
if language not in supported_languages:
|
| 86 |
gr.Warning(
|
|
|
|
| 197 |
|
| 198 |
# HF Space specific.. This error is unrecoverable need to restart space
|
| 199 |
space = api.get_space_runtime(repo_id=repo_id)
|
| 200 |
+
if space.stage != "BUILDING":
|
| 201 |
api.restart_space(repo_id=repo_id)
|
| 202 |
else:
|
| 203 |
print("TRIED TO RESTART but space is building")
|
|
|
|
| 211 |
(
|
| 212 |
gpt_cond_latent,
|
| 213 |
speaker_embedding,
|
| 214 |
+
) = model.get_conditioning_latents(
|
| 215 |
+
audio_path=speaker_wav, gpt_cond_len=30, max_ref_length=60
|
| 216 |
+
)
|
| 217 |
except Exception as e:
|
| 218 |
print("Speaker encoding error", str(e))
|
| 219 |
gr.Warning(
|
|
|
|
| 230 |
# metrics_text=f"Embedding calculation time: {latent_calculation_time:.2f} seconds\n"
|
| 231 |
|
| 232 |
# temporary comma fix
|
| 233 |
+
prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
|
| 234 |
|
| 235 |
wav_chunks = []
|
| 236 |
## Direct mode
|
|
|
|
| 275 |
print(
|
| 276 |
f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
|
| 277 |
)
|
| 278 |
+
# metrics_text += (
|
| 279 |
# f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
|
| 280 |
+
# )
|
| 281 |
|
| 282 |
wav = torch.cat(wav_chunks, dim=0)
|
| 283 |
print(wav.shape)
|
|
|
|
| 345 |
|
| 346 |
# HF Space specific.. This error is unrecoverable need to restart space
|
| 347 |
space = api.get_space_runtime(repo_id=repo_id)
|
| 348 |
+
if space.stage != "BUILDING":
|
| 349 |
api.restart_space(repo_id=repo_id)
|
| 350 |
else:
|
| 351 |
print("TRIED TO RESTART but space is building")
|
| 352 |
+
|
| 353 |
else:
|
| 354 |
if "Failed to decode" in str(e):
|
| 355 |
print("Speaker encoding error", str(e))
|
|
|
|
| 474 |
"zh-cn",
|
| 475 |
"ja",
|
| 476 |
"ko",
|
| 477 |
+
"hu",
|
| 478 |
],
|
| 479 |
value="en",
|
| 480 |
)
|
|
|
|
| 502 |
|
| 503 |
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
|
| 504 |
|
|
|
|
| 505 |
with gr.Column():
|
| 506 |
video_gr = gr.Video(label="Waveform Visual")
|
| 507 |
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
|
| 508 |
out_text_gr = gr.Text(label="Metrics")
|
| 509 |
ref_audio_gr = gr.Audio(label="Reference Audio Used")
|
| 510 |
|
| 511 |
+
tts_button.click(
|
| 512 |
+
predict,
|
| 513 |
+
[input_text_gr, language_gr, ref_gr, clean_ref_gr, auto_det_lang_gr, tos_gr],
|
| 514 |
+
outputs=[video_gr, audio_gr, out_text_gr, ref_audio_gr],
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
print("Starting server")
|
| 518 |
demo.queue().launch(debug=True, show_api=True)
|