Spaces:
Running
on
Zero
Running
on
Zero
| from app.logger_config import logger as logging | |
| import numpy as np | |
| import gradio as gr | |
| import asyncio | |
| from fastrtc.webrtc import WebRTC | |
| from fastrtc.utils import AdditionalOutputs | |
| from pydub import AudioSegment | |
| import time | |
| import os | |
| from gradio.utils import get_space | |
| from app.utils import ( | |
| generate_coturn_config, | |
| raise_function | |
| ) | |
| from app.new_session_utils import ( | |
| on_load, | |
| on_unload, | |
| get_active_sessions, | |
| reset_all_active_sessions, | |
| ) | |
# --------------------------------------------------------
# Initialization
# --------------------------------------------------------
# Reset the active-session bookkeeping from app.new_session_utils once, at
# import time, before any UI component is created.
reset_all_active_sessions()

# Bundled example recordings; the first entry is the default source used by
# the audio component and the initial active file path.
EXAMPLE_FILES = [
    "data/bonjour.wav",
    "data/bonjour2.wav",
]
DEFAULT_FILE, *_ = EXAMPLE_FILES
| # -------------------------------------------------------- | |
| # Utility functions | |
| # -------------------------------------------------------- | |
| def _is_stop_requested(stop_streaming_flags: dict) -> bool: | |
| """Check if the stop signal was requested.""" | |
| if not isinstance(stop_streaming_flags, dict): | |
| return False | |
| return bool(stop_streaming_flags.get("stop", False)) | |
def handle_stream_error(session_id: str, error: Exception | str, stop_streaming_flags: dict | None = None):
    """
    Report a streaming failure to the log and to the client.

    This is a generator: nothing happens (not even logging) until it is
    iterated, typically via ``yield from`` inside the streaming generator.

    Steps:
    - Log the error, with its traceback when ``error`` is an exception.
    - Reset ``stop_streaming_flags["stop"]`` to False (when a dict is given)
      so the next stream does not start in a stopped state.
    - Yield a structured error payload, then the ``"STREAM_DONE"`` sentinel.

    Args:
        session_id: Identifier used as the log prefix.
        error: The exception that was raised, or a plain error message.
        stop_streaming_flags: Shared flags dict; ignored when not a dict.

    Yields:
        ``(None, AdditionalOutputs({"error": True, "message": msg}))`` followed by
        ``(None, AdditionalOutputs("STREAM_DONE"))``.
    """
    if isinstance(error, Exception):
        msg = f"{type(error).__name__}: {error}"
        # Hand the exception object itself to the logger: ``exc_info=True`` only
        # finds a traceback while an exception is being handled in the calling
        # context, which a lazily-consumed generator cannot guarantee.
        exc_info = error
    else:
        msg = str(error)
        exc_info = False
    logging.error(f"[{session_id}] Streaming error: {msg}", exc_info=exc_info)

    if isinstance(stop_streaming_flags, dict):
        stop_streaming_flags["stop"] = False

    yield (None, AdditionalOutputs({"error": True, "message": msg}))
    yield (None, AdditionalOutputs("STREAM_DONE"))
def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streaming_flags: dict):
    """
    Read an audio file and stream it to the client in 1-second chunks.

    Each chunk is yielded as ``((frame_rate, samples), AdditionalOutputs(progress))``
    where ``samples`` is a ``(1, n)`` NumPy array and ``progress`` is the percentage
    of chunks sent, rounded to two decimals. Before each chunk the shared flags are
    checked; a truthy ``stop_streaming_flags["stop"]`` ends the stream early.

    Exactly one ``(None, AdditionalOutputs("STREAM_DONE"))`` is yielded at the end,
    whether the stream completed, was stopped, or failed. On failure it is emitted
    by ``handle_stream_error``, after a structured ``{"error": True, ...}`` payload.

    Args:
        filepath_to_stream: Path of the audio file to decode with pydub.
        session_id: Session identifier; a falsy value is reported as an error.
        stop_streaming_flags: Shared flags dict; its ``"stop"`` entry is reset to
            False once streaming ends (normally, by stop, or by error).
    """
    if not session_id:
        yield from handle_stream_error("unknown", "No session_id provided.", stop_streaming_flags)
        return
    if not filepath_to_stream or not os.path.exists(filepath_to_stream):
        yield from handle_stream_error(session_id, f"Audio file not found: {filepath_to_stream}", stop_streaming_flags)
        return

    chunk_duration_ms = 1000
    try:
        segment = AudioSegment.from_file(filepath_to_stream)
        # Step-slicing a segment yields ceil(len / step) chunks (the last one may be
        # partial), so use ceiling division; ``len // step + 1`` over-counted exact
        # multiples and left the progress bar short of 100%.
        total_chunks = -(-len(segment) // chunk_duration_ms)
        logging.info(f"[{session_id}] Starting audio streaming ({total_chunks} chunks).")

        for i, chunk in enumerate(segment[::chunk_duration_ms]):
            if _is_stop_requested(stop_streaming_flags):
                logging.info(f"[{session_id}] Stop signal received. Terminating stream.")
                break
            frame_rate = chunk.frame_rate
            # NOTE(review): for multi-channel audio get_array_of_samples() is
            # interleaved, and reshape(1, -1) sends it as a single row; confirm
            # the receiver expects that layout.
            samples = np.array(chunk.get_array_of_samples()).reshape(1, -1)
            progress = round(((i + 1) / total_chunks) * 100, 2)
            yield ((frame_rate, samples), AdditionalOutputs(progress))
            logging.debug(f"[{session_id}] Sent chunk {i+1}/{total_chunks} ({progress}%).")
            # Slightly shorter than the chunk length; presumably keeps the sender
            # ahead of playback. Confirm the intended pacing.
            time.sleep(0.9)
            raise_function()  # Optional injected test exception
        else:
            # Only reached when the loop was not interrupted by a stop request.
            logging.info(f"[{session_id}] Audio streaming completed successfully.")
    except asyncio.CancelledError:
        # CancelledError derives from BaseException, so it needs its own clause.
        yield from handle_stream_error(session_id, "Streaming cancelled by user.", stop_streaming_flags)
        return  # handle_stream_error already emitted STREAM_DONE
    except Exception as e:
        yield from handle_stream_error(session_id, e, stop_streaming_flags)
        return  # handle_stream_error already emitted STREAM_DONE
    finally:
        # Never yield here: when the consumer closes this generator, GeneratorExit
        # runs this block and a yield would raise
        # "RuntimeError: generator ignored GeneratorExit".
        if isinstance(stop_streaming_flags, dict):
            stop_streaming_flags["stop"] = False
            logging.info(f"[{session_id}] Stop flag reset.")

    yield (None, AdditionalOutputs("STREAM_DONE"))
def stop_streaming(session_id: str, stop_streaming_flags: dict):
    """
    Request that the active stream stop.

    When ``stop_streaming_flags`` is a dict its ``"stop"`` entry is set to True in
    place. The streaming loop polls that entry via ``_is_stop_requested`` before
    each chunk, presumably on the same session-state object; confirm with
    Gradio's State semantics. Any other value is replaced by a fresh
    ``{"stop": True}`` dict.

    Args:
        session_id: Identifier used as the log prefix.
        stop_streaming_flags: Shared flags dict (or any non-dict placeholder).

    Returns:
        The flags dict with ``"stop"`` set to True, to be stored back into state.
    """
    # ASCII arrow: the previous non-ASCII arrow had been mis-decoded into "β".
    logging.info(f"[{session_id}] Stop button clicked -> sending stop signal.")
    if isinstance(stop_streaming_flags, dict):
        stop_streaming_flags["stop"] = True
    else:
        stop_streaming_flags = {"stop": True}
    return stop_streaming_flags
def handle_additional_outputs(start_button, stop_button, main_audio, status_slider, progress_value):
    """
    Translate a value from the stream's side channel into UI updates.

    ``progress_value`` is whatever the stream sent through ``AdditionalOutputs``:
    a numeric percentage, an error dict ``{"error": True, "message": ...}``, or a
    non-numeric value such as the ``"STREAM_DONE"`` sentinel (treated as 0).
    The four component arguments are received but not used by this function.

    Returns:
        A 4-tuple of ``gr.update`` objects for
        (start_button, stop_button, main_audio, status_slider).
    """
    logging.debug(f"Additional output received: {progress_value}")

    def idle_state(slider_value):
        # Start enabled, stop disabled, audio shown, progress slider hidden.
        return (
            gr.update(interactive=True),
            gr.update(interactive=False),
            gr.update(visible=True),
            gr.update(visible=False, value=slider_value),
        )

    if isinstance(progress_value, dict) and progress_value.get("error"):
        message = progress_value.get("message", "Unknown error.")
        logging.error(f"[stream_ui] Client-side error: {message}")
        return idle_state(0)

    try:
        pct = float(progress_value)
    except (ValueError, TypeError):
        pct = 0

    if pct <= 0:  # not started yet, or a non-numeric sentinel
        return idle_state(0)
    if pct >= 100:  # finished
        return idle_state(100)

    # In progress: lock Start, allow Stop, swap the audio player for the slider.
    return (
        gr.update(interactive=False),
        gr.update(interactive=True),
        gr.update(visible=False),
        gr.update(visible=True, value=pct),
    )
# --------------------------------------------------------
# Gradio Interface
# --------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session identifier, filled by on_load when the page loads.
    session_hash = gr.State()
    session_hash_box = gr.Textbox(label="Session ID", interactive=False)
    demo.load(fn=on_load, inputs=None, outputs=[session_hash, session_hash_box])
    demo.unload(on_unload)

    # Flags dict passed both to the streaming generator and to the Stop handler.
    stop_streaming_flags = gr.State(value={"stop": False})

    gr.Markdown(
        "## WebRTC Audio Streamer (Server -> Client)\n"
        "Upload or record an audio file, then click **Start Streaming** to listen to the streamed audio."
    )

    # File the next stream will read; starts as the bundled default example.
    active_filepath = gr.State(value=DEFAULT_FILE)

    with gr.Row(equal_height=True):
        with gr.Column(elem_id="column_source", scale=1):
            with gr.Group(elem_id="centered_content"):
                main_audio = gr.Audio(
                    label="Audio File",
                    sources=["upload", "microphone"],
                    type="filepath",
                    value=DEFAULT_FILE,
                )
                # Hidden until streaming starts (toggled by handle_additional_outputs).
                status_slider = gr.Slider(
                    0, 100, value=0, label="Streaming Progress", interactive=False, visible=False
                )
        with gr.Column():
            webrtc_stream = WebRTC(
                label="Live",
                mode="receive",
                modality="audio",
                rtc_configuration=generate_coturn_config(),
                visible=True,
                height=200,
            )
            # NOTE(review): confirm whether this button row belongs under the live
            # player (as here) or below the whole two-column row.
            with gr.Row():
                with gr.Column():
                    start_button = gr.Button("Start Streaming", variant="primary")
                    stop_button = gr.Button("Stop Streaming", variant="stop", interactive=False)

    def set_new_file(filepath):
        """Return the new active audio path, falling back to DEFAULT_FILE when cleared (None)."""
        if filepath is None:
            logging.info("[ui] Audio cleared -> reverting to default example file.")
            new_path = DEFAULT_FILE
        else:
            logging.info(f"[ui] New audio source selected: {filepath}")
            new_path = filepath
        return new_path

    main_audio.change(fn=set_new_file, inputs=[main_audio], outputs=[active_filepath])
    main_audio.stop_recording(fn=set_new_file, inputs=[main_audio], outputs=[active_filepath])

    # Same list for inputs and outputs of the side-channel handler; order must
    # match handle_additional_outputs' parameters and returned tuple.
    ui_components = [start_button, stop_button, main_audio, status_slider]

    # Clicking Start launches read_and_stream_audio over WebRTC.
    stream_event = webrtc_stream.stream(
        fn=read_and_stream_audio,
        inputs=[active_filepath, session_hash, stop_streaming_flags],
        outputs=[webrtc_stream],
        trigger=start_button.click,
        concurrency_id="audio_stream",
        concurrency_limit=10,
    )
    webrtc_stream.on_additional_outputs(
        fn=handle_additional_outputs,
        inputs=ui_components,
        outputs=ui_components,
        concurrency_id="additional_outputs_audio_stream",
        concurrency_limit=10,
    )
    stop_button.click(
        fn=stop_streaming,
        inputs=[session_hash, stop_streaming_flags],
        outputs=[stop_streaming_flags],
    )

    with gr.Accordion("Active Sessions", open=False):
        sessions_table = gr.DataFrame(
            headers=["session_hash", "file", "start_time", "status"],
            interactive=False,
            wrap=True,
            max_height=200,
        )
        # Refresh the sessions table every 3 seconds.
        gr.Timer(3.0).tick(fn=get_active_sessions, outputs=sessions_table)
# --------------------------------------------------------
# Custom CSS
# --------------------------------------------------------
# Styles for the component created with elem_id="column_source": its children
# are stacked vertically and centered, and the auto top/bottom margins
# presumably center the column vertically inside its row.
# NOTE(review): ".gr-row" looks like a Gradio 3-era class name; confirm it still
# matches an element in the installed Gradio version.
custom_css = """
#column_source {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
gap: 1rem;
margin-top: auto;
margin-bottom: auto;
}
#column_source .gr-row {
padding-top: 12px;
padding-bottom: 12px;
}
"""
# Assigned after the Blocks context has been built rather than via gr.Blocks(css=...).
# NOTE(review): confirm the installed Gradio reads `demo.css` at launch time.
demo.css = custom_css
# --------------------------------------------------------
# Main
# --------------------------------------------------------
if __name__ == "__main__":
    # Enable the event queue (max_size=10, API closed), then serve the app
    # with the API page hidden and debug mode on.
    app = demo.queue(max_size=10, api_open=False)
    app.launch(show_api=False, debug=True)