Respair committed on
Commit
6c7edb1
·
verified ·
1 Parent(s): a6e3096

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -8
app.py CHANGED
@@ -29,8 +29,50 @@ code {
29
  color: #b45309 !important;
30
  font-weight: 600;
31
  }
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def load_examples(csv_path):
35
  examples = []
36
 
@@ -95,9 +137,17 @@ def run_generation_pipeline_client(
95
  top_k,
96
  temperature,
97
  use_chained_longform,
98
- seed # Add seed parameter
 
99
  ):
100
  try:
 
 
 
 
 
 
 
101
  # Handle audio prompt - save to temporary file if provided
102
  audio_prompt_for_api = None
103
 
@@ -173,9 +223,17 @@ def run_duration_generation_pipeline_client(
173
  add_steps,
174
  use_duration_aware,
175
  chars_per_second,
176
- seed # Add seed parameter
 
177
  ):
178
  try:
 
 
 
 
 
 
 
179
  # Handle audio prompt - save to temporary file if provided
180
  audio_prompt_for_api = None
181
 
@@ -242,6 +300,15 @@ def run_duration_generation_pipeline_client(
242
  except Exception as e:
243
  return None, f"Status: Connection error: {str(e)}"
244
 
 
 
 
 
 
 
 
 
 
245
  # Load examples
246
  examples_csv_path = "./samples.csv" # Adjust path as needed for client side
247
  example_list = load_examples(examples_csv_path)
@@ -307,10 +374,19 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
307
  value=False
308
  )
309
  audio_prompt_input = gr.Audio(
310
- label="Audio Prompt (Optional - オプション)",
311
  sources=["upload", "microphone"],
312
  type="numpy"
313
  )
 
 
 
 
 
 
 
 
 
314
 
315
  # Turbo mode event handler
316
  def toggle_turbo(turbo_enabled):
@@ -332,7 +408,7 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
332
  status_output = gr.Textbox(label="Status", interactive=False)
333
  audio_output = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
334
 
335
- # Event handler
336
  generate_button.click(
337
  fn=run_generation_pipeline_client,
338
  inputs=[
@@ -343,7 +419,8 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
343
  top_k_slider,
344
  temperature_slider,
345
  chained_longform_checkbox,
346
- seed_slider # Add seed slider to inputs
 
347
  ],
348
  outputs=[audio_output, status_output],
349
  concurrency_limit=4 # Limit concurrent requests
@@ -434,10 +511,19 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
434
  value=False
435
  )
436
  audio_prompt_input_dur = gr.Audio(
437
- label="Audio Prompt (Optional - オプション)",
438
  sources=["upload", "microphone"],
439
  type="numpy"
440
  )
 
 
 
 
 
 
 
 
 
441
 
442
  # Turbo mode event handler for duration tab
443
  def toggle_turbo_dur(turbo_enabled):
@@ -459,7 +545,7 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
459
  status_output_dur = gr.Textbox(label="Status", interactive=False)
460
  audio_output_dur = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
461
 
462
- # Event handler for duration tab
463
  generate_button_dur.click(
464
  fn=run_duration_generation_pipeline_client,
465
  inputs=[
@@ -473,7 +559,8 @@ with gr.Blocks(theme="Respair/Shiki@9.1.0", css=css) as demo:
473
  add_steps_slider,
474
  use_duration_aware_checkbox,
475
  chars_per_second_slider,
476
- seed_slider_dur # Add seed slider to inputs
 
477
  ],
478
  outputs=[audio_output_dur, status_output_dur],
479
  concurrency_limit=4 # Limit concurrent requests
 
29
  color: #b45309 !important;
30
  font-weight: 600;
31
  }
32
+
33
+ .audio-warning {
34
+ color: #ff6b35 !important;
35
+ font-weight: 600;
36
+ margin: 10px 0;
37
+ }
38
+
39
+ .audio-error {
40
+ color: #dc2626 !important;
41
+ font-weight: 600;
42
+ margin: 10px 0;
43
+ }
44
  """
45
 
46
def validate_audio_duration(audio_data, max_seconds=10, warn_seconds=8):
    """Check the length of an audio prompt against the allowed limits.

    Args:
        audio_data: ``None`` when no prompt was supplied, otherwise a
            ``(sample_rate, samples)`` pair. The duration is computed as
            ``len(samples) / sample_rate``.
        max_seconds: Durations strictly greater than this are rejected.
        warn_seconds: Durations strictly greater than this (but not above
            ``max_seconds``) are accepted with a warning.

    Returns:
        A ``(is_valid, message)`` tuple. ``message`` is an empty string when
        there is nothing to report, otherwise an HTML snippet wrapped in a
        ``<div>`` with class ``audio-error`` or ``audio-warning``.
    """
    if audio_data is None:
        return True, ""

    sample_rate, audio_array = audio_data
    duration_seconds = len(audio_array) / sample_rate

    if duration_seconds > max_seconds:
        # The limit shown to the user is derived from ``max_seconds`` so the
        # text can never drift from the threshold that is actually enforced.
        error_msg = f"""
        <div class="audio-error">
        ❌ Error: Audio is {duration_seconds:.1f} seconds long. Maximum allowed is {max_seconds:g} seconds.<br>
        ❌ エラー: 音声が{duration_seconds:.1f}秒です。最大{max_seconds:g}秒まで許可されています。
        </div>
        """
        return False, error_msg

    if duration_seconds > warn_seconds:
        warning_msg = f"""
        <div class="audio-warning">
        ⚠️ Warning: Your audio is {duration_seconds:.1f} seconds, it will eat up precious context and may result in poor generation.<br>
        ⚠️ 警告: 音声が{duration_seconds:.1f}秒を超えています。貴重なコンテキストを消費し、生成品質が低下する可能性があります。
        </div>
        """
        return True, warning_msg

    return True, ""
75
+
76
  def load_examples(csv_path):
77
  examples = []
78
 
 
137
  top_k,
138
  temperature,
139
  use_chained_longform,
140
+ seed, # Add seed parameter
141
+ audio_warning_display
142
  ):
143
  try:
144
+ # Validate audio duration first
145
+ is_valid, warning_msg = validate_audio_duration(audio_prompt)
146
+
147
+ if not is_valid:
148
+ # Return error without processing
149
+ return None, "Status: Audio too long. Please use audio under 10 seconds."
150
+
151
  # Handle audio prompt - save to temporary file if provided
152
  audio_prompt_for_api = None
153
 
 
223
  add_steps,
224
  use_duration_aware,
225
  chars_per_second,
226
+ seed, # Add seed parameter
227
+ audio_warning_display_dur
228
  ):
229
  try:
230
+ # Validate audio duration first
231
+ is_valid, warning_msg = validate_audio_duration(audio_prompt)
232
+
233
+ if not is_valid:
234
+ # Return error without processing
235
+ return None, "Status: Audio too long. Please use audio under 10 seconds."
236
+
237
  # Handle audio prompt - save to temporary file if provided
238
  audio_prompt_for_api = None
239
 
 
300
  except Exception as e:
301
  return None, f"Status: Connection error: {str(e)}"
302
 
303
+ # Audio validation callback
304
def on_audio_upload(audio_data):
    """Change callback for an audio prompt input.

    Runs ``validate_audio_duration`` on the new value and returns
    ``(audio_value, message)``: the audio is passed back unchanged when it is
    accepted and replaced by ``None`` (clearing the input) when it is
    rejected. The validation message is returned in both cases.
    """
    accepted, message = validate_audio_duration(audio_data)
    # A rejected clip is swapped for None so the audio component is emptied.
    return (audio_data if accepted else None), message
311
+
312
  # Load examples
313
  examples_csv_path = "./samples.csv" # Adjust path as needed for client side
314
  example_list = load_examples(examples_csv_path)
 
374
  value=False
375
  )
376
  audio_prompt_input = gr.Audio(
377
+ label="Audio Prompt (Optional - オプション) [Max 10 seconds / 最大10秒]",
378
  sources=["upload", "microphone"],
379
  type="numpy"
380
  )
381
+ # Warning display for audio duration
382
+ audio_warning_display = gr.HTML(value="", visible=True)
383
+
384
+ # Audio validation on change
385
+ audio_prompt_input.change(
386
+ fn=on_audio_upload,
387
+ inputs=[audio_prompt_input],
388
+ outputs=[audio_prompt_input, audio_warning_display]
389
+ )
390
 
391
  # Turbo mode event handler
392
  def toggle_turbo(turbo_enabled):
 
408
  status_output = gr.Textbox(label="Status", interactive=False)
409
  audio_output = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
410
 
411
+ # Event handler - pass the warning display as a dummy input
412
  generate_button.click(
413
  fn=run_generation_pipeline_client,
414
  inputs=[
 
419
  top_k_slider,
420
  temperature_slider,
421
  chained_longform_checkbox,
422
+ seed_slider, # Add seed slider to inputs
423
+ audio_warning_display # Pass as dummy input
424
  ],
425
  outputs=[audio_output, status_output],
426
  concurrency_limit=4 # Limit concurrent requests
 
511
  value=False
512
  )
513
  audio_prompt_input_dur = gr.Audio(
514
+ label="Audio Prompt (Optional - オプション) [Max 10 seconds / 最大10秒]",
515
  sources=["upload", "microphone"],
516
  type="numpy"
517
  )
518
+ # Warning display for audio duration
519
+ audio_warning_display_dur = gr.HTML(value="", visible=True)
520
+
521
+ # Audio validation on change
522
+ audio_prompt_input_dur.change(
523
+ fn=on_audio_upload,
524
+ inputs=[audio_prompt_input_dur],
525
+ outputs=[audio_prompt_input_dur, audio_warning_display_dur]
526
+ )
527
 
528
  # Turbo mode event handler for duration tab
529
  def toggle_turbo_dur(turbo_enabled):
 
545
  status_output_dur = gr.Textbox(label="Status", interactive=False)
546
  audio_output_dur = gr.Audio(label="Generated Speech", interactive=False, show_download_button=True)
547
 
548
+ # Event handler for duration tab - pass the warning display as a dummy input
549
  generate_button_dur.click(
550
  fn=run_duration_generation_pipeline_client,
551
  inputs=[
 
559
  add_steps_slider,
560
  use_duration_aware_checkbox,
561
  chars_per_second_slider,
562
+ seed_slider_dur, # Add seed slider to inputs
563
+ audio_warning_display_dur # Pass as dummy input
564
  ],
565
  outputs=[audio_output_dur, status_output_dur],
566
  concurrency_limit=4 # Limit concurrent requests