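"""Live caption pipeline for an RTMP stream.

Transcodes the incoming RTMP feed to MPEG-DASH with ffmpeg, captures the audio
as 16 kHz mono PCM, segments speech with WebRTC VAD, transcribes it with
Whisper, translates the English text with MarianMT models, and writes one
WebVTT caption file per language alongside the DASH manifest.

Example invocation (stream URL, output directory, and script name are
illustrative placeholders):

    python app.py --rtmp_url rtmp://localhost/live/stream \
        --output_directory ./dash_out --model base
"""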
import subprocess
import threading
import argparse
import fcntl
import select
import whisper
import ffmpeg
import signal
import numpy as np
import queue
import time
import webrtcvad
import collections
import os
from transformers import MarianMTModel, MarianTokenizer

# Global variables
rtmp_url = ""
dash_output_path = ""
segment_duration = 2
last_activity_time = 0.0
cleanup_threshold = 10  # seconds of inactivity before cleanup
start_time = 0.0

# Languages for translation (ISO 639-1 codes)
target_languages = ["es", "zh", "ru"]  # Example: Spanish, Chinese, Russian

# Whisper model, loaded in __main__ once the model size is known
whisper_model = None

# Define Frame class
class Frame:
    def __init__(self, data, timestamp, duration):
        self.data = data
        self.timestamp = timestamp
        self.duration = duration

# Audio buffer and caption queues
audio_buffer = queue.Queue()
caption_queues = {lang: queue.Queue() for lang in target_languages + ["original", "en"]}

language_model_names = {
    "es": "Helsinki-NLP/opus-mt-en-es",
    "zh": "Helsinki-NLP/opus-mt-en-zh",
    "ru": "Helsinki-NLP/opus-mt-en-ru",
}
translation_models = {}
tokenizers = {}

# Initialize VAD
vad = webrtcvad.Vad(3)  # Aggressiveness mode 3 (most aggressive)

# Event to signal threads to stop
stop_event = threading.Event()

def transcode_rtmp_to_dash():
    """Transcode the RTMP input into a DASH stream with ffmpeg until stop is requested."""
    ffmpeg_command = [
        "ffmpeg",
        "-i", rtmp_url,
        "-map", "0:v:0", "-map", "0:a:0",
        "-c:v", "libx264", "-preset", "slow",
        "-c:a", "aac", "-b:a", "128k",
        "-f", "dash",
        "-seg_duration", str(segment_duration),
        "-use_timeline", "1",
        "-use_template", "1",
        "-init_seg_name", "init_$RepresentationID$.m4s",
        "-media_seg_name", "chunk_$RepresentationID$_$Number%05d$.m4s",
        "-adaptation_sets", "id=0,streams=v id=1,streams=a",
        f"{dash_output_path}/manifest.mpd"
    ]
    process = subprocess.Popen(ffmpeg_command)
    while not stop_event.is_set():
        time.sleep(1)
    process.kill()

def capture_audio():
    """Pull 16 kHz mono PCM audio from the RTMP stream and push fixed-size frames to the buffer."""
    global last_activity_time
    command = [
        'ffmpeg',
        '-i', rtmp_url,
        '-acodec', 'pcm_s16le',
        '-ar', '16000',
        '-ac', '1',
        '-f', 's16le',
        '-'
    ]
    sample_rate = 16000
    frame_duration_ms = 30
    sample_width = 2  # Only 16-bit audio supported
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)

    # Set stdout to non-blocking mode
    fd = process.stdout.fileno()
    fl = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

    frame_size = int(sample_rate * frame_duration_ms / 1000) * sample_width
    frame_count = 0
    while not stop_event.is_set():
        ready, _, _ = select.select([process.stdout], [], [], 0.1)
        if ready:
            try:
                in_bytes = os.read(fd, frame_size)
                if not in_bytes:
                    break
                if len(in_bytes) < frame_size:
                    in_bytes += b'\x00' * (frame_size - len(in_bytes))
                last_activity_time = time.time()
                timestamp = frame_count * frame_duration_ms * 0.85
                frame = Frame(np.frombuffer(in_bytes, np.int16), timestamp, frame_duration_ms)
                audio_buffer.put(frame)
                frame_count += 1
            except BlockingIOError:
                continue
        else:
            time.sleep(0.01)
    process.kill()

def frames_to_numpy(frames):
    """Concatenate frames and convert int16 samples to float32 in [-1, 1] for Whisper."""
    all_frames = np.concatenate([f.data for f in frames])
    float_samples = all_frames.astype(np.float32) / np.iinfo(np.int16).max
    return float_samples

def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Yield frames that WebRTC VAD classifies as speech; a None marks the end of a speech segment."""
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    for frame in frames:
        if len(frame.data) != int(sample_rate * (frame_duration_ms / 1000.0)):
            print(f"Skipping frame with incorrect size: {len(frame.data)} samples", flush=True)
            continue
        is_speech = vad.is_speech(frame.data.tobytes(), sample_rate)
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            if num_voiced > 0.8 * ring_buffer.maxlen:
                triggered = True
                for f, s in ring_buffer:
                    yield f
                ring_buffer.clear()
        else:
            yield frame
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            if num_unvoiced > 0.8 * ring_buffer.maxlen:
                triggered = False
                yield None
                ring_buffer.clear()
    # Flush any frames still held in the ring buffer when the input ends
    for f, s in ring_buffer:
        yield f
    ring_buffer.clear()

def process_audio():
    """Accumulate ~1.5 s of audio, run VAD, transcribe with Whisper, and queue captions and translations."""
    global last_activity_time
    frames = []
    buffer_duration_ms = 1500  # About 1.5 seconds of audio
    while not stop_event.is_set():
        while not audio_buffer.empty():
            frame = audio_buffer.get(timeout=5.0)
            frames.append(frame)
        if frames and sum(f.duration for f in frames) >= buffer_duration_ms:
            vad_frames = list(vad_collector(16000, 30, 300, vad, frames))
            if vad_frames:
                audio_segment = [f for f in vad_frames if f is not None]
                if audio_segment:
                    # Transcribe the original audio
                    result = whisper_model.transcribe(frames_to_numpy(audio_segment))
                    if result["text"]:
                        timestamp = audio_segment[0].timestamp
                        caption_queues["original"].put((timestamp, result["text"]))
                        # Translate to English with Whisper's translate task
                        english_translation = whisper_model.transcribe(frames_to_numpy(audio_segment), task="translate")
                        caption_queues["en"].put((timestamp, english_translation["text"]))
                        # Translate to target languages
                        for lang in target_languages:
                            tokenizer = tokenizers[lang]
                            translation_model = translation_models[lang]
                            inputs = tokenizer.encode(english_translation["text"], return_tensors="pt", padding=True, truncation=True)
                            translated_tokens = translation_model.generate(inputs)
                            translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
                            caption_queues[lang].put((timestamp, translated_text))
            frames = []
        time.sleep(0.01)

def write_captions(lang):
    """Drain the caption queue for one language and append cues to its WebVTT file."""
    os.makedirs(dash_output_path, exist_ok=True)
    filename = f"{dash_output_path}/captions_{lang}.vtt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
    last_end_time = None
    while not stop_event.is_set():
        if not caption_queues[lang].empty():
            timestamp, text = caption_queues[lang].get()
            start_time = format_time(timestamp / 1000)  # Convert ms to seconds
            end_time = format_time((timestamp + 5000) / 1000)  # Assume 5-second duration for each caption
            # Adjust the previous caption's end time if necessary
            if last_end_time and start_time != last_end_time:
                adjust_previous_caption(filename, last_end_time, start_time)
            # Write the new caption
            with open(filename, "a", encoding="utf-8") as f:
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")
                f.flush()
            last_end_time = end_time
        time.sleep(0.1)

def adjust_previous_caption(filename, old_end_time, new_end_time):
    """Rewrite the most recent cue's end time so it meets the next cue's start time."""
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()
    # Walk backwards to the most recent timing line and update its end time
    for i in range(len(lines) - 1, -1, -1):
        if "-->" in lines[i]:
            parts = lines[i].split("-->")
            if parts[1].strip() == old_end_time:
                lines[i] = f"{parts[0].strip()} --> {new_end_time}\n"
            break
    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(lines)

def format_time(seconds):
    """Format seconds as an HH:MM:SS.mmm WebVTT timestamp."""
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{int(hours):02d}:{int(minutes):02d}:{seconds:06.3f}"

def signal_handler(signum, frame):
    """Handle SIGTERM by asking all worker threads to stop."""
    print(f"Received signal {signum}. Cleaning up and exiting...")
    # Signal all threads to stop
    stop_event.set()

def cleanup():
    """Watch for stream inactivity, then stop all threads and remove generated output."""
    global last_activity_time
    while not stop_event.is_set():
        current_time = time.time()
        if last_activity_time != 0.0 and current_time - last_activity_time > cleanup_threshold:
            print(f"No activity detected for {cleanup_threshold} seconds. Cleaning up...", flush=True)
            # Signal all threads to stop
            stop_event.set()
            break
        time.sleep(1)  # Check for inactivity every second
    # Clear caption queues
    for lang in target_languages + ["original", "en"]:
        while not caption_queues[lang].empty():
            caption_queues[lang].get()
    # Delete DASH output files
    for root, dirs, files in os.walk(dash_output_path, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))
    print("Cleanup completed.", flush=True)

if __name__ == "__main__":
    # Get RTMP URL and DASH output path from the command line
    signal.signal(signal.SIGTERM, signal_handler)
    parser = argparse.ArgumentParser(description="Process audio for translation.")
    parser.add_argument('--rtmp_url', required=True, help='RTMP URL')
    parser.add_argument('--output_directory', required=True, help='DASH output directory')
    parser.add_argument('--model', required=True, help='Whisper model size: base|small|medium|large|large-v2')
    start_time = time.time()
    args = parser.parse_args()
    rtmp_url = args.rtmp_url
    dash_output_path = args.output_directory
    model_size = args.model
    print(f"RTMP URL: {rtmp_url}")
    print(f"DASH output path: {dash_output_path}")
    print(f"Model: {model_size}")

    print("Downloading models\n")
    print("Whisper\n")
    whisper_model = whisper.load_model(model_size, download_root="/tmp/model/")  # Adjust model size as necessary
    for lang, model_name in language_model_names.items():
        print(f"Lang: {lang}, model: {model_name}\n")
        tokenizers[lang] = MarianTokenizer.from_pretrained(model_name)
        translation_models[lang] = MarianMTModel.from_pretrained(model_name)

    # Start RTMP to DASH transcoding in a separate thread
    transcode_thread = threading.Thread(target=transcode_rtmp_to_dash)
    transcode_thread.start()

    # Start audio capture in a separate thread
    audio_capture_thread = threading.Thread(target=capture_audio)
    audio_capture_thread.start()

    # Start audio processing in a separate thread
    audio_processing_thread = threading.Thread(target=process_audio)
    audio_processing_thread.start()

    # Start caption writing threads for original and all target languages
    caption_threads = []
    for lang in target_languages + ["original", "en"]:
        caption_thread = threading.Thread(target=write_captions, args=(lang,))
        caption_threads.append(caption_thread)
        caption_thread.start()

    # Start the cleanup thread
    cleanup_thread = threading.Thread(target=cleanup)
    cleanup_thread.start()

    # Wait for all threads to complete
    print("Join transcode", flush=True)
    if transcode_thread.is_alive():
        transcode_thread.join()
    print("Join audio capture", flush=True)
    if audio_capture_thread.is_alive():
        audio_capture_thread.join()
    print("Join audio processing", flush=True)
    if audio_processing_thread.is_alive():
        audio_processing_thread.join()
    for thread in caption_threads:
        if thread.is_alive():
            thread.join()
    print("Join cleanup", flush=True)
    if cleanup_thread.is_alive():
        cleanup_thread.join()
    print("All threads have been stopped and cleaned up.")
    exit(0)