"""MiniCPM-o 4.5 MLX Chat — Image, Audio & TTS inference on Apple Silicon."""

import argparse
import logging
import os
import re
import time
import warnings

# Silence tokenizer fork warnings and transformers log noise before the
# heavy imports below pull those libraries in.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

import mlx.core as mx
import numpy as np
import torch
from PIL import Image

from mlx_vlm import load
from mlx_vlm.generate import generate_step

# Strips a complete <think>...</think> reasoning block (plus trailing
# whitespace) from decoded output.
THINK_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)


def process_image_inputs(processor, image, prompt, max_slice_nums=9):
    """Process image + text through the MiniCPM-o processor."""
    # Chat template with a pre-seeded empty <think> block so the model skips
    # its reasoning trace.
    text = (
        f"<|im_start|>user\n<image>./</image>\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n\n</think>\n\n"
    )

    inputs = processor(
        text=text,
        images=[image],
        max_slice_nums=max_slice_nums,
    )

    input_ids = mx.array(inputs["input_ids"].numpy())
    mask = mx.array(inputs["attention_mask"].numpy())

    pixel_values_list = inputs["pixel_values"]
    tgt_sizes_list = inputs["tgt_sizes"]
    image_bound = inputs["image_bound"]

    # The processor returns each slice as CHW; transpose to HWC for MLX.
    all_pv = []
    for batch_pvs in pixel_values_list:
        for pv in batch_pvs:
            pv_np = np.transpose(pv.numpy(), (1, 2, 0))
            all_pv.append(pv_np)

    # Slices come in different sizes, so zero-pad each one to the largest
    # H x W before stacking into a single batch tensor.
    if all_pv:
        max_h = max(p.shape[0] for p in all_pv)
        max_w = max(p.shape[1] for p in all_pv)
        padded = []
        for p in all_pv:
            pad_h = max_h - p.shape[0]
            pad_w = max_w - p.shape[1]
            if pad_h > 0 or pad_w > 0:
                p = np.pad(p, ((0, pad_h), (0, pad_w), (0, 0)))
            padded.append(p)
        pixel_values = mx.array(np.stack(padded, axis=0))
    else:
        pixel_values = None
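
    # Zero-padding creates patch positions that carry no image content. The
    # mask built below marks only the first h * w patches of each slice as
    # valid (h, w come from tgt_sizes, the slice's patch grid at 14 px per
    # patch) so the vision tower can ignore the padding.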
    patch_attention_mask = None
    if pixel_values is not None:
        B = pixel_values.shape[0]
        total_patches = (pixel_values.shape[1] // 14) * (pixel_values.shape[2] // 14)
        patch_attention_mask = np.zeros((B, total_patches), dtype=bool)
        offset = 0
        for ts_batch in tgt_sizes_list:
            if isinstance(ts_batch, torch.Tensor):
                for j in range(ts_batch.shape[0]):
                    idx = offset + j
                    if idx < B:
                        h, w = int(ts_batch[j][0]), int(ts_batch[j][1])
                        n_patches = h * w
                        patch_attention_mask[idx, :n_patches] = True
                offset += ts_batch.shape[0]
        patch_attention_mask = mx.array(patch_attention_mask)

    # Flatten tgt_sizes across the batch for the vision tower.
    tgt_sizes = []
    for ts_batch in tgt_sizes_list:
        if isinstance(ts_batch, torch.Tensor):
            tgt_sizes.append(ts_batch.numpy())
    if tgt_sizes:
        tgt_sizes = mx.array(np.concatenate(tgt_sizes, axis=0).astype(np.int32))
    else:
        tgt_sizes = None

    # image_bound gives the (start, end) token positions where image
    # embeddings are spliced into the text sequence.
    bounds = []
    for batch_bounds in image_bound:
        if isinstance(batch_bounds, torch.Tensor) and batch_bounds.numel() > 0:
            bounds.extend(batch_bounds.numpy().tolist())
        elif isinstance(batch_bounds, list):
            bounds.extend(batch_bounds)

    return {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "mask": mask,
        "tgt_sizes": tgt_sizes,
        "image_bound": bounds,
        "patch_attention_mask": patch_attention_mask,
    }


def process_audio_inputs(processor, audio_path, prompt, pool_step=5):
    """Process audio + text through mel spectrogram extraction.

    Pipeline:
    1. Load audio file -> resample to 16 kHz if needed
    2. WhisperFeatureExtractor -> mel spectrogram (80 bins)
    3. Build prompt with <|audio_start|><unk>...<|audio_end|> placeholders
    4. Compute audio_bound from placeholder token positions
    """
    import soundfile as sf

    audio_data, sr = sf.read(audio_path, dtype="float32")
    if audio_data.ndim > 1:
        # Downmix multi-channel audio to mono.
        audio_data = audio_data.mean(axis=1)

    if sr != 16000:
        try:
            import librosa

            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
        except ImportError:
            print("Warning: librosa not installed, cannot resample. Install with: pip install librosa")
            print(f"Audio sample rate is {sr} Hz, expected 16000 Hz.")

    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor(
        feature_size=80, sampling_rate=16000, n_fft=400, hop_length=160
    )
    audio_input = feature_extractor(
        audio_data,
        sampling_rate=16000,
        return_tensors="pt",
        padding="max_length",
        return_attention_mask=True,
    )

    # The extractor pads to a fixed 30 s window; trim the mel features back
    # to the real length reported by the attention mask.
    audio_feature = audio_input["input_features"]
    actual_len = int(audio_input["attention_mask"].sum(dim=1)[0])
    audio_feature = audio_feature[:, :, :actual_len]

    # Token count after the Whisper encoder's stride-2 conv and the
    # stride-`pool_step` average pooling that follows it.
    cnn_out = (actual_len - 1) // 2 + 1
    num_audio_tokens = (cnn_out - pool_step) // pool_step + 1
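
    # Worked example with the defaults above: 16 kHz audio at hop_length=160
    # gives 100 mel frames per second, so a 3 s clip has ~300 frames; the
    # stride-2 conv leaves (300 - 1) // 2 + 1 = 150, and pooling with
    # pool_step=5 yields (150 - 5) // 5 + 1 = 30 audio tokens.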

    # Reserve one <unk> slot per audio token between the audio markers.
    audio_placeholder = "<|audio_start|>" + "<unk>" * num_audio_tokens + "<|audio_end|>"

    text = (
        f"<|im_start|>user\n{audio_placeholder}\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n\n</think>\n\n"
    )

    tokenized = processor.tokenizer(text, return_tensors="np")
    input_ids = mx.array(tokenized["input_ids"])
    mask = mx.array(tokenized["attention_mask"])

    input_ids_np = tokenized["input_ids"][0]
    audio_start_token = processor.tokenizer.convert_tokens_to_ids("<|audio_start|>")
    audio_end_token = processor.tokenizer.convert_tokens_to_ids("<|audio_end|>")
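
    # audio_bound records the (start, end) positions of the <unk> run between
    # each marker pair -- the span later overwritten with audio embeddings.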
    audio_bounds = []
    in_audio = False
    start_idx = None
    for i, tok in enumerate(input_ids_np):
        if tok == audio_start_token:
            in_audio = True
            start_idx = i + 1
        elif tok == audio_end_token and in_audio:
            audio_bounds.append((start_idx, i))
            in_audio = False

    audio_features = mx.array(audio_feature.numpy())

    return {
        "input_ids": input_ids,
        "pixel_values": None,
        "mask": mask,
        "audio_features": audio_features,
        "audio_bound": audio_bounds,
    }


def generate(model, processor, inputs, max_tokens=512, temp=0.0, no_think=True, stream=True):
    """Generate text from model inputs, with optional streaming output."""
    input_ids = inputs["input_ids"]
    pixel_values = inputs.get("pixel_values")
    mask = inputs.get("mask")

    # Forward only the modality inputs that are actually present.
    kwargs = {}
    if inputs.get("tgt_sizes") is not None:
        kwargs["tgt_sizes"] = inputs["tgt_sizes"]
    if inputs.get("image_bound"):
        kwargs["image_bound"] = inputs["image_bound"]
    if inputs.get("patch_attention_mask") is not None:
        kwargs["patch_attention_mask"] = inputs["patch_attention_mask"]
    if inputs.get("audio_features") is not None:
        kwargs["audio_features"] = inputs["audio_features"]
    if inputs.get("audio_bound"):
        kwargs["audio_bound"] = inputs["audio_bound"]

    tokens = []
    start = time.time()
    gen_start = start  # fallback if generation yields no tokens
    prompt_tps = 0.0
    printed_len = 0

    for n, (token, _logprobs) in enumerate(
        generate_step(input_ids, model, pixel_values, mask, temp=temp, **kwargs)
    ):
        tok_val = token.item() if hasattr(token, "item") else int(token)
        tokens.append(tok_val)
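
        # The first yielded token marks the end of prefill, so elapsed time
        # up to this point measures prompt-processing throughput.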
        if n == 0:
            prompt_time = time.time() - start
            prompt_tps = input_ids.size / prompt_time
            gen_start = time.time()

        tok_str = processor.tokenizer.decode([tok_val])
        if tok_str in ("<|im_end|>", "<|endoftext|>", "<|tts_eos|>"):
            break
        if len(tokens) >= max_tokens:
            break
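
        # Detokenization isn't prefix-stable (merges and multi-byte
        # characters can change earlier text), so periodically re-decode the
        # whole sequence and print only the newly appended characters.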
        if stream and (n + 1) % 4 == 0:
            full_text = processor.tokenizer.decode(tokens, skip_special_tokens=True)
            if no_think:
                full_text = THINK_RE.sub("", full_text)
            if len(full_text) > printed_len:
                print(full_text[printed_len:], end="", flush=True)
                printed_len = len(full_text)

    gen_time = time.time() - gen_start

    if stream:
        # Flush whatever the periodic decode above hasn't printed yet.
        full_text = processor.tokenizer.decode(tokens, skip_special_tokens=True)
        if no_think:
            full_text = THINK_RE.sub("", full_text)
        if len(full_text) > printed_len:
            print(full_text[printed_len:], end="", flush=True)
        print()
    else:
        raw = processor.tokenizer.decode(tokens, skip_special_tokens=True)
        if no_think:
            raw = THINK_RE.sub("", raw)
        print(raw.strip())

    print(
        f"\n--- {input_ids.size} prompt tok @ {prompt_tps:.0f} t/s | "
        f"{len(tokens)} gen tok @ {len(tokens)/gen_time:.0f} t/s | "
        f"mem {mx.get_peak_memory()/1e9:.1f} GB ---"
    )

    return tokens


def run_once(model, processor, args):
    """Single-shot inference."""
    if args.audio:
        inputs = process_audio_inputs(processor, args.audio, args.prompt or "What is being said?")
    elif args.file:
        image = Image.open(args.file).convert("RGB")
        inputs = process_image_inputs(
            processor, image, args.prompt, max_slice_nums=args.max_slices
        )
    else:
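        # Text-only path: same chat template as the media paths, minus any
        # image/audio placeholders.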
        text = (
            f"<|im_start|>user\n{args.prompt}<|im_end|>\n"
            f"<|im_start|>assistant\n<think>\n\n</think>\n\n"
        )
        input_ids = mx.array(
            processor.tokenizer(text, return_tensors="np")["input_ids"]
        )
        inputs = {"input_ids": input_ids, "pixel_values": None, "mask": None}

    tokens = generate(model, processor, inputs, max_tokens=args.max_tokens, temp=args.temp)

    if args.tts and tokens:
        generate_tts(model, processor, tokens, inputs, args)


def generate_tts(model, processor, generated_tokens, inputs, args):
    """Generate speech from model output via the TTS pipeline."""
    if not hasattr(model, "tts"):
        print("TTS model not loaded. Re-convert model with audio+TTS weights.")
        return

    tts_bos_token = processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
    tts_eos_token = processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
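
    # Find the span to vocalize: scan prompt + output and keep the most
    # recent <|tts_bos|>/<|tts_eos|> pair.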
    input_ids_np = inputs["input_ids"].tolist()
    if isinstance(input_ids_np[0], list):
        input_ids_np = input_ids_np[0]
    full_sequence = input_ids_np + generated_tokens

    tts_bos_idx = -1
    tts_eos_idx = None
    for i, tok in enumerate(full_sequence):
        if tok == tts_bos_token:
            tts_bos_idx = i + 1
        elif tok == tts_eos_token:
            tts_eos_idx = i

    if tts_bos_idx == -1:
        print("No TTS markers found in output. Use a TTS prompt template.")
        return

    tts_bound = (tts_bos_idx, tts_eos_idx)
    print(f"Generating speech for tokens [{tts_bos_idx}:{tts_eos_idx}]...")
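
    # Re-run the decoder over prompt + generated tokens to recover hidden
    # states; the speech head conditions on these rather than on raw token
    # ids. (Caching hidden states during generation would avoid this second
    # forward pass.)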
    full_ids = mx.array([full_sequence])
    hidden = model.language_model.model(
        full_ids,
        inputs_embeds=model.language_model.model.embed_tokens(full_ids),
    )

    token_ids = mx.array(full_sequence)
    audio_tokens = model.generate_speech(
        hidden_states=hidden[0],
        token_ids=token_ids,
        tts_bound=tts_bound,
        temperature=0.1,
        top_p=0.9,
    )
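
    # generate_speech returns discrete audio codec tokens; a separate vocoder
    # (Token2wav) renders them to a waveform.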
    audio_tokens_list = audio_tokens.squeeze(-1).squeeze(0).tolist()
    print(f"Generated {len(audio_tokens_list)} audio tokens.")

    output_path = args.tts_output or "output.wav"
    try:
        import io

        import soundfile as sf
        from stepaudio2 import Token2wav

        vocoder = Token2wav()
        wav_bytes = vocoder(audio_tokens_list, None)
        waveform, sr = sf.read(io.BytesIO(wav_bytes))
        sf.write(output_path, waveform, sr)
        print(f"Speech saved to {output_path}")
    except ImportError:
        # No vocoder available: dump the raw codec tokens so they can be
        # converted later.
        token_path = output_path.replace(".wav", "_tokens.npy")
        np.save(token_path, np.array(audio_tokens_list))
        print(f"Token2wav not installed. Raw audio tokens saved to {token_path}")
        print("Install vocoder: pip install minicpmo-utils[all]")


def run_interactive(model, processor, args):
    """Interactive chat mode."""
    current_file = args.file
    current_audio = args.audio
    print("MiniCPM-o 4.5 MLX Chat")
    print("Commands: /image <path> | /audio <path> | /live | /clear | /quit")
    if current_file:
        print(f"Loaded image: {current_file}")
    if current_audio:
        print(f"Loaded audio: {current_audio}")
    print()

    while True:
        try:
            prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not prompt:
            continue
        if prompt.lower() in ("/quit", "/exit", "/q"):
            print("Bye!")
            break
        if prompt.lower() == "/clear":
            current_file = None
            current_audio = None
            print("Cleared.\n")
            continue
        if prompt.lower().startswith("/image "):
            current_file = prompt[7:].strip()
            current_audio = None
            print(f"Image loaded: {current_file}\n")
            continue
        if prompt.lower().startswith("/audio "):
            current_audio = prompt[7:].strip()
            current_file = None
            print(f"Audio loaded: {current_audio}\n")
            continue
        if prompt.lower() == "/live":
            from streaming import run_live_mode

            run_live_mode(model, processor, args)
            print()
            continue

        print()

        if current_audio:
            inputs = process_audio_inputs(processor, current_audio, prompt)
        elif current_file:
            image = Image.open(current_file).convert("RGB")
            inputs = process_image_inputs(
                processor, image, prompt, max_slice_nums=args.max_slices
            )
        else:
            text = (
                f"<|im_start|>user\n{prompt}<|im_end|>\n"
                f"<|im_start|>assistant\n<think>\n\n</think>\n\n"
            )
            input_ids = mx.array(
                processor.tokenizer(text, return_tensors="np")["input_ids"]
            )
            inputs = {"input_ids": input_ids, "pixel_values": None, "mask": None}

        tokens = generate(
            model, processor, inputs, max_tokens=args.max_tokens, temp=args.temp
        )

        if args.tts and tokens:
            generate_tts(model, processor, tokens, inputs, args)

        print()


def main():
    parser = argparse.ArgumentParser(
        description="MiniCPM-o 4.5 MLX Chat",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""Examples:
  python chat_minicpmo.py photo.jpg -p "What's in this image?"
  python chat_minicpmo.py --audio speech.wav -p "Transcribe this."
  python chat_minicpmo.py --audio speech.wav    # interactive with audio
  python chat_minicpmo.py --live                # full duplex streaming
  python chat_minicpmo.py --live --capture-region 0,0,1920,1080
  python chat_minicpmo.py                       # interactive mode
""",
    )
    parser.add_argument("file", nargs="?", help="Image file (optional)")
    parser.add_argument("-p", "--prompt", default=None, help="Prompt (interactive if omitted)")
    parser.add_argument("-m", "--model", default="./minicpmo-mlx", help="MLX model path")
    parser.add_argument("--audio", default=None, help="Audio file (.wav) for speech input")
    parser.add_argument("--tts", action="store_true", help="Enable TTS speech output")
    parser.add_argument("--tts-output", default="output.wav", help="TTS output .wav path")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens to generate")
    parser.add_argument("--temp", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--max-slices", type=int, default=9, help="Max image slices")
    parser.add_argument("--live", action="store_true", help="Full duplex streaming mode")
    parser.add_argument("--capture-region", default=None, help="Screen region x,y,w,h (default: primary monitor)")
    parser.add_argument("--audio-device", default="BlackHole", help="Audio input device (default: BlackHole)")
    args = parser.parse_args()

    print("Loading model...", flush=True)
    model, processor = load(args.model, trust_remote_code=True)
    print("Model ready.\n")

    if args.live:
        from streaming import run_live_mode

        run_live_mode(model, processor, args)
    elif args.prompt:
        run_once(model, processor, args)
    else:
        run_interactive(model, processor, args)


if __name__ == "__main__":
    main()