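# Gradio Space serving the Dia-1.6B-0626 text-to-speech model.
# Users enter [S1]/[S2]-tagged dialogue and can optionally provide an audio
# prompt for voice cloning; the app returns the generated 44.1 kHz audio.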
import tempfile
from pathlib import Path
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from dia.model import Dia

MODEL_ID = "nari-labs/Dia-1.6B-0626"
SAMPLE_RATE = 44100  # Dia generates 44.1 kHz audio

# Load the model once at startup, on GPU if one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Dia.from_pretrained(MODEL_ID, device=device)
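
# Generate speech from the tagged script; an optional (sample_rate, array)
# audio prompt from Gradio is written to a temporary WAV file and passed to
# the model as a voice-cloning reference.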
def generate_audio(text, audio_prompt):
    if not text or text.isspace():
        raise gr.Error("Text input cannot be empty.")

    # Persist the optional prompt to a temp file, since the model takes a path.
    prompt_path = None
    if audio_prompt is not None:
        sr, audio = audio_prompt
        if audio is not None and audio.size > 0:
            with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f:
                sf.write(f.name, audio, sr)
                prompt_path = f.name

    audio_out = model.generate(
        text=text,
        audio_prompt=prompt_path,
        max_tokens=None,
        cfg_scale=3.0,
        temperature=1.8,
        top_p=0.95,
    )

    # Remove the temporary prompt file now that generation is done.
    if prompt_path is not None and Path(prompt_path).exists():
        try:
            Path(prompt_path).unlink()
        except OSError:
            pass

    # Gradio's numpy-typed Audio output expects a (sample_rate, array) tuple.
    audio_np = np.asarray(audio_out, dtype=np.float32)
    return SAMPLE_RATE, audio_np
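
# Minimal two-column UI: script and optional voice prompt on the left,
# generated audio on the right.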
with gr.Blocks() as demo:
    gr.Markdown("# Dia 1.6B-0626 Text-to-Speech")
    with gr.Row():
        with gr.Column(scale=1):
            text_in = gr.Textbox(
                label="Input text",
                lines=6,
                placeholder="Start with [S1] / [S2] tags, e.g.:\n[S1] Hello. [S2] Hi there.",
            )
            audio_prompt_in = gr.Audio(
                label="Audio prompt (optional, voice cloning)",
                sources=["upload", "microphone"],
                type="numpy",
            )
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            audio_out = gr.Audio(
                label="Generated audio",
                type="numpy",
                autoplay=False,
            )

    btn.click(
        fn=generate_audio,
        inputs=[text_in, audio_prompt_in],
        outputs=[audio_out],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()