canary_aed_streaming / app /ui_utils.py
Archime's picture
Updates UI state error
d296c7b
raw
history blame
18.3 kB
from app.logger_config import logger as logging
import gradio as gr
from pathlib import Path
import os
from app.utils import (
remove_active_task_flag_file,
task_fake,
is_active_task,
is_active_stream,
task
)
# from app.utils import (
# raise_error,
# READ_SIZE,
# generate_coturn_config,
# read_and_stream_audio,
# stop_streaming,
# task,
# task_fake
# )
# Baseline value for each of the 12 configuration fields handled by the UI.
# The key order matches the order used by to_updates().
DEFAULT_CONFIG = {
    "task_type": "Transcription",
    "lang_source": "French",
    "lang_target": "English",
    "chunk_secs": 1.0,
    "left_context_secs": 20.0,
    "right_context_secs": 0.5,
    "streaming_policy": "alignatt",
    "alignatt_thr": 8,
    "waitk_lagging": 2,
    "exclude_sink_frames": 8,
    "xatt_scores_layer": -2,
    "hallucinations_detector": True,
}

# Complete presets for the bundled example recordings, keyed by sample path.
# Each preset carries the same 12 keys as DEFAULT_CONFIG.
EXAMPLE_CONFIGS = {
    "data/english_meeting.wav": {
        "task_type": "Transcription",
        "lang_source": "English",
        "lang_target": "English",
        "chunk_secs": 1.0,
        "left_context_secs": 20.0,
        "right_context_secs": 0.5,
        "streaming_policy": "waitk",
        "alignatt_thr": 8,
        "waitk_lagging": 2,
        "exclude_sink_frames": 8,
        "xatt_scores_layer": -2,
        "hallucinations_detector": True,
    },
    "data/french_news.wav": {
        "task_type": "Transcription",
        "lang_source": "French",
        "lang_target": "French",
        "chunk_secs": 1.0,
        "left_context_secs": 15.0,
        "right_context_secs": 0.5,
        "streaming_policy": "alignatt",
        "alignatt_thr": 8.0,
        "waitk_lagging": 3,
        "exclude_sink_frames": 8,
        "xatt_scores_layer": -2,
        "hallucinations_detector": True,
    },
    "data/spanish_podcast.wav": {
        "task_type": "Translation",
        "lang_source": "Spanish",
        "lang_target": "English",
        "chunk_secs": 1.5,
        "left_context_secs": 25.0,
        "right_context_secs": 0.4,
        "streaming_policy": "waitk",
        "alignatt_thr": 7,
        "waitk_lagging": 1,
        "exclude_sink_frames": 8,
        "xatt_scores_layer": -2,
        "hallucinations_detector": False,
    },
}
# ========== UTILITY FUNCTIONS ==========
def to_updates(cfg):
    """Turn a config dict into a list of ``gr.update`` objects.

    The list has one entry per configuration field, in the order of the UI
    outputs. The ``lang_target`` entry additionally sets ``visible`` so that
    it is shown only when ``cfg["task_type"]`` is ``"Translation"``.

    Args:
        cfg: Mapping that contains all 12 configuration keys.

    Returns:
        A list of 12 ``gr.update`` objects.
    """
    field_order = (
        "task_type",
        "lang_source",
        "lang_target",
        "chunk_secs",
        "left_context_secs",
        "right_context_secs",
        "streaming_policy",
        "alignatt_thr",
        "waitk_lagging",
        "exclude_sink_frames",
        "xatt_scores_layer",
        "hallucinations_detector",
    )
    updates = []
    for field in field_order:
        if field == "lang_target":
            # The target language only matters for translation tasks.
            updates.append(
                gr.update(
                    value=cfg[field],
                    visible=(cfg["task_type"] == "Translation"),
                )
            )
        else:
            updates.append(gr.update(value=cfg[field]))
    return updates
def apply_preset_if_example(filepath, auto_apply):
    """Apply the matching example preset when auto-apply is enabled.

    The preset is looked up in ``EXAMPLE_CONFIGS`` by file name only (the
    directory part of both paths is ignored).

    Args:
        filepath: Path of the loaded audio file; may be empty or ``None``.
        auto_apply: Truthy when presets should be applied automatically.

    Returns:
        A 13-tuple of ``gr.update`` objects: the 12 configuration fields
        followed by the summary output. When nothing is loaded, auto-apply
        is off, or no preset matches, every entry is an empty
        ``gr.update()`` so that nothing changes.
    """

    def _leave_untouched():
        # 12 configuration fields + 1 summary output, all unchanged.
        return tuple(gr.update() for _ in range(13))

    if not (filepath and auto_apply):
        return _leave_untouched()

    # Compare file names only, not full paths.
    file_name = Path(filepath).name

    cfg = None
    for example_path, example_cfg in EXAMPLE_CONFIGS.items():
        if Path(example_path).name == file_name:
            cfg = example_cfg
            break

    if not cfg:
        return _leave_untouched()

    summary_update = gr.update(value=f"Preset applied for: {file_name}")
    return tuple(to_updates(cfg) + [summary_update])
def reset_to_defaults():
    """Reset every configuration field to its ``DEFAULT_CONFIG`` value.

    Returns:
        A 13-tuple of ``gr.update`` objects: the 12 configuration fields
        followed by the summary output set to ``"Defaults restored."``.
    """
    summary_update = gr.update(value="Defaults restored.")
    return (*to_updates(DEFAULT_CONFIG), summary_update)
def summarize_config(
    task, src, tgt,
    chunk, left, right,
    policy, thr, lag, sink, xatt, halluc
):
    """Render the selected configuration as a Markdown summary.

    The summary starts with the task and source language, adds the target
    language only when ``task`` equals ``"Translation"``, and ends with a
    bullet list of the advanced parameters.

    Returns:
        The Markdown text as a single string.
    """
    header = f"🧠 **Task:** {task}\n🌐 **Source language:** {src}"
    # The target language line is only relevant for translation.
    target_line = f"\n🎯 **Target language:** {tgt}" if task == "Translation" else ""
    advanced = "".join([
        f"\n\n### ⚙️ Advanced Parameters:\n",
        f"- chunk_secs = {chunk}\n",
        f"- left_context_secs = {left}\n",
        f"- right_context_secs = {right}\n",
        f"- decoding.streaming_policy = {policy}\n",
        f"- decoding.alignatt_thr = {thr}\n",
        f"- decoding.waitk_lagging = {lag}\n",
        f"- decoding.exclude_sink_frames = {sink}\n",
        f"- decoding.xatt_scores_layer = {xatt}\n",
        f"- decoding.hallucinations_detector = {halluc}",
    ])
    return header + target_line + advanced
def handle_additional_outputs(webrtc_stream, msg):
    """Translate a streaming status message into UI component updates.

    ``msg`` is read as a dict that may carry one of the flags ``errored``,
    ``stoped``/``stopped`` or ``progressed``, together with ``value`` and
    ``session_hash_code``. Every component that a branch does not set
    explicitly receives an empty ``gr.update()``, which leaves it unchanged.

    Args:
        webrtc_stream: Unused in this function.
        msg: Status payload; anything that is not a dict is ignored.

    Returns:
        An 8-tuple of ``gr.update`` objects in this order:
        start_stream_button, stop_stream_button, start_task_button,
        go_to_task, audio_source_step, status_slider, walkthrough,
        status_message. The order must match the ``outputs=[...]`` list of
        the Gradio event this handler is attached to.
    """
    # Defaults: an empty gr.update() means "do not touch this component".
    start_btn = gr.update()
    stop_btn = gr.update()
    start_task_btn = gr.update()
    go_to_task_btn = gr.update()
    audio_step = gr.update()
    slider = gr.update()
    walkthrough = gr.update()
    status_msg = gr.update(visible=False, value="")

    # Safety: ignore payloads that are not dictionaries.
    if not isinstance(msg, dict):
        return (
            start_btn, stop_btn, start_task_btn, go_to_task_btn,
            audio_step, slider, walkthrough, status_msg,
        )

    session_hash = msg.get("session_hash_code", "")

    # --- CASE 1: ERROR ---
    if msg.get("errored"):
        error_val = msg.get("value", "Unknown error")
        logging.error(f"[stream_ui] Client-side error: {error_val}")
        start_btn = gr.update(visible=True)
        stop_btn = gr.update(visible=False)
        start_task_btn = gr.update(visible=False)
        go_to_task_btn = gr.update(visible=False)
        audio_step = gr.update(interactive=True)
        slider = gr.update(visible=False, value=0)
        status_msg = gr.update(value=f"⚠️ **Error:** {error_val}", visible=True)

    # --- CASE 2: MANUAL STOP ---
    # The key "stoped" (one "p") is the historical spelling; "stopped" is
    # accepted too in case the sender is corrected.
    elif msg.get("stoped") or msg.get("stopped"):
        start_btn = gr.update(visible=True)
        stop_btn = gr.update(visible=False)
        start_task_btn = gr.update(visible=False)
        go_to_task_btn = gr.update(visible=False)
        audio_step = gr.update(interactive=True)
        slider = gr.update(visible=True, value=0)
        status_msg = gr.update(value="ℹ️ Stream stopped by user.", visible=True)

    # --- CASE 3: PROGRESS ---
    elif msg.get("progressed"):
        raw_progress = msg.get("value", 0)
        try:
            progress = float(raw_progress)
        except (TypeError, ValueError):
            # A malformed value (e.g. None or non-numeric text) must not
            # crash the UI callback; treat it as "no progress yet".
            logging.error(f"[stream_ui] Invalid progress value: {raw_progress!r}")
            progress = 0.0

        finished = progress >= 100.0

        # While streaming: Start hidden, Stop shown, audio input locked.
        # Once finished (>= 100%): the opposite.
        start_btn = gr.update(visible=finished)
        stop_btn = gr.update(visible=not finished)
        audio_step = gr.update(interactive=finished)
        slider = gr.update(visible=True, value=progress)
        go_to_task_btn = gr.update(visible=True)

        if finished:
            start_task_btn = gr.update(visible=False)
        elif (not is_active_task(session_hash)) and is_active_stream(session_hash):
            # Offer "start task" only while the stream is active and no task
            # is running; otherwise the button is left untouched.
            start_task_btn = gr.update(visible=True)
        # The status message stays hidden and empty during normal progress.

    # --- SINGLE RETURN ---
    return (
        start_btn,       # 1. start_stream_button
        stop_btn,        # 2. stop_stream_button
        start_task_btn,  # 3. start_task_button
        go_to_task_btn,  # 4. go_to_task
        audio_step,      # 5. audio_source_step
        slider,          # 6. status_slider
        walkthrough,     # 7. walkthrough
        status_msg,      # 8. status_message (Markdown/HTML)
    )
# def handle_additional_outputs(webrtc_stream, msg):
# """
# Update UI elements based on streaming progress or errors.
# Controls button states, audio visibility, and progress slider.
# """
# # ui_components = [start_stream_button, stop_stream_button,start_task_button,go_to_task, audio_source_step, status_slider,walkthrough]
# progress = float(0)
# # Handle structured error message
# if isinstance(msg, dict) and msg.get("errored"):
# value = msg.get("value", "Unknown error.")
# logging.error(f"[stream_ui] Client-side error: {value}")
# return (
# gr.update(visible=True), # start_stream_button enabled
# gr.update(visible=False), # stop_stream_button disabled
# gr.update(visible=False), #start_task_button
# gr.update(visible=False), # go_to_task disabled
# gr.update(interactive=True), # audio_source_step re-shown
# gr.update(visible=False, value=0), # slider hidden
# gr.update(), #walkthrough
# gr.update(value=f"**Error:** {value}", visible=True)
# )
# elif msg.get("progressed") :
# value = msg.get("value", 0)
# progress = float(value)
# if progress == 100.00 :
# return (
# gr.update(visible=True), # start_stream_button disabled
# gr.update(visible=False), # stop_stream_button enabled
# gr.update(visible=False), #start_task_button
# gr.update(visible=True), # go_to_task enabled
# gr.update(interactive=True), # hide audio_source_step
# gr.update(visible=True, value=progress), # show progress
# gr.update(), #walkthrough
# gr.update(value="", visible=False)
# )
# else :
# return (
# gr.update(visible=False), # start_stream_button disabled
# gr.update(visible=True), # stop_stream_button enabled
# gr.update() if is_active_task(msg.get("session_hash_code")) else gr.update(visible=True), #start_task_button
# gr.update(visible=True), # go_to_task enabled
# gr.update(interactive=False), # hide audio_source_step
# gr.update(visible=True, value=progress), # show progress
# gr.update(), #walkthrough
# gr.update(value="", visible=False)
# )
# elif msg.get("stoped") :
# return (
# gr.update(visible=True), # start_stream_button disabled
# gr.update(visible=False), # stop_stream_button enabled
# gr.update(visible=False), #start_task_button
# gr.update(visible=False), # go_to_task enabled
# gr.update(interactive=True), # hide audio_source_step
# gr.update(visible=True, value=0), # show progress
# gr.update(), #walkthrough
# gr.update(value="ℹStream stopped by user.", visible=True)
# )
def on_file_load(filepath):
    """Pass the loaded audio path through and toggle a dependent component.

    Args:
        filepath: Path of the loaded audio, or ``None`` when nothing is
            loaded.

    Returns:
        A pair ``(filepath, update)`` where ``update`` is a ``gr.update``
        whose ``visible`` flag is true only when ``filepath`` is not
        ``None``.
    """
    # Any loaded file (upload, microphone or example) gives a non-None path.
    visibility_update = gr.update(visible=(filepath is not None))
    return filepath, visibility_update
def get_custom_theme():
    """Build the custom Gradio theme and load the companion stylesheet.

    The stylesheet is read from ``assets/custom_style.css``, resolved
    relative to the parent of the directory that contains this module.

    Returns:
        A pair ``(theme, css_style)``: the configured ``gr.themes.Base``
        theme and the stylesheet content as text.
    """
    # === Custom theme ===
    overrides = {
        "body_background_fill": "#F7F8FA",
        "body_text_color": "#222222",
        "block_border_color": "#D0D3D9",
        "button_primary_background_fill": "#3B82F6",
        "button_primary_background_fill_hover": "#2563EB",
        "button_primary_text_color": "#FFFFFF",
    }
    base_theme = gr.themes.Base(primary_hue="blue", secondary_hue="indigo")
    theme = base_theme.set(**overrides)

    module_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(module_dir)
    css_path = os.path.join(parent_dir, "assets", "custom_style.css")
    with open(css_path, encoding="utf-8") as css_file:
        css_style = css_file.read()

    return theme, css_style
########## task
# def start_task_asr_ast(
# session_hash_code,
# task_type, lang_source, lang_target,
# chunk_secs, left_context_secs, right_context_secs,
# streaming_policy, alignatt_thr, waitk_lagging,
# exclude_sink_frames, xatt_scores_layer, hallucinations_detector
# ):
# """Stream transcription or translation results in real time."""
# accumulated = ""
# # Boucle sur le générateur de `task2()`
# # outputs=[task_output,status_message_task,start_task_button,stop_task_button,config_step]
# for result, status, current_chunk in task_fake(
# session_hash_code,
# task_type, lang_source, lang_target,
# chunk_secs, left_context_secs, right_context_secs,
# streaming_policy, alignatt_thr, waitk_lagging,
# exclude_sink_frames, xatt_scores_layer, hallucinations_detector
# ):
# if status == "success":
# yield (accumulated + result, #task_output
# gr.update(visible=True,value=current_chunk,elem_classes=[status]),#status_message_task
# gr.update(visible=False),#start_task_button
# gr.update(visible=True), #stop_task_button
# gr.update(interactive=False) # config_step
# )
# accumulated += result
# elif status in ["warning","info" ]:
# yield (accumulated, #task_output
# gr.update(visible=True,value=result , elem_classes=[status]),#status_message_task
# gr.update(visible=False),#start_task_button
# gr.update(visible=True),#stop_task_button
# gr.update(interactive=False) # config_step
# )
# elif status in [ "done"]:
# yield (accumulated, #task_output
# gr.update(visible=True,value=result , elem_classes=[status]),#status_message_task
# gr.update(visible=True) if is_active_stream(session_hash_code) else gr.update(visible=False),#start_task_button
# gr.update(visible=False),#stop_task_button
# gr.update(interactive=True) # config_step
# )
def start_task_asr_ast(
    session_hash_code,
    task_type, lang_source, lang_target,
    chunk_secs, left_context_secs, right_context_secs,
    streaming_policy, alignatt_thr, waitk_lagging,
    exclude_sink_frames, xatt_scores_layer, hallucinations_detector
):
    """Stream transcription (ASR) or translation (AST) results to the UI.

    Iterates over the task generator, which yields
    ``(result_data, status, debug_info)`` triples, and converts each one
    into UI updates. Text from ``"success"`` items is accumulated; other
    statuses only change the status message and the controls.

    Yields:
        A 5-tuple in the order expected by the Gradio outputs:
        ``(task_output, status_message_task, start_task_button,
        stop_task_button, config_step)``.
    """
    accumulated_text = ""

    # NOTE(review): this calls ``task_fake`` although ``task`` is imported
    # as well — confirm which backend is intended here.
    task_generator = task_fake(
        session_hash_code,
        task_type, lang_source, lang_target,
        chunk_secs, left_context_secs, right_context_secs,
        streaming_policy, alignatt_thr, waitk_lagging,
        exclude_sink_frames, xatt_scores_layer, hallucinations_detector
    )

    # result_data is either a new text chunk or a message, depending on status.
    for result_data, status, debug_info in task_generator:
        # Defaults for this iteration ("in progress"): configuration locked,
        # Start hidden, Stop available.
        start_btn = gr.update(visible=False)
        stop_btn = gr.update(visible=True)
        config_step = gr.update(interactive=False)
        status_msg = gr.update(visible=True)

        # --- CASE 1: SUCCESS (new text segment) ---
        if status == "success":
            accumulated_text += result_data
            # The status line shows the per-chunk information.
            status_msg = gr.update(visible=True, value=debug_info, elem_classes=[status])

        # --- CASE 2: WARNING / INFO (system message) ---
        elif status in ["warning", "info"]:
            # result_data is the message; the accumulated text is untouched.
            status_msg = gr.update(visible=True, value=result_data, elem_classes=[status])

        # --- CASE 3: DONE / ERROR ---
        elif status in ["done", "error"]:
            if status == "error":
                # Log only real failures; a normal "done" is not an error.
                logging.error(f"[ui] Task error: {result_data}")
            # Re-enable the controls; Start is shown only while the audio
            # stream is still active.
            is_streaming = is_active_stream(session_hash_code)
            start_btn = gr.update(visible=is_streaming)
            stop_btn = gr.update(visible=False)
            config_step = gr.update(interactive=True)
            # result_data is the completion (or error) message.
            status_msg = gr.update(visible=True, value=result_data, elem_classes=[status])

        # Single dispatch to the UI for every item.
        yield (
            accumulated_text,
            status_msg,
            start_btn,
            stop_btn,
            config_step
        )
def stop_task_fn(session_hash_code):
    """Remove the session's active-task flag file and emit a stop notice.

    This is a generator: it calls ``remove_active_task_flag_file`` for
    ``session_hash_code`` and then yields a single message.
    """
    remove_active_task_flag_file(session_hash_code)
    stop_notice = "Task stopped by user."
    yield stop_notice
# # --------------------------------------------------------
def raise_error(message="Une erreur est survenue."):
    """Raise a ``gr.Error`` carrying ``message``.

    Raises:
        gr.Error: Always.
    """
    error = gr.Error(message)
    raise error