from app.logger_config import logger as logging
import numpy as np
import gradio as gr
import asyncio
from fastrtc.webrtc import WebRTC
from fastrtc.utils import AdditionalOutputs
from pydub import AudioSegment
import time
import os
import json
import spaces
from app.utils import generate_coturn_config, raise_function
from app.old_session_utils import (
    TMP_DIR,
    generate_session_id,
    register_session,
    unregister_session,
    get_active_sessions,
    stop_file_path,
    create_stop_flag,
    clear_stop_flag,
    reset_all_active_sessions,
    on_load,
    on_unload,
)

# Reset sessions at startup so stale flags from a previous run don't block users.
reset_all_active_sessions()

EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
DEFAULT_FILE = EXAMPLE_FILES[0]


# --------------------------------------------------------
# STREAMING
# --------------------------------------------------------
def read_and_stream_audio(filepath_to_stream: str, session_id: str, chunk_seconds: float):
    """Stream audio chunks to the client; save .npz files only while transcription is active.

    Generator used as a FastRTC ``receive`` handler. Each iteration yields
    ``((sample_rate, int16 array shaped (1, n)), AdditionalOutputs(msg))``
    and writes a small JSON progress file that the UI timer polls.

    Args:
        filepath_to_stream: Path of the audio file to stream.
        session_id: Per-user session identifier; used to namespace flag/progress files.
        chunk_seconds: Duration of each streamed chunk, in seconds.

    Yields:
        ``((rate, np.ndarray), AdditionalOutputs)`` per chunk, then a final
        ``(None, AdditionalOutputs("STREAM_DONE"))`` sentinel from the
        ``finally`` block so the client always learns the stream ended.
    """
    stop_file = os.path.join(TMP_DIR, f"stream_stop_flag_{session_id}.txt")
    transcribe_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    # Assigned before the try block: the finally clause references it, and the
    # original code could hit NameError there when the early "file not found"
    # path raised before progress_path was ever bound.
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    logging.debug(f"[{session_id}] read_and_stream_audio() started with file: {filepath_to_stream}")
    try:
        if not filepath_to_stream or not os.path.exists(filepath_to_stream):
            logging.error(f"[{session_id}] Audio file not found: {filepath_to_stream}")
            # BUG FIX: the original did `raise f"..."`, which raises a plain
            # string — a TypeError in Python 3. Raise a proper exception.
            raise FileNotFoundError(f"Audio file not found: {filepath_to_stream}")

        clear_stop_flag(session_id)
        register_session(session_id, filepath_to_stream)

        segment = AudioSegment.from_file(filepath_to_stream)
        chunk_ms = int(chunk_seconds * 1000)
        total_chunks = len(segment) // chunk_ms + 1
        logging.info(f"[{session_id}] Streaming {total_chunks} chunks ({chunk_seconds:.2f}s each)...")

        # pydub slicing with a step yields consecutive chunk_ms-long segments.
        for i, chunk in enumerate(segment[::chunk_ms], start=1):
            # Cooperative cancellation: the UI stop button creates this file.
            if os.path.exists(stop_file):
                logging.info(f"[{session_id}] Stop flag detected at chunk {i}. Ending stream.")
                clear_stop_flag(session_id)
                break

            logging.info(f"[{session_id}] Streaming chunk {i}.")

            # Publish progress (percent + hh:mm:ss) for the UI polling timer.
            elapsed_s = i * chunk_seconds
            hours, remainder = divmod(int(elapsed_s), 3600)
            minutes, seconds = divmod(remainder, 60)
            elapsed_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
            percent = round((i / total_chunks) * 100, 2)
            with open(progress_path, "w") as f:
                json.dump({"value": percent, "elapsed": elapsed_str}, f)

            chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
            rate = chunk.frame_rate

            # Persist the chunk for the transcription loop only when it is running.
            if os.path.exists(transcribe_flag):
                chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
                os.makedirs(chunk_dir, exist_ok=True)
                npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
                np.savez_compressed(npz_path, data=chunk_array, rate=rate)
                logging.debug(f"[{session_id}] Saved chunk {i}/{total_chunks} (transcribe active)")

            msg = f"Chunk {i}/{total_chunks}"
            yield ((rate, chunk_array.reshape(1, -1)), AdditionalOutputs(msg))

            # Fixed-rate pacing: sleep one chunk duration so playback stays
            # roughly real-time (processing time is not subtracted).
            time.sleep(chunk_seconds)

        # NOTE(review): raise_function() comes from app.utils; presumably a
        # deliberate test/debug hook — confirm before removing.
        raise_function()
        logging.info(f"[{session_id}] Streaming completed successfully.")
    except Exception as e:
        logging.error(f"[{session_id}] Stream error: {e}", exc_info=True)
    finally:
        # Always release session state, even on error or client disconnect.
        unregister_session(session_id)
        clear_stop_flag(session_id)
        if os.path.exists(progress_path):
            os.remove(progress_path)
        yield (None, AdditionalOutputs("STREAM_DONE"))


# --------------------------------------------------------
# TRANSCRIPTION
# --------------------------------------------------------
@spaces.GPU
def transcribe(session_id: str):
    """Continuously read and delete .npz chunks while transcription is active.

    Creates the per-session ``transcribe_active`` flag file (the streamer
    checks it to decide whether to save chunks), then loops until that flag
    is removed by ``stop_transcribe_ui``. Returns UI updates restoring the
    transcription buttons.

    Args:
        session_id: Per-user session identifier.

    Returns:
        dict mapping Gradio components (module-level ``start_transcribe``,
        ``stop_transcribe``, ``progress_text``) to their updates.
    """
    active_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    with open(active_flag, "w") as f:
        f.write("1")
    logging.info(f"[{session_id}] Transcription started.")

    chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
    try:
        logging.info(f"[{session_id}] Transcription loop started.")
        while os.path.exists(active_flag):
            if not os.path.exists(chunk_dir):
                logging.warning(f"[{session_id}] No chunk directory found for transcription.")
                time.sleep(0.25)
                continue

            # Sorted so chunks are consumed in streaming order (chunk_00001, ...).
            files = sorted(f for f in os.listdir(chunk_dir) if f.endswith(".npz"))
            if not files:
                time.sleep(0.25)
                continue

            for fname in files:
                fpath = os.path.join(chunk_dir, fname)
                try:
                    # Context manager closes the NpzFile handle before the
                    # os.remove below (deleting an open file fails on Windows).
                    with np.load(fpath) as npz:
                        samples = npz["data"]
                        rate = int(npz["rate"])
                    text = f"Transcribed {fname}: {len(samples)} samples @ {rate}Hz"
                    logging.debug(f"[{session_id}] {text}")
                    os.remove(fpath)
                    logging.debug(f"[{session_id}] Deleted processed chunk: {fname}")
                except Exception as e:
                    # Best-effort per chunk: skip a bad file, keep the loop alive.
                    logging.error(f"[{session_id}] Error processing {fname}: {e}")
                    continue
            time.sleep(0.25)

        # NOTE(review): raise_function() comes from app.utils; presumably a
        # deliberate test/debug hook — confirm before removing.
        raise_function()
        logging.info(f"[{session_id}] Transcription loop ended (flag removed).")
    except Exception as e:
        logging.error(f"[{session_id}] Transcription error: {e}", exc_info=True)
    finally:
        # active_flag already holds this exact path; no need to rebuild it.
        if os.path.exists(active_flag):
            os.remove(active_flag)
        logging.info(f"[{session_id}] Transcription stopped.")
        try:
            if os.path.exists(chunk_dir) and not os.listdir(chunk_dir):
                os.rmdir(chunk_dir)
                logging.debug(f"[{session_id}] Cleaned up empty chunk dir.")
        except Exception as e:
            logging.error(f"[{session_id}] Cleanup error: {e}")

    logging.info(f"[{session_id}] Exiting transcription loop.")
    return {
        start_transcribe: gr.update(interactive=True),
        stop_transcribe: gr.update(interactive=False),
        progress_text: gr.update(value="🛑 Transcription stopped."),
    }
def get_session_progress(session_id: str):
    """Read streaming progress and return slider position + elapsed time.

    Args:
        session_id: Per-user session identifier.

    Returns:
        tuple (percent 0-100 as float, elapsed "hh:mm:ss" string). Falls back
        to ``(0.0, "00:00:00")`` when the progress file is absent or unreadable.
    """
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    if not os.path.exists(progress_path):
        return 0.0, "00:00:00"
    try:
        with open(progress_path, "r") as f:
            data = json.load(f)
        return data.get("value", 0.0), data.get("elapsed", "00:00:00")
    except Exception:
        # The streamer may be mid-write; treat any read/parse failure as
        # "no progress yet" instead of crashing the polling timer.
        return 0.0, "00:00:00"


def handle_additional_outputs(message):
    """Map an AdditionalOutputs payload from the stream to a status string."""
    logging.debug(f"📡 Additional output received: {message}")
    if message == "STREAM_DONE":
        return "✅ Streaming finished"
    if message:
        return f"📡 {message}"
    return ""


# --------------------------------------------------------
# UI
# --------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## 🎧 WebRTC Audio Streamer (Multi-user)\n"
        "Each user controls their own stream. Transcription runs only during streaming."
    )

    # Per-browser-tab session identity, populated by on_load.
    session_id = gr.State()
    sid_box = gr.Textbox(label="Session ID", interactive=False)
    demo.load(fn=on_load, inputs=None, outputs=[session_id, sid_box])
    demo.unload(on_unload)

    active_filepath = gr.State(value=DEFAULT_FILE)

    with gr.Row(equal_height=True):
        with gr.Column(elem_id="column_source", scale=1):
            with gr.Group(elem_id="centered_content"):
                main_audio = gr.Audio(
                    label="Audio Source",
                    sources=["upload", "microphone"],
                    type="filepath",
                    value=DEFAULT_FILE,
                )
                chunk_slider = gr.Slider(
                    label="Chunk Duration (seconds)",
                    minimum=0.5,
                    maximum=5.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                )
                progress_bar = gr.Slider(
                    label="Streaming Progress (%)",
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=0.1,
                    interactive=False,
                    visible=False,
                )
                progress_text = gr.Textbox(
                    label="Elapsed Time (hh:mm:ss)",
                    interactive=False,
                    visible=False,
                )
                with gr.Row():
                    start_button = gr.Button("â–ļī¸ Start Streaming", variant="primary")
                    stop_button = gr.Button("âšī¸ Stop Streaming", variant="stop", interactive=False)
        with gr.Column():
            status_box = gr.Textbox(label="Status", interactive=False)
            webrtc_stream = WebRTC(
                label="Audio Stream",
                mode="receive",
                modality="audio",
                rtc_configuration=generate_coturn_config(),
                visible=True,
            )

    # --- Transcription Controls ---
    with gr.Row(equal_height=True):
        with gr.Column():
            start_transcribe = gr.Button("đŸŽ™ī¸ Start Transcribe", interactive=False)
            stop_transcribe = gr.Button("🛑 Stop Transcribe", interactive=False)

    # --- UI Logic ---
    def start_streaming(session_id):
        """Lock the input controls and reveal progress widgets while streaming."""
        return {
            start_button: gr.update(interactive=False),
            stop_button: gr.update(interactive=True),
            start_transcribe: gr.update(interactive=True),
            stop_transcribe: gr.update(interactive=False),
            chunk_slider: gr.update(interactive=False),
            main_audio: gr.update(visible=False),
            progress_bar: gr.update(value=0, visible=True),
            progress_text: gr.update(value="00:00:00", visible=True),
        }

    def stop_streaming(session_id):
        """Signal the streamer to stop (via flag file) and restore the controls."""
        logging.debug(f"[{session_id}] UI: Stop clicked → restoring controls.")
        create_stop_flag(session_id)
        return {
            webrtc_stream: None,
            start_button: gr.update(interactive=True),
            stop_button: gr.update(interactive=False),
            start_transcribe: gr.update(interactive=False),
            stop_transcribe: gr.update(interactive=False),
            chunk_slider: gr.update(interactive=True),
            main_audio: gr.update(visible=True),
            progress_bar: gr.update(value=0, visible=False),
            progress_text: gr.update(value="00:00:00", visible=False),
        }

    # Components toggled together by the start/stop handlers above.
    ui_components = [
        start_button,
        stop_button,
        start_transcribe,
        stop_transcribe,
        chunk_slider,
        main_audio,
        progress_bar,
        progress_text,
    ]

    # --- Streaming event ---
    webrtc_stream.stream(
        fn=read_and_stream_audio,
        inputs=[active_filepath, session_id, chunk_slider],
        outputs=[webrtc_stream],
        trigger=start_button.click,
        concurrency_limit=20,
        concurrency_id="receive",
    )
    webrtc_stream.on_additional_outputs(
        fn=handle_additional_outputs,
        outputs=[status_box],
    )

    start_button.click(fn=start_streaming, inputs=[session_id], outputs=ui_components)
    stop_button.click(fn=stop_streaming, inputs=[session_id], outputs=[webrtc_stream] + ui_components)

    # --- Transcription control logic ---
    def start_transcribe_ui(session_id: str):
        """Update the UI for transcription start.

        The transcription flag itself is created inside transcribe(), which
        runs in the chained .then() below — this handler only flips buttons.
        """
        return {
            start_transcribe: gr.update(interactive=False),
            stop_transcribe: gr.update(interactive=True),
            progress_text: gr.update(value="đŸŽ™ī¸ Transcription started..."),
        }

    def stop_transcribe_ui(session_id: str):
        """Stop transcription by removing its flag file, then update the UI."""
        transcribe_active = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
        if os.path.exists(transcribe_active):
            os.remove(transcribe_active)
        return {
            start_transcribe: gr.update(interactive=True),
            stop_transcribe: gr.update(interactive=False),
            progress_text: gr.update(value="🛑 Transcription stopped."),
        }

    start_transcribe.click(
        fn=start_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
        # Chain the (long-running) transcription loop after the UI flip.
    ).then(
        fn=transcribe,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )
    stop_transcribe.click(
        fn=stop_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )

    # --- Active sessions ---
    with gr.Accordion("📊 Active Sessions", open=False):
        sessions_table = gr.DataFrame(
            headers=["session_id", "file", "start_time", "status"],
            interactive=False,
            wrap=True,
            max_height=200,
        )

    # Periodic refreshes: session table every 3s, progress widgets every 1s.
    gr.Timer(3.0).tick(fn=get_active_sessions, outputs=sessions_table)
    gr.Timer(1.0).tick(fn=get_session_progress, inputs=[session_id], outputs=[progress_bar, progress_text])


# --------------------------------------------------------
# CSS
# --------------------------------------------------------
custom_css = """
#column_source {
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
  gap: 1rem;
  margin-top: auto;
  margin-bottom: auto;
}
#column_source .gr-row {
  padding-top: 12px;
  padding-bottom: 12px;
}
"""
# NOTE(review): assigning .css after Blocks creation; newer Gradio prefers
# gr.Blocks(css=...) — confirm this attribute is still honored.
demo.css = custom_css

# --------------------------------------------------------
# MAIN
# --------------------------------------------------------
if __name__ == "__main__":
    demo.queue(max_size=20, api_open=False).launch(show_api=False, debug=True)