Commit f4275bf · 1 parent: bd39e10
Transcript

Files changed:
- inference.py +19 -0
- shared.py +57 -8
- test_websocket.py +1 -0
- ui.py +10 -2
inference.py
CHANGED
@@ -10,6 +10,16 @@ import time
 from typing import Set, Dict, Any
 import traceback
 
+# Check for RealtimeSTT and install if needed
+try:
+    from RealtimeSTT import AudioToTextRecorder
+except ImportError:
+    import subprocess
+    import sys
+    print("Installing RealtimeSTT dependency...")
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "RealtimeSTT"])
+    from RealtimeSTT import AudioToTextRecorder
+
 # Set up logging
 logging.basicConfig(
     level=logging.INFO,
@@ -185,6 +195,15 @@ async def shutdown_event():
     try:
         diart.stop_recording()
         logger.info("Recording stopped")
+
+        # Shutdown RealtimeSTT properly if available
+        if hasattr(diart, 'recorder') and diart.recorder:
+            try:
+                diart.recorder.shutdown()
+                logger.info("Transcription model shut down")
+            except Exception as e:
+                logger.error(f"Error shutting down transcription model: {e}")
+
     except Exception as e:
         logger.error(f"Error stopping recording: {e}")
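Note: the shutdown() call matters because RealtimeSTT runs its transcription in background worker processes; without it the Python process can hang on exit. As a standalone illustration (not code from this commit; the model name is an assumption), the recorder lifecycle the diff relies on looks roughly like this:

# Sketch: construct / feed / shutdown, mirroring the calls used above.
import numpy as np
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(
    model="tiny.en",        # assumption: any faster-whisper model id works
    use_microphone=False,   # audio is pushed in via feed_audio() instead
    spinner=False,
)

# feed_audio() takes raw 16 kHz, 16-bit mono PCM bytes by default.
chunk = np.zeros(16000, dtype=np.int16)  # one second of silence
recorder.feed_audio(chunk.tobytes())

# recorder.text() would block until a complete utterance is detected,
# so it is only meaningful once real speech has been fed.
recorder.shutdown()  # terminates the worker processes cleanly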
shared.py
CHANGED
@@ -8,6 +8,9 @@ import torchaudio
 from scipy.spatial.distance import cosine
 from scipy.signal import resample
 import logging
+import urllib.request
+# Import RealtimeSTT for transcription
+from RealtimeSTT import AudioToTextRecorder
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -64,12 +67,26 @@ class SpeechBrainEncoder:
         self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
         os.makedirs(self.cache_dir, exist_ok=True)
 
+    def _download_model(self):
+        """Download pre-trained SpeechBrain ECAPA-TDNN model if not present"""
+        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
+        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
+
+        if not os.path.exists(model_path):
+            print(f"Downloading ECAPA-TDNN model to {model_path}...")
+            urllib.request.urlretrieve(model_url, model_path)
+
+        return model_path
+
     def load_model(self):
         """Load the ECAPA-TDNN model"""
         try:
             # Import SpeechBrain
             from speechbrain.pretrained import EncoderClassifier
 
+            # Get model path
+            model_path = self._download_model()
+
             # Load the pre-trained model
             self.model = EncoderClassifier.from_hparams(
                 source="speechbrain/spkrec-ecapa-voxceleb",
@@ -286,7 +303,7 @@ class RealtimeSpeakerDiarization:
         self.encoder = None
         self.audio_processor = None
         self.speaker_detector = None
-        self.recorder = None
+        self.recorder = None  # RealtimeSTT recorder
         self.sentence_queue = queue.Queue()
         self.full_sentences = []
         self.sentence_speakers = []
@@ -314,6 +331,25 @@ class RealtimeSpeakerDiarization:
                 change_threshold=self.change_threshold,
                 max_speakers=self.max_speakers
             )
+
+            # Initialize RealtimeSTT transcription model
+            self.recorder = AudioToTextRecorder(
+                spinner=False,
+                use_microphone=False,
+                model=FINAL_TRANSCRIPTION_MODEL,
+                language=TRANSCRIPTION_LANGUAGE,
+                silero_sensitivity=SILERO_SENSITIVITY,
+                webrtc_sensitivity=WEBRTC_SENSITIVITY,
+                post_speech_silence_duration=0.7,
+                min_length_of_recording=MIN_LENGTH_OF_RECORDING,
+                pre_recording_buffer_duration=PRE_RECORDING_BUFFER_DURATION,
+                enable_realtime_transcription=True,
+                realtime_processing_pause=0,
+                realtime_model_type=REALTIME_TRANSCRIPTION_MODEL,
+                on_realtime_transcription_stabilized=self.live_text_detected,
+                on_recording_complete=self.process_final_text
+            )
+
             logger.info("Models initialized successfully!")
             return True
         else:
@@ -416,6 +452,11 @@ class RealtimeSpeakerDiarization:
             self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
             self.sentence_thread.start()
 
+            # Start the RealtimeSTT recorder if not already started
+            if self.recorder and not getattr(self.recorder, '_is_running', False):
+                self.recorder.start()
+                logger.info("RealtimeSTT recorder started")
+
             return "Recording started successfully!"
 
         except Exception as e:
@@ -425,6 +466,15 @@ class RealtimeSpeakerDiarization:
     def stop_recording(self):
         """Stop the recording process"""
         self.is_running = False
+
+        # Stop the RealtimeSTT recorder
+        if self.recorder:
+            try:
+                self.recorder.stop()
+                logger.info("RealtimeSTT recorder stopped")
+            except Exception as e:
+                logger.error(f"Error stopping recorder: {e}")
+
         return "Recording stopped!"
 
     def clear_conversation(self):
@@ -573,6 +623,12 @@ class RealtimeSpeakerDiarization:
             # Add to audio processor buffer for speaker detection
            self.audio_processor.add_audio_chunk(audio_data)
 
+            # Feed to RealtimeSTT for transcription
+            if self.recorder:
+                # Convert to int16 for RealtimeSTT
+                audio_int16 = (audio_data * 32768).astype(np.int16)
+                self.recorder.feed_audio(audio_int16.tobytes())
+
             # Periodically extract embeddings for speaker detection
             embedding = None
             speaker_id = self.speaker_detector.current_speaker
@@ -582,12 +638,6 @@
                 embedding = self.audio_processor.extract_embedding_from_buffer()
                 if embedding is not None:
                     speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
-
-                    # Add a simulated sentence for demo purposes
-                    if similarity < 0.5:
-                        with self.transcription_lock:
-                            self.full_sentences.append((f"[Audio segment {self.speaker_detector.segment_counter}]", speaker_id))
-                            self.update_conversation_display()
 
             # Return processing result
             return {
@@ -595,7 +645,6 @@
                 "buffer_size": len(self.audio_processor.audio_buffer),
                 "speaker_id": int(speaker_id) if not isinstance(speaker_id, int) else speaker_id,
                 "similarity": float(similarity) if embedding is not None and not isinstance(similarity, float) else similarity,
-                "latest_sentence": f"[Audio segment {self.speaker_detector.segment_counter}]" if similarity < 0.5 else None,
                 "conversation_html": self.current_conversation
             }
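The AudioToTextRecorder initialization above references configuration constants (FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL, TRANSCRIPTION_LANGUAGE, SILERO_SENSITIVITY, WEBRTC_SENSITIVITY, MIN_LENGTH_OF_RECORDING, PRE_RECORDING_BUFFER_DURATION) whose definitions fall outside the hunks shown. A plausible definition block is sketched below; the two model names mirror the JavaScript fallbacks this commit adds to ui.py, and every other value is an assumption:

# Plausible config block elsewhere in shared.py (not shown in this diff).
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"     # matches the ui.py fallback
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"  # matches the ui.py fallback
TRANSCRIPTION_LANGUAGE = "en"         # assumption
SILERO_SENSITIVITY = 0.4              # assumption: Silero VAD sensitivity, 0-1
WEBRTC_SENSITIVITY = 3                # assumption: WebRTC VAD aggressiveness, 0-3
MIN_LENGTH_OF_RECORDING = 0.7         # assumption: seconds
PRE_RECORDING_BUFFER_DURATION = 0.35  # assumption: seconds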
test_websocket.py
CHANGED
@@ -15,6 +15,7 @@ async def test_ws():
        audio = (np.random.randn(3200) * 3000).astype(np.int16)
        await websocket.send(audio.tobytes())
        print(f"Sent audio chunk {i+1}/20")
+        await asyncio.sleep(0.05)
 
    try:
        while True:
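The added asyncio.sleep(0.05) paces the synthetic stream so the server receives chunks over time instead of all at once. Assembled into a runnable client it looks roughly like the sketch below; only the loop body appears in the diff, so the endpoint URL and the receive loop are assumptions:

# Sketch of the full paced test client.
import asyncio
import numpy as np
import websockets

async def test_ws():
    uri = "ws://localhost:8000/ws_inference"  # assumed endpoint
    async with websockets.connect(uri) as websocket:
        for i in range(20):
            # 3200 int16 samples = 0.2 s of audio at 16 kHz
            audio = (np.random.randn(3200) * 3000).astype(np.int16)
            await websocket.send(audio.tobytes())
            print(f"Sent audio chunk {i+1}/20")
            await asyncio.sleep(0.05)

        try:
            while True:
                reply = await asyncio.wait_for(websocket.recv(), timeout=5)
                print(f"Received: {reply}")
        except asyncio.TimeoutError:
            print("No more responses; closing.")

asyncio.run(test_ws())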
ui.py
CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 from fastapi import FastAPI
-from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS
+from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS, FINAL_TRANSCRIPTION_MODEL, REALTIME_TRANSCRIPTION_MODEL
 print(gr.__version__)
 # Connection configuration (separate signaling server from model server)
 # These will be replaced at deployment time with the correct URLs
@@ -23,7 +23,10 @@ def build_ui():
 
     # Header and description
     gr.Markdown("# 🎤 Live Speaker Diarization")
-    gr.Markdown("Real-time speech recognition with automatic speaker identification")
+    gr.Markdown(f"Real-time speech recognition with automatic speaker identification")
+
+    # Add transcription model info
+    gr.Markdown(f"**Using Models:** Final: {FINAL_TRANSCRIPTION_MODEL}, Realtime: {REALTIME_TRANSCRIPTION_MODEL}")
 
     # Status indicator
     connection_status = gr.HTML(
@@ -459,6 +462,11 @@ def build_ui():
                 <li>Threshold: ${threshold}</li>
                 <li>Max Speakers: ${maxSpeakers}</li>
                 </ul>
+                <p>Transcription Models:</p>
+                <ul>
+                <li>Final: ${window.FINAL_TRANSCRIPTION_MODEL || "distil-large-v3"}</li>
+                <li>Realtime: ${window.REALTIME_TRANSCRIPTION_MODEL || "distil-small.en"}</li>
+                </ul>
                 `;
             }
         });
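One loose end: the status panel reads window.FINAL_TRANSCRIPTION_MODEL and window.REALTIME_TRANSCRIPTION_MODEL, but no hunk in this commit assigns those globals, so the || fallbacks will always be used. A hypothetical bridge from the Python constants could look like the sketch below (whether an inline script executes from gr.HTML depends on the Gradio version; gr.Blocks(head=...) is an alternative injection point):

# Hypothetical (not in this commit): expose the Python-side model names to
# the page's JavaScript so the status panel reflects the real configuration.
gr.HTML(f'''
<script>
window.FINAL_TRANSCRIPTION_MODEL = "{FINAL_TRANSCRIPTION_MODEL}";
window.REALTIME_TRANSCRIPTION_MODEL = "{REALTIME_TRANSCRIPTION_MODEL}";
</script>
''')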