from app.logger_config import DEBUG
import gradio as gr
from fastrtc.webrtc import WebRTC
import os
from app.utils import (
    generate_coturn_config,
    read_and_stream_audio,
    stop_streaming,
    task
)
from app.session_utils import (
    on_load,
    on_unload,
    get_active_session_hash_code,
    reset_all_active_session_hash_code,
    get_active_task_flag_file,
)
from app.ui_utils import (
SUPPORTED_LANGS_MAP,
EXAMPLE_CONFIGS,
apply_preset_if_example,
reset_to_defaults,
summarize_config,
handle_additional_outputs,
get_custom_theme,
on_file_load
)
import nemo.collections.asr as nemo_asr
# --------------------------------------------------------
# Initialization
# --------------------------------------------------------
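# Clear any session hashes left over from a previous run so the
# active-sessions table starts empty.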
reset_all_active_session_hash_code()
theme, css_style = get_custom_theme()
from omegaconf import OmegaConf
cfg = OmegaConf.load('app/config.yaml')
from app.canary_speech_engine import CanarySpeechEngine
from app.silero_vad_engine import Silero_Vad_Engine
from app.streaming_audio_processor import StreamingAudioProcessor, StreamingAudioProcessorConfig
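# Load the pretrained ASR checkpoint named in app/config.yaml and wrap it in
# the Canary speech engine, paired with a Silero VAD engine for
# speech/silence detection.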
asr_model = nemo_asr.models.ASRModel.from_pretrained(cfg.pretrained_name)
canary_speech_engine = CanarySpeechEngine(asr_model, cfg)
silero_vad_engine = Silero_Vad_Engine()
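# Streaming parameters: read_size is the amount of audio consumed per
# iteration (presumably in samples; at 16 kHz, 4000 samples ≈ 0.25 s), and
# silence_threshold_chunks is likely how many consecutive silent chunks
# close out a speech segment.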
streaming_audio_processor_config = StreamingAudioProcessorConfig(
read_size=4000,
silence_threshold_chunks=1
)
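# The processor presumably gates incoming chunks through the VAD before
# handing speech segments to the ASR engine.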
streamer = StreamingAudioProcessor(
    speech_engine=canary_speech_engine,
    vad_engine=silero_vad_engine,
    cfg=streaming_audio_processor_config,
)
with gr.Blocks(theme=theme, css=css_style) as demo:
session_hash_code = gr.State()
session_hash_code_box = gr.Textbox(label="Session ID", interactive=False, visible=DEBUG)
with gr.Accordion("📊 Active Sessions Hash", open=True ,visible=DEBUG):
sessions_table = gr.DataFrame(
headers=["session_hash_code", "file", "start_time", "status"],
interactive=False,
wrap=True,
max_height=200,
)
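        # Refresh the active-sessions table every 3 seconds (debug view only).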
gr.Timer(3.0).tick(fn=get_active_session_hash_code, outputs=sessions_table)
demo.load(fn=on_load, inputs=None, outputs=[session_hash_code, session_hash_code_box])
demo.unload(on_unload)
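    # Per-session mutable state: flipping the "stop" flag asks an in-progress
    # stream to halt (see stop_streaming below).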
stop_streaming_flags = gr.State(value={"stop": False})
active_filepath = gr.State(value=next(iter(EXAMPLE_CONFIGS)))
with gr.Walkthrough(selected=0) as walkthrough:
# === STEP 1 ===
with gr.Step("Audio", id=0) as audio_source_step:
gr.Markdown(
"""
### Step 1: Upload or Record an Audio File
You can upload an existing file or record directly from your microphone.
Accepted formats: **.wav**, **.mp3**, **.flac**
            Recommended maximum length: **60 seconds**
"""
)
with gr.Group():
with gr.Column():
main_audio = gr.Audio(
label="Audio Input",
sources=["upload", "microphone"],
type="filepath",
interactive=True
)
with gr.Accordion("Need a quick test? Try one of the sample audios below", open=True):
examples = gr.Examples(
examples=list(EXAMPLE_CONFIGS.keys()),
inputs=main_audio,
label=None,
examples_per_page=3
)
gr.Markdown(
"""
🔹 **english_meeting.wav** – Short business meeting in English
🔹 **french_news.wav** – Excerpt from a French radio broadcast
🔹 **spanish_podcast.wav** – Segment from a Spanish-language podcast
"""
)
btn_proceed_streaming = gr.Button("Proceed to Streaming", visible=False)
            ui_components_on_load_audio = [active_filepath, btn_proceed_streaming]
            main_audio.change(fn=on_file_load, inputs=[main_audio], outputs=ui_components_on_load_audio)
btn_proceed_streaming.click(lambda: gr.Walkthrough(selected=1), outputs=walkthrough)
# === STEP 2 ===
with gr.Step("Stream", id=1) as audio_stream:
gr.Markdown("### Step 2: Start audio streaming")
with gr.Group():
with gr.Column():
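                    # Receive-only WebRTC widget: the server pushes audio to the
                    # browser; generate_coturn_config() supplies the TURN/STUN
                    # servers used for NAT traversal.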
webrtc_stream = WebRTC(
label="Live Stream",
mode="receive",
modality="audio",
rtc_configuration=generate_coturn_config(),
visible=True,
inputs=main_audio
)
start_stream_button = gr.Button("Start Streaming")
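            # Clicking "Start Streaming" pulls chunks from read_and_stream_audio
            # and plays them through the WebRTC widget; at most 10 streams share
            # the "audio_stream" concurrency group.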
webrtc_stream.stream(
fn=read_and_stream_audio,
                inputs=[active_filepath, session_hash_code, stop_streaming_flags,
                        gr.State(streaming_audio_processor_config.read_size)],
outputs=[webrtc_stream],
trigger=start_stream_button.click,
concurrency_id="audio_stream",
concurrency_limit=10,
)
status_message_stream = gr.Markdown("", elem_id="status-message-stream", visible=False)
go_to_config = gr.Button("Go to Configuration", visible=False)
go_to_config.click(lambda: gr.Walkthrough(selected=2), outputs=walkthrough)
# === STEP 3 ===
with gr.Step("Configuration", id=2):
gr.Markdown("## Step 3: Configure the Task")
task_type = gr.Radio(["Transcription", "Translation"], value="Transcription", label="Task Type")
lang_source = gr.Dropdown(list(SUPPORTED_LANGS_MAP.keys()), value="French", label="Source Language")
lang_target = gr.Dropdown(list(SUPPORTED_LANGS_MAP.keys()), value="English", label="Target Language", visible=False)
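            # Advanced knobs mirror the streaming decoder's config: chunk/context
            # sizes bound how much audio each decoding step sees, and the policy
            # decides when tokens are emitted (roughly, "waitk" lags a fixed
            # number of chunks; "alignatt" gates on cross-attention alignment).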
with gr.Accordion("Advanced Configuration", open=False):
chunk_secs = gr.Number(value=1.0, label="chunk_secs", precision=1)
left_context_secs = gr.Number(value=20.0, label="left_context_secs", precision=1)
right_context_secs = gr.Number(value=0.5, label="right_context_secs", precision=1)
streaming_policy = gr.Dropdown(["waitk", "alignatt"], value="waitk", label="decoding.streaming_policy")
alignatt_thr = gr.Number(value=8, label="alignatt_thr", precision=0)
waitk_lagging = gr.Number(value=2, label="waitk_lagging", precision=0)
exclude_sink_frames = gr.Number(value=8, label="exclude_sink_frames", precision=0)
xatt_scores_layer = gr.Number(value=-2, label="xatt_scores_layer", precision=0)
hallucinations_detector = gr.Checkbox(value=True, label="hallucinations_detector")
with gr.Row():
auto_apply_presets = gr.Checkbox(value=True, label="Auto-apply presets for sample audios")
reset_btn = gr.Button("Reset to defaults")
summary_box = gr.Textbox(label="Configuration Summary", lines=10, interactive=False)
# --- Events ---
task_type.change(
fn=lambda t: gr.update(visible=(t == "Translation")),
inputs=task_type,
outputs=lang_target,
queue=False
)
inputs_list = [
task_type, lang_source, lang_target,
chunk_secs, left_context_secs, right_context_secs,
streaming_policy, alignatt_thr, waitk_lagging,
exclude_sink_frames, xatt_scores_layer, hallucinations_detector
]
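            # Re-render the configuration summary whenever any input changes.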
for inp in inputs_list:
inp.change(
fn=summarize_config,
inputs=inputs_list,
outputs=summary_box,
queue=False
)
# Apply preset or not
main_audio.change(
fn=apply_preset_if_example,
inputs=[main_audio, auto_apply_presets],
outputs=[
task_type, lang_source, lang_target,
chunk_secs, left_context_secs, right_context_secs,
streaming_policy, alignatt_thr, waitk_lagging,
exclude_sink_frames, xatt_scores_layer, hallucinations_detector,
summary_box
],
queue=False
)
# Reset defaults
reset_btn.click(
fn=reset_to_defaults,
inputs=None,
outputs=[
task_type, lang_source, lang_target,
chunk_secs, left_context_secs, right_context_secs,
streaming_policy, alignatt_thr, waitk_lagging,
exclude_sink_frames, xatt_scores_layer, hallucinations_detector,
summary_box
],
queue=False
)
go_to_task = gr.Button("Go to Task")
go_to_task.click(lambda: gr.Walkthrough(selected=3), outputs=walkthrough)
# === STEP 4 ===
with gr.Step("Task", id=3) as task_step:
gr.Markdown("## Step 4: Start the Task")
with gr.Group():
with gr.Column():
status_slider = gr.Slider(
0, 100,
value=0,
label="Streaming Progress",
interactive=False,
visible=False
)
stop_stream_button = gr.Button("Stop Streaming", visible=False)
transcription_output = gr.Textbox(
label="Transcription / Translation Result",
placeholder="Waiting for output...",
lines=10,
                    max_lines=10,
interactive=False,
visible=True,
autoscroll=True
)
start_task_button = gr.Button("Start Task", visible=True)
stop_task_button = gr.Button("Stop Task", visible=False)
stop_stream_button.click(
fn=stop_streaming,
inputs=[session_hash_code, stop_streaming_flags],
outputs=[stop_streaming_flags],
)
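            # The running task is coordinated through a per-session flag file:
            # deleting the file signals the worker loop to stop.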
def stop_task_fn(session_hash_code):
transcribe_active = get_active_task_flag_file(session_hash_code)
if os.path.exists(transcribe_active):
os.remove(transcribe_active)
yield "Task stopped by user."
stop_task_button.click(
fn=stop_task_fn,
inputs=session_hash_code,
outputs=transcription_output
)
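            # Generator wired to "Start Task": it streams accumulated text into
            # the output box and toggles the Start/Stop buttons via gr.update.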
def start_transcription(
session_hash_code, stop_streaming_flags,
task_type, lang_source, lang_target,
chunk_secs, left_context_secs, right_context_secs,
streaming_policy, alignatt_thr, waitk_lagging,
exclude_sink_frames, xatt_scores_layer, hallucinations_detector
):
"""Stream transcription or translation results in real time."""
                accumulated = ""
                yield f"Starting {task_type.lower()}...\n\n", gr.update(visible=False), gr.update(visible=True)
                # Loop over the `task()` generator, appending each new message.
                for msg in task(session_hash_code, streamer=streamer):
                    accumulated += msg
                    yield accumulated, gr.update(visible=False), gr.update(visible=True)
                yield accumulated + "\nDone.", gr.update(visible=True), gr.update(visible=False)
start_task_button.click(
fn=start_transcription,
inputs=[
session_hash_code, stop_streaming_flags,
task_type, lang_source, lang_target,
chunk_secs, left_context_secs, right_context_secs,
streaming_policy, alignatt_thr, waitk_lagging,
exclude_sink_frames, xatt_scores_layer, hallucinations_detector
],
                outputs=[transcription_output, start_task_button, stop_task_button]
)
ui_components = [
start_stream_button, stop_stream_button,
        go_to_config, audio_source_step, status_slider, walkthrough, status_message_stream
]
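    # Status updates emitted server-side during streaming (via fastrtc's
    # AdditionalOutputs) land here and update the progress/navigation widgets.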
webrtc_stream.on_additional_outputs(
fn=handle_additional_outputs,
inputs=[webrtc_stream],
outputs=ui_components,
concurrency_id="additional_outputs_audio_stream",
concurrency_limit=10,
)
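# Cap the request queue at 10 and keep the auto-generated API surface hidden.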
if __name__ == "__main__":
    demo.queue(max_size=10, api_open=False).launch(show_api=False, show_error=True, debug=DEBUG)