#!/usr/bin/env python3
"""
Convert DeepSeek-R1-Distill-Qwen-1.5B to ternary format.

Stores linear weights as bitplanes (pos_mask, neg_mask) + per-row scale.
Embeddings and layernorms stay FP16. LM head stays FP16.

(c) 2026 OpenTransformers Ltd / Scott Bisset
"""

import json
import os
import struct  # noqa: F401  (unused here; kept so existing importers are unaffected)
import time
from pathlib import Path

import numpy as np

# Number of mask bits packed into each output word.
_WORD_BITS = 64

# Substrings that identify the linear-layer weights to ternarize.
_TERNARY_PATTERNS = (
    "q_proj.weight",
    "k_proj.weight",
    "v_proj.weight",
    "o_proj.weight",
    "gate_proj.weight",
    "up_proj.weight",
    "down_proj.weight",
)

# Architecture values written to the output config.json when the caller does
# not override them.
# NOTE(review): these are hard-coded for the model named in the module
# docstring; confirm them against the source checkpoint's config.json.
_DEFAULT_CONFIG = {
    "hidden_size": 1536,
    "intermediate_size": 8960,
    "num_attention_heads": 12,
    "num_key_value_heads": 2,
    "num_hidden_layers": 28,
    "vocab_size": 151936,
    "head_dim": 128,
    "rope_theta": 1000000.0,
    "rms_norm_eps": 1e-6,
}


def load_safetensors(model_dir):
    """Load all tensors from the ``*.safetensors`` files in *model_dir*.

    Files are read in sorted filename order; if a tensor name appears in more
    than one file, the later file wins.

    Args:
        model_dir: Directory containing the safetensors shards.

    Returns:
        Dict mapping tensor name to a float32 NumPy array.

    Raises:
        FileNotFoundError: If the directory contains no ``*.safetensors``
            file (previously this silently produced an empty model).
    """
    # Imported lazily so the quantization helpers stay usable without torch.
    import torch  # noqa: F401
    from safetensors.torch import load_file

    shard_paths = sorted(Path(model_dir).glob("*.safetensors"))
    if not shard_paths:
        raise FileNotFoundError(
            f"No .safetensors files found in {str(model_dir)!r}"
        )

    tensors = {}
    for shard in shard_paths:
        print(f"Loading {shard.name}...")
        for key, val in load_file(str(shard)).items():
            # Upcast first: NumPy has no bfloat16, so convert via float32.
            tensors[key] = val.float().numpy()
    return tensors


def _ternary_masks(w, alpha):
    """Return ``(pos, neg, abs_w)`` for float array *w*, thresholded per row.

    The threshold of each row (last axis) is ``alpha * mean(|row|)``.
    """
    abs_w = np.abs(w)
    if w.shape[-1]:
        thresholds = alpha * abs_w.mean(axis=-1, keepdims=True)
    else:
        # Mean of an empty row is undefined; there is nothing to select anyway.
        thresholds = np.zeros(w.shape[:-1] + (1,), dtype=w.dtype)
    # The sign tests keep the two masks disjoint even when the threshold is
    # zero or negative (all-zero row, alpha <= 0); without them an exact zero
    # would be flagged as both +1 and -1.
    pos = (w >= thresholds) & (w > 0)
    neg = (w <= -thresholds) & (w < 0)
    return pos, neg, abs_w


def _row_scales(abs_w, nz):
    """Return the mean of ``abs_w`` over selected entries per row (1.0 if none)."""
    counts = nz.sum(axis=-1)
    # Accumulate in float64 so long rows do not lose precision.
    sums = np.where(nz, abs_w, 0.0).sum(axis=-1, dtype=np.float64)
    means = sums / np.maximum(counts, 1)
    return np.where(counts > 0, means, 1.0).astype(np.float32)


def _pack_bits(mask):
    """Pack a boolean mask along its last axis into uint64 words.

    The last axis is zero-padded to a multiple of 64. Within each 64-element
    chunk, element ``i`` becomes bit ``i`` (least-significant bit first).
    """
    pad = -mask.shape[-1] % _WORD_BITS
    if pad:
        padding = np.zeros(mask.shape[:-1] + (pad,), dtype=bool)
        mask = np.concatenate([mask, padding], axis=-1)
    chunks = mask.shape[-1] // _WORD_BITS
    if mask.size == 0:
        return np.zeros(mask.shape[:-1] + (chunks,), dtype=np.uint64)
    # Little bit-order bytes reinterpreted as little-endian 64-bit words give
    # exactly "element i -> bit i", without a 64-bit-per-element temporary.
    packed = np.ascontiguousarray(
        np.packbits(mask, axis=-1, bitorder="little")
    )
    return packed.view("<u8").astype(np.uint64, copy=False)


def quantize_row_ternary(row, alpha=0.7):
    """Quantize a single row to ternary {-1, 0, +1} bitplanes.

    Args:
        row: 1-D array-like of weights.
        alpha: Threshold factor; entries with ``|x| >= alpha * mean(|row|)``
            become +/-1, the rest become 0.

    Returns:
        Tuple ``(pos_bits, neg_bits, scale)``: two uint64 arrays of length
        ``ceil(len(row) / 64)`` (element ``i`` of a chunk is bit ``i``) and
        the float32 mean absolute value of the selected entries (1.0 when
        nothing is selected).
    """
    row = np.asarray(row, dtype=np.float32)
    pos, neg, abs_row = _ternary_masks(row, alpha)
    scale = _row_scales(abs_row, pos | neg)
    return _pack_bits(pos), _pack_bits(neg), np.float32(scale)


def quantize_weight_matrix(weight, alpha=0.7):
    """Quantize a weight matrix ``[out_dim, in_dim]`` to ternary bitplanes.

    Each row gets its own threshold ``alpha * mean(|row|)`` and its own scale.

    Args:
        weight: 2-D array-like of weights.
        alpha: Threshold factor (see :func:`quantize_row_ternary`).

    Returns:
        Tuple ``(all_pos, all_neg, scales, sparsity)``:
        ``all_pos`` / ``all_neg`` are uint64 arrays of shape
        ``[out_dim, ceil(in_dim / 64)]``; ``scales`` is a float32 array of
        length ``out_dim`` (1.0 for rows with no selected entry);
        ``sparsity`` is the fraction of entries quantized to 0 (0.0 for an
        empty matrix).
    """
    w = np.asarray(weight, dtype=np.float32)
    out_dim, in_dim = w.shape

    pos, neg, abs_w = _ternary_masks(w, alpha)
    nz = pos | neg
    scales = _row_scales(abs_w, nz)

    total = out_dim * in_dim
    sparsity = 1.0 - np.count_nonzero(nz) / total if total else 0.0

    return _pack_bits(pos), _pack_bits(neg), scales, sparsity


def save_ternary_model(tensors, output_dir, alpha=0.7, config_overrides=None):
    """Convert and save a full model to ternary format.

    Writes into *output_dir*: ``config.json``, ``manifest.json``, and for each
    tensor (dots in the name replaced by underscores) either
    ``.pos``/``.neg``/``.scales`` raw files (ternarized linear weights) or a
    ``.fp16`` raw file (everything else).

    Args:
        tensors: Dict mapping tensor name to NumPy array.
        output_dir: Destination directory (created if missing).
        alpha: Threshold factor passed to :func:`quantize_weight_matrix`;
            also recorded in ``config.json``.
        config_overrides: Optional dict whose entries replace the built-in
            architecture defaults in ``config.json``.
    """
    os.makedirs(output_dir, exist_ok=True)

    config = dict(_DEFAULT_CONFIG)
    if config_overrides:
        config.update(config_overrides)
    config["alpha"] = alpha

    # Split tensors into ternarized linear weights and FP16 pass-through.
    ternary_keys = []
    keep_keys = []
    for key, tensor in tensors.items():
        is_linear = any(p in key for p in _TERNARY_PATTERNS)
        # Only 2-D tensors can be quantized row-wise; anything else is kept.
        if is_linear and np.ndim(tensor) == 2:
            ternary_keys.append(key)
        else:
            keep_keys.append(key)

    print(f"\nTernary layers: {len(ternary_keys)}")
    print(f"FP16 layers: {len(keep_keys)}")

    with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Ternary weights.
    total_ternary_bytes = 0
    total_original_bytes = 0
    for key in ternary_keys:
        # asarray avoids copying tensors that are already float32.
        w = np.asarray(tensors[key], dtype=np.float32)
        total_original_bytes += w.nbytes

        t0 = time.perf_counter()
        pos, neg, scales, sparsity = quantize_weight_matrix(w, alpha)
        dt = time.perf_counter() - t0

        prefix = os.path.join(output_dir, key.replace(".", "_"))
        pos.tofile(prefix + ".pos")
        neg.tofile(prefix + ".neg")
        scales.tofile(prefix + ".scales")

        ternary_bytes = pos.nbytes + neg.nbytes + scales.nbytes
        total_ternary_bytes += ternary_bytes
        ratio = w.nbytes / ternary_bytes if ternary_bytes else 0.0
        print(f"  {key}: {w.shape} -> ternary ({ternary_bytes/1024:.0f}KB, "
              f"{ratio:.1f}x compression, {sparsity:.1%} sparse, {dt:.1f}s)")

    # FP16 weights.
    total_fp16_bytes = 0
    for key in keep_keys:
        w = np.asarray(tensors[key]).astype(np.float16)
        prefix = os.path.join(output_dir, key.replace(".", "_"))
        w.tofile(prefix + ".fp16")
        total_fp16_bytes += w.nbytes
        print(f"  {key}: {w.shape} -> fp16 ({w.nbytes/1024:.0f}KB)")

    # Tensor manifest: name -> original shape, grouped by storage format.
    manifest = {
        "ternary": {k: list(tensors[k].shape) for k in ternary_keys},
        "fp16": {k: list(tensors[k].shape) for k in keep_keys},
    }
    with open(os.path.join(output_dir, "manifest.json"), "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    total_bytes = total_ternary_bytes + total_fp16_bytes
    # Baseline counts linear weights at FP32 and the remaining tensors at FP16.
    orig_bytes = total_original_bytes + total_fp16_bytes
    compression = orig_bytes / total_bytes if total_bytes else 0.0
    print("\n=== Summary ===")
    print(f"Original FP32 linear weights: {total_original_bytes/1024/1024:.1f} MB")
    print(f"Ternary linear weights: {total_ternary_bytes/1024/1024:.1f} MB")
    print(f"FP16 other weights: {total_fp16_bytes/1024/1024:.1f} MB")
    print(f"Total model size: {total_bytes/1024/1024:.1f} MB")
    print(f"Compression vs FP32: {compression:.1f}x")


def _load_source_config(model_dir):
    """Return config keys this converter writes, read from the source model.

    Reads ``config.json`` in *model_dir* when it exists and returns only the
    non-null entries whose keys appear in the built-in defaults. Returns an
    empty dict when the file is missing or unreadable.
    """
    try:
        with open(Path(model_dir) / "config.json", encoding="utf-8") as f:
            source = json.load(f)
    except (OSError, ValueError):
        return {}
    if not isinstance(source, dict):
        return {}
    return {
        key: source[key]
        for key in _DEFAULT_CONFIG
        if source.get(key) is not None
    }


def main(argv=None):
    """Command-line entry point: ``script [model_dir] [output_dir] [alpha]``."""
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    model_dir = args[0] if len(args) > 0 else "deepseek-r1-1.5b-hf"
    output_dir = args[1] if len(args) > 1 else "deepseek-r1-1.5b-ternary"
    alpha = float(args[2]) if len(args) > 2 else 0.7

    print(f"Loading model from {model_dir}...")
    tensors = load_safetensors(model_dir)

    # Prefer the checkpoint's own architecture values over built-in defaults.
    overrides = _load_source_config(model_dir)
    for key, value in overrides.items():
        if value != _DEFAULT_CONFIG[key]:
            print(f"Config: using {key}={value} from source config.json "
                  f"(built-in default {_DEFAULT_CONFIG[key]})")

    print(f"Converting to ternary (alpha={alpha})...")
    save_ternary_model(tensors, output_dir, alpha, config_overrides=overrides)
    print("Done!")


if __name__ == "__main__":
    main()