Spaces: Running on Zero

Commit: fix

app.py CHANGED
@@ -26,6 +26,7 @@ except Exception as e:
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
+    transcription_input: Optional[str],
     max_new_tokens: int,
     cfg_scale: float,
     temperature: float,
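Note: the new transcription_input parameter sits third, between audio_prompt_input and max_new_tokens. Gradio passes component values to the callback positionally, so every inputs list wired to run_inference must insert transcription_input at the same position; both inputs-list hunks further down do exactly that.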
@@ -50,6 +51,10 @@ def run_inference(
     prompt_path_for_generate = None
     if audio_prompt_input is not None:
         sr, audio_data = audio_prompt_input
+        # Enforce maximum duration of 10 seconds for the audio prompt
+        duration_sec = len(audio_data) / float(sr) if sr else 0
+        if duration_sec > 10.0:
+            raise gr.Error("Audio prompt must be 10 seconds or shorter.")
         # Check if audio_data is valid
         if (
             audio_data is None or audio_data.size == 0 or audio_data.max() == 0
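As a sanity check, the added duration guard in isolation: a minimal runnable sketch, assuming the Gradio-style (sample_rate, samples) tuple, with ValueError standing in for gr.Error so it runs without Gradio installed.

import numpy as np

MAX_PROMPT_SEC = 10.0  # mirrors the 10 s limit enforced above

def validate_prompt_duration(sr: int, audio_data: np.ndarray) -> float:
    # Same arithmetic as the diff: samples / sample rate, guarding sr == 0.
    # len() counts the first axis, so a stereo (samples, channels) array
    # still measures samples, not channels.
    duration_sec = len(audio_data) / float(sr) if sr else 0
    if duration_sec > MAX_PROMPT_SEC:
        raise ValueError("Audio prompt must be 10 seconds or shorter.")
    return duration_sec

validate_prompt_duration(44100, np.zeros(44100 * 5))     # 5 s prompt: passes
# validate_prompt_duration(44100, np.zeros(44100 * 12))  # 12 s prompt: raises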
@@ -117,8 +122,15 @@ def run_inference(
 
     # Use torch.inference_mode() context manager for the generation call
     with torch.inference_mode():
+        # Concatenate transcription (if provided) to the main text
+        combined_text = (
+            text_input.strip() + "\n" + transcription_input.strip()
+            if transcription_input and not transcription_input.isspace()
+            else text_input
+        )
+
         output_audio_np = model.generate(
-            text_input,
+            combined_text,
             max_tokens=max_new_tokens,
             cfg_scale=cfg_scale,
             temperature=temperature,
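The combined_text conditional above is the entirety of the transcription handling. Pulled out as a pure function (an illustrative sketch, not code from the diff), its edge cases are easy to verify:

def combine_text(text_input: str, transcription_input) -> str:
    # Append the transcription on a new line only when it is non-empty.
    return (
        text_input.strip() + "\n" + transcription_input.strip()
        if transcription_input and not transcription_input.isspace()
        else text_input
    )

assert combine_text("Hello.", None) == "Hello."        # no transcription given
assert combine_text("Hello.", "   ") == "Hello."       # whitespace-only ignored
assert combine_text("Hello.", "Hi!") == "Hello.\nHi!"  # appended on a new line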
@@ -242,11 +254,16 @@ with gr.Blocks(css=css) as demo:
             lines=5, # Increased lines
         )
         audio_prompt_input = gr.Audio(
-            label="Audio Prompt (Optional)",
+            label="Audio Prompt (≤ 10 s, Optional)",
             show_label=True,
             sources=["upload", "microphone"],
             type="numpy",
         )
+        transcription_input = gr.Textbox(
+            label="Audio Prompt Transcription (Optional)",
+            placeholder="Enter transcription of your audio prompt here...",
+            lines=3,
+        )
         with gr.Accordion("Generation Parameters", open=False):
             max_new_tokens = gr.Slider(
                 label="Max New Tokens (Audio Length)",
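Design note: the new label only advertises the 10 s limit in the UI; enforcement happens server-side in run_inference (the duration check above), so an over-long upload still fails loudly with gr.Error rather than silently degrading output.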
@@ -312,6 +329,7 @@ with gr.Blocks(css=css) as demo:
         inputs=[
             text_input,
             audio_prompt_input,
+            transcription_input,
             max_new_tokens,
             cfg_scale,
             temperature,
@@ -350,10 +368,19 @@ with gr.Blocks(css=css) as demo:
 
     if examples_list:
         gr.Examples(
-            examples=examples_list,
+            examples=[
+                [
+                    ex[0], # text
+                    ex[1], # audio prompt path
+                    "", # transcription placeholder
+                    *ex[2:],
+                ]
+                for ex in examples_list
+            ],
             inputs=[
                 text_input,
                 audio_prompt_input,
+                transcription_input,
                 max_new_tokens,
                 cfg_scale,
                 temperature,
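The list comprehension keeps pre-existing example rows aligned with the widened inputs list by splicing an empty transcription string into position 2, the slot that now belongs to transcription_input. A toy check with a hypothetical row (the real rows come from examples_list):

row = ["Hello there.", "prompts/voice.wav", 1024, 3.0, 1.3]  # hypothetical old-style row
new_row = [row[0], row[1], "", *row[2:]]
assert new_row == ["Hello there.", "prompts/voice.wav", "", 1024, 3.0, 1.3]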