#!/usr/bin/env python3
"""MiniCPM-o 4.5 MLX Chat — Image, Audio & TTS inference on Apple Silicon."""
import argparse
import logging
import math
import os
import re
import sys
import time
import warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
import mlx.core as mx
import numpy as np
import torch
from PIL import Image
from mlx_vlm import load
from mlx_vlm.generate import generate_step
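# Matches the model's <think>...</think> reasoning block so it can be stripped from printed output.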
THINK_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
def process_image_inputs(processor, image, prompt, max_slice_nums=9):
"""Process image + text through MiniCPM-o processor."""
# Build text with image placeholder
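    # Pre-filling the assistant turn with an empty <think> block nudges the model to answer without a reasoning phase.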
text = f"<|im_start|>user\n<image>./</image>\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Run through processor
inputs = processor(
text=text,
images=[image],
max_slice_nums=max_slice_nums,
)
# Convert to MLX arrays
input_ids = mx.array(inputs["input_ids"].numpy())
mask = mx.array(inputs["attention_mask"].numpy())
# Process pixel values: list of lists of tensors -> single batch tensor
pixel_values_list = inputs["pixel_values"]
tgt_sizes_list = inputs["tgt_sizes"]
image_bound = inputs["image_bound"]
# Flatten pixel values from nested lists to batch
all_pv = []
for batch_pvs in pixel_values_list:
for pv in batch_pvs:
            # pv: (C, H, W) PyTorch tensor -> (H, W, C); stacked into an NHWC batch below
pv_np = pv.numpy()
pv_np = np.transpose(pv_np, (1, 2, 0)) # CHW -> HWC
all_pv.append(pv_np)
# Pad to same size and stack
if all_pv:
max_h = max(p.shape[0] for p in all_pv)
max_w = max(p.shape[1] for p in all_pv)
padded = []
for p in all_pv:
pad_h = max_h - p.shape[0]
pad_w = max_w - p.shape[1]
if pad_h > 0 or pad_w > 0:
p = np.pad(p, ((0, pad_h), (0, pad_w), (0, 0)))
padded.append(p)
pixel_values = mx.array(np.stack(padded, axis=0))
else:
pixel_values = None
# Build patch attention mask
patch_attention_mask = None
if pixel_values is not None:
B = pixel_values.shape[0]
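        # The vision encoder consumes 14x14-pixel patches; only patches inside each image's true (h, w) grid are marked valid.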
total_patches = (pixel_values.shape[1] // 14) * (pixel_values.shape[2] // 14)
patch_attention_mask = np.zeros((B, total_patches), dtype=bool)
offset = 0
for ts_batch in tgt_sizes_list:
if isinstance(ts_batch, torch.Tensor):
for j in range(ts_batch.shape[0]):
idx = offset + j
if idx < B:
h, w = int(ts_batch[j][0]), int(ts_batch[j][1])
n_patches = h * w
patch_attention_mask[idx, :n_patches] = True
offset += ts_batch.shape[0]
patch_attention_mask = mx.array(patch_attention_mask)
# Build tgt_sizes array
tgt_sizes = []
for ts_batch in tgt_sizes_list:
if isinstance(ts_batch, torch.Tensor):
tgt_sizes.append(ts_batch.numpy())
if tgt_sizes:
tgt_sizes = mx.array(np.concatenate(tgt_sizes, axis=0).astype(np.int32))
else:
tgt_sizes = None
# Process image_bound
bounds = []
for batch_bounds in image_bound:
if isinstance(batch_bounds, torch.Tensor) and batch_bounds.numel() > 0:
bounds.extend(batch_bounds.numpy().tolist())
elif isinstance(batch_bounds, list):
bounds.extend(batch_bounds)
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"mask": mask,
"tgt_sizes": tgt_sizes,
"image_bound": bounds,
"patch_attention_mask": patch_attention_mask,
}
def process_audio_inputs(processor, audio_path, prompt, pool_step=5):
"""Process audio + text through mel spectrogram extraction.
Pipeline:
1. Load audio file -> resample to 16kHz if needed
2. WhisperFeatureExtractor -> mel spectrogram (80 bins)
3. Build prompt with <|audio_start|><unk>...<|audio_end|> placeholders
4. Compute audio_bound from placeholder token positions
"""
import soundfile as sf
# Load audio
audio_data, sr = sf.read(audio_path, dtype="float32")
if audio_data.ndim > 1:
audio_data = audio_data.mean(axis=1) # stereo -> mono
# Resample to 16kHz if needed
if sr != 16000:
try:
import librosa
audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
except ImportError:
print("Warning: librosa not installed, cannot resample. Install with: pip install librosa")
print(f"Audio sample rate is {sr}Hz, expected 16000Hz.")
# Extract mel features using WhisperFeatureExtractor
from transformers import WhisperFeatureExtractor
feature_extractor = WhisperFeatureExtractor(
feature_size=80, sampling_rate=16000, n_fft=400, hop_length=160
)
audio_input = feature_extractor(
audio_data,
sampling_rate=16000,
return_tensors="pt",
padding="max_length",
return_attention_mask=True,
)
audio_feature = audio_input["input_features"] # (1, 80, frames)
actual_len = audio_input["attention_mask"].sum(dim=1) # actual frames
audio_feature = audio_feature[:, :, :actual_len[0]] # trim padding
# Compute number of audio placeholder tokens
# After Conv1d stride=2: (frames-1)//2 + 1
# After avg pool stride=pool_step: (cnn_out - pool_step)//pool_step + 1
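    # e.g. a 10 s clip at 16 kHz -> ~1000 mel frames -> cnn_out = 500 -> 100 audio tokens (~10 tokens/s)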
feature_lens = int(actual_len[0])
cnn_out = (feature_lens - 1) // 2 + 1
num_audio_tokens = (cnn_out - pool_step) // pool_step + 1
# Build audio placeholder: <|audio_start|> + <unk>*N + <|audio_end|>
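    # One <unk> per pooled audio embedding; their token positions (audio_bound) mark where audio features are spliced in.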
audio_placeholder = "<|audio_start|>" + "<unk>" * num_audio_tokens + "<|audio_end|>"
# Build prompt text
text = f"<|im_start|>user\n{audio_placeholder}\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Tokenize
tokenized = processor.tokenizer(text, return_tensors="np")
input_ids = mx.array(tokenized["input_ids"])
mask = mx.array(tokenized["attention_mask"])
# Find audio_bound: positions of <unk> tokens between audio_start and audio_end
input_ids_np = tokenized["input_ids"][0]
audio_start_token = processor.tokenizer.convert_tokens_to_ids("<|audio_start|>")
audio_end_token = processor.tokenizer.convert_tokens_to_ids("<|audio_end|>")
audio_bounds = []
in_audio = False
start_idx = None
for i, tok in enumerate(input_ids_np):
if tok == audio_start_token:
in_audio = True
start_idx = i + 1 # first <unk> is after <|audio_start|>
elif tok == audio_end_token and in_audio:
audio_bounds.append((start_idx, i))
in_audio = False
# Convert mel features to MLX: (1, 80, frames)
audio_features = mx.array(audio_feature.numpy())
return {
"input_ids": input_ids,
"pixel_values": None,
"mask": mask,
"audio_features": audio_features,
"audio_bound": audio_bounds,
}
def generate(model, processor, inputs, max_tokens=512, temp=0.0, no_think=True, stream=True):
"""Generate text from model inputs with optional streaming."""
input_ids = inputs["input_ids"]
pixel_values = inputs.get("pixel_values")
mask = inputs.get("mask")
# Build kwargs for model
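    # Forward only the modality-specific tensors that are present; text-only prompts use none of these.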
kwargs = {}
if inputs.get("tgt_sizes") is not None:
kwargs["tgt_sizes"] = inputs["tgt_sizes"]
if inputs.get("image_bound"):
kwargs["image_bound"] = inputs["image_bound"]
if inputs.get("patch_attention_mask") is not None:
kwargs["patch_attention_mask"] = inputs["patch_attention_mask"]
if inputs.get("audio_features") is not None:
kwargs["audio_features"] = inputs["audio_features"]
if inputs.get("audio_bound"):
kwargs["audio_bound"] = inputs["audio_bound"]
tokens = []
start = time.time()
printed_len = 0 # how many chars we've already printed
for n, (token, _logprobs) in enumerate(
generate_step(input_ids, model, pixel_values, mask, temp=temp, **kwargs)
):
tok_val = token.item() if hasattr(token, "item") else int(token)
tokens.append(tok_val)
if n == 0:
prompt_time = time.time() - start
prompt_tps = input_ids.size / prompt_time
gen_start = time.time()
tok_str = processor.tokenizer.decode([tok_val])
if tok_str in ["<|im_end|>", "<|endoftext|>", "<|tts_eos|>"]:
break
        if len(tokens) >= max_tokens:
            break
if stream and (n + 1) % 4 == 0:
# Decode all tokens so far for correct subword handling
full_text = processor.tokenizer.decode(tokens, skip_special_tokens=True)
if no_think:
full_text = THINK_RE.sub("", full_text)
# Print only the new characters
if len(full_text) > printed_len:
print(full_text[printed_len:], end="", flush=True)
printed_len = len(full_text)
gen_time = time.time() - gen_start
if stream:
# Flush remaining tokens
full_text = processor.tokenizer.decode(tokens, skip_special_tokens=True)
if no_think:
full_text = THINK_RE.sub("", full_text)
if len(full_text) > printed_len:
print(full_text[printed_len:], end="", flush=True)
print() # newline after streamed output
else:
raw = processor.tokenizer.decode(tokens, skip_special_tokens=True)
if no_think:
raw = THINK_RE.sub("", raw)
print(raw.strip())
print(
f"\n--- {input_ids.size} prompt tok @ {prompt_tps:.0f} t/s | "
f"{len(tokens)} gen tok @ {len(tokens)/gen_time:.0f} t/s | "
f"mem {mx.get_peak_memory()/1e9:.1f} GB ---"
)
return tokens
def run_once(model, processor, args):
"""Single-shot inference."""
if args.audio:
inputs = process_audio_inputs(processor, args.audio, args.prompt or "What is being said?")
elif args.file:
image = Image.open(args.file).convert("RGB")
inputs = process_image_inputs(
processor, image, args.prompt, max_slice_nums=args.max_slices
)
else:
# Text-only
text = f"<|im_start|>user\n{args.prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
input_ids = mx.array(
processor.tokenizer(text, return_tensors="np")["input_ids"]
)
inputs = {"input_ids": input_ids, "pixel_values": None, "mask": None}
tokens = generate(model, processor, inputs, max_tokens=args.max_tokens, temp=args.temp)
# TTS: generate speech if requested
if args.tts and tokens:
generate_tts(model, processor, tokens, inputs, args)
def generate_tts(model, processor, generated_tokens, inputs, args):
"""Generate speech from model output via TTS pipeline."""
if not hasattr(model, "tts"):
print("TTS model not loaded. Re-convert model with audio+TTS weights.")
return
# Find tts_bound in the generated sequence
tts_bos_token = processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
tts_eos_token = processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
# Build full sequence: input_ids + generated tokens
input_ids_np = inputs["input_ids"].tolist()
if isinstance(input_ids_np[0], list):
input_ids_np = input_ids_np[0]
full_sequence = input_ids_np + generated_tokens
tts_bos_idx = -1
tts_eos_idx = None
for i, tok in enumerate(full_sequence):
if tok == tts_bos_token:
tts_bos_idx = i + 1
elif tok == tts_eos_token:
tts_eos_idx = i
if tts_bos_idx == -1:
print("No TTS markers found in output. Use a TTS prompt template.")
return
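    # tts_bound delimits the token span between <|tts_bos|> and <|tts_eos|> whose hidden states drive speech generation.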
tts_bound = (tts_bos_idx, tts_eos_idx)
print(f"Generating speech for tokens [{tts_bos_idx}:{tts_eos_idx}]...")
# We need the LLM hidden states for TTS - this requires a separate forward pass
# since generate_step doesn't return hidden states.
# For now, re-run the full sequence through the LLM to get hidden states.
full_ids = mx.array([full_sequence])
hidden = model.language_model.model(
full_ids,
inputs_embeds=model.language_model.model.embed_tokens(full_ids),
)
# hidden: (1, seq_len, llm_dim)
token_ids = mx.array(full_sequence)
audio_tokens = model.generate_speech(
hidden_states=hidden[0],
token_ids=token_ids,
tts_bound=tts_bound,
temperature=0.1,
top_p=0.9,
)
# Convert audio tokens to waveform
audio_tokens_list = audio_tokens.squeeze(-1).squeeze(0).tolist()
print(f"Generated {len(audio_tokens_list)} audio tokens.")
output_path = args.tts_output or "output.wav"
try:
from stepaudio2 import Token2wav
vocoder = Token2wav()
wav_bytes = vocoder(audio_tokens_list, None)
import soundfile as sf
import io
waveform, sr = sf.read(io.BytesIO(wav_bytes))
sf.write(output_path, waveform, sr)
print(f"Speech saved to {output_path}")
except ImportError:
# Save raw tokens for later vocoder processing
token_path = output_path.replace(".wav", "_tokens.npy")
np.save(token_path, np.array(audio_tokens_list))
print(f"Token2wav not installed. Raw audio tokens saved to {token_path}")
print("Install vocoder: pip install minicpmo-utils[all]")
def run_interactive(model, processor, args):
"""Interactive chat mode."""
current_file = args.file
current_audio = args.audio
print("MiniCPM-o 4.5 MLX Chat")
print("Commands: /image <path> | /audio <path> | /live | /clear | /quit")
if current_file:
print(f"Loaded image: {current_file}")
if current_audio:
print(f"Loaded audio: {current_audio}")
print()
while True:
try:
prompt = input("You: ").strip()
except (EOFError, KeyboardInterrupt):
print("\nBye!")
break
if not prompt:
continue
if prompt.lower() in ("/quit", "/exit", "/q"):
print("Bye!")
break
if prompt.lower() == "/clear":
current_file = None
current_audio = None
print("Cleared.\n")
continue
if prompt.lower().startswith("/image "):
current_file = prompt[7:].strip()
current_audio = None
print(f"Image loaded: {current_file}\n")
continue
if prompt.lower().startswith("/audio "):
current_audio = prompt[7:].strip()
current_file = None
print(f"Audio loaded: {current_audio}\n")
continue
if prompt.lower() == "/live":
from streaming import run_live_mode
run_live_mode(model, processor, args)
print()
continue
print()
if current_audio:
inputs = process_audio_inputs(processor, current_audio, prompt)
elif current_file:
image = Image.open(current_file).convert("RGB")
inputs = process_image_inputs(
processor, image, prompt, max_slice_nums=args.max_slices
)
else:
text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
input_ids = mx.array(
processor.tokenizer(text, return_tensors="np")["input_ids"]
)
inputs = {"input_ids": input_ids, "pixel_values": None, "mask": None}
tokens = generate(
model, processor, inputs, max_tokens=args.max_tokens, temp=args.temp
)
if args.tts and tokens:
generate_tts(model, processor, tokens, inputs, args)
print()
def main():
parser = argparse.ArgumentParser(
description="MiniCPM-o 4.5 MLX Chat",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""Examples:
python chat_minicpmo.py photo.jpg -p "What's in this image?"
python chat_minicpmo.py --audio speech.wav -p "Transcribe this."
python chat_minicpmo.py --audio speech.wav # interactive with audio
python chat_minicpmo.py --live # full duplex streaming
python chat_minicpmo.py --live --capture-region 0,0,1920,1080
python chat_minicpmo.py # interactive mode
""",
)
parser.add_argument("file", nargs="?", help="Image file (optional)")
parser.add_argument("-p", "--prompt", default=None, help="Prompt (interactive if omitted)")
parser.add_argument(
"-m", "--model", default="./minicpmo-mlx", help="MLX model path"
)
parser.add_argument("--audio", default=None, help="Audio file (.wav) for speech input")
parser.add_argument("--tts", action="store_true", help="Enable TTS speech output")
parser.add_argument("--tts-output", default="output.wav", help="TTS output .wav path")
parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens")
parser.add_argument("--temp", type=float, default=0.0, help="Temperature")
parser.add_argument("--max-slices", type=int, default=9, help="Max image slices")
parser.add_argument("--live", action="store_true", help="Full duplex streaming mode")
parser.add_argument("--capture-region", default=None, help="Screen region x,y,w,h (default: primary monitor)")
parser.add_argument("--audio-device", default="BlackHole", help="Audio input device (default: BlackHole)")
args = parser.parse_args()
print("Loading model...", flush=True)
model, processor = load(args.model, trust_remote_code=True)
print("Model ready.\n")
if args.live:
from streaming import run_live_mode
run_live_mode(model, processor, args)
elif args.prompt:
run_once(model, processor, args)
else:
run_interactive(model, processor, args)
if __name__ == "__main__":
main()