from app.logger_config import (
    logger as logging,
    DEBUG
)
import numpy as np
import gradio as gr
from fastrtc.webrtc import WebRTC
from fastrtc.utils import AdditionalOutputs
from pydub import AudioSegment
from gradio.utils import get_space
from app.supported_languages import (
    SUPPORTED_LANGS_MAP,
)
from app.ui_utils import (
    EXAMPLE_CONFIGS,
    apply_preset_if_example,
    reset_to_defaults,
    summarize_config,
    handle_additional_outputs,
    get_custom_theme,
    on_file_load,
    start_task_asr_ast,
    stop_task_fn
)
from app.utils import (
    READ_SIZE,
    generate_coturn_config,
    read_and_stream_audio,
    stop_streaming,
    raise_error
)
from app.session_utils import (
    on_load,
    on_unload,
    get_active_session_hashes,
    reset_all_active_sessions,
)
import spaces
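
# `import spaces` is intentionally left unreferenced: on Hugging Face ZeroGPU
# Spaces ("Running on Zero"), importing the package is what sets up GPU
# allocation, so it must happen at import time.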
# --------------------------------------------------------
# Initialization
# --------------------------------------------------------
reset_all_active_sessions()
theme, css_style = get_custom_theme()

with gr.Blocks(theme=theme, css=css_style) as demo:
    session_hash_code = gr.State()
| with gr.Accordion("DEGUG PANEL", open=False, visible=DEBUG): | |
| session_hash_code_box = gr.Textbox(label="Session ID", interactive=False, visible=DEBUG) | |
| with gr.Accordion("📊 Active Sessions Hash", open=True ,visible=DEBUG): | |
| sessions_table = gr.DataFrame( | |
| headers=["session_hash_code", "file", "start_time", "status"], | |
| interactive=False, | |
| wrap=True, | |
| max_height=200, | |
| ) | |
| gr.Timer(3.0).tick(fn=get_active_session_hashes, outputs=sessions_table) | |
| demo.load(fn=on_load, inputs=None, outputs=[session_hash_code, session_hash_code_box]) | |
| demo.unload(fn=on_unload) | |
    active_filepath = gr.State(value=next(iter(EXAMPLE_CONFIGS)))
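    # Three-step wizard: 1) pick or record audio, 2) configure the task,
    # 3) run the streaming ASR/AST. Steps are switched by returning
    # gr.Walkthrough(selected=<step id>) from a callback.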
    with gr.Walkthrough(selected=0) as walkthrough:
        # === STEP 1 ===
        with gr.Step("Audio", id=0) as audio_source_step:
            gr.Markdown(
                """
                ### Step 1: Upload or Record an Audio File
                You can upload an existing file or record directly from your microphone.
                Accepted formats: **.wav**, **.mp3**, **.flac**
                Recommended maximum length: **60 seconds**
                """
            )
            with gr.Group():
                with gr.Column():
                    main_audio = gr.Audio(
                        label="Audio Input",
                        sources=["upload", "microphone"],
                        type="filepath",
                        interactive=True
                    )
            with gr.Accordion("Need a quick test? Try one of the sample audios below", open=True):
                examples = gr.Examples(
                    examples=list(EXAMPLE_CONFIGS.keys()),
                    inputs=main_audio,
                    label=None,
                    examples_per_page=3
                )
                gr.Markdown(
                    """
                    🔹 **english_meeting.wav** – Short business meeting in English
                    🔹 **french_news.wav** – Excerpt from a French radio broadcast
                    🔹 **spanish_podcast.wav** – Segment from a Spanish-language podcast
                    """
                )
            go_to_config = gr.Button("Go to Configuration", visible=False)
            # on_file_load stores the selected filepath in state and reveals the
            # navigation button once a file is present.
            ui_components_on_load_audio = [active_filepath, go_to_config]
            main_audio.change(fn=on_file_load, inputs=[main_audio], outputs=ui_components_on_load_audio)
            go_to_config.click(lambda: gr.Walkthrough(selected=1), outputs=walkthrough)
        # === STEP 2 ===
        with gr.Step("Configuration", id=1) as config_step:
            gr.Markdown("### Step 2: Configure the Task")
            with gr.Group():
                with gr.Row():
                    task_type = gr.Radio(["Transcription", "Translation"], value="Transcription", label="Task Type")
                with gr.Row():
                    lang_source = gr.Dropdown(list(SUPPORTED_LANGS_MAP.keys()), value="French", label="Source Language")
                    lang_target = gr.Dropdown(list(SUPPORTED_LANGS_MAP.keys()), value="English", label="Target Language", visible=False)
            with gr.Accordion("Advanced Configuration", open=False):
                with gr.Group():
                    with gr.Row():
                        gr.Markdown("##### Chunks")
                    with gr.Row():
                        left_context_secs = gr.Slider(value=20.0, label="left_context_secs", info="Left context duration in seconds", minimum=1.0, maximum=60.0, step=1.0, show_reset_button=False)
                        chunk_secs = gr.Slider(value=1.0, label="chunk_secs", info="Streaming chunk duration in seconds", minimum=0.1, maximum=5.0, step=0.1, show_reset_button=False)
                        right_context_secs = gr.Slider(value=0.5, label="right_context_secs", info="Right context duration in seconds", minimum=0.1, maximum=10.0, step=0.1, show_reset_button=False)
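                        # Note: each decoding step sees a sliding window of
                        # [left context | new chunk | right context]; a larger left
                        # context improves accuracy at the cost of higher latency.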
                gr.Markdown("---")
                with gr.Group():
                    with gr.Row():
                        gr.Markdown("##### Decoding")
                    with gr.Row():
                        streaming_policy = gr.Dropdown(
                            ["waitk", "alignatt"], value="alignatt", label="streaming_policy", elem_classes="full-width",
                            info="Wait-k: higher accuracy, requires a larger left context, higher latency.\nAlignAtt: lower latency, suitable for production, predicts multiple tokens per chunk."
                        )
                    with gr.Row():
                        alignatt_thr = gr.Number(value=8, label="alignatt_thr", info="Cross-attention threshold for the AlignAtt policy (default: 8); AlignAtt only", precision=0)
                        waitk_lagging = gr.Number(value=2, label="waitk_lagging", info="Number of chunks to wait at the beginning (default: 2); works for both policies", precision=0)
                    with gr.Row():
                        exclude_sink_frames = gr.Number(value=8, label="exclude_sink_frames", info="Number of frames to exclude from the cross-attention score calculation (default: 8); AlignAtt only", precision=0)
                        xatt_scores_layer = gr.Number(value=-2, label="xatt_scores_layer", info="Layer to take cross-attention (xatt) scores from (default: -2); AlignAtt only", precision=0)
                    with gr.Row():
                        hallucinations_detector = gr.Checkbox(value=True, label="hallucinations_detector", info="Detect hallucinations in the predicted tokens (default: True); works for both policies")
            with gr.Row():
                auto_apply_presets = gr.Checkbox(value=True, label="Auto-apply presets for sample audios")
                reset_btn = gr.Button("Reset to defaults")
            with gr.Accordion("Configuration Summary", open=False):
                summary_box = gr.Textbox(lines=15, interactive=False, show_label=False)
            # --- Events ---
            # Show the target-language dropdown only for translation tasks.
            task_type.change(
                fn=lambda t: gr.update(visible=(t == "Translation")),
                inputs=task_type,
                outputs=lang_target,
                queue=False
            )
            inputs_list = [
                task_type, lang_source, lang_target,
                chunk_secs, left_context_secs, right_context_secs,
                streaming_policy, alignatt_thr, waitk_lagging,
                exclude_sink_frames, xatt_scores_layer, hallucinations_detector
            ]
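            # Keep the "Configuration Summary" box in sync: a change to any of
            # the twelve controls above re-renders the summary text.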
            for inp in inputs_list:
                inp.change(
                    fn=summarize_config,
                    inputs=inputs_list,
                    outputs=summary_box,
                    queue=False
                )
            # Apply the matching preset when a bundled sample audio is selected
            # (and auto-apply is enabled). Outputs are the same twelve controls
            # plus the summary box.
            main_audio.change(
                fn=apply_preset_if_example,
                inputs=[main_audio, auto_apply_presets],
                outputs=inputs_list + [summary_box],
                queue=False
            )
            # Reset every control (and the summary) to its default value.
            reset_btn.click(
                fn=reset_to_defaults,
                inputs=None,
                outputs=inputs_list + [summary_box],
                queue=False
            )
            go_to_task = gr.Button("Go to Task")
            go_to_task.click(lambda: gr.Walkthrough(selected=2), outputs=walkthrough)
        # === STEP 3 ===
        with gr.Step("Task", id=2) as task_step:
            with gr.Row():
                gr.Markdown("### Step 3: Start the Streaming Task")
            with gr.Group():
                with gr.Column():
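                    # generate_coturn_config() is assumed to return a standard
                    # RTCConfiguration mapping for the TURN/STUN servers, roughly:
                    #   {"iceServers": [{"urls": [...], "username": ..., "credential": ...}]}
                    # (illustrative shape only; see app.utils.generate_coturn_config)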
                    webrtc_stream = WebRTC(
                        label="Live Stream",
                        mode="receive",
                        modality="audio",
                        rtc_configuration=generate_coturn_config(),
                        visible=True,
                        inputs=main_audio,
                        icon="https://cdn-icons-png.flaticon.com/128/18429/18429788.png",
                        pulse_color="#df7a7a",
                        icon_radius="10px",
                        icon_button_color="rgb(255, 255, 255)",
                        height=150,
                        show_label=False
                    )
                    status_slider = gr.Slider(
                        0, 100,
                        value=0,
                        label="Streaming Progress %",
                        show_label=True,
                        interactive=False,
                        visible=False,
                        show_reset_button=False
                    )
                    start_stream_button = gr.Button("▶️ Start Streaming", variant="primary")
                    stop_stream_button = gr.Button("⏹️ Stop Streaming", visible=False, variant="stop")
                    webrtc_stream.stream(
                        fn=read_and_stream_audio,
                        inputs=[active_filepath, session_hash_code, gr.State(READ_SIZE)],
                        outputs=[webrtc_stream],
                        trigger=start_stream_button.click,
                        concurrency_id="audio_stream",
                        concurrency_limit=10,
                    )
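                    # read_and_stream_audio is assumed to be a server-side generator:
                    # it reads the selected file in READ_SIZE chunks and pushes them
                    # to the browser (mode="receive" streams server -> client),
                    # starting when the Start button is clicked (trigger above).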
                    status_message_stream = gr.Markdown("", elem_id="status-message-stream", visible=False)
            with gr.Row():
                gr.Markdown("---")
            with gr.Row():
                gr.Markdown("##### Transcription / Translation Result")
            with gr.Row():
                task_output = gr.Textbox(
                    label="Transcription / Translation Result",
                    show_label=False,
                    lines=10,
                    max_lines=10,
                    interactive=False,
                    visible=True,
                    autoscroll=True,
                    elem_id="task-output-box"
                )
            with gr.Row():
                status_message_task = gr.Markdown("", elem_id="status-message-task", elem_classes=["info"], visible=False)
            with gr.Row():
                start_task_button = gr.Button("▶️ Start Task", visible=False, variant="primary")
                stop_task_button = gr.Button("⏹️ Stop Task", visible=False, variant="stop")
            stop_stream_button.click(
                fn=stop_streaming,
                inputs=[session_hash_code],
            )
            stop_task_button.click(
                fn=stop_task_fn,
                inputs=session_hash_code,
                outputs=task_output
            )
            # Full task configuration: the session id plus the twelve controls above.
            config_task_ui = [
                session_hash_code,
                task_type, lang_source, lang_target,
                chunk_secs, left_context_secs, right_context_secs,
                streaming_policy, alignatt_thr, waitk_lagging,
                exclude_sink_frames, xatt_scores_layer, hallucinations_detector
            ]
            start_task_button.click(
                fn=start_task_asr_ast,
                inputs=config_task_ui,
                outputs=[task_output, status_message_task, start_task_button, stop_task_button, config_step]
            )
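            # fastrtc pattern (assumed from the AdditionalOutputs import): the
            # streaming generator may yield AdditionalOutputs(...) alongside audio;
            # this handler routes those values (e.g. progress, status, step changes)
            # to the UI components listed below.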
            ui_components = [
                start_stream_button, stop_stream_button, start_task_button,
                go_to_config, audio_source_step, status_slider, walkthrough, status_message_stream
            ]
            webrtc_stream.on_additional_outputs(
                fn=handle_additional_outputs,
                inputs=[webrtc_stream],
                outputs=ui_components,
                concurrency_id="additional_outputs_audio_stream",
                concurrency_limit=10,
            )

if __name__ == "__main__":
    demo.queue(max_size=10, api_open=False).launch(show_api=False, show_error=True, debug=DEBUG)