""" Quantization utilities for ternary weight representation. This module implements the core quantization functions for converting dense weights to ternary ({-1, 0, +1}) representation with appropriate scaling factors. """ import torch from typing import Tuple, Optional def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor: """ Compute absmax scaling factor for quantization. The absmax scale is: scale = max(abs(tensor)) / Q_max where Q_max is the maximum quantized value (e.g., 1 for ternary). Args: tensor: Input tensor to compute scale for dim: Dimension to compute scale along (None = global, int = per-dim) Returns: Scaling factor(s) Examples: >>> W = torch.randn(512, 512) >>> scale = absmax_scale(W, dim=0) # Per output channel >>> scale.shape torch.Size([512]) """ if dim is None: # Global absmax scale = torch.max(torch.abs(tensor)) else: # Per-dimension absmax scale = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0] # Remove keepdim for cleaner output scale = scale.squeeze(dim) # Add small epsilon to avoid division by zero scale = torch.clamp(scale, min=1e-5) return scale def ternary_quantize( tensor: torch.Tensor, scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Quantize tensor to ternary values {-1, 0, +1}. Uses a threshold-based approach: - Values > threshold → +1 - Values < -threshold → -1 - Values in [-threshold, threshold] → 0 The threshold is typically computed as a fraction of the scale. Args: tensor: Input tensor to quantize scale: Optional pre-computed scale (if None, compute from tensor) Returns: Ternary tensor with values in {-1, 0, +1} Notes: - The threshold determines sparsity (more zeros) - Common thresholds: 0.33 * scale or 0.5 * scale - Inspired by BitNet's weight quantization scheme """ # Compute scale if not provided if scale is None: scale = absmax_scale(tensor, dim=None) # Compute threshold (using 0.5 as a reasonable default) # This can be tuned: smaller threshold = more zeros (more sparse) threshold = 0.5 * scale # Ensure scale and threshold have proper shape for broadcasting if scale.dim() > 0: # Add dimensions to match tensor shape for broadcasting while threshold.dim() < tensor.dim(): threshold = threshold.unsqueeze(-1) # Initialize ternary tensor with zeros ternary = torch.zeros_like(tensor) # Assign +1 and -1 based on threshold ternary[tensor > threshold] = 1 ternary[tensor < -threshold] = -1 return ternary def weight_to_ternary( W: torch.Tensor, per_channel: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert dense weights to ternary representation with scaling. This is the main quantization function that combines: 1. Scale computation (absmax per channel or global) 2. Ternary quantization 3. 

    Args:
        W: Dense weight matrix of shape [out_features, in_features]
        per_channel: If True, use per-output-channel scaling (recommended)

    Returns:
        W_ternary: Ternary weight matrix (values in {-1, 0, +1})
        gamma: Scaling factors (shape [out_features] or scalar)

    Examples:
        >>> W = torch.randn(512, 768)
        >>> W_t, gamma = weight_to_ternary(W, per_channel=True)
        >>> W_reconstructed = W_t * gamma.unsqueeze(1)
        >>> error = torch.norm(W - W_reconstructed)

    Notes:
        - Per-channel scaling preserves output scale better
        - The scaling factor gamma compensates for quantization
        - This function is used during layer initialization/conversion
    """
    if per_channel:
        # Compute scale per output channel (along dimension 1)
        # W is [out_features, in_features], so dim=1 gives scale per output
        gamma = absmax_scale(W, dim=1)
    else:
        # Global scale for entire weight matrix
        gamma = absmax_scale(W, dim=None)

    # Quantize to ternary using the computed scale
    W_ternary = ternary_quantize(W, gamma)

    return W_ternary, gamma


def quantize_activations_absmax(
    x: torch.Tensor,
    bits: int = 8,
    per_token: bool = True,
) -> torch.Tensor:
    """
    Quantize activations using absmax scaling.

    BitNet quantizes both weights (ternary) and activations (8-bit).
    This function implements activation quantization with per-token scaling.

    Args:
        x: Input activations of shape [batch, seq_len, features]
        bits: Number of bits for quantization (default: 8)
        per_token: If True, scale per token; if False, global scaling

    Returns:
        Quantized activations (as float, simulating INT8)

    Notes:
        - Per-token scaling is important for handling outliers
        - Returns float for autograd compatibility
        - Simulates quantization without actual int8 storage
    """
    # Calculate quantization range based on bits
    Q_max = 2 ** (bits - 1) - 1  # For 8-bit: 127
    Q_min = -Q_max               # Symmetric range: -127

    if per_token:
        # Compute scale per token (across feature dimension)
        # x shape: [batch, seq_len, features]
        # Scale along last dimension, keeping dims for broadcasting
        scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
        scale = torch.clamp(scale, min=1e-5)  # Avoid division by zero
    else:
        # Global scale for entire tensor
        scale = torch.max(torch.abs(x))
        scale = torch.clamp(scale, min=1e-5)

    # Quantize: scale to [-Q_max, Q_max], clamp, and round
    x_scaled = x / scale * Q_max
    x_quant = torch.clamp(x_scaled, Q_min, Q_max)
    x_quant = torch.round(x_quant)

    # Dequantize back to float (simulating int8 → float32 for autograd)
    x_dequant = x_quant * scale / Q_max

    return x_dequant


def dequantize_scale(
    x_quant: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Dequantize tensor back to float using scale.

    Simple helper for: x_float = x_quant * scale

    Args:
        x_quant: Quantized tensor (ternary or int8)
        scale: Scaling factors

    Returns:
        Dequantized float tensor
    """
    # Ensure scale has proper shape for broadcasting
    if scale.dim() > 0 and scale.dim() < x_quant.dim():
        # Add dimensions to the right to match x_quant shape
        while scale.dim() < x_quant.dim():
            scale = scale.unsqueeze(-1)

    return x_quant * scale
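

if __name__ == "__main__":
    # Illustrative usage sketch, added for documentation purposes only; it is
    # not part of the original module's API, and the shapes below are made up.
    # It round-trips a random weight matrix and a batch of activations through
    # the quantizers above and prints the relative reconstruction error.
    torch.manual_seed(0)

    # Weight round-trip: ternary values plus per-output-channel scale gamma
    W = torch.randn(512, 768)
    W_t, gamma = weight_to_ternary(W, per_channel=True)
    W_hat = dequantize_scale(W_t, gamma)
    rel_err_w = torch.norm(W - W_hat) / torch.norm(W)
    print(f"ternary values: {sorted(W_t.unique().tolist())}")
    print(f"weight relative error: {rel_err_w.item():.3f}")

    # Activation round-trip: simulated 8-bit absmax quantization per token
    x = torch.randn(2, 16, 768)
    x_q = quantize_activations_absmax(x, bits=8, per_token=True)
    rel_err_x = torch.norm(x - x_q) / torch.norm(x)
    print(f"activation relative error: {rel_err_x.item():.5f}")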