# canary_aed_streaming / old_app.py  (commit 6f523af)
from app.logger_config import logger as logging
import numpy as np
import gradio as gr
import asyncio
from fastrtc.webrtc import WebRTC
from fastrtc.utils import AdditionalOutputs
from pydub import AudioSegment
import time
import os
import json
import spaces
from app.utils import generate_coturn_config,raise_function
from app.old_session_utils import (
TMP_DIR,
generate_session_id,
register_session,
unregister_session,
get_active_sessions,
stop_file_path,
create_stop_flag,
clear_stop_flag,
reset_all_active_sessions,
on_load,
on_unload
)
# Reset sessions at startup
reset_all_active_sessions()
# Bundled demo audio files; the first one is pre-selected in the Audio widget.
EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
DEFAULT_FILE = EXAMPLE_FILES[0]
# --------------------------------------------------------
# STREAMING
# --------------------------------------------------------
def read_and_stream_audio(filepath_to_stream: str, session_id: str, chunk_seconds: float):
    """Stream audio chunks to the client, saving .npz files only while transcription is active.

    Generator used as the fastrtc stream handler (mode="receive").

    Args:
        filepath_to_stream: Path to the audio file to stream.
        session_id: Per-user session identifier used in flag/progress file names.
        chunk_seconds: Duration of each streamed chunk, in seconds.

    Yields:
        ``((rate, samples), AdditionalOutputs(msg))`` per chunk, then a final
        ``(None, AdditionalOutputs("STREAM_DONE"))`` sentinel after cleanup.
    """
    stop_file = os.path.join(TMP_DIR, f"stream_stop_flag_{session_id}.txt")
    transcribe_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    # Defined before the try so the finally clause can always reference it
    # (the original assigned it inside try, after a possible raise -> NameError).
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    logging.debug(f"[{session_id}] read_and_stream_audio() started with file: {filepath_to_stream}")
    try:
        if not filepath_to_stream or not os.path.exists(filepath_to_stream):
            logging.error(f"[{session_id}] Audio file not found: {filepath_to_stream}")
            # BUG FIX: the original raised a plain string, which is a TypeError in Python 3.
            raise FileNotFoundError(f"Audio file not found: {filepath_to_stream}")
        clear_stop_flag(session_id)
        register_session(session_id, filepath_to_stream)
        segment = AudioSegment.from_file(filepath_to_stream)
        chunk_ms = int(chunk_seconds * 1000)
        # Ceiling division: the original `// chunk_ms + 1` over-counted by one
        # whenever the audio length was an exact multiple of the chunk size.
        total_chunks = max(1, (len(segment) + chunk_ms - 1) // chunk_ms)
        logging.info(f"[{session_id}] Streaming {total_chunks} chunks ({chunk_seconds:.2f}s each)...")
        for i, chunk in enumerate(segment[::chunk_ms], start=1):
            if os.path.exists(stop_file):
                logging.info(f"[{session_id}] Stop flag detected at chunk {i}. Ending stream.")
                clear_stop_flag(session_id)
                break
            logging.info(f"[{session_id}] Streaming chunk {i}.")
            # Publish progress (percent + elapsed hh:mm:ss) to a per-session JSON
            # file polled by the UI timer (get_session_progress).
            elapsed_s = i * chunk_seconds
            hours, remainder = divmod(int(elapsed_s), 3600)
            minutes, seconds = divmod(remainder, 60)
            elapsed_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
            percent = round((i / total_chunks) * 100, 2)
            progress_data = {"value": percent, "elapsed": elapsed_str}
            with open(progress_path, "w") as f:
                json.dump(progress_data, f)
            chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
            rate = chunk.frame_rate
            # Persist the chunk for the transcription worker only while it is active.
            if os.path.exists(transcribe_flag):
                chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
                os.makedirs(chunk_dir, exist_ok=True)
                npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
                np.savez_compressed(npz_path, data=chunk_array, rate=rate)
                logging.debug(f"[{session_id}] Saved chunk {i}/{total_chunks} (transcribe active)")
            # Stream audio to client together with a status message.
            msg = f"Chunk {i}/{total_chunks}"
            yield ((rate, chunk_array.reshape(1, -1)), AdditionalOutputs(msg))
            # Pace the stream at roughly real time (no compensation for processing time).
            time.sleep(chunk_seconds)
        # NOTE(review): project helper from app.utils — presumably a debug/raise hook;
        # kept as in the original. Confirm its intent.
        raise_function()
        logging.info(f"[{session_id}] Streaming completed successfully.")
    except Exception as e:
        logging.error(f"[{session_id}] Stream error: {e}", exc_info=True)
    finally:
        unregister_session(session_id)
        clear_stop_flag(session_id)
        if os.path.exists(progress_path):
            os.remove(progress_path)
        yield (None, AdditionalOutputs("STREAM_DONE"))
# --------------------------------------------------------
# TRANSCRIPTION
# --------------------------------------------------------
@spaces.GPU
def transcribe(session_id: str):
    """Continuously consume and delete .npz chunks while the session's active flag exists.

    Creates the ``transcribe_active`` flag file (which tells the streamer to start
    saving chunks), then polls the session's chunk directory, "transcribing" each
    file and deleting it, until the flag is removed by ``stop_transcribe_ui``.

    Args:
        session_id: Per-user session identifier.

    Returns:
        A dict of ``gr.update`` values restoring the transcription buttons/status.
    """
    active_flag = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
    with open(active_flag, "w") as f:
        f.write("1")
    logging.info(f"[{session_id}] Transcription started.")
    chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
    try:
        logging.info(f"[{session_id}] Transcription loop started.")
        while os.path.exists(active_flag):
            if not os.path.exists(chunk_dir):
                # Streamer has not saved anything yet; wait and re-check.
                logging.warning(f"[{session_id}] No chunk directory found for transcription.")
                time.sleep(0.25)
                continue
            files = sorted(f for f in os.listdir(chunk_dir) if f.endswith(".npz"))
            if not files:
                time.sleep(0.25)
                continue
            for fname in files:
                fpath = os.path.join(chunk_dir, fname)
                try:
                    npz = np.load(fpath)
                    samples = npz["data"]
                    rate = int(npz["rate"])
                    # Placeholder "transcription" — real model inference would go here.
                    text = f"Transcribed {fname}: {len(samples)} samples @ {rate}Hz"
                    logging.debug(f"[{session_id}] {text}")
                    os.remove(fpath)
                    logging.debug(f"[{session_id}] Deleted processed chunk: {fname}")
                except Exception as e:
                    # Best-effort: skip an unreadable chunk rather than aborting the loop.
                    logging.error(f"[{session_id}] Error processing {fname}: {e}")
                    continue
            time.sleep(0.25)
        # NOTE(review): project helper from app.utils — presumably a debug/raise hook.
        raise_function()
        logging.info(f"[{session_id}] Transcription loop ended (flag removed).")
    except Exception as e:
        logging.error(f"[{session_id}] Transcription error: {e}", exc_info=True)
    finally:
        # CONSISTENCY FIX: reuse active_flag instead of rebuilding the same path.
        if os.path.exists(active_flag):
            os.remove(active_flag)
        logging.info(f"[{session_id}] Transcription stopped.")
        try:
            if os.path.exists(chunk_dir) and not os.listdir(chunk_dir):
                os.rmdir(chunk_dir)
                logging.debug(f"[{session_id}] Cleaned up empty chunk dir.")
        except Exception as e:
            logging.error(f"[{session_id}] Cleanup error: {e}")
    logging.info(f"[{session_id}] Exiting transcription loop.")
    return {
        start_transcribe: gr.update(interactive=True),
        stop_transcribe: gr.update(interactive=False),
        progress_text: gr.update(value="🛑 Transcription stopped."),
    }
# --------------------------------------------------------
# STOP STREAMING
# --------------------------------------------------------
# def stop_streaming(session_id: str):
# create_stop_flag(session_id)
# logging.info(f"[{session_id}] Stop button clicked β†’ stop flag created.")
# return None
def get_session_progress(session_id: str):
    """Return the session's streaming progress as (percent, elapsed hh:mm:ss string)."""
    default = (0.0, "00:00:00")
    progress_path = os.path.join(TMP_DIR, f"progress_{session_id}.json")
    if not os.path.exists(progress_path):
        return default
    try:
        with open(progress_path, "r") as f:
            data = json.load(f)
    except Exception:
        # Progress file may be mid-write or already deleted — fall back quietly.
        return default
    return data.get("value", 0.0), data.get("elapsed", "00:00:00")
def handle_additional_outputs(message):
    """Translate an AdditionalOutputs payload into the status-box text."""
    logging.debug(f"📡 Additional output received: {message}")
    if message == "STREAM_DONE":
        return "✅ Streaming finished"
    return f"📡 {message}" if message else ""
# --------------------------------------------------------
# UI
# --------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## 🎧 WebRTC Audio Streamer (Multi-user)\n"
        "Each user controls their own stream. Transcription runs only during streaming."
    )
    # Per-tab session id, assigned by on_load and released by on_unload.
    session_id = gr.State()
    sid_box = gr.Textbox(label="Session ID", interactive=False)
    demo.load(fn=on_load, inputs=None, outputs=[session_id, sid_box])
    demo.unload(on_unload)
    # File currently selected for streaming (defaults to the bundled example).
    active_filepath = gr.State(value=DEFAULT_FILE)
    with gr.Row(equal_height=True):
        with gr.Column(elem_id="column_source", scale=1):
            with gr.Group(elem_id="centered_content"):
                main_audio = gr.Audio(
                    label="Audio Source",
                    sources=["upload", "microphone"],
                    type="filepath",
                    value=DEFAULT_FILE,
                )
                chunk_slider = gr.Slider(
                    label="Chunk Duration (seconds)",
                    minimum=0.5,
                    maximum=5.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                )
                # Progress widgets stay hidden until streaming starts.
                progress_bar = gr.Slider(
                    label="Streaming Progress (%)",
                    minimum=0,
                    maximum=100,
                    value=0,
                    step=0.1,
                    interactive=False,
                    visible=False,
                )
                progress_text = gr.Textbox(
                    label="Elapsed Time (hh:mm:ss)",
                    interactive=False,
                    visible=False,
                )
                with gr.Row():
                    start_button = gr.Button("▶️ Start Streaming", variant="primary")
                    stop_button = gr.Button("⏹️ Stop Streaming", variant="stop", interactive=False)
        with gr.Column():
            status_box = gr.Textbox(label="Status", interactive=False)
            # Receive-only: the server pushes the chunked audio to the browser.
            webrtc_stream = WebRTC(
                label="Audio Stream",
                mode="receive",
                modality="audio",
                rtc_configuration=generate_coturn_config(),
                visible=True,
            )
    # --- Transcription Controls ---
    with gr.Row(equal_height=True):
        with gr.Column():
            start_transcribe = gr.Button("🎙️ Start Transcribe", interactive=False)
            stop_transcribe = gr.Button("🛑 Stop Transcribe", interactive=False)
    # --- UI Logic ---
def start_streaming(session_id):
return {
start_button: gr.update(interactive=False),
stop_button: gr.update(interactive=True),
start_transcribe: gr.update(interactive=True),
stop_transcribe: gr.update(interactive=False),
chunk_slider: gr.update(interactive=False),
main_audio: gr.update(visible=False),
progress_bar: gr.update(value=0, visible=True),
progress_text: gr.update(value="00:00:00", visible=True),
}
def stop_streaming(session_id):
logging.debug(f"[{session_id}] UI: Stop clicked β†’ restoring controls.")
create_stop_flag(session_id)
return {
webrtc_stream : None,
start_button: gr.update(interactive=True),
stop_button: gr.update(interactive=False),
start_transcribe: gr.update(interactive=False),
stop_transcribe: gr.update(interactive=False),
chunk_slider: gr.update(interactive=True),
main_audio: gr.update(visible=True),
progress_bar: gr.update(value=0, visible=False),
progress_text: gr.update(value="00:00:00", visible=False),
}
ui_components = [
start_button, stop_button, start_transcribe, stop_transcribe,
chunk_slider, main_audio, progress_bar, progress_text,
]
# --- Streaming event ---
webrtc_stream.stream(
fn=read_and_stream_audio,
inputs=[active_filepath, session_id, chunk_slider],
outputs=[webrtc_stream ],
trigger=start_button.click,
concurrency_limit=20,
concurrency_id="receive",
)
webrtc_stream.on_additional_outputs(
fn=handle_additional_outputs,
outputs=[status_box],
)
# status_box.change(
# fn=update_status,
# inputs=[status_box],
# outputs=[status_box],
# )
# .then(
# fn=stop_streaming,
# inputs=[session_id],
# outputs=ui_components
# )
start_button.click(fn=start_streaming, inputs=[session_id], outputs=ui_components)
# .then(fn=stop_streaming, inputs=[session_id], outputs=[webrtc_stream] + ui_components)
stop_button.click(fn=stop_streaming, inputs=[session_id], outputs=[webrtc_stream] + ui_components)
# --- Transcription control logic ---
    def start_transcribe_ui(session_id: str):
        """Toggle the transcription buttons and status text for a starting run.

        Note: the transcribe_active flag itself is created by transcribe(),
        which is chained after this handler via .then() — this function only
        updates the UI.
        """
        return {
            start_transcribe: gr.update(interactive=False),
            stop_transcribe: gr.update(interactive=True),
            progress_text: gr.update(value="🎙️ Transcription started..."),
        }
def stop_transcribe_ui(session_id: str):
"""Stop transcription by removing flag and update UI."""
transcribe_active = os.path.join(TMP_DIR, f"transcribe_active_{session_id}.txt")
if os.path.exists(transcribe_active):
os.remove(transcribe_active)
return {
start_transcribe: gr.update(interactive=True),
stop_transcribe: gr.update(interactive=False),
progress_text: gr.update(value="πŸ›‘ Transcription stopped."),
}
    # Start: flip the buttons first, then launch the blocking transcription
    # loop chained via .then(); transcribe() returns the final button state.
    start_transcribe.click(
        fn=start_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    ).then(
        fn=transcribe,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )
    # Stop: removing the flag file makes the transcribe() loop exit.
    stop_transcribe.click(
        fn=stop_transcribe_ui,
        inputs=[session_id],
        outputs=[start_transcribe, stop_transcribe, progress_text],
    )
    # --- Active sessions ---
    with gr.Accordion("📊 Active Sessions", open=False):
        sessions_table = gr.DataFrame(
            headers=["session_id", "file", "start_time", "status"],
            interactive=False,
            wrap=True,
            max_height=200,
        )
    # Poll the session registry and the per-session progress file.
    gr.Timer(3.0).tick(fn=get_active_sessions, outputs=sessions_table)
    gr.Timer(1.0).tick(fn=get_session_progress, inputs=[session_id], outputs=[progress_bar, progress_text])
# --------------------------------------------------------
# CSS
# --------------------------------------------------------
custom_css = """
#column_source {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
gap: 1rem;
margin-top: auto;
margin-bottom: auto;
}
#column_source .gr-row {
padding-top: 12px;
padding-bottom: 12px;
}
"""
demo.css = custom_css
# --------------------------------------------------------
# MAIN
# --------------------------------------------------------
if __name__ == "__main__":
    # Queue caps concurrent jobs; the HTTP API surface is disabled for this demo.
    demo.queue(max_size=20, api_open=False).launch(show_api=False, debug=True)