|
|
"""Full duplex streaming mode for MiniCPM-o 4.5 MLX. |
|
|
|
|
|
Captures screen video + system audio, processes through the model in real-time, |
|
|
and outputs text analysis with optional TTS playback. |
|
|
|
|
|
Architecture: |
|
|
[Screen 1fps] + [Audio 16kHz] -> ChunkSynchronizer -> DuplexGenerator -> TTSPlayback |
|
|
""" |
|
|
|
|
|
import queue |
|
|
import threading |
|
|
import time |
|
|
from typing import Optional |
|
|
|
|
|
import mlx.core as mx |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class ScreenCapture: |
|
|
"""Capture screen region at 1fps using mss. |
|
|
|
|
|
Produces (H, W, C) float32 frames resized to 448x448. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
out_queue: queue.Queue, |
|
|
region: Optional[tuple] = None, |
|
|
fps: float = 1.0, |
|
|
target_size: int = 448, |
|
|
): |
|
|
self.out_queue = out_queue |
|
|
self.region = region |
|
|
self.fps = fps |
|
|
self.target_size = target_size |
|
|
self._stop = threading.Event() |
|
|
self._thread: Optional[threading.Thread] = None |
|
|
|
|
|
def start(self): |
|
|
self._stop.clear() |
|
|
self._thread = threading.Thread(target=self._run, daemon=True) |
|
|
self._thread.start() |
|
|
|
|
|
def stop(self): |
|
|
self._stop.set() |
|
|
if self._thread: |
|
|
self._thread.join(timeout=2) |
|
|
|
|
|
def _run(self): |
|
|
import mss |
|
|
from PIL import Image |
|
|
|
|
|
with mss.mss() as sct: |
|
|
if self.region: |
|
|
x, y, w, h = self.region |
|
|
monitor = {"left": x, "top": y, "width": w, "height": h} |
|
|
else: |
|
|
monitor = sct.monitors[1] |
|
|
|
|
|
while not self._stop.is_set(): |
|
|
t0 = time.time() |
|
|
screenshot = sct.grab(monitor) |
|
|
|
|
|
img = Image.frombytes("RGB", screenshot.size, screenshot.rgb) |
|
|
img = img.resize( |
|
|
(self.target_size, self.target_size), Image.BILINEAR |
|
|
) |
|
|
frame = np.array(img, dtype=np.float32) / 255.0 |
|
|
|
|
|
try: |
|
|
self.out_queue.put_nowait( |
|
|
{"type": "video", "frame": frame, "time": time.time()} |
|
|
) |
|
|
except queue.Full: |
|
|
pass |
|
|
|
|
|
elapsed = time.time() - t0 |
|
|
sleep_time = max(0, (1.0 / self.fps) - elapsed) |
|
|
if sleep_time > 0: |
|
|
self._stop.wait(sleep_time) |
|
|
|
|
|
|
|
|
class AudioCapture: |
|
|
"""Capture system audio at 16kHz using sounddevice. |
|
|
|
|
|
Uses BlackHole virtual audio device for system audio loopback on macOS. |
|
|
Produces 1-second mono float32 audio chunks. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
out_queue: queue.Queue, |
|
|
device: Optional[str] = None, |
|
|
sample_rate: int = 16000, |
|
|
chunk_seconds: float = 1.0, |
|
|
): |
|
|
self.out_queue = out_queue |
|
|
self.device = device |
|
|
self.sample_rate = sample_rate |
|
|
self.chunk_seconds = chunk_seconds |
|
|
self.chunk_samples = int(sample_rate * chunk_seconds) |
|
|
self._stop = threading.Event() |
|
|
self._thread: Optional[threading.Thread] = None |
|
|
|
|
|
def start(self): |
|
|
self._stop.clear() |
|
|
self._thread = threading.Thread(target=self._run, daemon=True) |
|
|
self._thread.start() |
|
|
|
|
|
def stop(self): |
|
|
self._stop.set() |
|
|
if self._thread: |
|
|
self._thread.join(timeout=2) |
|
|
|
|
|
def _find_device(self): |
|
|
"""Find audio device by name.""" |
|
|
import sounddevice as sd |
|
|
|
|
|
if self.device is None: |
|
|
return None |
|
|
|
|
|
if isinstance(self.device, int): |
|
|
return self.device |
|
|
|
|
|
devices = sd.query_devices() |
|
|
for i, d in enumerate(devices): |
|
|
if self.device.lower() in d["name"].lower() and d["max_input_channels"] > 0: |
|
|
return i |
|
|
|
|
|
print(f"Warning: Audio device '{self.device}' not found, using default.") |
|
|
return None |
|
|
|
|
|
def _run(self): |
|
|
import sounddevice as sd |
|
|
|
|
|
device_id = self._find_device() |
|
|
buffer = np.array([], dtype=np.float32) |
|
|
|
|
|
def callback(indata, frames, time_info, status): |
|
|
nonlocal buffer |
|
|
if status: |
|
|
pass |
|
|
mono = indata.mean(axis=1) if indata.ndim > 1 else indata.flatten() |
|
|
buffer = np.concatenate([buffer, mono]) |
|
|
|
|
|
try: |
|
|
with sd.InputStream( |
|
|
device=device_id, |
|
|
channels=1, |
|
|
samplerate=self.sample_rate, |
|
|
blocksize=1024, |
|
|
callback=callback, |
|
|
): |
|
|
while not self._stop.is_set(): |
|
|
if len(buffer) >= self.chunk_samples: |
|
|
chunk = buffer[: self.chunk_samples].copy() |
|
|
buffer = buffer[self.chunk_samples :] |
|
|
try: |
|
|
self.out_queue.put_nowait( |
|
|
{ |
|
|
"type": "audio", |
|
|
"data": chunk, |
|
|
"time": time.time(), |
|
|
} |
|
|
) |
|
|
except queue.Full: |
|
|
pass |
|
|
else: |
|
|
self._stop.wait(0.05) |
|
|
except Exception as e: |
|
|
print(f"Audio capture error: {e}") |
|
|
|
|
|
|
|
|
class ChunkSynchronizer: |
|
|
"""Synchronize video frames and audio into 1-second chunks. |
|
|
|
|
|
Pairs the latest video frame with each 1-second audio chunk. |
|
|
Runs mel processing on the audio. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
raw_queue: queue.Queue, |
|
|
sync_queue: queue.Queue, |
|
|
mel_processor, |
|
|
): |
|
|
self.raw_queue = raw_queue |
|
|
self.sync_queue = sync_queue |
|
|
self.mel_processor = mel_processor |
|
|
self._stop = threading.Event() |
|
|
self._thread: Optional[threading.Thread] = None |
|
|
self._latest_frame: Optional[np.ndarray] = None |
|
|
|
|
|
def start(self): |
|
|
self._stop.clear() |
|
|
self._thread = threading.Thread(target=self._run, daemon=True) |
|
|
self._thread.start() |
|
|
|
|
|
def stop(self): |
|
|
self._stop.set() |
|
|
if self._thread: |
|
|
self._thread.join(timeout=2) |
|
|
|
|
|
def _run(self): |
|
|
while not self._stop.is_set(): |
|
|
try: |
|
|
item = self.raw_queue.get(timeout=0.1) |
|
|
except queue.Empty: |
|
|
continue |
|
|
|
|
|
if item["type"] == "video": |
|
|
self._latest_frame = item["frame"] |
|
|
elif item["type"] == "audio": |
|
|
self.mel_processor.add_audio(item["data"]) |
|
|
mel_chunk = self.mel_processor.get_mel_chunk() |
|
|
if mel_chunk is not None: |
|
|
try: |
|
|
self.sync_queue.put_nowait( |
|
|
{ |
|
|
"video_frame": self._latest_frame, |
|
|
"mel_chunk": mel_chunk, |
|
|
"time": item["time"], |
|
|
} |
|
|
) |
|
|
except queue.Full: |
|
|
pass |
|
|
|
|
|
|
|
|
class DuplexGenerator: |
|
|
"""Main processing loop for full duplex streaming. |
|
|
|
|
|
Dequeues synchronized chunks, runs model inference, generates text responses, |
|
|
and optionally queues TTS audio for playback. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model, |
|
|
processor, |
|
|
sync_queue: queue.Queue, |
|
|
tts_queue: Optional[queue.Queue] = None, |
|
|
temperature: float = 0.0, |
|
|
max_tokens_per_chunk: int = 50, |
|
|
enable_tts: bool = False, |
|
|
): |
|
|
self.model = model |
|
|
self.processor = processor |
|
|
self.sync_queue = sync_queue |
|
|
self.tts_queue = tts_queue |
|
|
self.temperature = temperature |
|
|
self.max_tokens = max_tokens_per_chunk |
|
|
self.enable_tts = enable_tts |
|
|
self._stop = threading.Event() |
|
|
self._thread: Optional[threading.Thread] = None |
|
|
self.ctx = None |
|
|
self.chunk_count = 0 |
|
|
self.on_text = None |
|
|
self.on_status = None |
|
|
|
|
|
def start(self): |
|
|
self._stop.clear() |
|
|
self._thread = threading.Thread(target=self._run, daemon=True) |
|
|
self._thread.start() |
|
|
|
|
|
def stop(self): |
|
|
self._stop.set() |
|
|
if self._thread: |
|
|
self._thread.join(timeout=5) |
|
|
|
|
|
def _build_chunk_prompt(self, has_video: bool, has_audio: bool): |
|
|
"""Build prompt tokens for one streaming chunk. |
|
|
|
|
|
Returns: |
|
|
dict with input_ids, image_bound, audio_bound |
|
|
""" |
|
|
tokenizer = self.processor.tokenizer |
|
|
|
|
|
parts = [] |
|
|
parts.append("<|im_start|>user\n") |
|
|
|
|
|
image_bound = [] |
|
|
audio_bound = [] |
|
|
|
|
|
|
|
|
if has_video: |
|
|
|
|
|
n_img_tokens = self.model.config.query_num |
|
|
img_placeholder = "<image>" + "<unk>" * n_img_tokens + "</image>" |
|
|
parts.append(img_placeholder) |
|
|
|
|
|
|
|
|
if has_audio: |
|
|
|
|
|
n_audio_tokens = 10 |
|
|
audio_placeholder = ( |
|
|
"<|audio_start|>" + "<unk>" * n_audio_tokens + "<|audio_end|>" |
|
|
) |
|
|
parts.append(audio_placeholder) |
|
|
|
|
|
parts.append("\nDescribe what you see and hear.<|im_end|>\n") |
|
|
parts.append("<|im_start|>assistant\n") |
|
|
|
|
|
text = "".join(parts) |
|
|
tokenized = tokenizer(text, return_tensors="np") |
|
|
input_ids = mx.array(tokenized["input_ids"]) |
|
|
|
|
|
|
|
|
ids_list = tokenized["input_ids"][0].tolist() |
|
|
unk_id = tokenizer.convert_tokens_to_ids("<unk>") |
|
|
|
|
|
if has_video: |
|
|
img_start_id = tokenizer.convert_tokens_to_ids("<image>") |
|
|
img_end_id = tokenizer.convert_tokens_to_ids("</image>") |
|
|
in_img = False |
|
|
start_idx = None |
|
|
for i, tok in enumerate(ids_list): |
|
|
if tok == img_start_id: |
|
|
in_img = True |
|
|
start_idx = i + 1 |
|
|
elif tok == img_end_id and in_img: |
|
|
image_bound.append((start_idx, i)) |
|
|
in_img = False |
|
|
|
|
|
if has_audio: |
|
|
audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_start|>") |
|
|
audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_end|>") |
|
|
in_audio = False |
|
|
start_idx = None |
|
|
for i, tok in enumerate(ids_list): |
|
|
if tok == audio_start_id: |
|
|
in_audio = True |
|
|
start_idx = i + 1 |
|
|
elif tok == audio_end_id and in_audio: |
|
|
audio_bound.append((start_idx, i)) |
|
|
in_audio = False |
|
|
|
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"image_bound": image_bound if image_bound else None, |
|
|
"audio_bound": audio_bound if audio_bound else None, |
|
|
} |
|
|
|
|
|
def _prepare_video_frame(self, frame: np.ndarray): |
|
|
"""Prepare a video frame for model input. |
|
|
|
|
|
Args: |
|
|
frame: (H, W, 3) float32 frame |
|
|
|
|
|
Returns: |
|
|
(pixel_values, tgt_sizes, patch_attention_mask) |
|
|
""" |
|
|
|
|
|
|
|
|
pv = mx.array(frame[np.newaxis, ...]) |
|
|
|
|
|
|
|
|
h_patches = frame.shape[0] // 14 |
|
|
w_patches = frame.shape[1] // 14 |
|
|
tgt_sizes = mx.array([[h_patches, w_patches]], dtype=mx.int32) |
|
|
|
|
|
total_patches = h_patches * w_patches |
|
|
patch_attention_mask = mx.ones((1, total_patches), dtype=mx.bool_) |
|
|
|
|
|
return pv, tgt_sizes, patch_attention_mask |
|
|
|
|
|
def _run(self): |
|
|
|
|
|
self.ctx = self.model.init_streaming() |
|
|
self.chunk_count = 0 |
|
|
|
|
|
while not self._stop.is_set(): |
|
|
try: |
|
|
chunk = self.sync_queue.get(timeout=0.5) |
|
|
except queue.Empty: |
|
|
continue |
|
|
|
|
|
t0 = time.time() |
|
|
self.chunk_count += 1 |
|
|
|
|
|
video_frame = chunk.get("video_frame") |
|
|
mel_chunk = chunk.get("mel_chunk") |
|
|
|
|
|
has_video = video_frame is not None |
|
|
has_audio = mel_chunk is not None |
|
|
|
|
|
if not has_video and not has_audio: |
|
|
continue |
|
|
|
|
|
|
|
|
prompt = self._build_chunk_prompt(has_video, has_audio) |
|
|
|
|
|
|
|
|
pixel_values = None |
|
|
tgt_sizes = None |
|
|
patch_attention_mask = None |
|
|
if has_video: |
|
|
pixel_values, tgt_sizes, patch_attention_mask = ( |
|
|
self._prepare_video_frame(video_frame) |
|
|
) |
|
|
|
|
|
|
|
|
logits = self.model.process_streaming_chunk( |
|
|
ctx=self.ctx, |
|
|
video_frame=pixel_values, |
|
|
audio_chunk=mel_chunk, |
|
|
prompt_tokens=prompt["input_ids"], |
|
|
image_bound=prompt["image_bound"], |
|
|
audio_bound=prompt["audio_bound"], |
|
|
tgt_sizes=tgt_sizes, |
|
|
patch_attention_mask=patch_attention_mask, |
|
|
) |
|
|
|
|
|
|
|
|
tokens = self.model.streaming_generate( |
|
|
ctx=self.ctx, |
|
|
logits=logits, |
|
|
tokenizer=self.processor.tokenizer, |
|
|
max_tokens=self.max_tokens, |
|
|
temperature=self.temperature, |
|
|
) |
|
|
|
|
|
elapsed = time.time() - t0 |
|
|
|
|
|
if tokens: |
|
|
text = self.processor.tokenizer.decode( |
|
|
tokens, skip_special_tokens=True |
|
|
) |
|
|
if self.on_text and text.strip(): |
|
|
self.on_text(text.strip()) |
|
|
|
|
|
|
|
|
if self.enable_tts and self.tts_queue and tokens: |
|
|
self.tts_queue.put_nowait( |
|
|
{"tokens": tokens, "text": text} |
|
|
) |
|
|
|
|
|
if self.on_status: |
|
|
self.on_status( |
|
|
{ |
|
|
"chunk": self.chunk_count, |
|
|
"mode": self.ctx.mode, |
|
|
"cache_tokens": self.ctx.total_tokens, |
|
|
"latency_ms": int(elapsed * 1000), |
|
|
"mem_gb": mx.get_peak_memory() / 1e9, |
|
|
} |
|
|
) |
|
|
|
|
|
|
|
|
class TTSPlayback: |
|
|
"""Dequeue TTS tokens, convert to audio, and play back. |
|
|
|
|
|
Uses Token2wav vocoder for audio synthesis and sounddevice for playback. |
|
|
""" |
|
|
|
|
|
def __init__(self, tts_queue: queue.Queue, sample_rate: int = 24000): |
|
|
self.tts_queue = tts_queue |
|
|
self.sample_rate = sample_rate |
|
|
self._stop = threading.Event() |
|
|
self._thread: Optional[threading.Thread] = None |
|
|
self._vocoder = None |
|
|
|
|
|
def start(self): |
|
|
self._stop.clear() |
|
|
self._thread = threading.Thread(target=self._run, daemon=True) |
|
|
self._thread.start() |
|
|
|
|
|
def stop(self): |
|
|
self._stop.set() |
|
|
if self._thread: |
|
|
self._thread.join(timeout=2) |
|
|
|
|
|
def _run(self): |
|
|
import sounddevice as sd |
|
|
|
|
|
|
|
|
try: |
|
|
from stepaudio2 import Token2wav |
|
|
self._vocoder = Token2wav() |
|
|
except ImportError: |
|
|
print("TTSPlayback: Token2wav not available, TTS disabled.") |
|
|
return |
|
|
|
|
|
while not self._stop.is_set(): |
|
|
try: |
|
|
item = self.tts_queue.get(timeout=0.5) |
|
|
except queue.Empty: |
|
|
continue |
|
|
|
|
|
tokens = item.get("tokens", []) |
|
|
if not tokens: |
|
|
continue |
|
|
|
|
|
try: |
|
|
import io |
|
|
import soundfile as sf |
|
|
|
|
|
wav_bytes = self._vocoder(tokens, None) |
|
|
waveform, sr = sf.read(io.BytesIO(wav_bytes)) |
|
|
sd.play(waveform, sr, blocking=False) |
|
|
except Exception as e: |
|
|
print(f"TTS playback error: {e}") |
|
|
|
|
|
|
|
|
def run_live_mode(model, processor, args): |
|
|
"""Run full duplex streaming mode. |
|
|
|
|
|
Args: |
|
|
model: loaded MiniCPM-o model |
|
|
processor: tokenizer/processor |
|
|
args: argparse namespace with capture_region, audio_device, tts options |
|
|
""" |
|
|
from mlx_vlm.models.minicpmo.audio import StreamingMelProcessor |
|
|
|
|
|
print("Starting live streaming mode...") |
|
|
print("Press Ctrl+C to stop.\n") |
|
|
|
|
|
|
|
|
raw_queue = queue.Queue(maxsize=30) |
|
|
sync_queue = queue.Queue(maxsize=10) |
|
|
tts_queue = queue.Queue(maxsize=10) if args.tts else None |
|
|
|
|
|
|
|
|
mel_processor = StreamingMelProcessor(sample_rate=16000) |
|
|
|
|
|
|
|
|
region = None |
|
|
if hasattr(args, "capture_region") and args.capture_region: |
|
|
parts = args.capture_region.split(",") |
|
|
if len(parts) == 4: |
|
|
region = tuple(int(p) for p in parts) |
|
|
|
|
|
|
|
|
screen = ScreenCapture(raw_queue, region=region, fps=1.0) |
|
|
audio_dev = getattr(args, "audio_device", "BlackHole") |
|
|
audio = AudioCapture(raw_queue, device=audio_dev, sample_rate=16000) |
|
|
sync = ChunkSynchronizer(raw_queue, sync_queue, mel_processor) |
|
|
|
|
|
generator = DuplexGenerator( |
|
|
model, |
|
|
processor, |
|
|
sync_queue, |
|
|
tts_queue=tts_queue, |
|
|
temperature=getattr(args, "temp", 0.0), |
|
|
max_tokens_per_chunk=getattr(args, "max_tokens", 50), |
|
|
enable_tts=getattr(args, "tts", False), |
|
|
) |
|
|
|
|
|
tts_playback = None |
|
|
if tts_queue: |
|
|
tts_playback = TTSPlayback(tts_queue) |
|
|
|
|
|
|
|
|
def on_text(text): |
|
|
print(f"[{generator.chunk_count}] {text}") |
|
|
|
|
|
def on_status(status): |
|
|
print( |
|
|
f" >> chunk={status['chunk']} mode={status['mode']} " |
|
|
f"cache={status['cache_tokens']}tok " |
|
|
f"latency={status['latency_ms']}ms " |
|
|
f"mem={status['mem_gb']:.1f}GB", |
|
|
flush=True, |
|
|
) |
|
|
|
|
|
generator.on_text = on_text |
|
|
generator.on_status = on_status |
|
|
|
|
|
|
|
|
screen.start() |
|
|
audio.start() |
|
|
sync.start() |
|
|
generator.start() |
|
|
if tts_playback: |
|
|
tts_playback.start() |
|
|
|
|
|
print("Live mode active. Capturing screen + audio...\n") |
|
|
|
|
|
try: |
|
|
while True: |
|
|
time.sleep(0.5) |
|
|
except KeyboardInterrupt: |
|
|
print("\nStopping live mode...") |
|
|
finally: |
|
|
screen.stop() |
|
|
audio.stop() |
|
|
sync.stop() |
|
|
generator.stop() |
|
|
if tts_playback: |
|
|
tts_playback.stop() |
|
|
print("Live mode stopped.") |
|
|
|