|
|
"""
|
|
|
Quantization utilities for ternary weight representation.
|
|
|
|
|
|
This module implements the core quantization functions for converting
|
|
|
dense weights to ternary ({-1, 0, +1}) representation with appropriate
|
|
|
scaling factors.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
from typing import Tuple, Optional
|
|
|
|
|
|
|
|
|
def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """
    Return the absolute-maximum (absmax) scale of a tensor.

    The absmax scale is defined as:
        scale = max(abs(tensor)) / Q_max

    where Q_max is the largest quantized magnitude. This function returns
    the numerator only, which equals the scale when Q_max is 1 (ternary).

    Args:
        tensor: Tensor whose scale is computed.
        dim: Dimension to reduce over. ``None`` reduces over every element
            and yields a 0-dim tensor; an int reduces over that dimension
            only and removes it from the result.

    Returns:
        The scale(s), floored at 1e-5.

    Examples:
        >>> W = torch.randn(512, 768)
        >>> scale = absmax_scale(W, dim=1)  # one scale per row
        >>> scale.shape
        torch.Size([512])
    """
    magnitudes = tensor.abs()

    if dim is not None:
        # Reducing without keepdim drops `dim` from the result directly.
        peak, _ = magnitudes.max(dim=dim)
    else:
        peak = magnitudes.max()

    # Floor the result so an all-zero input never produces a zero scale.
    return peak.clamp(min=1e-5)
|
|
|
|
|
|
|
|
|
def ternary_quantize(
    tensor: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    threshold_ratio: float = 0.5,
) -> torch.Tensor:
    """
    Quantize tensor to ternary values {-1, 0, +1}.

    Uses a threshold-based approach, with
    ``threshold = threshold_ratio * scale``:
        - Values > threshold → +1
        - Values < -threshold → -1
        - Values in [-threshold, threshold] → 0

    Args:
        tensor: Input tensor to quantize.
        scale: Optional pre-computed scale. If None, the global absmax
            scale of ``tensor`` is used. A scale with fewer dimensions than
            ``tensor`` is aligned with the leading dimensions of ``tensor``
            (trailing singleton dimensions are appended), so a scale of
            shape [out_features] applies row-wise to a
            [out_features, in_features] tensor.
        threshold_ratio: Fraction of the scale used as the threshold.
            Defaults to 0.5; must be non-negative.

    Returns:
        Tensor of the same shape and dtype as ``tensor`` with values in
        {-1, 0, +1}.

    Raises:
        ValueError: If ``threshold_ratio`` is negative.

    Notes:
        - A larger threshold produces more zeros (higher sparsity)
        - Common thresholds: 0.33 * scale or 0.5 * scale
        - Inspired by BitNet's weight quantization scheme
    """
    if threshold_ratio < 0:
        raise ValueError(
            f"threshold_ratio must be non-negative, got {threshold_ratio}"
        )

    if scale is None:
        scale = absmax_scale(tensor, dim=None)

    threshold = threshold_ratio * scale

    # A 0-dim threshold broadcasts as-is. Otherwise append trailing singleton
    # dimensions so the threshold lines up with the leading dims of `tensor`.
    missing_dims = tensor.dim() - threshold.dim()
    if scale.dim() > 0 and missing_dims > 0:
        threshold = threshold[(Ellipsis,) + (None,) * missing_dims]

    ternary = torch.zeros_like(tensor)
    ternary[tensor > threshold] = 1
    ternary[tensor < -threshold] = -1
    return ternary
|
|
|
|
|
|
|
|
|
def weight_to_ternary(
    W: torch.Tensor,
    per_channel: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Turn a dense weight matrix into ternary weights plus scaling factors.

    The conversion is a two-step pipeline:
        1. Compute the absmax scale, either one per row of ``W`` (reduction
           over dim 1) or a single global value.
        2. Ternarize ``W`` against that scale with ``ternary_quantize``.

    Args:
        W: Dense weight matrix of shape [out_features, in_features]
        per_channel: If True, compute one scale per output channel (row);
            if False, compute one scale for the whole matrix.

    Returns:
        W_ternary: Ternary weight matrix (values in {-1, 0, +1})
        gamma: Scaling factors (shape [out_features] when per_channel is
            True, otherwise a 0-dim tensor)

    Examples:
        >>> W = torch.randn(512, 768)
        >>> W_t, gamma = weight_to_ternary(W, per_channel=True)
        >>> W_reconstructed = W_t * gamma.unsqueeze(1)
        >>> error = torch.norm(W - W_reconstructed)
    """
    # dim=1 reduces over in_features, leaving one scale per output row.
    reduce_dim = 1 if per_channel else None
    gamma = absmax_scale(W, dim=reduce_dim)
    return ternary_quantize(W, gamma), gamma
|
|
|
|
|
|
|
|
|
def quantize_activations_absmax(
|
|
|
x: torch.Tensor,
|
|
|
bits: int = 8,
|
|
|
per_token: bool = True,
|
|
|
) -> torch.Tensor:
|
|
|
"""
|
|
|
Quantize activations using absmax scaling.
|
|
|
|
|
|
BitNet quantizes both weights (ternary) and activations (8-bit).
|
|
|
This function implements activation quantization with per-token scaling.
|
|
|
|
|
|
Args:
|
|
|
x: Input activations of shape [batch, seq_len, features]
|
|
|
bits: Number of bits for quantization (default: 8)
|
|
|
per_token: If True, scale per token; if False, global scaling
|
|
|
|
|
|
Returns:
|
|
|
Quantized activations (as float, simulating INT8)
|
|
|
|
|
|
Notes:
|
|
|
- Per-token scaling is important for handling outliers
|
|
|
- Returns float for autograd compatibility
|
|
|
- Simulates quantization without actual int8 storage
|
|
|
"""
|
|
|
|
|
|
Q_max = 2 ** (bits - 1) - 1
|
|
|
Q_min = -Q_max
|
|
|
|
|
|
if per_token:
|
|
|
|
|
|
|
|
|
|
|
|
scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
|
|
|
scale = torch.clamp(scale, min=1e-5)
|
|
|
else:
|
|
|
|
|
|
scale = torch.max(torch.abs(x))
|
|
|
scale = torch.clamp(scale, min=1e-5)
|
|
|
|
|
|
|
|
|
x_scaled = x / scale * Q_max
|
|
|
x_quant = torch.clamp(x_scaled, Q_min, Q_max)
|
|
|
x_quant = torch.round(x_quant)
|
|
|
|
|
|
|
|
|
x_dequant = x_quant * scale / Q_max
|
|
|
|
|
|
return x_dequant
|
|
|
|
|
|
|
|
|
def dequantize_scale(
    x_quant: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Multiply a quantized tensor by its scale to recover float values.

    Computes:
        x_float = x_quant * scale

    A non-scalar ``scale`` with fewer dimensions than ``x_quant`` is given
    trailing singleton dimensions first, so it lines up with the leading
    dimensions of ``x_quant`` (e.g. a [out_features] scale applies row-wise
    to a [out_features, in_features] tensor).

    Args:
        x_quant: Quantized tensor (ternary or int8)
        scale: Scaling factors (0-dim tensor, or aligned with the leading
            dimensions of ``x_quant``)

    Returns:
        Dequantized float tensor
    """
    rank_gap = x_quant.dim() - scale.dim()

    # 0-dim scales broadcast as-is; only lower-rank, non-scalar scales
    # need trailing axes appended.
    if scale.dim() > 0 and rank_gap > 0:
        scale = scale[(Ellipsis,) + (None,) * rank_gap]

    return x_quant * scale
|
|
|
|