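"""Streaming audio transcription helper built on faster-whisper.

AudioProcessor buffers incoming audio, periodically transcribes the
unprocessed tail in a background thread, merges overlapping segments to
avoid duplicated text, and exposes the running transcription split into
sentences via an ONNX sentence-boundary-detection model.
"""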
import os
import re
import sys
import threading
import time
from typing import List

import numpy as np
import scipy.signal as signal
from faster_whisper import WhisperModel
from punctuators.models import SBDModelONNX

# Allow importing the project-level config module
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config

class AudioProcessor:
    def __init__(self, model_size="tiny.en", device=None, compute_type=None):
        """Initialize the audio processor with configurable parameters."""
        self.audio_buffer = np.array([])  # Stores raw audio for playback
        self.processed_length = 0  # Number of samples already processed
        self.sample_rate = 16000  # Default sample rate for whisper
        self.lock = threading.Lock()  # Thread safety for buffer access
        self.min_process_length = 1 * self.sample_rate  # Process at least 1 second
        self.max_buffer_size = 30 * self.sample_rate  # Maximum buffer size (30 seconds)
        self.overlap_size = 3 * self.sample_rate  # Keep 3 seconds of overlap when trimming
        self.last_process_time = time.time()
        self.process_interval = 0.5  # Process at most every 0.5 seconds
        self.is_processing = False  # Flag to prevent concurrent processing
        self.full_transcription = ""  # Complete history of transcription
        self.last_segment_text = ""  # Last segment that was transcribed
        self.confirmed_transcription = ""  # Transcription that won't change (beyond overlap zone)

        # Use config for device and compute type if not specified
        if device is None or compute_type is None:
            whisper_config = config.get_whisper_config()
            device = device or whisper_config["device"]
            compute_type = compute_type or whisper_config["compute_type"]

        # Initialize the whisper model
        self.audio_model = WhisperModel(model_size, device=device, compute_type=compute_type)
        print(f"Initialized {model_size} model on {device} with {compute_type}")

        # Initialize sentence boundary detection
        self.sentence_end_detect = SBDModelONNX.from_pretrained("sbd_multi_lang")
        if config.device == "cuda":
            print("SBD model initialized with CUDA support")

    def _trim_buffer_intelligently(self):
        """
        Trim the buffer while preserving transcription continuity.
        Keep some overlap to maintain context for the next processing pass.
        """
        if len(self.audio_buffer) <= self.max_buffer_size:
            return

        # Calculate how much to trim (keep overlap_size for context)
        trim_amount = len(self.audio_buffer) - self.max_buffer_size + self.overlap_size
        # Make sure we don't trim more than we have
        trim_amount = min(trim_amount, len(self.audio_buffer) - self.overlap_size)

        if trim_amount > 0:
            # The audio being removed has already been processed and its text
            # is already stored in full_transcription, so nothing is lost here.

            # Trim the buffer
            self.audio_buffer = self.audio_buffer[trim_amount:]
            # Adjust processed_length to account for the trimmed audio
            self.processed_length = max(0, self.processed_length - trim_amount)
            # Reset last_segment_text since our context has changed; this forces
            # the next processing pass to start fresh with overlap handling
            self.last_segment_text = ""

    def _process_audio_chunk(self):
        """Process the current audio buffer and return the new transcription."""
        try:
            with self.lock:
                # Check if there's enough new content to process
                unprocessed_length = len(self.audio_buffer) - self.processed_length
                if unprocessed_length < self.min_process_length:
                    self.is_processing = False
                    return None

                # Determine what portion to process, including some overlap
                # from already-processed audio for context
                overlap_samples = min(self.overlap_size, self.processed_length)
                start_pos = max(0, self.processed_length - overlap_samples)

                # Process from start_pos to the end of the buffer
                audio_to_process = self.audio_buffer[start_pos:].copy()
                end_pos = len(self.audio_buffer)

                # Normalize for transcription
                audio_norm = audio_to_process.astype(np.float32)
                if np.max(np.abs(audio_norm)) > 0:
                    audio_norm = audio_norm / np.max(np.abs(audio_norm))

                # Transcribe with faster settings for real-time processing
                segments, info = self.audio_model.transcribe(
                    audio_norm,
                    beam_size=1,
                    word_timestamps=False,
                    vad_filter=True,
                    vad_parameters=dict(min_silence_duration_ms=500)
                )
                result = list(segments)

                if result:
                    # Gather the new text from all segments
                    current_segment_text = " ".join(
                        seg.text.strip() for seg in result if seg.text.strip()
                    )
                    if not current_segment_text:
                        self.is_processing = False
                        return None

                    # Handle overlap and merge with the existing transcription
                    new_text = self._merge_transcription_intelligently(current_segment_text)
                    if new_text:
                        # Append the new text to the full transcription
                        if self.full_transcription and not self.full_transcription.endswith(' '):
                            self.full_transcription += " "
                        self.full_transcription += new_text

                    # Update state
                    self.last_segment_text = current_segment_text
                    self.processed_length = end_pos
                    return self.full_transcription

                return None
        except Exception as e:
            print(f"Transcription error: {e}")
            return None
        finally:
            self.is_processing = False

    def _merge_transcription_intelligently(self, new_segment_text):
        """
        Intelligently merge a new transcription with the existing text.
        Handles overlap detection and prevents duplication.
        """
        if not new_segment_text or not new_segment_text.strip():
            return ""

        # If this is the first transcription or the context was reset, use it directly
        if not self.last_segment_text:
            return new_segment_text

        def normalize_for_comparison(text):
            # Lowercase and strip punctuation so comparisons ignore formatting
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            return text.strip()

        norm_prev = normalize_for_comparison(self.last_segment_text)
        norm_new = normalize_for_comparison(new_segment_text)
        if not norm_prev or not norm_new:
            return new_segment_text

        # Split into words for overlap detection
        prev_words = norm_prev.split()
        new_words = norm_new.split()

        # Find the longest overlap between the end of the previous segment
        # and the start of the new one (check up to 15 words)
        max_overlap = min(len(prev_words), len(new_words), 15)
        overlap_found = 0
        for i in range(max_overlap, 2, -1):  # Require at least 3 words to count as overlap
            if prev_words[-i:] == new_words[:i]:
                overlap_found = i
                break

        # Handle the special case of counting sequences (e.g. "one, two, three")
        if overlap_found == 0:
            prev_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_prev)]
            new_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_new)]
            if prev_numbers and new_numbers:
                max_prev = max(prev_numbers)
                min_new = min(new_numbers)
                # If there's a logical continuation, find where it starts
                if min_new <= max_prev + 5:  # Allow a small gap in the count
                    new_text_words = new_segment_text.split()
                    for i, word in enumerate(new_text_words):
                        if re.search(r'\b\d+\b', word):
                            num = int(re.search(r'\d+', word).group())
                            if num > max_prev:
                                return " ".join(new_text_words[i:])

        # Apply overlap removal if found
        if overlap_found > 0:
            new_text_words = new_segment_text.split()
            return " ".join(new_text_words[overlap_found:])

        # If the new text is completely contained in the previous segment,
        # drop it to avoid duplication
        if norm_new in norm_prev:
            return ""

        # No overlap found; return the full new text
        return new_segment_text

    def add_audio(self, audio_data, sr):
        """
        Add audio to the buffer and process it if needed.

        Args:
            audio_data (numpy.ndarray): Audio data to add
            sr (int): Sample rate of the audio data

        Returns:
            int: Current buffer size in samples
        """
        with self.lock:
            # Convert to mono if stereo
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)

            # Convert to float32
            audio_data = audio_data.astype(np.float32)

            # Resample if needed
            if sr != self.sample_rate:
                try:
                    # Use scipy for proper resampling
                    number_of_samples = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = signal.resample(audio_data, number_of_samples)
                except Exception as e:
                    print(f"Resampling error: {e}")
                    # Fall back to linear interpolation over the full signal
                    new_length = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = np.interp(
                        np.linspace(0, len(audio_data) - 1, new_length),
                        np.arange(len(audio_data)),
                        audio_data
                    )

            # Apply a short fade-in (5 ms) to prevent clicks at chunk boundaries
            fade_samples = min(int(0.005 * self.sample_rate), len(audio_data))
            if fade_samples > 0:
                fade_in = np.linspace(0, 1, fade_samples)
                audio_data[:fade_samples] *= fade_in

            # Add to the buffer
            if len(self.audio_buffer) == 0:
                self.audio_buffer = audio_data
            else:
                self.audio_buffer = np.concatenate([self.audio_buffer, audio_data])

            # Trim the buffer intelligently if it gets too large
            self._trim_buffer_intelligently()

            # Check whether we should process now
            should_process = (
                len(self.audio_buffer) >= self.min_process_length and
                time.time() - self.last_process_time >= self.process_interval and
                not self.is_processing
            )
            if should_process:
                self.last_process_time = time.time()
                self.is_processing = True
                # Process in a background thread so add_audio returns quickly
                threading.Thread(target=self._process_audio_chunk, daemon=False).start()

            return len(self.audio_buffer)

    def wait_for_processing_complete(self, timeout=5.0):
        """Wait for any in-flight processing to complete."""
        start_time = time.time()
        while self.is_processing and (time.time() - start_time) < timeout:
            time.sleep(0.05)
        return not self.is_processing

    def force_complete_processing(self):
        """Force completion of any pending processing, ensuring sequential execution."""
        # Wait for any current processing to complete
        self.wait_for_processing_complete(10.0)

        # Check for remaining audio under the lock, but call
        # _process_audio_chunk outside it: that method acquires the
        # same (non-reentrant) lock, so calling it while holding the
        # lock would deadlock
        with self.lock:
            has_remaining = len(self.audio_buffer) > self.processed_length
        if has_remaining:
            # Force-process the remaining audio
            self.is_processing = True
            self._process_audio_chunk()

        # Final wait to ensure everything is complete
        self.wait_for_processing_complete(2.0)
        return self.get_transcription()

    def clear_buffer(self):
        """Clear the audio buffer and transcription state."""
        with self.lock:
            self.audio_buffer = np.array([])
            self.processed_length = 0
            self.full_transcription = ""
            self.last_segment_text = ""
            self.confirmed_transcription = ""
            self.is_processing = False
        return "Buffers cleared"

    def get_transcription(self):
        """Get the current transcription, split into sentences by the SBD model."""
        with self.lock:
            # Avoid running sentence-boundary detection on an empty string
            if not self.full_transcription:
                return []
            results: List[List[str]] = self.sentence_end_detect.infer([self.full_transcription])
            return results[0]

    def get_playback_audio(self):
        """Get properly formatted audio for Gradio playback."""
        with self.lock:
            if len(self.audio_buffer) == 0:
                return None

            # Make a copy and ensure the proper format for Gradio
            audio = self.audio_buffer.copy()

            # Keep the audio in the playback range (-1 to 1);
            # only scale down if it exceeds full scale
            if np.max(np.abs(audio)) > 0:
                audio = audio / max(1.0, np.max(np.abs(audio)))

            return (self.sample_rate, audio)

    def get_buffer_info(self):
        """Get information about the current buffer state."""
        with self.lock:
            return {
                "buffer_length_seconds": len(self.audio_buffer) / self.sample_rate,
                "processed_length_seconds": self.processed_length / self.sample_rate,
                "unprocessed_length_seconds": (len(self.audio_buffer) - self.processed_length) / self.sample_rate,
                "is_processing": self.is_processing,
                "transcription_length": len(self.full_transcription)
            }
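

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): feeds one second
    # of synthetic 440 Hz tone through the processor in 250 ms chunks to show
    # the expected call pattern. Real input would come from a microphone or a
    # file; a pure tone will likely produce little or no transcription. This
    # assumes the project's `config` module resolves a valid device and
    # compute type for the machine it runs on.
    processor = AudioProcessor(model_size="tiny.en")

    src_sr = 44100  # deliberately different from 16 kHz to exercise resampling
    t = np.arange(src_sr, dtype=np.float32) / src_sr
    tone = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

    chunk = src_sr // 4  # 250 ms chunks
    for start in range(0, len(tone), chunk):
        buffered = processor.add_audio(tone[start:start + chunk], src_sr)
        print(f"Buffered samples: {buffered}")
        time.sleep(0.25)  # simulate real-time pacing between chunks

    # Flush any pending audio and print whatever sentences were recognized
    print(processor.force_complete_processing())
    print(processor.get_buffer_info())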