# components/transcriber.py
import os
import re
import sys
import threading
import time
from typing import List

import numpy as np
import scipy.signal as signal
from faster_whisper import WhisperModel
from punctuators.models import SBDModelONNX

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config
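
# The project-level `config` module is assumed (from its usage below) to
# expose get_whisper_config(), returning at least {"device": ...,
# "compute_type": ...}, as well as a `device` attribute; it is not
# defined in this file.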
class AudioProcessor:
def __init__(self, model_size="tiny.en", device=None, compute_type=None):
"""Initialize the audio processor with configurable parameters"""
        self.audio_buffer = np.array([])  # Rolling buffer of raw audio (transcription + playback)
self.processed_length = 0 # Length of audio already processed
self.sample_rate = 16000 # Default sample rate for whisper
self.lock = threading.Lock() # Thread safety for buffer access
self.min_process_length = 1 * self.sample_rate # Process at least 1 second
self.max_buffer_size = 30 * self.sample_rate # Maximum buffer size (30 seconds)
self.overlap_size = 3 * self.sample_rate # Keep 3 seconds of overlap when trimming
self.last_process_time = time.time()
        self.process_interval = 0.5  # Process every 0.5 seconds
self.is_processing = False # Flag to prevent concurrent processing
self.full_transcription = "" # Complete history of transcription
self.last_segment_text = "" # Last segment that was transcribed
self.confirmed_transcription = "" # Transcription that won't change (beyond overlap zone)
# Use config for device and compute type if not specified
if device is None or compute_type is None:
whisper_config = config.get_whisper_config()
device = device or whisper_config["device"]
compute_type = compute_type or whisper_config["compute_type"]
# Initialize the whisper model
self.audio_model = WhisperModel(model_size, device=device, compute_type=compute_type)
print(f"Initialized {model_size} model on {device} with {compute_type}")
        # Initialize sentence boundary detection (note: no device argument is
        # passed to the SBD model; the print below only reports config state)
self.sentence_end_detect = SBDModelONNX.from_pretrained("sbd_multi_lang")
if config.device == "cuda":
print("SBD model initialized with CUDA support")
def _trim_buffer_intelligently(self):
"""
Trim the buffer while preserving transcription continuity
Keep some overlap to maintain context for the next processing
"""
if len(self.audio_buffer) <= self.max_buffer_size:
return
# Calculate how much to trim (keep overlap_size for context)
trim_amount = len(self.audio_buffer) - self.max_buffer_size + self.overlap_size
        # Never trim so much that fewer than overlap_size samples remain
        trim_amount = min(trim_amount, len(self.audio_buffer) - self.overlap_size)
        if trim_amount > 0:
            # If the removed audio was already processed, its text is already
            # captured in full_transcription, so no confirmed text is lost
            self.audio_buffer = self.audio_buffer[trim_amount:]
            # Shift processed_length to account for the trimmed samples
            self.processed_length = max(0, self.processed_length - trim_amount)
            # Reset last_segment_text since the context has changed; this
            # forces the next pass to start fresh with overlap handling
            self.last_segment_text = ""
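    # Worked trim example (illustrative numbers): with a 35 s buffer,
    # max_buffer_size = 30 s and overlap_size = 3 s give
    # trim_amount = 35 - 30 + 3 = 8 s of samples to drop, leaving a 27 s
    # buffer whose trailing processed audio provides context for the next pass.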
def _process_audio_chunk(self):
"""Process the current audio buffer and return new transcription"""
try:
with self.lock:
# Check if there's enough new content to process
unprocessed_length = len(self.audio_buffer) - self.processed_length
if unprocessed_length < self.min_process_length:
self.is_processing = False
return None
# Determine what portion to process
# Include some overlap from already processed audio for context
overlap_samples = min(self.overlap_size, self.processed_length)
start_pos = max(0, self.processed_length - overlap_samples)
# Process from start_pos to end of buffer
audio_to_process = self.audio_buffer[start_pos:].copy()
end_pos = len(self.audio_buffer)
# Normalize for transcription
audio_norm = audio_to_process.astype(np.float32)
if np.max(np.abs(audio_norm)) > 0:
audio_norm = audio_norm / np.max(np.abs(audio_norm))
# Transcribe with faster settings for real-time processing
segments, info = self.audio_model.transcribe(
audio_norm,
beam_size=1,
word_timestamps=False,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500)
)
result = list(segments)
if result:
# Get the new text from all segments
current_segment_text = " ".join([seg.text.strip() for seg in result if seg.text.strip()])
if not current_segment_text:
self.is_processing = False
return None
# Handle overlap and merge with existing transcription
new_text = self._merge_transcription_intelligently(current_segment_text)
if new_text:
# Append new text to full transcription
if self.full_transcription and not self.full_transcription.endswith(' '):
self.full_transcription += " "
self.full_transcription += new_text
# Update state
self.last_segment_text = current_segment_text
self.processed_length = end_pos
return self.full_transcription
return None
except Exception as e:
print(f"Transcription error: {e}")
return None
finally:
self.is_processing = False
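    # Overlap window example (illustrative numbers): with processed_length at
    # 20 s worth of samples and overlap_size at 3 s, start_pos lands at the
    # 17 s mark, so each pass re-transcribes the last 3 s of already-processed
    # audio; _merge_transcription_intelligently then strips the repeated words.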
def _merge_transcription_intelligently(self, new_segment_text):
"""
Intelligently merge new transcription with existing text
Handles overlap detection and prevents duplication
"""
if not new_segment_text or not new_segment_text.strip():
return ""
# If this is the first transcription or we reset context, use it directly
if not self.last_segment_text:
return new_segment_text
# Normalize text for comparison
def normalize_for_comparison(text):
# Convert to lowercase and remove punctuation for comparison
text = text.lower()
text = re.sub(r'[^\w\s]', '', text)
return text.strip()
norm_prev = normalize_for_comparison(self.last_segment_text)
norm_new = normalize_for_comparison(new_segment_text)
if not norm_prev or not norm_new:
return new_segment_text
# Split into words for overlap detection
prev_words = norm_prev.split()
new_words = norm_new.split()
# Find the longest overlap between end of previous and start of new
max_overlap = min(len(prev_words), len(new_words), 15) # Check up to 15 words
overlap_found = 0
for i in range(max_overlap, 2, -1): # Minimum 3 words to consider overlap
if prev_words[-i:] == new_words[:i]:
overlap_found = i
break
# Handle special cases for numbers (counting sequences)
if overlap_found == 0:
# Check if we have a counting sequence
prev_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_prev)]
new_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_new)]
if prev_numbers and new_numbers:
max_prev = max(prev_numbers)
min_new = min(new_numbers)
# If there's a logical continuation, find where it starts
if min_new <= max_prev + 5: # Allow some gap in counting
new_text_words = new_segment_text.split()
for i, word in enumerate(new_text_words):
if re.search(r'\b\d+\b', word):
num = int(re.search(r'\d+', word).group())
if num > max_prev:
return " ".join(new_text_words[i:])
# Apply overlap removal if found
if overlap_found > 0:
new_text_words = new_segment_text.split()
return " ".join(new_text_words[overlap_found:])
else:
# Check if new text is completely contained in previous (avoid duplication)
if norm_new in norm_prev:
return ""
# No overlap found, return the full new text
return new_segment_text
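    # Merge example (illustrative strings): if the previous segment ended
    # "... counted one two three" and the new segment is "one two three four
    # five", the three-word overlap is detected and only "four five" is kept;
    # a new segment of just "two three" would be dropped as fully contained
    # in the previous text.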
def add_audio(self, audio_data, sr):
"""
Add audio to the buffer and process if needed
Args:
audio_data (numpy.ndarray): Audio data to add
sr (int): Sample rate of the audio data
Returns:
int: Current buffer size in samples
"""
with self.lock:
# Convert to mono if stereo
if audio_data.ndim > 1:
audio_data = audio_data.mean(axis=1)
# Convert to float32
audio_data = audio_data.astype(np.float32)
# Resample if needed
if sr != self.sample_rate:
try:
# Use scipy for proper resampling
number_of_samples = int(len(audio_data) * self.sample_rate / sr)
audio_data = signal.resample(audio_data, number_of_samples)
except Exception as e:
print(f"Resampling error: {e}")
# Fallback resampling
ratio = self.sample_rate / sr
audio_data = np.interp(
np.arange(0, len(audio_data) * ratio, ratio),
np.arange(0, len(audio_data)),
audio_data
)
# Apply fade-in to prevent clicks (5ms fade)
fade_samples = min(int(0.005 * self.sample_rate), len(audio_data))
if fade_samples > 0:
fade_in = np.linspace(0, 1, fade_samples)
audio_data[:fade_samples] *= fade_in
# Add to buffer
if len(self.audio_buffer) == 0:
self.audio_buffer = audio_data
else:
self.audio_buffer = np.concatenate([self.audio_buffer, audio_data])
# Intelligently trim buffer if it gets too large
self._trim_buffer_intelligently()
# Check if we should process now
should_process = (
len(self.audio_buffer) >= self.min_process_length and
time.time() - self.last_process_time >= self.process_interval and
not self.is_processing
)
if should_process:
self.last_process_time = time.time()
self.is_processing = True
# Process in a separate thread
threading.Thread(target=self._process_audio_chunk, daemon=False).start()
return len(self.audio_buffer)
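    # Resampling example (illustrative numbers): a 100 ms chunk captured at
    # 44.1 kHz arrives as 4410 samples and becomes
    # int(4410 * 16000 / 44100) = 1600 samples at the model's 16 kHz rate.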
def wait_for_processing_complete(self, timeout=5.0):
"""Wait for any current processing to complete"""
start_time = time.time()
while self.is_processing and (time.time() - start_time) < timeout:
time.sleep(0.05)
return not self.is_processing
def force_complete_processing(self):
"""Force completion of any pending processing - ensures sequential execution"""
# Wait for any current processing to complete
self.wait_for_processing_complete(10.0)
        # Process any remaining audio in the buffer. _process_audio_chunk
        # acquires self.lock itself, so it must be invoked outside the lock
        # (threading.Lock is not reentrant; calling it while holding the
        # lock would deadlock)
        with self.lock:
            has_remaining = len(self.audio_buffer) > self.processed_length
            if has_remaining:
                self.is_processing = True
        if has_remaining:
            self._process_audio_chunk()
        # Final wait to ensure everything is complete
        self.wait_for_processing_complete(2.0)
return self.get_transcription()
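    # Typical end-of-session call (a sketch; `processor` is hypothetical):
    #
    #     sentences = processor.force_complete_processing()
    #
    # flushes whatever audio is still unprocessed and returns the
    # sentence-split transcription from get_transcription().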
def clear_buffer(self):
"""Clear the audio buffer and transcription"""
with self.lock:
self.audio_buffer = np.array([])
self.processed_length = 0
self.full_transcription = ""
self.last_segment_text = ""
self.confirmed_transcription = ""
self.is_processing = False
return "Buffers cleared"
def get_transcription(self):
"""Get the current transcription text"""
with self.lock:
results: List[List[str]] = self.sentence_end_detect.infer([self.full_transcription])
return results[0]
def get_playback_audio(self):
"""Get properly formatted audio for Gradio playback"""
with self.lock:
if len(self.audio_buffer) == 0:
return None
# Make a copy and ensure proper format for Gradio
audio = self.audio_buffer.copy()
            # Scale down only if the peak exceeds 1 so playback stays in [-1, 1]
            peak = np.max(np.abs(audio))
            if peak > 1.0:
                audio = audio / peak
return (self.sample_rate, audio)
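    # The (sample_rate, ndarray) tuple returned above matches the value
    # format Gradio's Audio component accepts, e.g. (assuming
    # `import gradio as gr`):
    #
    #     gr.Audio(value=processor.get_playback_audio())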
def get_buffer_info(self):
"""Get information about the current buffer state"""
with self.lock:
return {
"buffer_length_seconds": len(self.audio_buffer) / self.sample_rate,
"processed_length_seconds": self.processed_length / self.sample_rate,
"unprocessed_length_seconds": (len(self.audio_buffer) - self.processed_length) / self.sample_rate,
"is_processing": self.is_processing,
"transcription_length": len(self.full_transcription)
}
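
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the pipeline): feed one second
    # of a synthetic 440 Hz tone and print the buffer state. Real speech is
    # needed for meaningful transcription output, so this only exercises the
    # buffering, resampling, and threading paths.
    demo_sr = 44100
    t = np.linspace(0, 1.0, demo_sr, endpoint=False)
    demo_audio = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
    processor = AudioProcessor(model_size="tiny.en")
    processor.add_audio(demo_audio, sr=demo_sr)
    processor.wait_for_processing_complete()
    print(processor.get_buffer_info())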