someonecooool committed (verified)
Commit ac72b53 · Parent(s): f505439

Update app.py

Files changed (1): app.py (+69 -28)

app.py CHANGED
@@ -1,47 +1,88 @@
  import gradio as gr
  import numpy as np
  from dia.model import Dia

- MODEL_ID = "nari-labs/Dia-1.6B"
  SAMPLE_RATE = 44100

- model = Dia.from_pretrained(MODEL_ID)


- def synthesize(text: str):
-     if not text or not text.strip():
-         raise gr.Error("Text input is empty.")
-     audio = model.generate(text)
-     return SAMPLE_RATE, np.asarray(audio)


  with gr.Blocks() as demo:
-     gr.Markdown("# Dia 1.6B Text-to-Speech")

      with gr.Row():
-         with gr.Column():
-             text_input = gr.Textbox(
-                 label="Script",
-                 lines=8,
-                 placeholder="[S1] Hello. [S2] This is a reply.",
              )
-             generate_button = gr.Button("Generate")
-
-             gr.Examples(
-                 inputs=text_input,
-                 examples=[
-                     "[S1] Hi, this is Dia speaking. [S2] I can turn your script into a dialogue.",
-                     "[S1] The system is ready. [S2] Please provide a longer script for more natural audio.",
-                 ],
              )

-         with gr.Column():
-             audio_output = gr.Audio(label="Output audio", type="numpy")

-     generate_button.click(
-         fn=synthesize,
-         inputs=text_input,
-         outputs=audio_output,
      )

- demo.queue().launch()

+ import tempfile
+
+ from pathlib import Path
+
  import gradio as gr
  import numpy as np
+ import soundfile as sf
+ import torch
+
  from dia.model import Dia

+ MODEL_ID = "nari-labs/Dia-1.6B-0626"
  SAMPLE_RATE = 44100

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model = Dia.from_pretrained(MODEL_ID, device=device)


+ def generate_audio(text, audio_prompt):
+     if not text or text.isspace():
+         raise gr.Error("Text input cannot be empty.")
+
+     prompt_path = None
+
+     if audio_prompt is not None:
+         sr, audio = audio_prompt
+
+         if audio is not None and audio.size > 0:
+             with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f:
+                 sf.write(f.name, audio, sr)
+                 prompt_path = f.name
+
+     audio_out = model.generate(
+         text=text,
+         audio_prompt=prompt_path,
+         max_tokens=None,
+         cfg_scale=3.0,
+         temperature=1.8,
+         top_p=0.95,
+     )
+
+     if prompt_path is not None and Path(prompt_path).exists():
+         try:
+             Path(prompt_path).unlink()
+         except OSError:
+             pass
+
+     audio_np = np.asarray(audio_out, dtype=np.float32)
+
+     return SAMPLE_RATE, audio_np


  with gr.Blocks() as demo:
+     gr.Markdown("# Dia 1.6B-0626 Text-to-Speech")

      with gr.Row():
+         with gr.Column(scale=1):
+             text_in = gr.Textbox(
+                 label="Input text",
+                 lines=6,
+                 placeholder="Start with [S1] / [S2] tags, e.g.:\n[S1] Hello. [S2] Hi there.",
              )
+
+             audio_prompt_in = gr.Audio(
+                 label="Audio prompt (optional, voice cloning)",
+                 sources=["upload", "microphone"],
+                 type="numpy",
              )

+             btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             audio_out = gr.Audio(
+                 label="Generated audio",
+                 type="numpy",
+                 autoplay=False,
+             )

+     btn.click(
+         fn=generate_audio,
+         inputs=[text_in, audio_prompt_in],
+         outputs=[audio_out],
+         api_name="generate",
      )

+ if __name__ == "__main__":
+     demo.launch()
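
Because the click handler now sets api_name="generate", the updated Space can also be driven programmatically. The sketch below is a minimal, unverified example: it assumes the Space is public, that gradio_client is installed, and it uses "someonecooool/dia-tts" purely as a placeholder Space id (the real id is not part of this commit). Recent gradio_client releases download gr.Audio outputs and return a local file path.

# Hypothetical client-side call against the /generate endpoint exposed above.
# "someonecooool/dia-tts" is a placeholder Space id, not confirmed by this commit.
from gradio_client import Client

client = Client("someonecooool/dia-tts")  # replace with the real Space id

result = client.predict(
    "[S1] Hello from the API. [S2] Nice to meet you.",  # script with speaker tags
    None,  # optional audio prompt; pass gradio_client.handle_file("voice.wav") to clone a voice
    api_name="/generate",
)

print(result)  # local path to the generated audio downloaded by the client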