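"""Streaming audio transcription buffer built on faster-whisper.

AudioProcessor accumulates incoming audio chunks, periodically transcribes
the unprocessed tail in a background thread, and merges overlapping
transcription segments into a continuous transcript. Sentence boundaries
are added on read-out via an ONNX SBD model.
"""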
import numpy as np
import threading
import time
from faster_whisper import WhisperModel
import scipy.signal as signal
from typing import List
from punctuators.models import SBDModelONNX
import re
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config


class AudioProcessor:
def __init__(self, model_size="tiny.en", device=None, compute_type=None):
"""Initialize the audio processor with configurable parameters"""
self.audio_buffer = np.array([]) # Stores raw audio for playback
self.processed_length = 0 # Length of audio already processed
self.sample_rate = 16000 # Default sample rate for whisper
self.lock = threading.Lock() # Thread safety for buffer access
self.min_process_length = 1 * self.sample_rate # Process at least 1 second
self.max_buffer_size = 30 * self.sample_rate # Maximum buffer size (30 seconds)
self.overlap_size = 3 * self.sample_rate # Keep 3 seconds of overlap when trimming
self.last_process_time = time.time()
        self.process_interval = 0.5  # Process every 0.5 seconds
self.is_processing = False # Flag to prevent concurrent processing
self.full_transcription = "" # Complete history of transcription
self.last_segment_text = "" # Last segment that was transcribed
self.confirmed_transcription = "" # Transcription that won't change (beyond overlap zone)
# Use config for device and compute type if not specified
if device is None or compute_type is None:
whisper_config = config.get_whisper_config()
device = device or whisper_config["device"]
compute_type = compute_type or whisper_config["compute_type"]
# Initialize the whisper model
self.audio_model = WhisperModel(model_size, device=device, compute_type=compute_type)
print(f"Initialized {model_size} model on {device} with {compute_type}")
# Initialize sentence boundary detection with device config
self.sentence_end_detect = SBDModelONNX.from_pretrained("sbd_multi_lang")
if config.device == "cuda":
print("SBD model initialized with CUDA support")

    def _trim_buffer_intelligently(self):
"""
Trim the buffer while preserving transcription continuity
Keep some overlap to maintain context for the next processing
"""
if len(self.audio_buffer) <= self.max_buffer_size:
return
# Calculate how much to trim (keep overlap_size for context)
trim_amount = len(self.audio_buffer) - self.max_buffer_size + self.overlap_size
# Make sure we don't trim more than we have
trim_amount = min(trim_amount, len(self.audio_buffer) - self.overlap_size)
if trim_amount > 0:
            # The audio being removed may already have been processed; its
            # text is already recorded in full_transcription, so no confirmed
            # text is lost by trimming.
# Trim the buffer
self.audio_buffer = self.audio_buffer[trim_amount:]
# Adjust processed_length to account for trimmed audio
self.processed_length = max(0, self.processed_length - trim_amount)
# Reset last_segment_text since our context has changed
# This forces the next processing to start fresh with overlap handling
self.last_segment_text = ""

    def _process_audio_chunk(self):
        """Transcribe the unprocessed tail of the buffer and merge the result"""
        try:
            with self.lock:
                # Check if there's enough new content to process
                unprocessed_length = len(self.audio_buffer) - self.processed_length
                if unprocessed_length < self.min_process_length:
                    return None
                # Include some overlap from already processed audio for context
                overlap_samples = min(self.overlap_size, self.processed_length)
                start_pos = max(0, self.processed_length - overlap_samples)
                # Snapshot the region to process so transcription can run
                # without holding the lock (keeps add_audio responsive)
                audio_to_process = self.audio_buffer[start_pos:].copy()
                end_pos = len(self.audio_buffer)
            # Normalize for transcription
            audio_norm = audio_to_process.astype(np.float32)
            peak = np.max(np.abs(audio_norm))
            if peak > 0:
                audio_norm = audio_norm / peak
            # Transcribe with fast settings for real-time processing
            segments, info = self.audio_model.transcribe(
                audio_norm,
                beam_size=1,
                word_timestamps=False,
                vad_filter=True,
                vad_parameters=dict(min_silence_duration_ms=500)
            )
            # Consuming the generator is what actually runs the transcription
            result = list(segments)
            if not result:
                return None
            current_segment_text = " ".join(
                seg.text.strip() for seg in result if seg.text.strip()
            )
            if not current_segment_text:
                return None
            with self.lock:
                # The buffer may have been trimmed while we transcribed;
                # clamp the processed position to the current buffer length
                end_pos = min(end_pos, len(self.audio_buffer))
                # Handle overlap and merge with existing transcription
                new_text = self._merge_transcription_intelligently(current_segment_text)
                if new_text:
                    # Append new text to the full transcription
                    if self.full_transcription and not self.full_transcription.endswith(' '):
                        self.full_transcription += " "
                    self.full_transcription += new_text
                    # Update state
                    self.last_segment_text = current_segment_text
                    self.processed_length = end_pos
                    return self.full_transcription
            return None
        except Exception as e:
            print(f"Transcription error: {e}")
            return None
        finally:
            self.is_processing = False

    def _merge_transcription_intelligently(self, new_segment_text):
        """
        Merge new transcription with existing text, detecting word-level
        overlap between the end of the previous segment and the start of
        the new one, and stripping the duplicated words (e.g. previous
        "... three four five" + new "three four five six seven" yields
        "six seven")
        """
if not new_segment_text or not new_segment_text.strip():
return ""
# If this is the first transcription or we reset context, use it directly
if not self.last_segment_text:
return new_segment_text
        # Normalize text for comparison
        def normalize_for_comparison(text):
            # Lowercase and remove punctuation so comparison ignores formatting
            text = text.lower()
            text = re.sub(r'[^\w\s]', '', text)
            return text.strip()
norm_prev = normalize_for_comparison(self.last_segment_text)
norm_new = normalize_for_comparison(new_segment_text)
if not norm_prev or not norm_new:
return new_segment_text
# Split into words for overlap detection
prev_words = norm_prev.split()
new_words = norm_new.split()
# Find the longest overlap between end of previous and start of new
max_overlap = min(len(prev_words), len(new_words), 15) # Check up to 15 words
overlap_found = 0
for i in range(max_overlap, 2, -1): # Minimum 3 words to consider overlap
if prev_words[-i:] == new_words[:i]:
overlap_found = i
break
# Handle special cases for numbers (counting sequences)
if overlap_found == 0:
# Check if we have a counting sequence
prev_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_prev)]
new_numbers = [int(x) for x in re.findall(r'\b\d+\b', norm_new)]
if prev_numbers and new_numbers:
max_prev = max(prev_numbers)
min_new = min(new_numbers)
# If there's a logical continuation, find where it starts
if min_new <= max_prev + 5: # Allow some gap in counting
new_text_words = new_segment_text.split()
for i, word in enumerate(new_text_words):
if re.search(r'\b\d+\b', word):
num = int(re.search(r'\d+', word).group())
if num > max_prev:
return " ".join(new_text_words[i:])
# Apply overlap removal if found
if overlap_found > 0:
new_text_words = new_segment_text.split()
return " ".join(new_text_words[overlap_found:])
else:
# Check if new text is completely contained in previous (avoid duplication)
if norm_new in norm_prev:
return ""
# No overlap found, return the full new text
return new_segment_text

    def add_audio(self, audio_data, sr):
"""
Add audio to the buffer and process if needed
Args:
audio_data (numpy.ndarray): Audio data to add
sr (int): Sample rate of the audio data
Returns:
int: Current buffer size in samples
"""
with self.lock:
# Convert to mono if stereo
if audio_data.ndim > 1:
audio_data = audio_data.mean(axis=1)
# Convert to float32
audio_data = audio_data.astype(np.float32)
# Resample if needed
if sr != self.sample_rate:
try:
# Use scipy for proper resampling
number_of_samples = int(len(audio_data) * self.sample_rate / sr)
audio_data = signal.resample(audio_data, number_of_samples)
except Exception as e:
print(f"Resampling error: {e}")
                    # Fallback: linear-interpolation resampling (map each new
                    # sample index back to its position in the original signal)
                    number_of_samples = int(len(audio_data) * self.sample_rate / sr)
                    audio_data = np.interp(
                        np.linspace(0, len(audio_data) - 1, number_of_samples),
                        np.arange(len(audio_data)),
                        audio_data
                    )
# Apply fade-in to prevent clicks (5ms fade)
fade_samples = min(int(0.005 * self.sample_rate), len(audio_data))
if fade_samples > 0:
fade_in = np.linspace(0, 1, fade_samples)
audio_data[:fade_samples] *= fade_in
# Add to buffer
if len(self.audio_buffer) == 0:
self.audio_buffer = audio_data
else:
self.audio_buffer = np.concatenate([self.audio_buffer, audio_data])
# Intelligently trim buffer if it gets too large
self._trim_buffer_intelligently()
# Check if we should process now
should_process = (
len(self.audio_buffer) >= self.min_process_length and
time.time() - self.last_process_time >= self.process_interval and
not self.is_processing
)
if should_process:
self.last_process_time = time.time()
self.is_processing = True
# Process in a separate thread
threading.Thread(target=self._process_audio_chunk, daemon=False).start()
return len(self.audio_buffer)

    def wait_for_processing_complete(self, timeout=5.0):
"""Wait for any current processing to complete"""
start_time = time.time()
while self.is_processing and (time.time() - start_time) < timeout:
time.sleep(0.05)
return not self.is_processing

    def force_complete_processing(self):
        """Force completion of any pending processing - ensures sequential execution"""
        # Wait for any current processing to complete
        self.wait_for_processing_complete(10.0)
        # Check for leftover audio under the lock, but run _process_audio_chunk
        # outside it: that method acquires the same non-reentrant lock itself,
        # so calling it while holding the lock would deadlock
        with self.lock:
            has_remaining = len(self.audio_buffer) > self.processed_length
            if has_remaining:
                self.is_processing = True
        if has_remaining:
            # Force process remaining audio
            self._process_audio_chunk()
# Final wait to ensure everything is complete
self.wait_for_processing_complete(2.0)
return self.get_transcription()

    def clear_buffer(self):
"""Clear the audio buffer and transcription"""
with self.lock:
self.audio_buffer = np.array([])
self.processed_length = 0
self.full_transcription = ""
self.last_segment_text = ""
self.confirmed_transcription = ""
self.is_processing = False
return "Buffers cleared"

    def get_transcription(self):
        """Return the current transcription as a list of sentences,
        split by the sentence boundary detection (SBD) model"""
        with self.lock:
            results: List[List[str]] = self.sentence_end_detect.infer([self.full_transcription])
            return results[0]

    def get_playback_audio(self):
"""Get properly formatted audio for Gradio playback"""
with self.lock:
if len(self.audio_buffer) == 0:
return None
# Make a copy and ensure proper format for Gradio
audio = self.audio_buffer.copy()
# Ensure audio is in the correct range for playback (-1 to 1)
if np.max(np.abs(audio)) > 0:
audio = audio / max(1.0, np.max(np.abs(audio)))
return (self.sample_rate, audio)

    def get_buffer_info(self):
"""Get information about the current buffer state"""
with self.lock:
return {
"buffer_length_seconds": len(self.audio_buffer) / self.sample_rate,
"processed_length_seconds": self.processed_length / self.sample_rate,
"unprocessed_length_seconds": (len(self.audio_buffer) - self.processed_length) / self.sample_rate,
"is_processing": self.is_processing,
"transcription_length": len(self.full_transcription)
}
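

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the class). Running the
# second part assumes faster-whisper, punctuators, and the project's config
# module are importable and will download the tiny.en model; the synthetic
# sine tone only exercises the buffering/resampling path and will usually
# yield an empty transcription, since it contains no speech. The merge demo
# bypasses __init__ via object.__new__ purely to avoid loading the models.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Demonstrate overlap-aware merging without loading any models
    proc = object.__new__(AudioProcessor)  # skip heavy model initialization
    proc.last_segment_text = "one two three four five"
    merged = proc._merge_transcription_intelligently("three four five six seven")
    print(f"Merged tail: {merged!r}")  # expected: 'six seven'

    # Demonstrate the buffering path with a 2-second 440 Hz tone at 44.1 kHz
    processor = AudioProcessor(model_size="tiny.en")
    sr = 44100
    t = np.linspace(0, 2.0, int(sr * 2.0), endpoint=False)
    tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    processor.add_audio(tone, sr)
    print(processor.get_buffer_info())
    print(processor.force_complete_processing())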