Archime committed on
Commit 11c4a5a · 1 Parent(s): 4ff8afc

add nemo_asr and silero_vad Engine

.gitignore CHANGED
@@ -4,3 +4,4 @@ app/__pycache__/
 __pycache__/
 logs/
 .continue/
+tmp/
app.py CHANGED
@@ -39,7 +39,7 @@ from app.ui_utils import (
     get_custom_theme,
     on_file_load
 )
-
+import nemo.collections.asr as nemo_asr
 # --------------------------------------------------------
 # Initialization
 # --------------------------------------------------------
@@ -47,6 +47,24 @@ reset_all_active_session_hash_code()

 theme,css_style = get_custom_theme()

+from omegaconf import OmegaConf
+cfg = OmegaConf.load('app/config.yaml')
+# logger.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+from app.canary_speech_engine import CanarySpeechEngine
+from app.silero_vad_engine import Silero_Vad_Engine
+from app.streaming_audio_processor import StreamingAudioProcessor, StreamingAudioProcessorConfig
+
+
+asr_model = nemo_asr.models.ASRModel.from_pretrained(cfg.pretrained_name)
+canary_speech_engine = CanarySpeechEngine(asr_model, cfg)
+silero_vad_engine = Silero_Vad_Engine()
+streaming_audio_processor_config = StreamingAudioProcessorConfig(
+    read_size=4000,
+    silence_threshold_chunks=1
+)
+streamer = StreamingAudioProcessor(speech_engine=canary_speech_engine, vad_engine=silero_vad_engine, cfg=streaming_audio_processor_config)
+
+
 with gr.Blocks(theme=theme, css=css_style) as demo:
     session_hash_code = gr.State()
     session_hash_code_box = gr.Textbox(label="Session ID", interactive=False, visible=DEBUG)
@@ -125,7 +143,7 @@ with gr.Blocks(theme=theme, css=css_style) as demo:

     webrtc_stream.stream(
         fn=read_and_stream_audio,
-        inputs=[active_filepath, session_hash_code, stop_streaming_flags],
+        inputs=[active_filepath, session_hash_code, stop_streaming_flags, gr.State(streaming_audio_processor_config.read_size)],
         outputs=[webrtc_stream],
         trigger=start_stream_button.click,
         concurrency_id="audio_stream",
@@ -273,7 +291,7 @@ with gr.Blocks(theme=theme, css=css_style) as demo:
         yield f"Starting {task_type.lower()}...\n\n", gr.update(visible=False), gr.update(visible=True)

         # Loop over the `task()` generator
-        for msg in task(session_hash_code):
+        for msg in task(session_hash_code, streamer=streamer):
             accumulated += msg
             yield accumulated, gr.update(visible=False), gr.update(visible=True)
app/canary_speech_engine.py ADDED
@@ -0,0 +1,288 @@
+import time
+from typing import Optional, Tuple
+from app.interfaces import IStreamingSpeechEngine
+import numpy as np
+import torch
+import gc
+from omegaconf import OmegaConf
+
+
+from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
+from nemo.collections.asr.parts.submodules.aed_decoding import (
+    GreedyBatchedStreamingAEDComputer,
+    return_decoder_input_ids,
+)
+from nemo.collections.asr.parts.submodules.multitask_decoding import (
+    AEDStreamingDecodingConfig,
+    MultiTaskDecodingConfig,
+)
+# from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer  # Not used
+from nemo.collections.asr.parts.utils.streaming_utils import (
+    ContextSize,
+    StreamingBatchedAudioBuffer,
+)
+import nemo.collections.asr as nemo_asr
+from nemo.collections.asr.parts.utils.transcribe_utils import (
+    get_inference_device,
+    get_inference_dtype,
+)
+from app.logger_config import (
+    logger as logging,
+    DEBUG
+)
+
+
+def make_divisible_by(num: int, factor: int) -> int:
+    """Round num down to the nearest multiple of factor."""
+    return (num // factor) * factor
+
+
+class CanarySpeechEngine(IStreamingSpeechEngine):
+    """
+    Encapsulates the state and logic for streaming audio transcription
+    using a Canary model.
+    """
+    def __init__(self, asr_model, cfg: OmegaConf):
+        """
+        Initializes the speech engine and prepares the ASR model for inference.
+
+        Args:
+            asr_model: A preloaded NeMo ASR model.
+            cfg: An OmegaConf object with model and streaming settings.
+        """
+        self.cfg = cfg  # Store the full config
+
+        # Setup device and dtype from config
+        self.map_location = get_inference_device(cuda=self.cfg.cuda, allow_mps=self.cfg.allow_mps)
+        self.compute_dtype = get_inference_dtype(self.cfg.compute_dtype, device=self.map_location)
+        logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")
+
+        # Move the model to the inference device and configure decoding
+        asr_model, _ = self._setup_model(asr_model, self.cfg, self.map_location)
+        self.asr_model = asr_model
+
+        self.full_transcription = []  # Stores finalized segments
+        self._setup_streaming_params()
+
+        # The initial full reset (buffer + decoder)
+        self.reset()
+
+        logging.info("CanarySpeechEngine initialized and ready.")
+        logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")
+
+    def _setup_model(self, asr_model, model_cfg: OmegaConf, map_location: str):
+        """Moves the ASR model to the target device and configures it for inference."""
+        logging.info(f"Loading model {model_cfg.pretrained_name}...")
+        start_time = time.time()
+        try:
+            asr_model = asr_model.to(map_location)
+            asr_model.eval()
+
+            # Change decoding strategy to greedy for streaming
+            if hasattr(asr_model, 'change_decoding_strategy'):
+                multitask_decoding = MultiTaskDecodingConfig()
+                multitask_decoding.strategy = "greedy"
+                asr_model.change_decoding_strategy(multitask_decoding)
+                logging.info("Model decoding strategy set to 'greedy'.")
+
+            if map_location == "cuda":
+                torch.cuda.synchronize()
+
+            end_time = time.time()
+            logging.info("Model loaded successfully.")
+            load_time = end_time - start_time
+            logging.info("\n" + "="*30)
+            logging.info(f"Total model load time: {load_time:.2f} seconds")
+            logging.info("="*30)
+            return asr_model, None
+
+        except Exception as e:
+            logging.error(f"Error loading model: {e}")
+            logging.error("Ensure NeMo is installed (pip install nemo_toolkit['asr'])")
+            return None, None
+
+    def _setup_streaming_params(self):
+        """Helper to calculate model-specific streaming parameters."""
+        model_cfg = self.asr_model.cfg
+        audio_sample_rate = model_cfg.preprocessor['sample_rate']
+        feature_stride_sec = model_cfg.preprocessor['window_stride']
+        features_per_sec = 1.0 / feature_stride_sec
+        self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
+
+        self.features_frame2audio_samples = make_divisible_by(
+            int(audio_sample_rate * feature_stride_sec), factor=self.encoder_subsampling_factor
+        )
+        encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor
+
+        # The streaming parameters live at the top level of self.cfg
+        streaming_cfg = self.cfg
+        self.context_encoder_frames = ContextSize(
+            left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
+            chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
+            right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
+        )
+        self.context_samples = ContextSize(
+            left=self.context_encoder_frames.left * encoder_frame2audio_samples,
+            chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
+            right=self.context_encoder_frames.right * encoder_frame2audio_samples,
+        )
+
+    def _reset_decoder_state(self):
+        """
+        Resets ONLY the decoder state, preserving the audio buffer.
+        This prevents slowdowns on long audio streams.
+        """
+        start_time = time.perf_counter()
+        logging.debug("--- Resetting decoder state (audio buffer preserved) ---")
+
+        # Reset tracking for this segment
+        self.last_transcription = ""
+        self.chunk_count = 0
+        batch_size = 1  # Hardcoded for this script
+
+        # The streaming parameters live at the top level of self.cfg
+        streaming_cfg = self.cfg
+
+        # 1. Recreate the initial prompt for the decoder
+        self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)
+
+        # 2. Recreate the "computer" object that manages decoding
+        self.decoding_computer = GreedyBatchedStreamingAEDComputer(
+            self.asr_model,
+            frame_chunk_size=self.context_encoder_frames.chunk,
+            decoding_cfg=streaming_cfg.decoding,
+        )
+
+        # 3. Recreate an EMPTY STATE object (model_state)
+        self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
+            asr_model=self.asr_model,
+            decoder_input_ids=self.decoder_input_ids,
+            batch_size=batch_size,
+            context_encoder_frames=self.context_encoder_frames,
+            chunk_secs=streaming_cfg.chunk_secs,
+            right_context_secs=streaming_cfg.right_context_secs,
+        )
+
+        # Clear CUDA cache if possible
+        if torch.cuda.is_available():
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+        end_time = time.perf_counter()
+        duration_ms = (end_time - start_time) * 1000  # Convert to milliseconds
+        logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")
+
+    def reset(self):
+        """
+        Resets the transcriber's state completely (audio buffer + decoder state).
+        Called on initialization and on full stream resets.
+        """
+        start_time = time.perf_counter()
+        logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")
+
+        # Operation 1: Reset the decoder (this now includes GC)
+        self._reset_decoder_state()
+
+        # Operation 2: Reset the audio buffer
+        self.buffer = StreamingBatchedAudioBuffer(
+            batch_size=1,  # Hardcoded for this script
+            context_samples=self.context_samples,
+            dtype=torch.float32,
+            device=self.map_location,
+        )
+
+        end_time = time.perf_counter()
+        duration_ms = (end_time - start_time) * 1000
+        logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")
+
+    def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
+        """
+        Processes a single audio chunk and returns the newly predicted text.
+
+        Returns:
+            Tuple[str, str]:
+                (current_transcription: the full transcription for the current segment,
+                 new_text: the newly appended text since the last chunk)
+        """
+        start_time = time.perf_counter()
+        self.chunk_count += 1
+
+        # Preprocess audio: int16 -> float32 in [-1, 1]
+        signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
+        audio_batch = signal.unsqueeze(0).to(self.map_location)
+        audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)
+
+        # 1. Add the chunk to the persistent buffer
+        self.buffer.add_audio_batch_(
+            audio_batch,
+            audio_lengths=audio_batch_lengths,
+            is_last_chunk=is_last_chunk,
+            is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location)
+        )
+
+        self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)
+
+        # 2. Pass the buffer to the encoder
+        _, encoded_len, enc_states, _ = self.asr_model(
+            input_signal=self.buffer.samples, input_signal_length=self.buffer.context_size_batch.total()
+        )
+        encoder_context_batch = self.buffer.context_size_batch.subsample(
+            factor=self.features_frame2audio_samples * self.encoder_subsampling_factor
+        )
+        encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
+        encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
+        encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)
+
+        # 3. Pass to the decoding computer
+        self.model_state = self.decoding_computer(
+            encoder_output=enc_states,
+            encoder_output_len=encoded_length_corrected,
+            encoder_input_mask=encoder_input_mask,
+            prev_batched_state=self.model_state,
+        )
+
+        # 4. Calculate the new text
+        current_tokens = self.model_state.pred_tokens_ids[
+            0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]
+        ]
+
+        # OPTIMIZATION: Move tokens to CPU before converting to list
+        current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()
+
+        # Calculate the NEW text by "subtracting" the old history
+        if current_transcription.startswith(self.last_transcription):
+            new_text = current_transcription[len(self.last_transcription):]
+        else:
+            # Model corrected itself, send the full new transcription
+            new_text = current_transcription
+
+        # Memorize the FULL current transcription as the new history
+        if new_text:
+            self.last_transcription = current_transcription
+
+        end_time = time.perf_counter()
+        duration_ms = (end_time - start_time) * 1000
+        # logging.info(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")
+
+        # Return both the full segment transcription and the new diff
+        return current_transcription, new_text
+
+    def finalize_segment(self):
+        """
+        Finalizes the current transcription segment (e.g., on silence)
+        and adds it to the full history.
+        """
+        if self.last_transcription:
+            self.full_transcription.append(self.last_transcription)
+            self.last_transcription = ""
+            # We must reset the decoder state to start a new segment
+            self._reset_decoder_state()
+
+    def get_full_transcription(self) -> str:
+        """
+        Returns the full accumulated transcription from all finalized segments.
+        Does NOT include the currently active (unfinalized) segment.
+        """
+        return " ".join(self.full_transcription)
+
+    def get_current_segment_text(self) -> str:
+        """Returns the text of the segment currently being transcribed."""
+        return self.last_transcription
+
+
app/config.yaml ADDED
@@ -0,0 +1,28 @@
+pretrained_name: 'nvidia/canary-1b-v2'
+model_path: null
+batch_size: 32
+
+chunk_secs: 1.0
+left_context_secs: 20.0
+right_context_secs: 0.5
+cuda: null
+allow_mps: true
+compute_dtype: null
+matmul_precision: high
+decoding:
+  streaming_policy: alignatt
+  alignatt_thr: 8.0
+  waitk_lagging: 2
+  exclude_sink_frames: 8
+  xatt_scores_layer: -2
+  max_tokens_per_alignatt_step: 30
+  max_generation_length: 512
+  use_avgpool_for_alignatt: false
+  hallucinations_detector: true
+prompt:
+  pnc: 'no'
+  task: asr
+  source_lang: fr
+  target_lang: fr
+  timestamps: yes
+debug_mode: false
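A note on how this flat config is consumed: CanarySpeechEngine reads the model and streaming keys from the top level and the decoding sub-tree directly (there is no model:/streaming: nesting). Two YAML subtleties are worth flagging: pnc: 'no' is quoted so it stays the string 'no', while the unquoted timestamps: yes parses as the boolean true. A small access sketch:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("app/config.yaml")
    print(cfg.pretrained_name)                    # nvidia/canary-1b-v2
    print(cfg.chunk_secs, cfg.right_context_secs) # 1.0 0.5
    print(cfg.decoding.streaming_policy)          # alignatt
    print(cfg.prompt.pnc)                         # 'no' (quoted: stays a string)
    print(cfg.prompt.timestamps)                  # True (unquoted `yes` is a YAML bool)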
app/interfaces.py ADDED
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from typing import Any, Iterable
+
+class IVoiceActivityEngine(ABC):
+    """Contract for a Voice Activity Detector (VAD)."""
+
+    @abstractmethod
+    def __call__(self, audio_chunk: bytes) -> bool:
+        """
+        Analyzes an audio chunk and returns True if speech is detected,
+        False otherwise.
+        """
+        pass
+
+class IStreamingSpeechEngine(ABC):
+    """Contract for a streaming transcription service."""
+
+    @abstractmethod
+    def transcribe_chunk(self, audio_chunk: bytes) -> str:
+        """Processes an audio chunk and returns a transcription (partial or final)."""
+        pass
+
+    @abstractmethod
+    def finalize_segment(self) -> str:
+        """Called at the end of the stream to get the final transcription."""
+        pass
+
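Note the contract and its main implementation have drifted slightly: transcribe_chunk is typed bytes -> str here, while CanarySpeechEngine takes an np.ndarray and returns a (segment, new_text) tuple. For exercising the processor without a GPU, trivial stand-ins can satisfy the contract; both classes below are hypothetical test doubles:

    import numpy as np
    from app.interfaces import IStreamingSpeechEngine, IVoiceActivityEngine

    class AlwaysSpeechVad(IVoiceActivityEngine):  # hypothetical test double
        def __call__(self, audio_chunk) -> bool:
            return True  # pretend every chunk contains speech

    class EchoEngine(IStreamingSpeechEngine):  # hypothetical test double
        def transcribe_chunk(self, audio_chunk, is_last_chunk: bool = False):
            # Mirrors CanarySpeechEngine's actual (segment, new_text) return shape.
            return "", f"<{len(audio_chunk)} samples>"

        def finalize_segment(self):
            return ""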
app/logger_config.py CHANGED
@@ -5,10 +5,10 @@ load_dotenv(find_dotenv())
 import logging
 from logging.handlers import RotatingFileHandler
 import os
-# from nemo.utils.nemo_logging import Logger
+from nemo.utils.nemo_logging import Logger

-# nemo_logger = Logger()
-# nemo_logger.remove_stream_handlers()
+nemo_logger = Logger()
+nemo_logger.remove_stream_handlers()

 DEBUG = os.getenv("DEBUG", "false").lower() == "true"
 # Create "logs" directory if it doesn't exist
app/silero_vad_engine.py ADDED
@@ -0,0 +1,189 @@
+import os
+import subprocess
+import torch
+import numpy as np
+import onnxruntime
+import warnings
+from app.interfaces import IVoiceActivityEngine
+from app.logger_config import (
+    logger as logging,
+    DEBUG
+)
+
+class VoiceActivityDetection():
+
+    def __init__(self, force_onnx_cpu=True):
+        logging.info("Initializing VoiceActivityDetection...")
+        path = self.download()
+
+        opts = onnxruntime.SessionOptions()
+        opts.log_severity_level = 3  # Suppress ONNX runtime's own logs
+
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
+
+        try:
+            if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+                self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+                logging.info("ONNX VAD session created with CPUExecutionProvider.")
+            else:
+                self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
+                logging.info("ONNX VAD session created with CUDAExecutionProvider.")
+        except Exception as e:
+            logging.critical(f"Failed to create ONNX InferenceSession: {e}", exc_info=True)
+            raise
+
+        self.reset_states()
+        if '16k' in path:
+            logging.warning('This VAD model supports only 16000 sampling rate!')
+            self.sample_rates = [16000]
+        else:
+            logging.info("VAD model supports 8000Hz and 16000Hz.")
+            self.sample_rates = [8000, 16000]
+
+    def _validate_input(self, x, sr: int):
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+        if x.dim() > 2:
+            logging.error(f"Too many dimensions for input audio chunk: {x.dim()}")
+            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+        if sr != 16000 and (sr % 16000 == 0):
+            step = sr // 16000
+            x = x[:, ::step]
+            logging.debug(f"Downsampled input audio to 16000Hz from {sr}Hz.")
+            sr = 16000
+
+        if sr not in self.sample_rates:
+            logging.error(f"Unsupported sampling rate: {sr}. Supported: {self.sample_rates}")
+            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiple of 16000)")
+
+        return x, sr
+
+    def reset_states(self, batch_size=1):
+        logging.debug(f"Resetting VAD states for batch_size: {batch_size}")
+        self._state = torch.zeros((2, batch_size, 128)).float()
+        self._context = torch.zeros(0)
+        self._last_sr = 0
+        self._last_batch_size = 0
+
+    def __call__(self, x, sr: int):
+
+        x, sr = self._validate_input(x, sr)
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[-1] != num_samples:
+            logging.error(f"Invalid audio chunk size: {x.shape[-1]}. Expected {num_samples} for {sr}Hz.")
+            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
+
+        batch_size = x.shape[0]
+        context_size = 64 if sr == 16000 else 32
+
+        if not self._last_batch_size:
+            logging.debug("First call, resetting states.")
+            self.reset_states(batch_size)
+        if (self._last_sr) and (self._last_sr != sr):
+            logging.warning(f"Sample rate changed ({self._last_sr}Hz -> {sr}Hz). Resetting states.")
+            self.reset_states(batch_size)
+        if (self._last_batch_size) and (self._last_batch_size != batch_size):
+            logging.warning(f"Batch size changed ({self._last_batch_size} -> {batch_size}). Resetting states.")
+            self.reset_states(batch_size)
+
+        if not len(self._context):
+            self._context = torch.zeros(batch_size, context_size)
+
+        x = torch.cat([self._context, x], dim=1)
+        if sr in [8000, 16000]:
+            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
+            ort_outs = self.session.run(None, ort_inputs)
+            out, state = ort_outs
+            self._state = torch.from_numpy(state)
+        else:
+            # This should be caught by _validate_input, but as a safeguard:
+            logging.critical(f"Unexpected sample rate in VAD __call__: {sr}")
+            raise ValueError()
+
+        self._context = x[..., -context_size:]
+        self._last_sr = sr
+        self._last_batch_size = batch_size
+
+        out = torch.from_numpy(out)
+        return out
+
+    def audio_forward(self, x, sr: int):
+        outs = []
+        x, sr = self._validate_input(x, sr)
+        self.reset_states()
+        num_samples = 512 if sr == 16000 else 256
+
+        if x.shape[1] % num_samples:
+            pad_num = num_samples - (x.shape[1] % num_samples)
+            logging.debug(f"Padding audio input with {pad_num} samples.")
+            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+        for i in range(0, x.shape[1], num_samples):
+            wavs_batch = x[:, i:i+num_samples]
+            out_chunk = self.__call__(wavs_batch, sr)
+            outs.append(out_chunk)
+
+        stacked = torch.cat(outs, dim=1)
+        return stacked.cpu()
+
+    @staticmethod
+    def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
+        target_dir = os.path.expanduser("~/.cache/silero_vad/")
+        os.makedirs(target_dir, exist_ok=True)
+        model_filename = os.path.join(target_dir, "silero_vad.onnx")
+
+        if not os.path.exists(model_filename):
+            logging.info(f"Downloading VAD model to {model_filename}...")
+            try:
+                subprocess.run(["wget", "-O", model_filename, model_url], check=True)
+                logging.info("VAD model downloaded successfully.")
+            except subprocess.CalledProcessError as e:
+                logging.critical(f"Failed to download the model using wget: {e}")
+                raise
+        else:
+            logging.info(f"VAD model already exists at {model_filename}.")
+        return model_filename
+
+
+class Silero_Vad_Engine(IVoiceActivityEngine):
+    def __init__(self, threshold: float = 0.5, frame_rate: int = 16000):
+        """
+        Initializes the Silero_Vad_Engine with a voice activity detection model and a threshold.
+
+        Args:
+            threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
+            frame_rate (int, optional): Sampling rate of incoming audio. Defaults to 16000.
+        """
+        logging.info(f"Initializing Silero_Vad_Engine with threshold: {threshold} and frame_rate: {frame_rate}Hz.")
+        self.model = VoiceActivityDetection()
+        self.threshold = threshold
+        self.frame_rate = frame_rate
+
+    def __call__(self, audio_frame):
+        """
+        Determines if the given audio frame contains speech by comparing the detected speech probability against
+        the threshold.
+
+        Args:
+            audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
+                NumPy array of audio samples.
+
+        Returns:
+            bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
+                False otherwise.
+        """
+        # Convert frame to tensor
+        audio_tensor = torch.from_numpy(audio_frame.copy())
+
+        # Get speech probabilities
+        speech_probs = self.model.audio_forward(audio_tensor, self.frame_rate)[0]
+
+        # Check against threshold
+        is_speech = torch.any(speech_probs > self.threshold).item()
+
+        logging.debug(f"VAD check result: {is_speech} (Max prob: {torch.max(speech_probs).item():.4f})")
+
+        return is_speech
+
+
app/streaming_audio_processor.py ADDED
@@ -0,0 +1,173 @@
+import numpy as np
+from app.logger_config import (
+    logger as logging,
+    DEBUG
+)
+from app.interfaces import IStreamingSpeechEngine, IVoiceActivityEngine
+
+from dataclasses import dataclass
+
+@dataclass
+class StreamingAudioProcessorConfig:
+    """Configuration settings for the StreamingAudioProcessor."""
+    read_size: int = 8000
+    silence_threshold_chunks: int = 2
+    sample_rate: int = 16000
+    # Add other streaming-related parameters here
+    # e.g., vad_padding_chunks: int = 0
+
+
+class StreamingAudioProcessor:
+    """
+    Manages streaming transcription by combining a speech engine
+    and a voice activity detector (VAD).
+
+    This class handles internal audio buffering and VAD state.
+    """
+
+    def __init__(self, speech_engine: IStreamingSpeechEngine, vad_engine: IVoiceActivityEngine, cfg: StreamingAudioProcessorConfig):
+        """
+        Initializes the streaming processor.
+
+        Args:
+            speech_engine: The ASR speech engine (must have .transcribe_chunk() and .reset()).
+            vad_engine: The VAD engine (returns True/False for a chunk).
+            cfg: The configuration object for this processor.
+        """
+        logging.info("Initializing StreamingAudioProcessor...")
+        self.speech_engine = speech_engine
+        self.vad_engine = vad_engine
+
+        # Store config
+        self.VAD_SAMPLE_RATE = cfg.sample_rate
+        self.read_size = cfg.read_size
+        self.SILENCE_THRESHOLD_CHUNKS = cfg.silence_threshold_chunks
+
+        # Internal buffer state (Optimized: using numpy array)
+        self.internal_buffer = np.array([], dtype='int16')
+
+        # Internal logic state
+        self.is_first_logical_chunk = True
+        self.logical_chunk_size = self.speech_engine.context_samples.chunk
+        self.initial_logical_chunk_size = self.speech_engine.context_samples.chunk + self.speech_engine.context_samples.right
+
+        # Internal VAD state
+        self.silent_chunks_count = 0
+        self.chunks_count = 0
+
+        logging.info(f" Config: VAD Sample Rate={self.VAD_SAMPLE_RATE}Hz")
+        logging.info(f" Config: Physical Read Size={self.read_size} samples")
+        logging.info(f" Config: Silence Threshold={self.SILENCE_THRESHOLD_CHUNKS} chunks")
+        logging.info(f" Config: Initial Logical Chunk={self.initial_logical_chunk_size} samples")
+        logging.info(f" Config: Subsequent Logical Chunk={self.logical_chunk_size} samples")
+
+
+    def _append_to_buffer(self, chunk_np, asr_chunk_len):
+        """
+        Appends a NumPy chunk to the internal buffer and returns a logical chunk if ready.
+        (Optimized to use numpy concatenation).
+        """
+        logging.debug(f"Appending {len(chunk_np)} samples to internal buffer (current size: {len(self.internal_buffer)}).")
+        self.internal_buffer = np.concatenate((self.internal_buffer, chunk_np))
+
+        if len(self.internal_buffer) >= asr_chunk_len:
+            asr_signal_chunk = self.internal_buffer[:asr_chunk_len]
+            self.internal_buffer = self.internal_buffer[asr_chunk_len:]
+            logging.debug(f"Extracted logical chunk of {len(asr_signal_chunk)} samples. Buffer remaining: {len(self.internal_buffer)}.")
+            return asr_signal_chunk
+        else:
+            logging.debug(f"Buffer size ({len(self.internal_buffer)}) < target ({asr_chunk_len}). Holding.")
+            return None
+
+    def _flush_and_reset(self):
+        """
+        Flushes the remaining buffer to the transcriber, resets the state,
+        and returns the last transcribed text.
+        """
+        new_text = ""
+        if len(self.internal_buffer) > 0:
+            # Buffer is already a numpy array
+            final_segment_chunk = self.internal_buffer
+            logging.info(f"Flushing segment remainder of {len(final_segment_chunk)} samples.")
+            seg, new_text = self.speech_engine.transcribe_chunk(final_segment_chunk, is_last_chunk=True)
+        else:
+            # Buffer is empty, but send a silent "flush"
+            # to force the transcriber to finalize its internal state.
+            logging.info("Buffer empty, sending silent flush to finalize segment.")
+            flush_chunk = np.zeros(self.logical_chunk_size, dtype='int16')
+            seg, new_text = self.speech_engine.transcribe_chunk(flush_chunk, is_last_chunk=True)
+
+        # Full state reset
+        logging.debug("Resetting speech engine state...")
+        self.speech_engine.reset()  # Resets the speech engine (decoder state)
+
+        logging.debug("Resetting internal buffer and VAD state.")
+        self.internal_buffer = np.array([], dtype='int16')  # Reset buffer
+        self.is_first_logical_chunk = True
+        self.silent_chunks_count = 0
+
+        return new_text
+
+    def process_chunk(self, chunk: np.ndarray):
+        """
+        Processes a single physical chunk (e.g., 8000 samples).
+        Manages VAD, buffering, and transcription.
+
+        Args:
+            chunk (np.ndarray): The audio chunk (int16).
+
+        Returns:
+            list: A list of new transcribed text segments.
+                (Often empty, may contain one or more segments).
+        """
+        new_text_segments = []
+        self.chunks_count += 1
+        logging.debug(f"--- Processing Physical Chunk {self.chunks_count} ---")
+
+        # --- 1. VAD Logic ---
+        has_speech = self.vad_engine(chunk)
+        logging.debug(f"VAD result: {'SPEECH' if has_speech else 'SILENCE'}")
+
+        if has_speech:
+            self.silent_chunks_count = 0
+        else:
+            self.silent_chunks_count += 1
+            logging.debug(f"Silent chunks count: {self.silent_chunks_count}/{self.SILENCE_THRESHOLD_CHUNKS}")
+
+        silence_reset = self.silent_chunks_count >= self.SILENCE_THRESHOLD_CHUNKS
+
+        # --- 2. Buffering & Transcription Logic ---
+        target_size = self.initial_logical_chunk_size if self.is_first_logical_chunk else self.logical_chunk_size
+        asr_chunk_np = self._append_to_buffer(chunk, target_size)  # Returns np.ndarray or None
+
+        if asr_chunk_np is not None:
+            logging.debug(f"Sending logical chunk (size: {len(asr_chunk_np)}) to speech engine...")
+            seg, new_text = self.speech_engine.transcribe_chunk(asr_chunk_np, is_last_chunk=False)
+            if new_text:
+                logging.info(f"Received new text segment: '{new_text}'")
+                new_text_segments.append(new_text)
+            self.is_first_logical_chunk = False
+
+        # --- 3. VAD Reset Logic ---
+        if silence_reset and not self.is_first_logical_chunk:
+            logging.info(f"\n[VAD RESET: SILENCE detected ({self.silent_chunks_count} empty chunks) at {(self.chunks_count * (self.read_size/self.VAD_SAMPLE_RATE)):.2f}s]")
+
+            # Flush the buffer, reset state, and get final text
+            reset_text = self._flush_and_reset()
+            if reset_text:
+                logging.info(f"Received final reset text: '{reset_text}'")
+                new_text_segments.append(reset_text)
+
+        return new_text_segments
+
+    def finalize_stream(self):
+        """
+        Must be called at the very end of the stream (after the loop breaks).
+        Flushes anything remaining in the buffer.
+        """
+        logging.info("Finalizing stream. Flushing final buffer...")
+        final_text = self._flush_and_reset()
+        if final_text:
+            logging.info(f"Received final flushed text: '{final_text}'")
+        return final_text
+
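A sketch of exercising StreamingAudioProcessor end to end with test doubles (all names below are hypothetical; the real CanarySpeechEngine derives context_samples from the model). It demonstrates the first logical chunk being chunk + right context (24000 samples here) and the silence path triggering the flush:

    import numpy as np
    from types import SimpleNamespace
    from app.streaming_audio_processor import StreamingAudioProcessor, StreamingAudioProcessorConfig

    class FakeEngine:  # hypothetical double for CanarySpeechEngine
        context_samples = SimpleNamespace(left=0, chunk=16000, right=8000)
        def transcribe_chunk(self, chunk, is_last_chunk=False):
            return "", f"<{len(chunk)}{'!' if is_last_chunk else ''}>"
        def reset(self):
            pass

    class ScriptedVad:  # reports speech for 6 chunks, then silence
        def __init__(self): self.n = 0
        def __call__(self, chunk):
            self.n += 1
            return self.n <= 6

    cfg = StreamingAudioProcessorConfig(read_size=4000, silence_threshold_chunks=1)
    proc = StreamingAudioProcessor(FakeEngine(), ScriptedVad(), cfg)
    for _ in range(7):
        print(proc.process_chunk(np.zeros(4000, dtype='int16')))
    # Chunk 6 emits the first logical chunk (['<24000>']); chunk 7 hits the
    # silence threshold and flushes the 4000-sample remainder (['<4000!>']).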
app/utils.py CHANGED
@@ -12,7 +12,8 @@ import base64
 import os
 import time
 import random
-
+import torch
+from app.streaming_audio_processor import StreamingAudioProcessor
 from app.session_utils import (
     get_active_task_flag_file,
     get_folder_chunks
@@ -57,7 +58,7 @@ def generate_coturn_config():



-def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streaming_flags: dict):
+def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streaming_flags: dict, read_size: int = 8000, sample_rate: int = 16000):
     """
     Read an audio file and stream it chunk by chunk (1s per chunk).
     Handles errors safely and reports structured messages to the client.
@@ -72,7 +73,7 @@ def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streami
     transcribe_flag = get_active_task_flag_file(session_id)
     try:
         segment = AudioSegment.from_file(filepath_to_stream)
-        chunk_duration_ms = 1000
+        chunk_duration_ms = int((read_size/sample_rate)*1000)
         total_chunks = len(segment) // chunk_duration_ms + 1
         logging.info(f"[{session_id}] Starting audio streaming {filepath_to_stream} ({total_chunks} chunks).")
@@ -88,9 +89,9 @@ def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streami
                 break

             yield ((frame_rate, samples), AdditionalOutputs({"progressed": True, "value": progress} ))
-            logging.debug(f"[{session_id}] Sent chunk {i+1}/{total_chunks} ({progress}%).")
+            # logging.debug(f"[{session_id}] Sent chunk {i+1}/{total_chunks} ({progress}%).")

-            time.sleep(1)
+            time.sleep(chunk_duration_ms/1000)
             # Save only if transcription is active
             if os.path.exists(transcribe_flag) :
                 chunk_dir = get_folder_chunks(session_id)
@@ -99,7 +100,7 @@ def read_and_stream_audio(filepath_to_stream: str, session_id: str, stop_streami
                 npz_path = os.path.join(chunk_dir, f"chunk_{i:05d}.npz")
                 chunk_array = np.array(chunk.get_array_of_samples(), dtype=np.int16)
                 np.savez_compressed(npz_path, data=chunk_array, rate=frame_rate)
-                logging.debug(f"[{session_id}] Saved chunk {i}/{total_chunks} (transcribe active)")
+                logging.debug(f"[{session_id}] Saved chunk {i}/{total_chunks} (transcribe active) ({progress}%) ({npz_path}).")

             # raise_function() # Optional injected test exception
@@ -148,14 +149,15 @@ def _is_stop_requested(stop_streaming_flags: dict) -> bool:

 # --- Decorator compatibility layer ---
 if os.environ.get("SPACE_ID", "").startswith("zero-gpu"):
-    logging.warning("Running on ZeroGPU — disabling @spaces.GPU")
-    def gpu_decorator(f): return f
+    logging.warning("Running on ZeroGPU — gpu_fork_decorator @spaces.GPU")
+    def gpu_fork_decorator(f): return f
 else:
     gpu_decorator = spaces.GPU

+
 # --- Audio Stream Function ---
-@gpu_decorator
-def task(session_id: str):
+@spaces.GPU
+def task(session_id: str, streamer: StreamingAudioProcessor):
     """Continuously read and delete .npz chunks while task is active."""
     active_flag = get_active_task_flag_file(session_id)
     with open(active_flag, "w") as f:
@@ -171,11 +173,11 @@ def task(session_id: str):
             if not os.path.exists(chunk_dir):
                 logging.warning(f"[{session_id}] No chunk directory found for task.")
                 yield "No audio chunks yet... waiting for stream.\n"
-                time.sleep(0.25)
+                time.sleep(0.1)
                 continue
             files = sorted(f for f in os.listdir(chunk_dir) if f.endswith(".npz"))
             if not files:
-                time.sleep(0.25)
+                time.sleep(0.1)
                 continue

             for fname in files:
@@ -186,19 +188,26 @@ def task(session_id: str):
                     rate = int(npz["rate"])

                     text = f"Transcribed {fname}: {len(samples)} samples @ {rate}Hz"
-                    yield f"{text}\n"
-                    logging.debug(f"[{session_id}] {text}")
-
+                    new_texts = streamer.process_chunk(samples)
+                    for text in new_texts:
+                        print(text, end='', flush=True)
+                        yield f"{text}"
+                    logging.debug(f"[{session_id}] {new_texts}")
+                    # yield f"{text}\n"
                     os.remove(fpath)
                     logging.debug(f"[{session_id}] Deleted processed chunk: {fname}")
                 except Exception as e:
                     logging.error(f"[{session_id}] Error processing {fname}: {e}")
                     yield f"Error processing {fname}: {e}\n"
                     continue
-
-            time.sleep(0.25)
+
+            time.sleep(0.1)
             # raise_function()
-            yield "\nTask stopped by user or stream ended.\n"
+            final_text = streamer.finalize_stream()
+            if final_text:
+                print(final_text, end='', flush=True)
+                yield f"\n{final_text}"
+            # yield f"\n"
         logging.info(f"[{session_id}] task loop ended (flag removed).")

     except Exception as e:
requirements.txt CHANGED
@@ -3,3 +3,6 @@ spaces
 torch
 python-dotenv
 fastrtc==0.0.33
+Cython
+nemo_toolkit[asr,nlp] @ git+https://github.com/NVIDIA/NeMo.git@237e2c08ed8e5b6bad66b124d75b02f6510b9b56
+onnxruntime