# Hugging Face Spaces badge residue (captured by page extraction): "Spaces: Running on Zero"
| from app.logger_config import logger as logging | |
| import numpy as np | |
| import gradio as gr | |
| import asyncio | |
| from fastrtc.webrtc import WebRTC | |
| from fastrtc.utils import AdditionalOutputs | |
| from pydub import AudioSegment | |
| import time | |
| import os | |
| import json | |
| import spaces | |
| from app.utils import generate_coturn_config,raise_function | |
| from app.old_session_utils import ( | |
| TMP_DIR, | |
| generate_session_id, | |
| register_session, | |
| unregister_session, | |
| get_active_sessions, | |
| stop_file_path, | |
| create_stop_flag, | |
| clear_stop_flag, | |
| reset_all_active_sessions, | |
| on_load, | |
| on_unload | |
| ) | |
# Reset sessions at startup — presumably clears stale per-session flag/registry
# state left by a previous process (see app.old_session_utils; confirm there).
reset_all_active_sessions()

# Bundled demo audio files; the first one is preloaded into the Audio widget.
EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
DEFAULT_FILE = EXAMPLE_FILES[0]
| # -------------------------------------------------------- | |
| # STREAMING | |
| # -------------------------------------------------------- | |
def read_and_stream_audio(filepath_to_stream: str, session_id: str, chunk_seconds: float):
    """Generator: stream an audio file in fixed-size chunks over WebRTC.

    Yields ``((rate, samples), AdditionalOutputs(msg))`` per chunk, then a final
    ``(None, AdditionalOutputs("STREAM_DONE"))`` sentinel so the client can detect
    completion. While the per-session transcribe flag file exists, each chunk is
    also persisted as a compressed ``.npz`` for the transcription loop to consume.

    Args:
        filepath_to_stream: Path of the audio file to replay.
        session_id: Per-user id used to namespace flag/progress/chunk files.
        chunk_seconds: Duration of each streamed chunk, in seconds.
    """
    stop_file = os.path.join(TMP_DIR, f"stream_stop_flag_{session_id}.txt")
    transcribe_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    # BUG FIX: compute progress_path BEFORE the try block — the finally clause
    # references it, and it was previously unbound when the input file was missing
    # (NameError during cleanup).
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    logging.debug(f"[{session_id}] read_and_stream_audio() started with file: {filepath_to_stream}")
    try:
        if not filepath_to_stream or not os.path.exists(filepath_to_stream):
            logging.error(f"[{session_id}] Audio file not found: {filepath_to_stream}")
            # BUG FIX: `raise "<str>"` is a TypeError in Python 3; raise a real exception.
            raise FileNotFoundError(f"Audio file not found: {filepath_to_stream}")
        clear_stop_flag(session_id)
        register_session(session_id, filepath_to_stream)
        segment = AudioSegment.from_file(filepath_to_stream)
        chunk_ms = int(chunk_seconds * 1000)
        # BUG FIX: ceiling division. The old `// chunk_ms + 1` over-counted by one
        # whenever the audio length was an exact multiple of the chunk size, so the
        # reported progress could never reach 100%.
        total_chunks = max(1, -(-len(segment) // chunk_ms))
        logging.info(f"[{session_id}] Streaming {total_chunks} chunks ({chunk_seconds:.2f}s each)...")
        # pydub slicing with a step yields consecutive chunk_ms-long AudioSegments.
        for i, chunk in enumerate(segment[::chunk_ms], start=1):
            if os.path.exists(stop_file):
                logging.info(f"[{session_id}] Stop flag detected at chunk {i}. Ending stream.")
                clear_stop_flag(session_id)
                break
            logging.info(f"[{session_id}] Streaming chunk {i}.")
            # Publish elapsed playback time (hh:mm:ss) and percent for the UI poller.
            elapsed_s = i * chunk_seconds
            hours, remainder = divmod(int(elapsed_s), 3600)
            minutes, seconds = divmod(remainder, 60)
            elapsed_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
            percent = round((i / total_chunks) * 100, 2)
            progress_data = {"value": percent, "elapsed": elapsed_str}
            with open(progress_path, "w") as f:
                json.dump(progress_data, f)
            chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
            rate = chunk.frame_rate
            # Persist the chunk only while a transcription worker is consuming them.
            if os.path.exists(transcribe_flag):
                chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
                os.makedirs(chunk_dir, exist_ok=True)
                npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
                np.savez_compressed(npz_path, data=chunk_array, rate=rate)
                logging.debug(f"[{session_id}] Saved chunk {i}/{total_chunks} (transcribe active)")
            # Stream audio to the client; AdditionalOutputs carries a status message.
            msg = f"Chunk {i}/{total_chunks}"
            yield ( (rate, chunk_array.reshape(1, -1)), AdditionalOutputs(msg) )
            # Pace the generator at roughly real time (one chunk per chunk_seconds).
            time.sleep(chunk_seconds)
        raise_function()  # hook from app.utils — presumably a deliberate fault injector; confirm
        logging.info(f"[{session_id}] Streaming completed successfully.")
    except Exception as e:
        # Boundary handler: log and fall through so cleanup + sentinel still run.
        logging.error(f"[{session_id}] Stream error: {e}", exc_info=True)
    finally:
        unregister_session(session_id)
        clear_stop_flag(session_id)
        if os.path.exists(progress_path):
            os.remove(progress_path)
        # Sentinel so the client-side handler knows the stream has ended.
        yield (None, AdditionalOutputs("STREAM_DONE"))
| # -------------------------------------------------------- | |
| # TRANSCRIPTION | |
| # -------------------------------------------------------- | |
def transcribe(session_id: str):
    """Consume and delete queued ``.npz`` chunks while the transcribe flag file exists.

    Creates the per-session active-flag file (which makes the streamer start saving
    chunks), then polls the chunk directory every 250 ms until the flag is removed
    (by ``stop_transcribe_ui``). Returns Gradio updates restoring the transcribe
    buttons once the loop exits.
    """
    active_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    with open(active_flag, "w") as f:
        f.write("1")
    logging.info(f"[{session_id}] Transcription started.")
    chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
    try:
        logging.info(f"[{session_id}] Transcription loop started.")
        while os.path.exists(active_flag):
            if not os.path.exists(chunk_dir):
                # Streamer has not saved anything yet; wait and re-check.
                logging.warning(f"[{session_id}] No chunk directory found for transcription.")
                time.sleep(0.25)
                continue
            files = sorted(f for f in os.listdir(chunk_dir) if f.endswith(".npz"))
            if not files:
                time.sleep(0.25)
                continue
            for fname in files:
                fpath = os.path.join(chunk_dir, fname)
                try:
                    npz = np.load(fpath)
                    samples = npz["data"]
                    rate = int(npz["rate"])
                    # Placeholder "transcription": just report what was read.
                    text = f"Transcribed {fname}: {len(samples)} samples @ {rate}Hz"
                    logging.debug(f"[{session_id}] {text}")
                    os.remove(fpath)
                    logging.debug(f"[{session_id}] Deleted processed chunk: {fname}")
                except Exception as e:
                    logging.error(f"[{session_id}] Error processing {fname}: {e}")
                    # BUG FIX: also drop the unreadable chunk — previously a corrupt
                    # file stayed on disk and was retried on every poll, forever.
                    try:
                        os.remove(fpath)
                    except OSError:
                        pass
                    continue
            time.sleep(0.25)
        raise_function()  # hook from app.utils — presumably a deliberate fault injector; confirm
        logging.info(f"[{session_id}] Transcription loop ended (flag removed).")
    except Exception as e:
        logging.error(f"[{session_id}] Transcription error: {e}", exc_info=True)
    finally:
        # Remove the flag if the stop path has not already done so, then clean up
        # the chunk directory when it is empty. (Reuses active_flag instead of
        # recomputing the identical path as the original did.)
        if os.path.exists(active_flag):
            os.remove(active_flag)
        logging.info(f"[{session_id}] Transcription stopped.")
        try:
            if os.path.exists(chunk_dir) and not os.listdir(chunk_dir):
                os.rmdir(chunk_dir)
                logging.debug(f"[{session_id}] Cleaned up empty chunk dir.")
        except Exception as e:
            logging.error(f"[{session_id}] Cleanup error: {e}")
    logging.info(f"[{session_id}] Exiting transcription loop.")
    # Restore the transcription controls in the UI.
    return {
        start_transcribe: gr.update(interactive=True),
        stop_transcribe: gr.update(interactive=False),
        progress_text: gr.update(value="π Transcription stopped."),
    }
| # -------------------------------------------------------- | |
| # STOP STREAMING | |
| # -------------------------------------------------------- | |
| # def stop_streaming(session_id: str): | |
| # create_stop_flag(session_id) | |
| # logging.info(f"[{session_id}] Stop button clicked β stop flag created.") | |
| # return None | |
def get_session_progress(session_id: str):
    """Return ``(percent_complete, "hh:mm:ss")`` read from the session's progress file.

    Falls back to ``(0.0, "00:00:00")`` when the file is absent, half-written, or
    malformed — the streamer rewrites it once per chunk and deletes it on cleanup,
    so transient read failures are expected and benign.
    """
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    # EAFP: open directly instead of exists()+open — the streamer's finally clause
    # may delete the file between the check and the read (TOCTOU race).
    try:
        with open(progress_path, "r") as f:
            data = json.load(f)
        return data.get("value", 0.0), data.get("elapsed", "00:00:00")
    except (OSError, ValueError, AttributeError):
        # OSError: file missing/unreadable; ValueError covers JSONDecodeError for
        # half-written files; AttributeError: top-level JSON value is not a dict.
        return 0.0, "00:00:00"
def handle_additional_outputs(message):
    """Map an AdditionalOutputs payload from the stream to the status-box text."""
    logging.debug(f"π‘ Additional output received: {message}")
    # Completion sentinel emitted by the streamer's finally clause.
    if message == "STREAM_DONE":
        return "β Streaming finished"
    # Any other non-empty message is shown as-is; falsy payloads clear the box.
    return f"π‘ {message}" if message else ""
| # -------------------------------------------------------- | |
| # UI | |
| # -------------------------------------------------------- | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## π§ WebRTC Audio Streamer (Multi-user)\n"
        "Each user controls their own stream. Transcription runs only during streaming."
    )
    # Per-tab session id: on_load creates/registers it, on_unload tears it down.
    session_id = gr.State()
    sid_box = gr.Textbox(label="Session ID", interactive=False)
    demo.load(fn=on_load, inputs=None, outputs=[session_id, sid_box])
    demo.unload(on_unload)
    # Path of the file currently being streamed (defaults to the bundled example).
    active_filepath = gr.State(value=DEFAULT_FILE)
    with gr.Row(equal_height=True):
        with gr.Column(elem_id="column_source", scale=1):
            with gr.Group(elem_id="centered_content"):
                main_audio = gr.Audio(
                    label="Audio Source",
                    sources=["upload", "microphone"],
                    type="filepath",
                    value=DEFAULT_FILE,
                )
                chunk_slider = gr.Slider(
                    label="Chunk Duration (seconds)",
                    minimum=0.5,
                    maximum=5.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                )
                # Hidden until streaming starts; driven by the 1 s gr.Timer below.
                progress_bar = gr.Slider(
                    label="Streaming Progress (%)",
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=0.1,
                    interactive=False,
                    visible=False,
                )
                progress_text = gr.Textbox(
                    label="Elapsed Time (hh:mm:ss)",
                    interactive=False,
                    visible=False,
                )
            with gr.Row():
                start_button = gr.Button("βΆοΈ Start Streaming", variant="primary")
                stop_button = gr.Button("βΉοΈ Stop Streaming", variant="stop", interactive=False)
        with gr.Column():
            status_box = gr.Textbox(label="Status", interactive=False)
            webrtc_stream = WebRTC(
                label="Audio Stream",
                mode="receive",
                modality="audio",
                rtc_configuration=generate_coturn_config(),
                visible=True,
            )
    # --- Transcription Controls ---
    with gr.Row(equal_height=True):
        with gr.Column():
            start_transcribe = gr.Button("ποΈ Start Transcribe", interactive=False)
            stop_transcribe = gr.Button("π Stop Transcribe", interactive=False)

    # --- UI Logic ---
    def start_streaming(session_id):
        """Lock the input controls while streaming and reveal the progress widgets."""
        return {
            start_button: gr.update(interactive=False),
            stop_button: gr.update(interactive=True),
            start_transcribe: gr.update(interactive=True),
            stop_transcribe: gr.update(interactive=False),
            chunk_slider: gr.update(interactive=False),
            main_audio: gr.update(visible=False),
            progress_bar: gr.update(value=0, visible=True),
            progress_text: gr.update(value="00:00:00", visible=True),
        }

    def stop_streaming(session_id):
        """Signal the streaming generator to stop (it polls the flag file) and restore the UI."""
        logging.debug(f"[{session_id}] UI: Stop clicked β restoring controls.")
        create_stop_flag(session_id)
        return {
            webrtc_stream : None,
            start_button: gr.update(interactive=True),
            stop_button: gr.update(interactive=False),
            start_transcribe: gr.update(interactive=False),
            stop_transcribe: gr.update(interactive=False),
            chunk_slider: gr.update(interactive=True),
            main_audio: gr.update(visible=True),
            progress_bar: gr.update(value=0, visible=False),
            progress_text: gr.update(value="00:00:00", visible=False),
        }

    # Components toggled together by the start/stop handlers above.
    ui_components = [
        start_button, stop_button, start_transcribe, stop_transcribe,
        chunk_slider, main_audio, progress_bar, progress_text,
    ]
    # --- Streaming event ---
    # The generator is triggered by the Start button and feeds the WebRTC widget.
    webrtc_stream.stream(
        fn=read_and_stream_audio,
        inputs=[active_filepath, session_id, chunk_slider],
        outputs=[webrtc_stream ],
        trigger=start_button.click,
        concurrency_limit=20,
        concurrency_id="receive",
    )
    # Status messages yielded via AdditionalOutputs land in status_box.
    webrtc_stream.on_additional_outputs(
        fn=handle_additional_outputs,
        outputs=[status_box],
    )
    start_button.click(fn=start_streaming, inputs=[session_id], outputs=ui_components)
    stop_button.click(fn=stop_streaming, inputs=[session_id], outputs=[webrtc_stream] + ui_components)

    # --- Transcription control logic ---
    def start_transcribe_ui(session_id: str):
        """Toggle the transcribe buttons; the active-flag file itself is created by
        transcribe(), which is chained after this handler via .then()."""
        return {
            start_transcribe: gr.update(interactive=False),
            stop_transcribe: gr.update(interactive=True),
            progress_text: gr.update(value="ποΈ Transcription started..."),
        }

    def stop_transcribe_ui(session_id: str):
        """Stop transcription by removing its flag file, then restore the buttons."""
        transcribe_active = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
        if os.path.exists(transcribe_active):
            os.remove(transcribe_active)
        return {
            start_transcribe: gr.update(interactive=True),
            stop_transcribe: gr.update(interactive=False),
            progress_text: gr.update(value="π Transcription stopped."),
        }

    start_transcribe.click(
        fn=start_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
        # --- then chain the (blocking) transcription loop ---
    ).then(
        fn=transcribe,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )
    stop_transcribe.click(
        fn=stop_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )
    # --- Active sessions ---
    with gr.Accordion("π Active Sessions", open=False):
        sessions_table = gr.DataFrame(
            headers=["session_id", "file", "start_time", "status"],
            interactive=False,
            wrap=True,
            max_height=200,
        )
    # Poll the session registry every 3 s and this session's progress every 1 s.
    gr.Timer(3.0).tick(fn=get_active_sessions, outputs=sessions_table)
    gr.Timer(1.0).tick(fn=get_session_progress, inputs=[session_id], outputs=[progress_bar, progress_text])
# --------------------------------------------------------
# CSS
# --------------------------------------------------------
# Vertically center the source column and pad its rows.
custom_css = """
#column_source {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    gap: 1rem;
    margin-top: auto;
    margin-bottom: auto;
}
#column_source .gr-row {
    padding-top: 12px;
    padding-bottom: 12px;
}
"""
demo.css = custom_css
# --------------------------------------------------------
# MAIN
# --------------------------------------------------------
if __name__ == "__main__":
    # Queue caps concurrent event handlers at 20; the HTTP API surface is disabled.
    demo.queue(max_size=20, api_open=False).launch(show_api=False, debug=True)