"""
Quantization utilities for ternary weight representation.
This module implements the core quantization functions for converting
dense weights to ternary ({-1, 0, +1}) representation with appropriate
scaling factors.
"""
import torch
from typing import Tuple, Optional
def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
"""
Compute absmax scaling factor for quantization.
    The absmax scale is:
        scale = max(abs(tensor)) / Q_max
    where Q_max is the maximum quantized value. Note that this function
    returns max(abs(tensor)) directly, i.e. it assumes the ternary case
    where Q_max = 1.
Args:
tensor: Input tensor to compute scale for
dim: Dimension to compute scale along (None = global, int = per-dim)
Returns:
Scaling factor(s)
Examples:
        >>> W = torch.randn(512, 768)
        >>> scale = absmax_scale(W, dim=1)  # Per output channel (reduce over in_features)
        >>> scale.shape
        torch.Size([512])
"""
if dim is None:
# Global absmax
scale = torch.max(torch.abs(tensor))
else:
# Per-dimension absmax
scale = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0]
# Remove keepdim for cleaner output
scale = scale.squeeze(dim)
    # Clamp to a small positive minimum to avoid division by zero downstream
scale = torch.clamp(scale, min=1e-5)
return scale
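# Worked example (illustrative, not part of the original module):
#   >>> W = torch.tensor([[0.2, -1.6, 0.8], [0.1, 0.4, -0.3]])
#   >>> absmax_scale(W, dim=1)   # per-row absmax -> tensor([1.6000, 0.4000])
#   >>> absmax_scale(W)          # global absmax  -> tensor(1.6000)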
def ternary_quantize(
tensor: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Quantize tensor to ternary values {-1, 0, +1}.
Uses a threshold-based approach:
- Values > threshold → +1
- Values < -threshold → -1
- Values in [-threshold, threshold] → 0
The threshold is typically computed as a fraction of the scale.
Args:
tensor: Input tensor to quantize
scale: Optional pre-computed scale (if None, compute from tensor)
Returns:
Ternary tensor with values in {-1, 0, +1}
Notes:
        - The threshold controls sparsity: a larger threshold produces more zeros
- Common thresholds: 0.33 * scale or 0.5 * scale
- Inspired by BitNet's weight quantization scheme
"""
# Compute scale if not provided
if scale is None:
scale = absmax_scale(tensor, dim=None)
# Compute threshold (using 0.5 as a reasonable default)
    # This can be tuned: a larger threshold zeroes more values (more sparse)
threshold = 0.5 * scale
# Ensure scale and threshold have proper shape for broadcasting
if scale.dim() > 0:
        # Append trailing dims so a per-leading-dim scale (e.g., per output
        # channel) broadcasts against the full tensor shape
while threshold.dim() < tensor.dim():
threshold = threshold.unsqueeze(-1)
# Initialize ternary tensor with zeros
ternary = torch.zeros_like(tensor)
# Assign +1 and -1 based on threshold
ternary[tensor > threshold] = 1
ternary[tensor < -threshold] = -1
return ternary
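# Worked example (illustrative): with a global scale, max(|x|) = 2.0 and the
# default threshold is 0.5 * 2.0 = 1.0, so only entries beyond +/-1.0 survive:
#   >>> x = torch.tensor([-2.0, -0.4, 0.0, 0.3, 1.5])
#   >>> ternary_quantize(x)      # -> tensor([-1., 0., 0., 0., 1.])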
def weight_to_ternary(
W: torch.Tensor,
per_channel: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert dense weights to ternary representation with scaling.
This is the main quantization function that combines:
1. Scale computation (absmax per channel or global)
2. Ternary quantization
3. Return both quantized weights and scales
Args:
W: Dense weight matrix of shape [out_features, in_features]
per_channel: If True, use per-output-channel scaling (recommended)
Returns:
W_ternary: Ternary weight matrix (values in {-1, 0, +1})
gamma: Scaling factors (shape [out_features] or scalar)
Examples:
>>> W = torch.randn(512, 768)
>>> W_t, gamma = weight_to_ternary(W, per_channel=True)
>>> W_reconstructed = W_t * gamma.unsqueeze(1)
>>> error = torch.norm(W - W_reconstructed)
Notes:
- Per-channel scaling preserves output scale better
- The scaling factor gamma compensates for quantization
- This function is used during layer initialization/conversion
"""
if per_channel:
# Compute scale per output channel (along dimension 1)
# W is [out_features, in_features], so dim=1 gives scale per output
gamma = absmax_scale(W, dim=1)
else:
# Global scale for entire weight matrix
gamma = absmax_scale(W, dim=None)
# Quantize to ternary using the computed scale
W_ternary = ternary_quantize(W, gamma)
return W_ternary, gamma
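# Sketch (assumption: one way a BitLinear-style layer might consume these
# outputs; `linear` and `x` are hypothetical names, not an API defined in this
# module). Since W is approximated by gamma[:, None] * W_ternary, the
# per-channel scale can be applied after the matmul:
#   >>> W_t, gamma = weight_to_ternary(linear.weight, per_channel=True)
#   >>> y = torch.nn.functional.linear(x, W_t) * gamma   # ~= x @ W.T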
def quantize_activations_absmax(
x: torch.Tensor,
bits: int = 8,
per_token: bool = True,
) -> torch.Tensor:
"""
Quantize activations using absmax scaling.
BitNet quantizes both weights (ternary) and activations (8-bit).
This function implements activation quantization with per-token scaling.
Args:
x: Input activations of shape [batch, seq_len, features]
bits: Number of bits for quantization (default: 8)
per_token: If True, scale per token; if False, global scaling
Returns:
Quantized activations (as float, simulating INT8)
Notes:
- Per-token scaling is important for handling outliers
        - Returns float tensors (fake quantization; no actual int8 storage)
        - round() has zero gradient almost everywhere, so callers typically
          wrap this in a straight-through estimator when training
"""
# Calculate quantization range based on bits
Q_max = 2 ** (bits - 1) - 1 # For 8-bit: 127
    Q_min = -Q_max  # Symmetric range: -127 (the -128 code point is unused)
if per_token:
# Compute scale per token (across feature dimension)
# x shape: [batch, seq_len, features]
# Scale along last dimension, keeping dims for broadcasting
scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
scale = torch.clamp(scale, min=1e-5) # Avoid division by zero
else:
# Global scale for entire tensor
scale = torch.max(torch.abs(x))
scale = torch.clamp(scale, min=1e-5)
# Quantize: scale to [-Q_max, Q_max], round, and scale back
x_scaled = x / scale * Q_max
x_quant = torch.clamp(x_scaled, Q_min, Q_max)
x_quant = torch.round(x_quant)
# Dequantize back to float (simulating int8 → float32 for autograd)
x_dequant = x_quant * scale / Q_max
return x_dequant
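# Minimal sketch (assumption, not part of the original module): during training
# the rounding above blocks gradients, so a straight-through estimator (STE) is
# commonly wrapped around the fake-quantized activations. The helper name below
# is hypothetical.
def _quantize_activations_ste(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Fake-quantize activations while letting gradients pass through unchanged."""
    x_q = quantize_activations_absmax(x, bits=bits, per_token=True)
    # Forward value is x_q; backward gradient is that of x (identity),
    # because the (x_q - x) branch is detached from the graph.
    return x + (x_q - x).detach()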
def dequantize_scale(
x_quant: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
"""
Dequantize tensor back to float using scale.
Simple helper for:
x_float = x_quant * scale
Args:
x_quant: Quantized tensor (ternary or int8)
scale: Scaling factors
Returns:
Dequantized float tensor
"""
# Ensure scale has proper shape for broadcasting
if scale.dim() > 0 and scale.dim() < x_quant.dim():
# Add dimensions to the right to match x_quant shape
while scale.dim() < x_quant.dim():
scale = scale.unsqueeze(-1)
return x_quant * scale
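if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # round-trip a random weight matrix and some activations through the
    # quantization utilities and report the relative reconstruction error.
    torch.manual_seed(0)

    W = torch.randn(512, 768)
    W_t, gamma = weight_to_ternary(W, per_channel=True)
    W_rec = dequantize_scale(W_t, gamma)  # gamma broadcasts per output row
    print("weight rel. error:", (torch.norm(W - W_rec) / torch.norm(W)).item())

    x = torch.randn(2, 16, 768)
    x_q = quantize_activations_absmax(x, bits=8, per_token=True)
    print("activation rel. error:", (torch.norm(x - x_q) / torch.norm(x)).item())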