# sbv2_chupa_demo / app.py  (author: litagin, commit 70c3683 "init")
import datetime
from pathlib import Path
import gradio as gr
import random
from style_bert_vits2.constants import (
DEFAULT_LENGTH,
DEFAULT_LINE_SPLIT,
DEFAULT_NOISE,
DEFAULT_NOISEW,
DEFAULT_SPLIT_INTERVAL,
)
from style_bert_vits2.logging import logger
from style_bert_vits2.models.infer import InvalidToneError
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
from style_bert_vits2.tts_model import TTSModelHolder
# Spin up the pyopenjtalk worker process eagerly so the first synthesis
# request does not pay the initialization cost.
pyopenjtalk.initialize_worker()

# One example prompt per line; served by the "pick an example" button.
example_file = "chupa_examples.txt"
initial_text = (
    "ちゅぱ、ちゅるる、ぢゅ、んく、れーれゅれろれろれろ、じゅぽぽぽぽぽ……ちゅううう！"
)
examples = Path(example_file).read_text(encoding="utf-8").splitlines()
def get_random_text() -> str:
    """Return one of the preloaded example lines, chosen uniformly at random."""
    pool = examples
    return random.choice(pool)
initial_md = """
# ใƒใƒฅใƒ‘้Ÿณๅˆๆˆใƒ‡ใƒข
2024-07-07: initial ver
"""
def make_interactive():
    """Re-enable the synthesize button (used after a model finishes loading)."""
    update = gr.update(value="音声合成", interactive=True)
    return update
def make_non_interactive():
    """Disable the synthesize button until a model has been loaded."""
    update = gr.update(value="音声合成（モデルをロードしてください）", interactive=False)
    return update
def gr_util(item):
    """Toggle the two input widgets depending on the selected input mode.

    Returns a pair of Gradio updates: (preset widget, audio widget).
    Choosing the preset mode also clears any previously uploaded audio.
    """
    use_preset = item == "プリセットから選ぶ"
    if not use_preset:
        return (gr.update(visible=False), gr.update(visible=True))
    return (gr.update(visible=True), gr.Audio(visible=False, value=None))
def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks:
    """Build the Gradio Blocks UI for the chupa-sound TTS demo.

    Args:
        model_holder: Holder that discovers model files under its root
            directory and lazily loads the selected model.

    Returns:
        A ``gr.Blocks`` app; when no models are found, a page showing
        only an error message is returned instead of the full UI.
    """

    def tts_fn(
        model_name,
        model_path,
        text,
        language,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        line_split,
        split_interval,
        speaker,
    ):
        # Synthesize `text`; returns (status message, (sample_rate, waveform))
        # on success, or (error message, None) on failure.
        # Ensure the model currently selected in the UI is the loaded one.
        model_holder.get_model(model_name, model_path)
        assert model_holder.current_model is not None
        speaker_id = model_holder.current_model.spk2id[speaker]
        start_time = datetime.datetime.now()
        try:
            sr, audio = model_holder.current_model.infer(
                text=text,
                language=language,
                sdp_ratio=sdp_ratio,
                noise=noise_scale,
                noise_w=noise_scale_w,
                length=length_scale,
                line_split=line_split,
                split_interval=split_interval,
                speaker_id=speaker_id,
            )
        except InvalidToneError as e:
            # Invalid accent/tone markup in the input text.
            logger.error(f"Tone error: {e}")
            return f"Error: アクセント指定が不正です:\n{e}", None
        except ValueError as e:
            logger.error(f"Value error: {e}")
            return f"Error: {e}", None
        end_time = datetime.datetime.now()
        duration = (end_time - start_time).total_seconds()
        message = f"Success, time: {duration} seconds."
        return message, (sr, audio)

    def get_model_files(model_name: str) -> list[str]:
        # Checkpoint file paths available for `model_name`, as strings.
        return [str(f) for f in model_holder.model_files_dict[model_name]]

    model_names = model_holder.model_names
    if len(model_names) == 0:
        logger.error(
            f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
        )
        # No models available: return a page that only shows the error.
        with gr.Blocks() as app:
            gr.Markdown(
                f"Error: モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
            )
        return app

    # Pre-load the first model so the UI starts in a usable state.
    initial_pth_files = get_model_files(model_names[0])
    model = model_holder.get_model(model_names[0], initial_pth_files[0])
    speakers = list(model.spk2id.keys())

    with gr.Blocks(theme="ParityError/Anime") as app:
        gr.Markdown(initial_md)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        model_name = gr.Dropdown(
                            label="モデル一覧",
                            choices=model_names,
                            value=model_names[0],
                        )
                        model_path = gr.Dropdown(
                            label="モデルファイル",
                            choices=initial_pth_files,
                            value=initial_pth_files[0],
                        )
                    # Hidden; kept so model lists can be refreshed programmatically.
                    refresh_button = gr.Button("更新", scale=1, visible=False)
                    load_button = gr.Button("ロード", scale=1, variant="primary")
                with gr.Row():
                    text_input = gr.TextArea(
                        label="テキスト", value=initial_text, scale=3
                    )
                    random_button = gr.Button("例から選ぶ 🎲", scale=1)
                    random_button.click(get_random_text, outputs=[text_input])
                with gr.Row():
                    length_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_LENGTH,
                        step=0.1,
                        label="生成音声の長さ（Length）",
                    )
                    sdp_ratio = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=1,
                        step=0.1,
                        label="SDP Ratio",
                    )
                # Hidden: line splitting is always on its default for this demo.
                line_split = gr.Checkbox(
                    label="改行で分けて生成（分けたほうが感情が乗ります）",
                    value=DEFAULT_LINE_SPLIT,
                    visible=False,
                )
                split_interval = gr.Slider(
                    minimum=0.0,
                    maximum=2,
                    value=DEFAULT_SPLIT_INTERVAL,
                    step=0.1,
                    label="改行ごとに挟む無音の長さ（秒）",
                )
                # The pause slider only makes sense when line splitting is enabled.
                line_split.change(
                    lambda x: (gr.Slider(visible=x)),
                    inputs=[line_split],
                    outputs=[split_interval],
                )
                # Japanese only; hidden because there is nothing to choose.
                language = gr.Dropdown(
                    choices=["JP"], value="JP", label="Language", visible=False
                )
                speaker = gr.Dropdown(label="話者", choices=speakers, value=speakers[0])
                with gr.Accordion(label="詳細設定", open=True):
                    noise_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISE,
                        step=0.1,
                        label="Noise",
                    )
                    noise_scale_w = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISEW,
                        step=0.1,
                        label="Noise_W",
                    )
            with gr.Column():
                tts_button = gr.Button("音声合成", variant="primary")
                text_output = gr.Textbox(label="情報")
                audio_output = gr.Audio(label="結果")
        tts_button.click(
            tts_fn,
            inputs=[
                model_name,
                model_path,
                text_input,
                language,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                line_split,
                split_interval,
                speaker,
            ],
            outputs=[text_output, audio_output],
        )
        # Switching the model repopulates the file list for it.
        model_name.change(
            model_holder.update_model_files_for_gradio,
            inputs=[model_name],
            outputs=[model_path],
        )
        # A newly selected file must be loaded before synthesis is allowed.
        model_path.change(make_non_interactive, outputs=[tts_button])
        refresh_button.click(
            model_holder.update_model_names_for_gradio,
            outputs=[model_name, model_path, tts_button],
        )
        # Hidden style dropdown; only exists as an output slot for get_model_for_gradio.
        style = gr.Dropdown(label="スタイル", choices=[], visible=False)
        load_button.click(
            model_holder.get_model_for_gradio,
            inputs=[model_name, model_path],
            outputs=[style, tts_button, speaker],
        )
    return app
if __name__ == "__main__":
    import torch

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.nlp import bert_models

    # Preload the Japanese BERT model and tokenizer used by the TTS pipeline.
    bert_models.load_model(Languages.JP)
    bert_models.load_tokenizer(Languages.JP)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    holder = TTSModelHolder(Path("model_assets"), device)
    app = create_inference_app(holder)
    app.launch(inbrowser=True)