Archime committed on
Commit
70d2ece
·
1 Parent(s): 9d3beef

add new_app with better stop_streaming_flags

Browse files
Files changed (2) hide show
  1. app/new_session_utils.py +237 -0
  2. new_app.py +270 -0
app/new_session_utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import uuid
4
+ import shutil
5
+ from datetime import datetime
6
+ from app.logger_config import logger as logging
7
+ import gradio as gr
8
+ # TMP_DIR = "/tmp/canary_aed_streaming"
9
+
10
# Base directory for the per-session temporary files handled by this module;
# can be overridden through the TMP_DIR environment variable.
TMP_DIR = os.getenv("TMP_DIR", "/tmp/canary_aed_streaming")
# JSON registry of the currently active sessions, stored inside TMP_DIR.
ACTIVE_SESSIONS_FILE = os.path.join(TMP_DIR, "active_sessions.json")
12
+
13
+
14
+ # ---------------------------
15
+ # Helper to manage the JSON
16
+ # ---------------------------
17
def _read_sessions():
    """Load the active-session registry from ``ACTIVE_SESSIONS_FILE``.

    The read is best-effort: any problem yields an empty registry instead of
    an exception, but (unlike a silent swallow) the problem is logged.

    Returns:
        dict: The registry keyed by session identifier. An empty dict is
        returned when the file is missing, cannot be read or parsed, or does
        not contain a JSON object.
    """
    try:
        with open(ACTIVE_SESSIONS_FILE, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No registry written yet: nothing is registered.
        return {}
    except Exception as e:
        logging.warning(f"Failed to read {ACTIVE_SESSIONS_FILE}: {e}")
        return {}

    # Callers assign ``sessions[key] = ...``; a list/str/number payload would
    # make them crash, so treat anything that is not a JSON object as empty.
    if not isinstance(data, dict):
        logging.warning(
            f"{ACTIVE_SESSIONS_FILE} does not contain a JSON object; ignoring its content."
        )
        return {}
    return data
25
+
26
+
27
def _write_sessions(data):
    """Persist the session registry to ``ACTIVE_SESSIONS_FILE`` atomically.

    The JSON is first written to a uniquely named temporary file in the same
    directory and then moved over the registry with ``os.replace``, so a
    concurrent reader never observes a half-written file.

    Args:
        data: JSON-serialisable registry (a dict keyed by session identifier).

    Raises:
        OSError: If the directory or file cannot be written.
        TypeError: If ``data`` is not JSON-serialisable.
    """
    target_dir = os.path.dirname(ACTIVE_SESSIONS_FILE)
    if target_dir:  # os.makedirs("") raises, so skip it for a bare file name
        os.makedirs(target_dir, exist_ok=True)

    # Unique name so that two concurrent writers never share a temp file.
    tmp_path = f"{ACTIVE_SESSIONS_FILE}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, ACTIVE_SESSIONS_FILE)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when writing or replacing failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
31
+
32
+
33
+ # ---------------------------
34
+ # LOAD
35
+ # ---------------------------
36
def on_load(request: gr.Request):
    """Register the visitor's session when the app page is loaded.

    The request's ``session_hash`` is used directly as the unique session
    identifier and stored in the active-session registry.

    Args:
        request: The Gradio request of the page load.

    Returns:
        tuple: ``(session_hash, session_hash)`` - the same value twice, so it
        can feed both a ``gr.State`` and a display component.
    """
    # Local import: the module only imports ``datetime`` at file level.
    from datetime import timezone

    session_hash = request.session_hash
    sessions = _read_sessions()

    sessions[session_hash] = {
        # "session_id" is the key written by register_session() and read by
        # get_active_sessions(); without it the sessions table shows a blank id.
        "session_id": session_hash,
        "session_hash": session_hash,
        "file": "",
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        "start_time": datetime.now(timezone.utc).strftime("%H:%M:%S"),
        "status": "active",
    }

    _write_sessions(sessions)
    logging.info(f"[{session_hash}] Session registered (on_load).")

    return session_hash, session_hash  # can be used as gr.State + display
52
+
53
+
54
+ # ---------------------------
55
+ # UNLOAD
56
+ # ---------------------------
57
def on_unload(request: gr.Request):
    """Tear down the visitor's session when the page is closed or refreshed.

    For a registered session this writes the stop flag, removes the entry
    from the registry and deletes the session's temporary data. An unknown
    session is only logged.

    Args:
        request: The Gradio request carrying the ``session_hash`` to remove.
    """
    sid = request.session_hash
    registry = _read_sessions()

    # Guard clause: nothing to clean up for a session we never registered.
    if sid not in registry:
        logging.info(f"[{sid}] No active session found to remove.")
        return

    create_stop_flag(sid)
    del registry[sid]
    _write_sessions(registry)
    remove_session_data(sid)
    unregister_session(sid)
    logging.info(f"[{sid}] Session removed (on_unload).")
71
+
72
def ensure_tmp_dir():
    """Create ``TMP_DIR`` (and any missing parents) if it is not there yet.

    A failure is logged as an error and not propagated to the caller.
    """
    try:
        os.makedirs(TMP_DIR, exist_ok=True)
    except Exception as err:
        logging.error(f"Failed to create tmp directory {TMP_DIR}: {err}")
    return None
78
+
79
+
80
def reset_all_active_sessions():
    """Delete the active-sessions registry file so the app starts clean.

    Only ``ACTIVE_SESSIONS_FILE`` is removed here. Errors are logged and
    never raised.
    """
    ensure_tmp_dir()

    try:
        registry_present = os.path.exists(ACTIVE_SESSIONS_FILE)
        if not registry_present:
            logging.debug("No active sessions file found to reset.")
            return

        # --- Remove active sessions file ---
        os.remove(ACTIVE_SESSIONS_FILE)
        logging.info("Active sessions file reset at startup.")
    except Exception as err:
        logging.error(f"Error resetting active sessions: {err}")
94
+
95
def remove_session_data(session_id: str):
    """Remove the temporary files and registry entry of one session.

    Steps, each best-effort (a failure is logged and the next step still
    runs):

    1. drop ``session_id`` from ``active_sessions.json``;
    2. delete the session's progress / transcription flag files in ``TMP_DIR``;
    3. delete the ``chunks_<session_id>`` folder in ``TMP_DIR``.

    Args:
        session_id: Identifier of the session to clean up. A falsy value is
            rejected with a warning.
    """
    if not session_id:
        logging.warning("remove_session_data() called without a valid session_id.")
        return

    try:
        # --- Remove session from active_sessions.json ---
        if os.path.exists(ACTIVE_SESSIONS_FILE):
            try:
                with open(ACTIVE_SESSIONS_FILE, "r") as f:
                    data = json.load(f)
                if session_id in data:
                    data.pop(session_id)
                    with open(ACTIVE_SESSIONS_FILE, "w") as f:
                        json.dump(data, f, indent=2)
                    logging.debug(f"[{session_id}] Removed from active_sessions.json.")
            except Exception as e:
                logging.warning(f"[{session_id}] Failed to update active_sessions.json: {e}")

        # The id is interpolated into file names below and one of them is
        # passed to shutil.rmtree(); refuse ids containing a path separator
        # so the cleanup can never leave TMP_DIR.
        sid_text = str(session_id)
        if os.path.basename(sid_text) != sid_text:
            logging.warning(
                f"[{session_id}] Session id contains a path separator; skipping file cleanup."
            )
            return

        # --- Define all possible session file patterns ---
        # NOTE(review): stream_stop_flag_<id>.txt is not in this list (it was
        # commented out in the original); confirm whether it should be
        # cleaned up here, otherwise it is left behind in TMP_DIR.
        files_to_remove = [
            f"progress_{session_id}.json",
            f"transcribe_stop_flag_{session_id}.txt",
            f"transcribe_active_{session_id}.txt",
        ]

        # --- Remove all temporary files ---
        for fname in files_to_remove:
            path = os.path.join(TMP_DIR, fname)
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone (or never created): nothing to do.
                continue
            except Exception as e:
                logging.warning(f"[{session_id}] Failed to remove file {fname}: {e}")
            else:
                logging.debug(f"[{session_id}] Removed file: {fname}")

        # --- Remove chunk folder if exists ---
        chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
        if os.path.isdir(chunk_dir):
            try:
                shutil.rmtree(chunk_dir)
                logging.debug(f"[{session_id}] Removed chunk folder: chunks_{session_id}")
            except Exception as e:
                logging.warning(f"[{session_id}] Failed to remove chunk folder: {e}")

        logging.info(f"[{session_id}] Session fully reset.")

    except Exception as e:
        logging.error(f"[{session_id}] Error during remove_session_data: {e}")
146
def generate_session_id() -> str:
    """Create a fresh random (UUID4) session identifier and log its creation.

    Returns:
        str: The textual form of the new UUID.
    """
    new_id = f"{uuid.uuid4()}"
    logging.debug(f"[{new_id}] New session created.")
    return new_id
151
+
152
+
153
def register_session(session_id: str, filepath: str):
    """Add (or overwrite) a session entry in the active-session registry.

    The registry is read and written through ``_read_sessions`` /
    ``_write_sessions`` so that a missing, unreadable or malformed registry
    file is handled the same way as everywhere else in this module.

    Args:
        session_id: Identifier used as the registry key.
        filepath: Path of the file associated with the session.
    """
    # Local import: the module only imports ``datetime`` at file level.
    from datetime import timezone

    ensure_tmp_dir()
    data = _read_sessions()

    data[session_id] = {
        "session_id": session_id,
        "file": filepath,
        # Timezone-aware replacement for the deprecated datetime.utcnow().
        "start_time": datetime.now(timezone.utc).strftime("%H:%M:%S"),
        "status": "active",
    }

    _write_sessions(data)

    logging.debug(f"[{session_id}] Session registered in active_sessions.json.")
175
+
176
+
177
def unregister_session(session_id: str):
    """Drop ``session_id`` from the registry file, if both exist.

    Any error while reading or rewriting the registry is logged and not
    raised.

    Args:
        session_id: Registry key of the session to remove.
    """
    if not os.path.exists(ACTIVE_SESSIONS_FILE):
        return

    try:
        with open(ACTIVE_SESSIONS_FILE, "r") as src:
            registry = json.load(src)

        # Unknown session: leave the file untouched.
        if session_id not in registry:
            return

        registry.pop(session_id)
        with open(ACTIVE_SESSIONS_FILE, "w") as dst:
            json.dump(registry, dst)
        logging.debug(f"[{session_id}] Session unregistered.")
    except Exception as err:
        logging.error(f"[{session_id}] Error unregistering session: {err}")
192
+
193
+
194
def get_active_sessions():
    """Return the active sessions as a list of rows for the DataFrame.

    Each row is ``[session id, file, start_time, status]``. The identifier is
    taken from the entry's "session_id" key, falling back to "session_hash":
    entries written by ``on_load`` only carry the latter, and would otherwise
    be displayed with an empty id.

    Returns:
        list: One row per registry entry; an empty list when the registry
        file does not exist or cannot be read.
    """
    if not os.path.exists(ACTIVE_SESSIONS_FILE):
        return []

    try:
        with open(ACTIVE_SESSIONS_FILE, "r") as f:
            data = json.load(f)

        rows = [
            [
                s.get("session_id") or s.get("session_hash", ""),
                s.get("file", ""),
                s.get("start_time", ""),
                s.get("status", ""),
            ]
            for s in data.values()
        ]
        return rows
    except Exception as e:
        logging.error(f"Error reading active sessions: {e}")
        return []
216
+
217
+
218
def stop_file_path(session_id: str) -> str:
    """Build the path of the stream stop-flag file of a session.

    ``TMP_DIR`` is created first if needed.

    Args:
        session_id: Session the flag file belongs to.

    Returns:
        str: ``<TMP_DIR>/stream_stop_flag_<session_id>.txt``.
    """
    ensure_tmp_dir()
    flag_name = f"stream_stop_flag_{session_id}.txt"
    return os.path.join(TMP_DIR, flag_name)
222
+
223
+
224
def create_stop_flag(session_id: str):
    """Write the stop-flag file of a session (content ``"1"``) and log it.

    Args:
        session_id: Session whose stream should be flagged for stopping.
    """
    flag_path = stop_file_path(session_id)
    with open(flag_path, "w") as flag_file:
        flag_file.write("1")
    logging.info(f"[{session_id}] Stop flag file created at {flag_path}.")
230
+
231
+
232
def clear_stop_flag(session_id: str):
    """Delete the stop-flag file of a session if it exists.

    The file is removed directly instead of being checked first, so a flag
    that disappears between the check and the removal (e.g. cleared by
    another request) cannot raise ``FileNotFoundError``.

    Args:
        session_id: Session whose stop flag should be cleared.
    """
    path = stop_file_path(session_id)
    try:
        os.remove(path)
    except FileNotFoundError:
        # No flag present: nothing to clear.
        return
    logging.debug(f"[{session_id}] Stop flag cleared.")
new_app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.logger_config import logger as logging
2
+ import numpy as np
3
+ import gradio as gr
4
+ import asyncio
5
+ from fastrtc.webrtc import WebRTC
6
+ from pydub import AudioSegment
7
+ import time
8
+ import os
9
+ from gradio.utils import get_space
10
+
11
+ from app.logger_config import logger as logging
12
+ from app.utils import (
13
+ generate_coturn_config
14
+ )
15
+ from app.session_utils import (
16
+ on_load,
17
+ on_unload,
18
+ get_active_sessions,
19
+ reset_all_active_sessions,
20
+ )
21
# Runs at import time. Presumably clears the session registry left by a
# previous run (imported from app.session_utils, not shown here - confirm).
reset_all_active_sessions()
# Example audio files; the first one is the default source.
EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
DEFAULT_FILE = EXAMPLE_FILES[0]
24
+
25
+ def _is_stop_requested(stop_streaming_flags: dict) -> bool:
26
+ if not isinstance(stop_streaming_flags, dict):
27
+ return False
28
+ return bool(stop_streaming_flags.get("stop", False))
29
+
30
def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streaming_flags: dict):
    """
    Synchronous generator that reads an audio file (via filepath_to_stream)
    and streams it chunk by chunk, one second per chunk.

    Args:
        filepath_to_stream: Path of the audio file to stream. When it is
            empty or does not exist, DEFAULT_FILE is used instead (and the
            generator ends if that file is missing too).
        session_id: Session identifier, used in log messages. The generator
            ends immediately when it is falsy.
        stop_streaming_flags: Dict whose "stop" entry is polled to end the
            stream early. It is reset to False on entry and again on exit.

    Yields:
        tuple: ``(frame_rate, samples)`` where ``samples`` is a NumPy array
        of the chunk's samples reshaped to a single row.

    Raises:
        asyncio.CancelledError: Re-raised after being logged.
        Exception: Any other error during streaming is logged and re-raised,
            except FileNotFoundError, which is only logged.
    """

    if not session_id:
        logging.warning( "Aucun session_id fourni, arrêt du stream par sécurité.")
        return

    # Start from a clean state: a stop request left from a previous run must
    # not end this stream immediately.
    if isinstance(stop_streaming_flags, dict):
        stop_streaming_flags["stop"] = False
    else:
        logging.warning(f" [{session_id}] Stop stop_streaming_flags non initialisés, le stream continuera sans contrôle d'arrêt.")

    if not filepath_to_stream or not os.path.exists(filepath_to_stream):
        logging.error(f"[{session_id}] Fichier audio non trouvé ou non spécifié : {filepath_to_stream}")
        # Try to fall back to the default file when the requested one is unusable
        if os.path.exists(DEFAULT_FILE):
            logging.warning(f"[{session_id}] Utilisation du fichier par défaut : {DEFAULT_FILE}")
            filepath_to_stream = DEFAULT_FILE
        else:
            logging.error(f"[{session_id}] Fichier par défaut non trouvé. Arrêt du stream.")
            return

    logging.info(f"[{session_id}] Préparation du segment audio depuis : {filepath_to_stream}")

    try:
        segment = AudioSegment.from_file(filepath_to_stream)
        chunk_duree_ms = 1000
        logging.info(f"[{session_id}] Début du streaming en chunks de {chunk_duree_ms}ms...")

        # Slicing with a step splits the segment into chunk_duree_ms-long pieces.
        for i, chunk in enumerate(segment[::chunk_duree_ms]):
            iter_start_time = time.perf_counter()
            logging.info(f"Envoi du chunk {i+1}...")

            # Check before emitting so that a stop request drops the pending chunk.
            if _is_stop_requested(stop_streaming_flags):
                logging.info(f"[{session_id}]Signal d'arrêt reçu, arrêt de la boucle.")
                break

            output_chunk = (
                chunk.frame_rate,
                np.array(chunk.get_array_of_samples()).reshape(1, -1),
            )

            yield output_chunk

            # The measured duration includes the time spent suspended at the
            # yield above, i.e. the time taken by the consumer.
            iter_end_time = time.perf_counter()
            processing_duration_ms = (iter_end_time - iter_start_time) * 1000

            # Pace the stream: wait for the rest of the chunk duration, minus
            # a fixed 0.1 s margin.
            sleep_duration = (chunk_duree_ms / 1000.0) - (processing_duration_ms / 1000.0) - 0.1
            if sleep_duration < 0:
                sleep_duration = 0.01 # Avoid a negative sleep time

            logging.debug(f"[{session_id}]Temps de traitement: {processing_duration_ms:.2f}ms, Sommeil: {sleep_duration:.2f}s")

            # Sleep in short slices so a stop request is noticed within
            # `interval` seconds instead of after the whole wait.
            elapsed = 0.0
            interval = 0.05
            while elapsed < sleep_duration:
                if _is_stop_requested(stop_streaming_flags):
                    logging.info(f"[{session_id}]Signal d'arrêt reçu pendant l'attente.")
                    break
                wait_chunk = min(interval, sleep_duration - elapsed)
                time.sleep(wait_chunk)
                elapsed += wait_chunk
            # The break above only leaves the wait loop; leave the chunk loop too.
            if _is_stop_requested(stop_streaming_flags):
                break

        logging.info(f"[{session_id}]Streaming terminé.")

    except asyncio.CancelledError:
        logging.info(f"[{session_id}]Stream arrêté par l'utilisateur (CancelledError).")
        raise
    except FileNotFoundError:
        # Logged only: the generator simply ends.
        logging.error(f"[{session_id}] Erreur critique : Fichier non trouvé : {filepath_to_stream}")
    except Exception as e:
        logging.error(f"[{session_id}] Erreur pendant le stream: {e}", exc_info=True)
        raise
    finally:
        # Always clear the stop request so the next stream can start.
        if isinstance(stop_streaming_flags, dict):
            stop_streaming_flags["stop"] = False
            logging.info(f"[{session_id}]Signal d'arrêt nettoyé.")
112
+
113
+
114
def stop_streaming(session_id: str, stop_streaming_flags: dict):
    """Raise the stop signal read by the streaming generator.

    Args:
        session_id: Session identifier (not used by this function).
        stop_streaming_flags: Flags dict to update in place; when it is not a
            dict a new one is created.

    Returns:
        dict: The flags with ``"stop"`` set to True.
    """
    logging.info("Bouton Stop cliqué: envoi du signal d'arrêt.")
    if isinstance(stop_streaming_flags, dict):
        stop_streaming_flags["stop"] = True
        return stop_streaming_flags
    return {"stop": True}
122
+
123
+ # --- Interface Gradio ---
124
+
125
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    # Per-session identifier, filled by on_load when the page is opened.
    session_hash = gr.State()
    session_hash_box = gr.Textbox(label="Session ID", interactive=False)
    demo.load(fn=on_load, inputs=None, outputs=[session_hash, session_hash_box])
    demo.unload(on_unload)

    # Flags dict given to both the streaming generator and the Stop handler.
    stop_streaming_flags = gr.State(value={"stop": False})

    gr.Markdown(
        "## Application 'Streamer' WebRTC (Serveur -> Client)\n"
        "Utilisez l'exemple fourni, uploadez un fichier ou enregistrez depuis votre micro, "
        "puis cliquez sur 'Start' pour écouter le stream."
    )

    # State holding the path of the file to stream
    active_filepath = gr.State(value=DEFAULT_FILE)

    with gr.Row():
        with gr.Column():
            main_audio = gr.Audio(
                label="Source Audio",
                sources=["upload", "microphone"],  # Both sources combined
                type="filepath",
                value=DEFAULT_FILE,  # Defaults to the first example
            )
        with gr.Column():
            webrtc_stream = WebRTC(
                label="Stream Audio",
                mode="receive",
                modality="audio",
                rtc_configuration=generate_coturn_config(),
                visible=True,  # Shown from the start
                height = 200,
            )
    # Control buttons
    with gr.Row():
        with gr.Column():
            start_button = gr.Button("Start Streaming", variant="primary")
            stop_button = gr.Button("Stop Streaming", variant="stop", interactive=False)
        with gr.Column():
            gr.Text()

    def set_new_file(filepath):
        """Return the path to store in the state, or the default file if None."""
        if filepath is None:
            logging.info("Audio effacé, retour au fichier d'exemple par défaut.")
            new_path = DEFAULT_FILE
        else:
            logging.info(f"Nouvelle source audio sélectionnée : {filepath}")
            new_path = filepath
        # Value to put in the gr.State
        return new_path

    # Update the path when the user uploads, clears or changes the file
    main_audio.change(
        fn=set_new_file,
        inputs=[main_audio],
        outputs=[active_filepath]
    )

    # Update the path when the user finishes a recording
    main_audio.stop_recording(
        fn=set_new_file,
        inputs=[main_audio],
        outputs=[active_filepath]
    )

    # Functions updating the state of the interface
    def start_streaming_ui(session_id: str, flags: dict):
        """Disable Start, enable Stop, hide the source and reset the stop flag."""
        logging.info("UI : Démarrage du streaming. Désactivation des contrôles.")
        if not isinstance(flags, dict):
            flags = {"stop": False}
        else:
            flags["stop"] = False
        return (
            gr.Button(interactive=False),
            gr.Button(interactive=True),
            gr.Audio(visible=False),
            flags,
        )

    def stop_streaming_ui(flags: dict, filepath: str = None):
        """Re-enable Start, disable Stop and show the source audio again.

        ``filepath`` is the session's current ``active_filepath`` value. It is
        passed in as an event input because ``active_filepath.value`` is only
        the State's initial value (DEFAULT_FILE), which would discard the
        file the user uploaded or recorded.
        """
        logging.info("UI : Arrêt du streaming. Réactivation des contrôles.")
        return (
            gr.Button(interactive=True),
            gr.Button(interactive=False),
            gr.Audio(
                label="Source Audio",
                sources=["upload", "microphone"],
                type="filepath",
                value=filepath or DEFAULT_FILE,
                visible=True,
            ),
        )

    ui_components = [
        start_button, stop_button,
        main_audio,
    ]

    stream_event = webrtc_stream.stream(
        fn=read_and_stream_audio,
        inputs=[active_filepath, session_hash, stop_streaming_flags],
        outputs=[webrtc_stream],
        trigger=start_button.click,
        concurrency_id="audio_stream",  # Concurrency ID
        concurrency_limit=10
    )

    # Update the interface when START is clicked
    start_button.click(
        fn=start_streaming_ui,
        inputs=[session_hash, stop_streaming_flags],
        outputs=ui_components + [stop_streaming_flags]
    )

    # Raise the stop flag first so the stream ends, then restore the UI
    stop_button.click(
        fn=stop_streaming,
        inputs=[session_hash, stop_streaming_flags],
        outputs=[stop_streaming_flags],
    ).then(
        fn=stop_streaming_ui,  # THEN update the interface
        inputs=[stop_streaming_flags, active_filepath],
        outputs=ui_components
    )
    # --- Active sessions ---
    with gr.Accordion("📊 Active Sessions", open=False):
        sessions_table = gr.DataFrame(
            headers=["session_hash", "file", "start_time", "status"],
            interactive=False,
            wrap=True,
            max_height=200,
        )

    # Refresh the sessions table every 3 seconds
    gr.Timer(3.0).tick(fn=get_active_sessions, outputs=sessions_table)
268
+
269
if __name__ == "__main__":
    # Enable the request queue (at most 10 waiting, API closed), then start
    # the server in debug mode with the API page hidden.
    queued_demo = demo.queue(max_size=10, api_open=False)
    queued_demo.launch(show_api=False, debug=True)