import numpy as np
import threading
import time
import re
import sys
import os
from typing import List

import scipy.signal as signal
from faster_whisper import WhisperModel
from punctuators.models import SBDModelONNX

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config


class AudioProcessor:
    def __init__(self, model_size="tiny.en", device=None, compute_type=None):
        """Initialize the audio processor with configurable parameters"""
        self.audio_buffer = np.array([])  # Stores raw audio for playback
        self.processed_length = 0  # Length of audio already processed
        self.sample_rate = 16000  # Default sample rate for whisper
        # Re-entrant lock for buffer access: force_complete_processing() calls
        # _process_audio_chunk() while already holding it, which would deadlock
        # with a plain threading.Lock
        self.lock = threading.RLock()
        self.min_process_length = 1 * self.sample_rate  # Process at least 1 second
        self.max_buffer_size = 30 * self.sample_rate  # Maximum buffer size (30 seconds)
        self.overlap_size = 3 * self.sample_rate  # Keep 3 seconds of overlap when trimming
        self.last_process_time = time.time()
        self.process_interval = 0.5  # Process every 0.5 seconds
        self.is_processing = False  # Flag to prevent concurrent processing
        self.full_transcription = ""  # Complete history of transcription
        self.last_segment_text = ""  # Last segment that was transcribed
        self.confirmed_transcription = ""  # Transcription that won't change (beyond overlap zone)

        # Use config for device and compute type if not specified
        if device is None or compute_type is None:
            whisper_config = config.get_whisper_config()
            device = device or whisper_config["device"]
            compute_type = compute_type or whisper_config["compute_type"]

        # Initialize the whisper model
        self.audio_model = WhisperModel(model_size, device=device, compute_type=compute_type)
        print(f"Initialized {model_size} model on {device} with {compute_type}")

        # Initialize sentence boundary detection with device config
        self.sentence_end_detect = SBDModelONNX.from_pretrained("sbd_multi_lang")
        if config.device == "cuda":
            print("SBD model initialized with CUDA support")

    def _trim_buffer_intelligently(self):
        """
        Trim the buffer while preserving transcription continuity.
        Keep some overlap to maintain context for the next processing.
        """
        if len(self.audio_buffer) <= self.max_buffer_size:
            return

        # Calculate how much to trim (keep overlap_size for context)
        trim_amount = len(self.audio_buffer) - self.max_buffer_size + self.overlap_size
        # Make sure we don't trim more than we have
        trim_amount = min(trim_amount, len(self.audio_buffer) - self.overlap_size)

        if trim_amount > 0:
            # Before trimming, finalize the transcription for the part we're removing
            # so that confirmed text is not lost
            if self.processed_length > trim_amount:
                # We're removing audio that was already processed; its transcription
                # is already contained in full_transcription and is considered final
                pass

            # Trim the buffer
            self.audio_buffer = self.audio_buffer[trim_amount:]
            # Adjust processed_length to account for trimmed audio
            self.processed_length = max(0, self.processed_length - trim_amount)
            # Reset last_segment_text since our context has changed.
            # This forces the next processing to start fresh with overlap handling.
            self.last_segment_text = ""

    def _process_audio_chunk(self):
        """Process the current audio buffer and return new transcription"""
        try:
            with self.lock:
                # Check if there's enough new content to process
                unprocessed_length = len(self.audio_buffer) - self.processed_length
                if unprocessed_length < self.min_process_length:
                    self.is_processing = False
                    return None

                # Determine what portion to process:
                # include some overlap from already processed audio for context
                overlap_samples = min(self.overlap_size, self.processed_length)
                start_pos = max(0, self.processed_length - overlap_samples)

                # Process from start_pos to end of buffer
                audio_to_process = self.audio_buffer[start_pos:].copy()
                end_pos = len(self.audio_buffer)

                # Normalize for transcription
                audio_norm = audio_to_process.astype(np.float32)
                if np.max(np.abs(audio_norm)) > 0:
                    audio_norm = audio_norm / np.max(np.abs(audio_norm))

                # Transcribe with faster settings for real-time processing
                segments, info = self.audio_model.transcribe(
                    audio_norm,
                    beam_size=1,
                    word_timestamps=False,
                    vad_filter=True,
                    vad_parameters=dict(min_silence_duration_ms=500)
                )

                result = list(segments)
                if result:
                    # Get the new text from all segments
                    current_segment_text = " ".join(
                        [seg.text.strip() for seg in result if seg.text.strip()]
                    )
                    if not current_segment_text:
                        self.is_processing = False
                        return None

                    # Handle overlap and merge with existing transcription
                    new_text = self._merge_transcription_intelligently(current_segment_text)
                    if new_text:
                        # Append new text to full transcription
                        if self.full_transcription and not self.full_transcription.endswith(' '):
                            self.full_transcription += " "
                        self.full_transcription += new_text

                    # Update state
                    self.last_segment_text = current_segment_text
                    self.processed_length = end_pos
                    return self.full_transcription

                return None
        except Exception as e:
            print(f"Transcription error: {e}")
            return None
        finally:
            self.is_processing = False

    def _merge_transcription_intelligently(self, new_segment_text):
        """
        Intelligently merge new transcription with existing text.
        Handles overlap detection and prevents duplication.
        """
        if not new_segment_text or not new_segment_text.strip():
            return ""

        # If this is the first transcription or we reset context, use it directly
        if not self.last_segment_text:
            return new_segment_text

        def normalize_for_comparison(text):
            # Convert to lowercase and remove punctuation for comparison
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            return text.strip()

        norm_prev = normalize_for_comparison(self.last_segment_text)
        norm_new = normalize_for_comparison(new_segment_text)
        if not norm_prev or not norm_new:
            return new_segment_text

        # Split into words for overlap detection
        prev_words = norm_prev.split()
        new_words = norm_new.split()

        # Find the longest overlap between end of previous and start of new
        max_overlap = min(len(prev_words), len(new_words), 15)  # Check up to 15 words
        overlap_found = 0
        for i in range(max_overlap, 2, -1):  # Minimum 3 words to consider overlap
            if prev_words[-i:] == new_words[:i]:
                overlap_found = i
                break

        # Handle special cases for numbers (counting sequences)
        if overlap_found == 0:
            # Check if we have a counting sequence
            prev_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_prev)]
            new_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_new)]
            if prev_numbers and new_numbers:
                max_prev = max(prev_numbers)
                min_new = min(new_numbers)
                # If there's a logical continuation, find where it starts
                if min_new <= max_prev + 5:  # Allow some gap in counting
                    new_text_words = new_segment_text.split()
                    for i, word in enumerate(new_text_words):
                        if re.search(r'\b\d+\b', word):
                            num = int(re.search(r'\d+', word).group())
                            if num > max_prev:
                                return " ".join(new_text_words[i:])

        # Apply overlap removal if found
        if overlap_found > 0:
            new_text_words = new_segment_text.split()
            return " ".join(new_text_words[overlap_found:])
        else:
            # Check if new text is completely contained in previous (avoid duplication)
            if norm_new in norm_prev:
                return ""

        # No overlap found, return the full new text
        return new_segment_text

    def add_audio(self, audio_data, sr):
        """
        Add audio to the buffer and process if needed

        Args:
            audio_data (numpy.ndarray): Audio data to add
            sr (int): Sample rate of the audio data

        Returns:
            int: Current buffer size in samples
        """
        with self.lock:
            # Convert to mono if stereo
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)

            # Convert to float32
            audio_data = audio_data.astype(np.float32)

            # Resample if needed
            if sr != self.sample_rate:
                try:
                    # Use scipy for proper resampling
                    number_of_samples = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = signal.resample(audio_data, number_of_samples)
                except Exception as e:
                    print(f"Resampling error: {e}")
                    # Fallback: linear interpolation onto the target number of samples
                    number_of_samples = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = np.interp(
                        np.linspace(0, len(audio_data) - 1, number_of_samples),
                        np.arange(len(audio_data)),
                        audio_data
                    )

            # Apply fade-in to prevent clicks (5ms fade)
            fade_samples = min(int(0.005 * self.sample_rate), len(audio_data))
            if fade_samples > 0:
                fade_in = np.linspace(0, 1, fade_samples)
                audio_data[:fade_samples] *= fade_in

            # Add to buffer
            if len(self.audio_buffer) == 0:
                self.audio_buffer = audio_data
            else:
                self.audio_buffer = np.concatenate([self.audio_buffer, audio_data])

            # Intelligently trim buffer if it gets too large
            self._trim_buffer_intelligently()

            # Check if we should process now
            should_process = (
                len(self.audio_buffer) >= self.min_process_length and
                time.time() - self.last_process_time >= self.process_interval and
                not self.is_processing
            )
            if should_process:
                self.last_process_time = time.time()
                self.is_processing = True
                # Process in a separate thread
                threading.Thread(target=self._process_audio_chunk, daemon=False).start()

            return len(self.audio_buffer)

    def wait_for_processing_complete(self, timeout=5.0):
        """Wait for any current processing to complete"""
        start_time = time.time()
        while self.is_processing and (time.time() - start_time) < timeout:
            time.sleep(0.05)
        return not self.is_processing

    def force_complete_processing(self):
        """Force completion of any pending processing - ensures sequential execution"""
        # Wait for any current processing to complete
        self.wait_for_processing_complete(10.0)

        # Process any remaining audio in buffer
        with self.lock:
            if len(self.audio_buffer) > self.processed_length:
                # Force process remaining audio (safe because the lock is re-entrant)
                self.is_processing = True
                self._process_audio_chunk()

        # Final wait to ensure everything is complete
        self.wait_for_processing_complete(2.0)
        return self.get_transcription()

    def clear_buffer(self):
        """Clear the audio buffer and transcription"""
        with self.lock:
            self.audio_buffer = np.array([])
            self.processed_length = 0
            self.full_transcription = ""
            self.last_segment_text = ""
            self.confirmed_transcription = ""
            self.is_processing = False
        return "Buffers cleared"

    def get_transcription(self):
        """Get the current transcription, split into sentences by the SBD model"""
        with self.lock:
            results: List[List[str]] = self.sentence_end_detect.infer([self.full_transcription])
            return results[0]

    def get_playback_audio(self):
        """Get properly formatted audio for Gradio playback"""
        with self.lock:
            if len(self.audio_buffer) == 0:
                return None
            # Make a copy and ensure proper format for Gradio
            audio = self.audio_buffer.copy()
            # Scale down only if the peak exceeds 1 so playback stays within [-1, 1]
            if np.max(np.abs(audio)) > 0:
                audio = audio / max(1.0, np.max(np.abs(audio)))
            return (self.sample_rate, audio)

    def get_buffer_info(self):
        """Get information about the current buffer state"""
        with self.lock:
            return {
                "buffer_length_seconds": len(self.audio_buffer) / self.sample_rate,
                "processed_length_seconds": self.processed_length / self.sample_rate,
                "unprocessed_length_seconds": (len(self.audio_buffer) - self.processed_length) / self.sample_rate,
                "is_processing": self.is_processing,
                "transcription_length": len(self.full_transcription)
            }
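

# --- Usage sketch ---------------------------------------------------------
# Minimal, illustrative driver for the class above: feed chunks through
# add_audio(), inspect the buffer, then flush with force_complete_processing().
# The synthetic sine-wave input is an assumption for demonstration only; in
# the real application the chunks would come from a microphone/Gradio stream,
# and meaningful transcription obviously requires actual speech audio plus the
# Whisper/SBD model dependencies and the project's config module.
if __name__ == "__main__":
    processor = AudioProcessor(model_size="tiny.en")

    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    # One second of a quiet 440 Hz tone, standing in for a captured chunk
    chunk = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

    # Simulate streaming by pushing a few one-second chunks into the buffer
    for _ in range(3):
        processor.add_audio(chunk, sr)
        time.sleep(0.5)

    print(processor.get_buffer_info())
    print(processor.force_complete_processing())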