import numpy as np
import re
import threading
import time
from faster_whisper import WhisperModel
import scipy.signal as signal
from typing import List
from punctuators.models import SBDModelONNX
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config


class AudioProcessor:
    def __init__(self, model_size="tiny.en", device=None, compute_type=None):
        """Initialize the audio processor with configurable parameters"""
        self.audio_buffer = np.array([])  # Stores raw audio for playback
        self.processed_length = 0         # Length of audio already processed
        self.sample_rate = 16000          # Default sample rate for whisper
        self.lock = threading.Lock()      # Thread safety for buffer access
        self.min_process_length = 1 * self.sample_rate  # Process at least 1 second
        self.max_buffer_size = 30 * self.sample_rate  # Maximum buffer size (30 seconds)
        self.overlap_size = 3 * self.sample_rate  # Keep 3 seconds of overlap when trimming
        self.last_process_time = time.time()
        self.process_interval = 0.5       # Process every 0.5 seconds
        self.is_processing = False        # Flag to prevent concurrent processing

        self.full_transcription = ""      # Complete history of transcription
        self.last_segment_text = ""       # Last segment that was transcribed
        self.confirmed_transcription = "" # Transcription that won't change (beyond overlap zone)

        # Use config for device and compute type if not specified
        if device is None or compute_type is None:
            whisper_config = config.get_whisper_config()
            device = device or whisper_config["device"]
            compute_type = compute_type or whisper_config["compute_type"]

        # Initialize the whisper model
        self.audio_model = WhisperModel(model_size, device=device, compute_type=compute_type)
        print(f"Initialized {model_size} model on {device} with {compute_type}")

        # Initialize sentence boundary detection with device config
        self.sentence_end_detect = SBDModelONNX.from_pretrained("sbd_multi_lang")
        if config.device == "cuda":
            print("SBD model initialized with CUDA support")

    def _trim_buffer_intelligently(self):
        """
        Trim the buffer while preserving transcription continuity
        Keep some overlap to maintain context for the next processing
        """
        if len(self.audio_buffer) <= self.max_buffer_size:
            return

        # Calculate how much to trim (keep overlap_size for context)
        trim_amount = len(self.audio_buffer) - self.max_buffer_size + self.overlap_size
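        # Worked example: a 35 s buffer with max_buffer_size = 30 s and
        # overlap_size = 3 s trims 35 - 30 + 3 = 8 s, leaving a 27 s buffer
        # (3 s of headroom below the 30 s cap before the next trim).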

        # Make sure we don't trim more than we have
        trim_amount = min(trim_amount, len(self.audio_buffer) - self.overlap_size)

        if trim_amount > 0:
            # No transcription is lost by trimming: any already-processed
            # audio we remove has its text stored in full_transcription

            # Trim the buffer
            self.audio_buffer = self.audio_buffer[trim_amount:]

            # Adjust processed_length to account for trimmed audio
            self.processed_length = max(0, self.processed_length - trim_amount)

            # Reset last_segment_text since our context has changed
            # This forces the next processing to start fresh with overlap handling
            self.last_segment_text = ""

    def _process_audio_chunk(self):
        """Process the current audio buffer and return new transcription"""
        try:
            with self.lock:
                # Check if there's enough new content to process
                unprocessed_length = len(self.audio_buffer) - self.processed_length
                if unprocessed_length < self.min_process_length:
                    self.is_processing = False
                    return None

                # Determine what portion to process
                # Include some overlap from already processed audio for context
                overlap_samples = min(self.overlap_size, self.processed_length)
                start_pos = max(0, self.processed_length - overlap_samples)

                # Process from start_pos to end of buffer
                audio_to_process = self.audio_buffer[start_pos:].copy()
                end_pos = len(self.audio_buffer)

            # Normalize for transcription
            audio_norm = audio_to_process.astype(np.float32)
            if np.max(np.abs(audio_norm)) > 0:
                audio_norm = audio_norm / np.max(np.abs(audio_norm))

            # Transcribe with faster settings for real-time processing
            segments, info = self.audio_model.transcribe(
                audio_norm,
                beam_size=1,
                word_timestamps=False,
                vad_filter=True,
                vad_parameters=dict(min_silence_duration_ms=500)
            )

            result = list(segments)

            if result:
                # Get the new text from all segments
                current_segment_text = " ".join([seg.text.strip() for seg in result if seg.text.strip()])

                if not current_segment_text:
                    self.is_processing = False
                    return None

                # Handle overlap and merge with existing transcription
                new_text = self._merge_transcription_intelligently(current_segment_text)

                # Update shared state under the lock; add_audio() and
                # clear_buffer() may touch these fields concurrently
                with self.lock:
                    if new_text:
                        # Append new text to full transcription
                        if self.full_transcription and not self.full_transcription.endswith(' '):
                            self.full_transcription += " "
                        self.full_transcription += new_text

                    self.last_segment_text = current_segment_text
                    self.processed_length = end_pos

                    return self.full_transcription

            return None

        except Exception as e:
            print(f"Transcription error: {e}")
            return None
        finally:
            self.is_processing = False

    def _merge_transcription_intelligently(self, new_segment_text):
        """
        Intelligently merge new transcription with existing text
        Handles overlap detection and prevents duplication
        """
        if not new_segment_text or not new_segment_text.strip():
            return ""

        # If this is the first transcription or we reset context, use it directly
        if not self.last_segment_text:
            return new_segment_text

        # Normalize text for comparison
        def normalize_for_comparison(text):
            # Convert to lowercase and remove punctuation for comparison
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            return text.strip()

        norm_prev = normalize_for_comparison(self.last_segment_text)
        norm_new = normalize_for_comparison(new_segment_text)

        if not norm_prev or not norm_new:
            return new_segment_text

        # Split into words for overlap detection
        prev_words = norm_prev.split()
        new_words = norm_new.split()

        # Find the longest overlap between end of previous and start of new
        max_overlap = min(len(prev_words), len(new_words), 15)  # Check up to 15 words
        overlap_found = 0

        for i in range(max_overlap, 2, -1):  # Minimum 3 words to consider overlap
            if prev_words[-i:] == new_words[:i]:
                overlap_found = i
                break

        # Handle special cases for numbers (counting sequences)
        if overlap_found == 0:
            # Check if we have a counting sequence
            prev_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_prev)]
            new_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_new)]

            if prev_numbers and new_numbers:
                max_prev = max(prev_numbers)
                min_new = min(new_numbers)

                # If there's a logical continuation, find where it starts
                if min_new <= max_prev + 5:  # Allow some gap in counting
                    new_text_words = new_segment_text.split()
                    for i, word in enumerate(new_text_words):
                        if re.search(r'\b\d+\b', word):
                            num = int(re.search(r'\d+', word).group())
                            if num > max_prev:
                                return " ".join(new_text_words[i:])

        # Apply overlap removal if found
        if overlap_found > 0:
            new_text_words = new_segment_text.split()
            return " ".join(new_text_words[overlap_found:])
        else:
            # Check if new text is completely contained in previous (avoid duplication)
            if norm_new in norm_prev:
                return ""
            # No overlap found, return the full new text
            return new_segment_text

    def add_audio(self, audio_data, sr):
        """
        Add audio to the buffer and process if needed

        Args:
            audio_data (numpy.ndarray): Audio data to add
            sr (int): Sample rate of the audio data

        Returns:
            int: Current buffer size in samples
        """
        with self.lock:
            # Convert to mono if stereo
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)

            # Convert to float32
            audio_data = audio_data.astype(np.float32)

            # Resample if needed
            if sr != self.sample_rate:
                try:
                    # Use scipy for proper resampling
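                    # e.g. 44100 samples at 44.1 kHz map to
                    # int(44100 * 16000 / 44100) = 16000 samples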
                    number_of_samples = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = signal.resample(audio_data, number_of_samples)
                except Exception as e:
                    print(f"Resampling error: {e}")
                    # Fallback resampling
                    ratio = self.sample_rate / sr
                    audio_data = np.interp(
                        np.arange(0, len(audio_data) * ratio, ratio),
                        np.arange(0, len(audio_data)),
                        audio_data
                    )

            # Apply fade-in to prevent clicks (5ms fade)
            fade_samples = min(int(0.005 * self.sample_rate), len(audio_data))
            if fade_samples > 0:
                fade_in = np.linspace(0, 1, fade_samples)
                audio_data[:fade_samples] *= fade_in

            # Add to buffer
            if len(self.audio_buffer) == 0:
                self.audio_buffer = audio_data
            else:
                self.audio_buffer = np.concatenate([self.audio_buffer, audio_data])

            # Intelligently trim buffer if it gets too large
            self._trim_buffer_intelligently()

            # Check if we should process now
            should_process = (
                len(self.audio_buffer) >= self.min_process_length and
                time.time() - self.last_process_time >= self.process_interval and
                not self.is_processing
            )

            if should_process:
                self.last_process_time = time.time()
                self.is_processing = True
                # Process in a separate thread
                threading.Thread(target=self._process_audio_chunk, daemon=False).start()

            return len(self.audio_buffer)

    def wait_for_processing_complete(self, timeout=5.0):
        """Wait for any current processing to complete"""
        start_time = time.time()
        while self.is_processing and (time.time() - start_time) < timeout:
            time.sleep(0.05)
        return not self.is_processing

    def force_complete_processing(self):
        """Force completion of any pending processing - ensures sequential execution"""
        # Wait for any current processing to complete
        self.wait_for_processing_complete(10.0)

        # Check under the lock whether unprocessed audio remains, but run
        # _process_audio_chunk outside the lock: it acquires self.lock itself
        # and threading.Lock is not reentrant, so calling it while holding
        # the lock would deadlock
        with self.lock:
            has_remaining = len(self.audio_buffer) > self.processed_length
            if has_remaining:
                self.is_processing = True

        if has_remaining:
            self._process_audio_chunk()

        # Final wait to ensure everything is complete
        self.wait_for_processing_complete(2.0)

        return self.get_transcription()

    def clear_buffer(self):
        """Clear the audio buffer and transcription"""
        with self.lock:
            self.audio_buffer = np.array([])
            self.processed_length = 0
            self.full_transcription = ""
            self.last_segment_text = ""
            self.confirmed_transcription = ""
            self.is_processing = False
            return "Buffers cleared"

    def get_transcription(self):
        """Get the current transcription text"""
        with self.lock:
            results: List[List[str]] = self.sentence_end_detect.infer([self.full_transcription])
            return results[0]

    def get_playback_audio(self):
        """Get properly formatted audio for Gradio playback"""
        with self.lock:
            if len(self.audio_buffer) == 0:
                return None

            # Make a copy and ensure proper format for Gradio
            audio = self.audio_buffer.copy()

            # Ensure audio is in the correct range for playback (-1 to 1);
            # max(1.0, peak) makes this a no-op for signals already in range
            audio = audio / max(1.0, np.max(np.abs(audio)))

            return (self.sample_rate, audio)

    def get_buffer_info(self):
        """Get information about the current buffer state"""
        with self.lock:
            return {
                "buffer_length_seconds": len(self.audio_buffer) / self.sample_rate,
                "processed_length_seconds": self.processed_length / self.sample_rate,
                "unprocessed_length_seconds": (len(self.audio_buffer) - self.processed_length) / self.sample_rate,
                "is_processing": self.is_processing,
                "transcription_length": len(self.full_transcription)
            }
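
# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, not part of the original class): feeds
# one second of synthetic audio through the processor and prints the result.
# Assumes faster-whisper, punctuators, and the local config module are
# importable; a real caller would stream microphone chunks instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    processor = AudioProcessor(model_size="tiny.en")

    # One second of low-amplitude noise at 16 kHz stands in for mic input;
    # real speech is needed for a meaningful transcription
    dummy_chunk = (np.random.randn(16000) * 0.01).astype(np.float32)
    processor.add_audio(dummy_chunk, sr=16000)

    # Flush pending work and print what the model produced (likely empty
    # for pure noise) along with the buffer state
    print("Sentences:", processor.force_complete_processing())
    print("Buffer info:", processor.get_buffer_info())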