Archime commited on
Commit
3532025
·
1 Parent(s): 703ca2c

add on_load and on_unload

Browse files
Files changed (2) hide show
  1. app.py +15 -6
  2. app/session_utils.py +138 -29
app.py CHANGED
@@ -19,11 +19,14 @@ from app.session_utils import (
19
  stop_file_path,
20
  create_stop_flag,
21
  clear_stop_flag,
22
- reset_active_sessions,
 
 
 
23
  )
24
 
25
  # Reset sessions at startup
26
- reset_active_sessions()
27
 
28
  EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
29
  DEFAULT_FILE = EXAMPLE_FILES[0]
@@ -60,7 +63,7 @@ def read_and_stream_audio(filepath_to_stream: str, session_id: str, chunk_second
60
  logging.info(f"[{session_id}] Stop flag detected at chunk {i}. Ending stream.")
61
  clear_stop_flag(session_id)
62
  break
63
-
64
  iter_start = time.perf_counter()
65
 
66
  elapsed_s = i * chunk_seconds
@@ -75,7 +78,7 @@ def read_and_stream_audio(filepath_to_stream: str, session_id: str, chunk_second
75
  chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
76
  rate = chunk.frame_rate
77
 
78
- #Save only if transcription is active
79
  if os.path.exists(transcribe_flag):
80
  npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
81
  np.savez_compressed(npz_path, data=chunk_array, rate=rate)
@@ -184,7 +187,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
184
  "Each user controls their own stream. Transcription runs only during streaming."
185
  )
186
 
187
- session_id = gr.State(value=generate_session_id())
 
 
 
188
  active_filepath = gr.State(value=DEFAULT_FILE)
189
 
190
  with gr.Row(equal_height=True):
@@ -274,6 +280,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
274
  inputs=[active_filepath, session_id, chunk_slider],
275
  outputs=[webrtc_stream],
276
  trigger=start_button.click,
 
 
277
  )
278
 
279
  start_button.click(fn=start_streaming_ui, inputs=[session_id], outputs=[
@@ -372,4 +380,5 @@ demo.css = custom_css
372
  # MAIN
373
  # --------------------------------------------------------
374
  if __name__ == "__main__":
375
- demo.queue(max_size=50, api_open=False).launch(show_api=False, debug=True)
 
 
19
  stop_file_path,
20
  create_stop_flag,
21
  clear_stop_flag,
22
+ reset_all_active_sessions,
23
+ on_load,
24
+ on_unload
25
+
26
  )
27
 
28
  # Reset sessions at startup
29
+ reset_all_active_sessions()
30
 
31
  EXAMPLE_FILES = ["data/bonjour.wav", "data/bonjour2.wav"]
32
  DEFAULT_FILE = EXAMPLE_FILES[0]
 
63
  logging.info(f"[{session_id}] Stop flag detected at chunk {i}. Ending stream.")
64
  clear_stop_flag(session_id)
65
  break
66
+ logging.info(f"[{session_id}] Streaming chunk {i}.")
67
  iter_start = time.perf_counter()
68
 
69
  elapsed_s = i * chunk_seconds
 
78
  chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
79
  rate = chunk.frame_rate
80
 
81
+ # Save only if transcription is active
82
  if os.path.exists(transcribe_flag):
83
  npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
84
  np.savez_compressed(npz_path, data=chunk_array, rate=rate)
 
187
  "Each user controls their own stream. Transcription runs only during streaming."
188
  )
189
 
190
+ session_id = gr.State()
191
+ sid_box = gr.Textbox(label="Session ID", interactive=False)
192
+ demo.load(fn=on_load, inputs=None, outputs=[session_id, sid_box])
193
+ demo.unload(on_unload)
194
  active_filepath = gr.State(value=DEFAULT_FILE)
195
 
196
  with gr.Row(equal_height=True):
 
280
  inputs=[active_filepath, session_id, chunk_slider],
281
  outputs=[webrtc_stream],
282
  trigger=start_button.click,
283
+ concurrency_limit=20,
284
+ concurrency_id="receive"
285
  )
286
 
287
  start_button.click(fn=start_streaming_ui, inputs=[session_id], outputs=[
 
380
  # MAIN
381
  # --------------------------------------------------------
382
  if __name__ == "__main__":
383
+
384
+ demo.queue(max_size=20, api_open=False).launch(show_api=False, debug=True)
app/session_utils.py CHANGED
@@ -1,15 +1,74 @@
1
  import os
2
  import json
3
  import uuid
 
4
  from datetime import datetime
5
  from app.logger_config import logger as logging
 
 
6
 
7
- TMP_DIR = "/tmp/canary_aed_streaming"
8
- # TMP_DIR = "/home/sifar-dev/workspace/canary_aed_streaming/tmp/canary_aed_streaming"
9
-
10
  ACTIVE_SESSIONS_FILE = os.path.join(TMP_DIR, "active_sessions.json")
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def ensure_tmp_dir():
14
  """Ensures the base temporary directory exists."""
15
  try:
@@ -18,58 +77,108 @@ def ensure_tmp_dir():
18
  logging.error(f"Failed to create tmp directory {TMP_DIR}: {e}")
19
 
20
 
21
- def reset_active_sessions():
22
- """Removes all temporary session files at startup."""
23
  ensure_tmp_dir()
24
 
25
  try:
26
- # Remove active sessions file
27
  if os.path.exists(ACTIVE_SESSIONS_FILE):
28
  os.remove(ACTIVE_SESSIONS_FILE)
29
  logging.info("Active sessions file reset at startup.")
30
  else:
31
- logging.debug("No active sessions found to reset.")
32
 
33
- # Clean up old progress files
34
  for f in os.listdir(TMP_DIR):
35
  if f.startswith("progress_") and f.endswith(".json"):
 
36
  try:
37
- os.remove(os.path.join(TMP_DIR, f))
38
  logging.debug(f"Removed leftover progress file: {f}")
39
  except Exception as e:
40
  logging.warning(f"Failed to remove progress file {f}: {e}")
41
 
42
- # Clean up old stop flag files
43
  for f in os.listdir(TMP_DIR):
44
- if f.startswith("stream_stop_flag_") and f.endswith(".txt"):
 
 
 
 
 
45
  try:
46
- os.remove(os.path.join(TMP_DIR, f))
47
- logging.debug(f"Removed leftover stop flag file: {f}")
48
  except Exception as e:
49
- logging.warning(f"Failed to remove stop flag file {f}: {e}")
50
- # Clean up old transcribe_stop_flag
51
- for f in os.listdir(TMP_DIR):
52
- if f.startswith("transcribe_stop_flag_") and f.endswith(".txt"):
53
- try:
54
- os.remove(os.path.join(TMP_DIR, f))
55
- logging.debug(f"Removed leftover transcribe_stop_flag flag file: {f}")
56
- except Exception as e:
57
- logging.warning(f"Failed to remove transcribe_stop_flag file {f}: {e}")
58
- # Clean up old transcribe_active_flag
59
- for f in os.listdir(TMP_DIR):
60
- if f.startswith("transcribe_active_") and f.endswith(".txt"):
61
  try:
62
- os.remove(os.path.join(TMP_DIR, f))
63
- logging.debug(f"Removed leftover transcribe active flag file: {f}")
64
  except Exception as e:
65
- logging.warning(f"Failed to remove transcribe active file {f}: {e}")
66
-
67
 
 
68
 
69
  except Exception as e:
70
  logging.error(f"Error resetting active sessions: {e}")
71
 
 
 
 
 
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def generate_session_id() -> str:
74
  """Generates a unique session ID."""
75
  sid = str(uuid.uuid4())
 
1
  import os
2
  import json
3
  import uuid
4
+ import shutil
5
  from datetime import datetime
6
  from app.logger_config import logger as logging
7
+ import gradio as gr
8
+ # TMP_DIR = "/tmp/canary_aed_streaming"
9
 
10
+ TMP_DIR = os.getenv("TMP_DIR", "/tmp/canary_aed_streaming")
 
 
11
  ACTIVE_SESSIONS_FILE = os.path.join(TMP_DIR, "active_sessions.json")
12
 
13
 
14
+ # ---------------------------
15
+ # Helper to manage the JSON
16
+ # ---------------------------
17
+ def _read_sessions():
18
+ if not os.path.exists(ACTIVE_SESSIONS_FILE):
19
+ return {}
20
+ try:
21
+ with open(ACTIVE_SESSIONS_FILE, "r") as f:
22
+ return json.load(f)
23
+ except Exception:
24
+ return {}
25
+
26
+
27
+ def _write_sessions(data):
28
+ os.makedirs(os.path.dirname(ACTIVE_SESSIONS_FILE), exist_ok=True)
29
+ with open(ACTIVE_SESSIONS_FILE, "w") as f:
30
+ json.dump(data, f, indent=2)
31
+
32
+
33
+ # ---------------------------
34
+ # LOAD
35
+ # ---------------------------
36
+ def on_load(request: gr.Request):
37
+ """Called when a new visitor opens the app."""
38
+ sid = request.session_hash # ✅ Directly use session_hash as unique ID
39
+ sessions = _read_sessions()
40
+
41
+ sessions[sid] = {
42
+ "session_id": sid,
43
+ "file": "",
44
+ "start_time": datetime.utcnow().strftime("%H:%M:%S"),
45
+ "status": "active",
46
+ }
47
+
48
+ _write_sessions(sessions)
49
+ logging.info(f"[{sid}] Session registered (on_load).")
50
+
51
+ return sid, sid # can be used as gr.State + display
52
+
53
+
54
+ # ---------------------------
55
+ # UNLOAD
56
+ # ---------------------------
57
+ def on_unload(request: gr.Request):
58
+ """Called when the visitor closes or refreshes the app."""
59
+ sid = request.session_hash
60
+ sessions = _read_sessions()
61
+
62
+ if sid in sessions:
63
+ create_stop_flag(sid)
64
+ sessions.pop(sid)
65
+ _write_sessions(sessions)
66
+ remove_session_data(sid)
67
+ unregister_session(sid)
68
+ logging.info(f"[{sid}] Session removed (on_unload).")
69
+ else:
70
+ logging.info(f"[{sid}] No active session found to remove.")
71
+
72
  def ensure_tmp_dir():
73
  """Ensures the base temporary directory exists."""
74
  try:
 
77
  logging.error(f"Failed to create tmp directory {TMP_DIR}: {e}")
78
 
79
 
80
+ def reset_all_active_sessions():
81
+ """Removes all temporary session files and folders at startup."""
82
  ensure_tmp_dir()
83
 
84
  try:
85
+ # --- Remove active sessions file ---
86
  if os.path.exists(ACTIVE_SESSIONS_FILE):
87
  os.remove(ACTIVE_SESSIONS_FILE)
88
  logging.info("Active sessions file reset at startup.")
89
  else:
90
+ logging.debug("No active sessions file found to reset.")
91
 
92
+ # --- Clean progress files ---
93
  for f in os.listdir(TMP_DIR):
94
  if f.startswith("progress_") and f.endswith(".json"):
95
+ path = os.path.join(TMP_DIR, f)
96
  try:
97
+ os.remove(path)
98
  logging.debug(f"Removed leftover progress file: {f}")
99
  except Exception as e:
100
  logging.warning(f"Failed to remove progress file {f}: {e}")
101
 
102
+ # --- Clean all flag files (stream + transcribe) ---
103
  for f in os.listdir(TMP_DIR):
104
+ if (
105
+ f.startswith("stream_stop_flag_")
106
+ or f.startswith("transcribe_stop_flag_")
107
+ or f.startswith("transcribe_active_")
108
+ ) and f.endswith(".txt"):
109
+ path = os.path.join(TMP_DIR, f)
110
  try:
111
+ os.remove(path)
112
+ logging.debug(f"Removed leftover flag file: {f}")
113
  except Exception as e:
114
+ logging.warning(f"Failed to remove flag file {f}: {e}")
115
+
116
+ # --- Clean chunk directories ---
117
+ for name in os.listdir(TMP_DIR):
118
+ path = os.path.join(TMP_DIR, name)
119
+ if os.path.isdir(path) and name.startswith("chunks_"):
 
 
 
 
 
 
120
  try:
121
+ shutil.rmtree(path)
122
+ logging.debug(f"Removed leftover chunk folder: {name}")
123
  except Exception as e:
124
+ logging.warning(f"Failed to remove chunk folder {name}: {e}")
 
125
 
126
+ logging.info("Temporary session cleanup completed successfully.")
127
 
128
  except Exception as e:
129
  logging.error(f"Error resetting active sessions: {e}")
130
 
131
+ def remove_session_data(session_id: str):
132
+ """Removes all temporary files and data related to a specific session."""
133
+ if not session_id:
134
+ logging.warning("reset_session() called without a valid session_id.")
135
+ return
136
 
137
+ try:
138
+ # --- Remove session from active_sessions.json ---
139
+ if os.path.exists(ACTIVE_SESSIONS_FILE):
140
+ try:
141
+ with open(ACTIVE_SESSIONS_FILE, "r") as f:
142
+ data = json.load(f)
143
+ if session_id in data:
144
+ data.pop(session_id)
145
+ with open(ACTIVE_SESSIONS_FILE, "w") as f:
146
+ json.dump(data, f, indent=2)
147
+ logging.debug(f"[{session_id}] Removed from active_sessions.json.")
148
+ except Exception as e:
149
+ logging.warning(f"[{session_id}] Failed to update active_sessions.json: {e}")
150
+
151
+ # --- Define all possible session file patterns ---
152
+ files_to_remove = [
153
+ f"progress_{session_id}.json",
154
+ # f"stream_stop_flag_{session_id}.txt",
155
+ f"transcribe_stop_flag_{session_id}.txt",
156
+ f"transcribe_active_{session_id}.txt",
157
+ ]
158
+
159
+ # --- Remove all temporary files ---
160
+ for fname in files_to_remove:
161
+ path = os.path.join(TMP_DIR, fname)
162
+ if os.path.exists(path):
163
+ try:
164
+ os.remove(path)
165
+ logging.debug(f"[{session_id}] Removed file: {fname}")
166
+ except Exception as e:
167
+ logging.warning(f"[{session_id}] Failed to remove file {fname}: {e}")
168
+
169
+ # --- Remove chunk folder if exists ---
170
+ chunk_dir = os.path.join(TMP_DIR, f"chunks_{session_id}")
171
+ if os.path.isdir(chunk_dir):
172
+ try:
173
+ shutil.rmtree(chunk_dir)
174
+ logging.debug(f"[{session_id}] Removed chunk folder: chunks_{session_id}")
175
+ except Exception as e:
176
+ logging.warning(f"[{session_id}] Failed to remove chunk folder: {e}")
177
+
178
+ logging.info(f"[{session_id}] Session fully reset.")
179
+
180
+ except Exception as e:
181
+ logging.error(f"[{session_id}] Error during reset_session: {e}")
182
  def generate_session_id() -> str:
183
  """Generates a unique session ID."""
184
  sid = str(uuid.uuid4())