"""Full duplex streaming mode for MiniCPM-o 4.5 MLX.
Captures screen video + system audio, processes through the model in real-time,
and outputs text analysis with optional TTS playback.
Architecture:
[Screen 1fps] + [Audio 16kHz] -> ChunkSynchronizer -> DuplexGenerator -> TTSPlayback
"""
import queue
import threading
import time
from typing import Optional
import mlx.core as mx
import numpy as np
class ScreenCapture:
    """Capture a screen region using mss (default 1 fps).

    Produces (H, W, C) float32 frames resized to 448x448.
    """
def __init__(
self,
out_queue: queue.Queue,
region: Optional[tuple] = None,
fps: float = 1.0,
target_size: int = 448,
):
self.out_queue = out_queue
self.region = region # (x, y, w, h) or None for primary monitor
self.fps = fps
self.target_size = target_size
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
def start(self):
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
self._stop.set()
if self._thread:
self._thread.join(timeout=2)
def _run(self):
import mss
from PIL import Image
with mss.mss() as sct:
if self.region:
x, y, w, h = self.region
monitor = {"left": x, "top": y, "width": w, "height": h}
else:
monitor = sct.monitors[1] # Primary monitor
while not self._stop.is_set():
t0 = time.time()
screenshot = sct.grab(monitor)
# Convert to PIL Image, resize, convert to float32
img = Image.frombytes("RGB", screenshot.size, screenshot.rgb)
img = img.resize(
(self.target_size, self.target_size), Image.BILINEAR
)
frame = np.array(img, dtype=np.float32) / 255.0 # (H, W, 3)
try:
self.out_queue.put_nowait(
{"type": "video", "frame": frame, "time": time.time()}
)
except queue.Full:
pass # Drop frame if queue full
                # Maintain the target frame rate; wake early if stop is requested
                elapsed = time.time() - t0
                self._stop.wait(max(0.0, (1.0 / self.fps) - elapsed))
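
# Standalone smoke test for ScreenCapture (a sketch; run interactively):
#
#   q = queue.Queue(maxsize=4)
#   cap = ScreenCapture(q, fps=1.0)
#   cap.start()
#   frame = q.get(timeout=5)["frame"]  # (448, 448, 3) float32 in [0, 1]
#   cap.stop()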
class AudioCapture:
"""Capture system audio at 16kHz using sounddevice.
Uses BlackHole virtual audio device for system audio loopback on macOS.
Produces 1-second mono float32 audio chunks.
"""
def __init__(
self,
out_queue: queue.Queue,
device: Optional[str] = None,
sample_rate: int = 16000,
chunk_seconds: float = 1.0,
):
self.out_queue = out_queue
self.device = device # Device name or index
self.sample_rate = sample_rate
self.chunk_seconds = chunk_seconds
self.chunk_samples = int(sample_rate * chunk_seconds)
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
def start(self):
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
self._stop.set()
if self._thread:
self._thread.join(timeout=2)
def _find_device(self):
"""Find audio device by name."""
import sounddevice as sd
if self.device is None:
return None # Use default
if isinstance(self.device, int):
return self.device
devices = sd.query_devices()
for i, d in enumerate(devices):
if self.device.lower() in d["name"].lower() and d["max_input_channels"] > 0:
return i
print(f"Warning: Audio device '{self.device}' not found, using default.")
return None
    def _run(self):
        import sounddevice as sd

        device_id = self._find_device()
        buffer = np.array([], dtype=np.float32)
        lock = threading.Lock()  # Guards buffer shared with the audio callback

        def callback(indata, frames, time_info, status):
            nonlocal buffer
            # Overflow/underflow statuses are ignored; occasional dropped
            # blocks are acceptable for streaming analysis.
            mono = indata.mean(axis=1) if indata.ndim > 1 else indata.flatten()
            with lock:
                buffer = np.concatenate([buffer, mono])

        try:
            with sd.InputStream(
                device=device_id,
                channels=1,
                samplerate=self.sample_rate,
                blocksize=1024,
                callback=callback,
            ):
                while not self._stop.is_set():
                    chunk = None
                    with lock:
                        if len(buffer) >= self.chunk_samples:
                            chunk = buffer[: self.chunk_samples].copy()
                            buffer = buffer[self.chunk_samples :]
                    if chunk is None:
                        self._stop.wait(0.05)
                        continue
                    try:
                        self.out_queue.put_nowait(
                            {
                                "type": "audio",
                                "data": chunk,
                                "time": time.time(),
                            }
                        )
                    except queue.Full:
                        pass  # Drop chunk if the consumer is behind
        except Exception as e:
            print(f"Audio capture error: {e}")
class ChunkSynchronizer:
"""Synchronize video frames and audio into 1-second chunks.
Pairs the latest video frame with each 1-second audio chunk.
Runs mel processing on the audio.
"""
def __init__(
self,
raw_queue: queue.Queue,
sync_queue: queue.Queue,
mel_processor,
):
self.raw_queue = raw_queue
self.sync_queue = sync_queue
self.mel_processor = mel_processor
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
self._latest_frame: Optional[np.ndarray] = None
def start(self):
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
self._stop.set()
if self._thread:
self._thread.join(timeout=2)
def _run(self):
while not self._stop.is_set():
try:
item = self.raw_queue.get(timeout=0.1)
except queue.Empty:
continue
if item["type"] == "video":
self._latest_frame = item["frame"]
elif item["type"] == "audio":
self.mel_processor.add_audio(item["data"])
mel_chunk = self.mel_processor.get_mel_chunk()
if mel_chunk is not None:
try:
self.sync_queue.put_nowait(
{
"video_frame": self._latest_frame,
"mel_chunk": mel_chunk,
"time": item["time"],
}
)
except queue.Full:
pass # Drop if consumer is slow
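
# Each synchronized chunk pairs the most recent frame (None until the first
# frame arrives) with mel features for ~1 s of audio:
#
#   {"video_frame": (448, 448, 3) ndarray or None, "mel_chunk": ..., "time": float}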
class DuplexGenerator:
"""Main processing loop for full duplex streaming.
Dequeues synchronized chunks, runs model inference, generates text responses,
and optionally queues TTS audio for playback.
"""
def __init__(
self,
model,
processor,
sync_queue: queue.Queue,
tts_queue: Optional[queue.Queue] = None,
temperature: float = 0.0,
max_tokens_per_chunk: int = 50,
enable_tts: bool = False,
):
self.model = model
self.processor = processor
self.sync_queue = sync_queue
self.tts_queue = tts_queue
self.temperature = temperature
self.max_tokens = max_tokens_per_chunk
self.enable_tts = enable_tts
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
self.ctx = None
self.chunk_count = 0
self.on_text = None # callback(text: str)
self.on_status = None # callback(status: dict)
def start(self):
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
self._stop.set()
if self._thread:
self._thread.join(timeout=5)
def _build_chunk_prompt(self, has_video: bool, has_audio: bool):
"""Build prompt tokens for one streaming chunk.
Returns:
dict with input_ids, image_bound, audio_bound
"""
tokenizer = self.processor.tokenizer
parts = []
parts.append("<|im_start|>user\n")
image_bound = []
audio_bound = []
# Video placeholder
if has_video:
# 64 query tokens for resampled image
n_img_tokens = self.model.config.query_num # 64
img_placeholder = "<image>" + "<unk>" * n_img_tokens + "</image>"
parts.append(img_placeholder)
# Audio placeholder
if has_audio:
# Approximate audio tokens: ~10 after pooling for 1 second
n_audio_tokens = 10
audio_placeholder = (
"<|audio_start|>" + "<unk>" * n_audio_tokens + "<|audio_end|>"
)
parts.append(audio_placeholder)
parts.append("\nDescribe what you see and hear.<|im_end|>\n")
parts.append("<|im_start|>assistant\n")
text = "".join(parts)
tokenized = tokenizer(text, return_tensors="np")
input_ids = mx.array(tokenized["input_ids"])
        # Locate embedding insertion spans between the start/end marker tokens
        ids_list = tokenized["input_ids"][0].tolist()

        def find_bounds(start_tok, end_tok):
            start_id = tokenizer.convert_tokens_to_ids(start_tok)
            end_id = tokenizer.convert_tokens_to_ids(end_tok)
            bounds = []
            start_idx = None
            for i, tok in enumerate(ids_list):
                if tok == start_id:
                    start_idx = i + 1
                elif tok == end_id and start_idx is not None:
                    bounds.append((start_idx, i))
                    start_idx = None
            return bounds

        if has_video:
            image_bound = find_bounds("<image>", "</image>")
        if has_audio:
            audio_bound = find_bounds("<|audio_start|>", "<|audio_end|>")
return {
"input_ids": input_ids,
"image_bound": image_bound if image_bound else None,
"audio_bound": audio_bound if audio_bound else None,
}
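
    # Example: with has_video=True, the scan above yields image_bound =
    # [(s, s + 64)], a span covering exactly the query_num <unk> placeholder
    # positions where the vision embeddings are inserted.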
def _prepare_video_frame(self, frame: np.ndarray):
"""Prepare a video frame for model input.
Args:
frame: (H, W, 3) float32 frame
Returns:
(pixel_values, tgt_sizes, patch_attention_mask)
"""
# Frame is already (448, 448, 3) float32
# Add batch dimension: (1, H, W, 3)
pv = mx.array(frame[np.newaxis, ...])
# Compute patch sizes
h_patches = frame.shape[0] // 14 # 32
w_patches = frame.shape[1] // 14 # 32
tgt_sizes = mx.array([[h_patches, w_patches]], dtype=mx.int32)
total_patches = h_patches * w_patches
patch_attention_mask = mx.ones((1, total_patches), dtype=mx.bool_)
return pv, tgt_sizes, patch_attention_mask
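
    # For the default 448x448 frame, _prepare_video_frame returns pixel_values
    # (1, 448, 448, 3), tgt_sizes [[32, 32]], and a (1, 1024) all-True patch
    # mask, since 448 / 14 = 32 patches per side and 32 * 32 = 1024.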
def _run(self):
# Initialize streaming context
self.ctx = self.model.init_streaming()
self.chunk_count = 0
while not self._stop.is_set():
try:
chunk = self.sync_queue.get(timeout=0.5)
except queue.Empty:
continue
t0 = time.time()
self.chunk_count += 1
video_frame = chunk.get("video_frame")
mel_chunk = chunk.get("mel_chunk")
has_video = video_frame is not None
has_audio = mel_chunk is not None
if not has_video and not has_audio:
continue
# Build prompt for this chunk
prompt = self._build_chunk_prompt(has_video, has_audio)
# Prepare video
pixel_values = None
tgt_sizes = None
patch_attention_mask = None
if has_video:
pixel_values, tgt_sizes, patch_attention_mask = (
self._prepare_video_frame(video_frame)
)
# Process chunk through model
logits = self.model.process_streaming_chunk(
ctx=self.ctx,
video_frame=pixel_values,
audio_chunk=mel_chunk,
prompt_tokens=prompt["input_ids"],
image_bound=prompt["image_bound"],
audio_bound=prompt["audio_bound"],
tgt_sizes=tgt_sizes,
patch_attention_mask=patch_attention_mask,
)
# Generate text response
tokens = self.model.streaming_generate(
ctx=self.ctx,
logits=logits,
tokenizer=self.processor.tokenizer,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
            elapsed = time.time() - t0
            if tokens:
                text = self.processor.tokenizer.decode(
                    tokens, skip_special_tokens=True
                )
                if self.on_text and text.strip():
                    self.on_text(text.strip())
                # Queue audio for TTS if enabled; drop if playback is behind
                if self.enable_tts and self.tts_queue:
                    try:
                        self.tts_queue.put_nowait({"tokens": tokens, "text": text})
                    except queue.Full:
                        pass
            if self.on_status:
                self.on_status(
                    {
                        "chunk": self.chunk_count,
                        "mode": self.ctx.mode,
                        "cache_tokens": self.ctx.total_tokens,
                        "latency_ms": int(elapsed * 1000),
                        "mem_gb": mx.get_peak_memory() / 1e9,
                    }
                )
class TTSPlayback:
"""Dequeue TTS tokens, convert to audio, and play back.
Uses Token2wav vocoder for audio synthesis and sounddevice for playback.
"""
def __init__(self, tts_queue: queue.Queue, sample_rate: int = 24000):
self.tts_queue = tts_queue
self.sample_rate = sample_rate
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
self._vocoder = None
def start(self):
self._stop.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
self._stop.set()
if self._thread:
self._thread.join(timeout=2)
    def _run(self):
        import io

        import sounddevice as sd

        # Load the vocoder (and soundfile for decoding); disable TTS if
        # either is unavailable.
        try:
            import soundfile as sf
            from stepaudio2 import Token2wav

            self._vocoder = Token2wav()
        except Exception as e:
            print(f"TTSPlayback: vocoder unavailable ({e}), TTS disabled.")
            return
        while not self._stop.is_set():
            try:
                item = self.tts_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            tokens = item.get("tokens", [])
            if not tokens:
                continue
            try:
                # Synthesize a waveform from TTS tokens and play it back
                wav_bytes = self._vocoder(tokens, None)
                waveform, sr = sf.read(io.BytesIO(wav_bytes))
                sd.play(waveform, sr, blocking=False)
            except Exception as e:
                print(f"TTS playback error: {e}")
def run_live_mode(model, processor, args):
"""Run full duplex streaming mode.
Args:
model: loaded MiniCPM-o model
processor: tokenizer/processor
args: argparse namespace with capture_region, audio_device, tts options
"""
from mlx_vlm.models.minicpmo.audio import StreamingMelProcessor
print("Starting live streaming mode...")
print("Press Ctrl+C to stop.\n")
# Create queues
raw_queue = queue.Queue(maxsize=30)
sync_queue = queue.Queue(maxsize=10)
    tts_queue = queue.Queue(maxsize=10) if getattr(args, "tts", False) else None
# Create mel processor
mel_processor = StreamingMelProcessor(sample_rate=16000)
    # Parse capture region "x,y,w,h"
    region = None
    region_str = getattr(args, "capture_region", None)
    if region_str:
        try:
            x, y, w, h = (int(p) for p in region_str.split(","))
            region = (x, y, w, h)
        except ValueError:
            print(f"Warning: ignoring malformed capture region '{region_str}'.")
# Create threads
screen = ScreenCapture(raw_queue, region=region, fps=1.0)
audio_dev = getattr(args, "audio_device", "BlackHole")
audio = AudioCapture(raw_queue, device=audio_dev, sample_rate=16000)
sync = ChunkSynchronizer(raw_queue, sync_queue, mel_processor)
generator = DuplexGenerator(
model,
processor,
sync_queue,
tts_queue=tts_queue,
temperature=getattr(args, "temp", 0.0),
max_tokens_per_chunk=getattr(args, "max_tokens", 50),
enable_tts=getattr(args, "tts", False),
)
tts_playback = None
if tts_queue:
tts_playback = TTSPlayback(tts_queue)
# Set up callbacks
def on_text(text):
print(f"[{generator.chunk_count}] {text}")
def on_status(status):
print(
f" >> chunk={status['chunk']} mode={status['mode']} "
f"cache={status['cache_tokens']}tok "
f"latency={status['latency_ms']}ms "
f"mem={status['mem_gb']:.1f}GB",
flush=True,
)
generator.on_text = on_text
generator.on_status = on_status
# Start all threads
screen.start()
audio.start()
sync.start()
generator.start()
if tts_playback:
tts_playback.start()
print("Live mode active. Capturing screen + audio...\n")
try:
while True:
time.sleep(0.5)
except KeyboardInterrupt:
print("\nStopping live mode...")
finally:
screen.stop()
audio.stop()
sync.stop()
generator.stop()
if tts_playback:
tts_playback.stop()
print("Live mode stopped.")