ekwek commited on
Commit
7e66c78
·
verified ·
1 Parent(s): 63d4ab6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from soprano import SopranoTTS
5
+ from scipy.io.wavfile import write as wav_write
6
+ import tempfile
7
+ import os
8
+
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ print(DEVICE)
11
+
12
+ # Load model once
13
+ model = SopranoTTS(
14
+ backend="auto",
15
+ device=DEVICE,
16
+ cache_size_mb=100,
17
+ decoder_batch_size=1,
18
+ )
19
+
20
+ SAMPLE_RATE = 32000
21
+
22
+
23
+ def tts_stream(text, temperature, top_p, repetition_penalty, state):
24
+ if not text.strip():
25
+ yield None, state
26
+ return
27
+
28
+ chunks = []
29
+ stream = model.infer_stream(
30
+ text,
31
+ chunk_size=1,
32
+ temperature=temperature,
33
+ top_p=top_p,
34
+ repetition_penalty=repetition_penalty,
35
+ )
36
+
37
+ for chunk in stream:
38
+ if isinstance(chunk, torch.Tensor):
39
+ audio_np = chunk.detach().cpu().numpy().astype(np.float32)
40
+ chunks.append(audio_np)
41
+ # stream partial audio
42
+ yield (SAMPLE_RATE, audio_np), np.concatenate(chunks)
43
+
44
+ if chunks:
45
+ final_audio = np.concatenate(chunks)
46
+ yield (SAMPLE_RATE, final_audio), final_audio
47
+
48
+
49
+ def save_audio(state):
50
+ if state is None or len(state) == 0:
51
+ return None
52
+ fd, path = tempfile.mkstemp(suffix=".wav")
53
+ os.close(fd)
54
+ wav_write(path, SAMPLE_RATE, state)
55
+ return path
56
+
57
+
58
+ with gr.Blocks() as demo:
59
+ state_audio = gr.State(None)
60
+
61
+ with gr.Row():
62
+ with gr.Column():
63
+ gr.Markdown("## Soprano Demo")
64
+
65
+ text_in = gr.Textbox(
66
+ label="Input Text",
67
+ placeholder="Enter text to synthesize...",
68
+ lines=4,
69
+ )
70
+
71
+ with gr.Accordion("Advanced options", open=False):
72
+ temperature = gr.Slider(
73
+ 0.0, 1.0, value=0.3, step=0.05, label="Temperature"
74
+ )
75
+ top_p = gr.Slider(
76
+ 0.0, 1.0, value=0.95, step=0.01, label="Top-p"
77
+ )
78
+ repetition_penalty = gr.Slider(
79
+ 0.5, 2.0, value=1.2, step=0.05, label="Repetition penalty"
80
+ )
81
+
82
+ gen_btn = gr.Button("Generate")
83
+
84
+ with gr.Column():
85
+ audio_out = gr.Audio(
86
+ label="Output Audio",
87
+ autoplay=True,
88
+ streaming=True,
89
+ )
90
+ download_btn = gr.Button("Download")
91
+ file_out = gr.File(label="Download file")
92
+ gr.Markdown(
93
+ "Usage tips: (placeholder)\n\n"
94
+ "- Tip 1\n"
95
+ "- Tip 2\n"
96
+ "- Tip 3"
97
+ )
98
+
99
+ gen_btn.click(
100
+ fn=tts_stream,
101
+ inputs=[text_in, temperature, top_p, repetition_penalty, state_audio],
102
+ outputs=[audio_out, state_audio],
103
+ )
104
+
105
+ download_btn.click(
106
+ fn=save_audio,
107
+ inputs=[state_audio],
108
+ outputs=[file_out],
109
+ )
110
+
111
+ demo.queue()
112
+ demo.launch()