"""
Quantization utilities for ternary weight representation.
This module implements the core quantization functions for converting
dense weights to ternary ({-1, 0, +1}) representation with appropriate
scaling factors.
"""
import torch
from typing import Tuple, Optional
def absmax_scale(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
"""
Compute absmax scaling factor for quantization.
    The absmax scale is:
        scale = max(abs(tensor)) / Q_max
    where Q_max is the maximum quantized magnitude. This module targets
    ternary weights (Q_max = 1), so the function returns the raw absmax.
Args:
tensor: Input tensor to compute scale for
dim: Dimension to compute scale along (None = global, int = per-dim)
Returns:
Scaling factor(s)
Examples:
>>> W = torch.randn(512, 512)
        >>> scale = absmax_scale(W, dim=1)  # Per output channel
>>> scale.shape
torch.Size([512])
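        A global scale (dim=None) reduces over the whole tensor:
        >>> absmax_scale(torch.tensor([[1.0, -3.0], [2.0, 0.5]])).item()
        3.0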
"""
if dim is None:
# Global absmax
scale = torch.max(torch.abs(tensor))
else:
# Per-dimension absmax
scale = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0]
# Remove keepdim for cleaner output
scale = scale.squeeze(dim)
    # Clamp to a small floor to avoid division by zero downstream
scale = torch.clamp(scale, min=1e-5)
return scale
def ternary_quantize(
tensor: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Quantize tensor to ternary values {-1, 0, +1}.
Uses a threshold-based approach:
- Values > threshold → +1
- Values < -threshold → -1
- Values in [-threshold, threshold] → 0
The threshold is typically computed as a fraction of the scale.
Args:
tensor: Input tensor to quantize
scale: Optional pre-computed scale (if None, compute from tensor)
Returns:
Ternary tensor with values in {-1, 0, +1}
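    Examples:
        A small worked case (global scale = 0.9, threshold = 0.45):
        >>> ternary_quantize(torch.tensor([0.9, 0.2, -0.7, -0.1])).tolist()
        [1.0, 0.0, -1.0, 0.0]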
    Notes:
        - The threshold controls sparsity: a larger threshold yields more zeros
        - Common thresholds: 0.33 * scale or 0.5 * scale
        - Inspired by BitNet's weight quantization scheme
"""
# Compute scale if not provided
if scale is None:
scale = absmax_scale(tensor, dim=None)
    # Compute threshold (0.5 * scale is a reasonable default); a larger
    # fraction zeroes more weights, giving a sparser ternary matrix
    threshold = 0.5 * scale
    # Ensure the threshold broadcasts against the tensor. This assumes the
    # scale aligns with the leading dimension(s), e.g. shape [out_features]
    # for a 2D weight matrix
    if scale.dim() > 0:
        # Append trailing singleton dims until the ranks match
        while threshold.dim() < tensor.dim():
            threshold = threshold.unsqueeze(-1)
# Initialize ternary tensor with zeros
ternary = torch.zeros_like(tensor)
# Assign +1 and -1 based on threshold
ternary[tensor > threshold] = 1
ternary[tensor < -threshold] = -1
return ternary
def weight_to_ternary(
W: torch.Tensor,
per_channel: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert dense weights to ternary representation with scaling.
This is the main quantization function that combines:
1. Scale computation (absmax per channel or global)
2. Ternary quantization
3. Return both quantized weights and scales
Args:
W: Dense weight matrix of shape [out_features, in_features]
per_channel: If True, use per-output-channel scaling (recommended)
Returns:
W_ternary: Ternary weight matrix (values in {-1, 0, +1})
gamma: Scaling factors (shape [out_features] or scalar)
Examples:
>>> W = torch.randn(512, 768)
>>> W_t, gamma = weight_to_ternary(W, per_channel=True)
>>> W_reconstructed = W_t * gamma.unsqueeze(1)
>>> error = torch.norm(W - W_reconstructed)
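        >>> gamma.shape  # one scale per output channel
        torch.Size([512])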
Notes:
- Per-channel scaling preserves output scale better
- The scaling factor gamma compensates for quantization
- This function is used during layer initialization/conversion
"""
if per_channel:
# Compute scale per output channel (along dimension 1)
# W is [out_features, in_features], so dim=1 gives scale per output
gamma = absmax_scale(W, dim=1)
else:
# Global scale for entire weight matrix
gamma = absmax_scale(W, dim=None)
# Quantize to ternary using the computed scale
W_ternary = ternary_quantize(W, gamma)
return W_ternary, gamma
def quantize_activations_absmax(
x: torch.Tensor,
bits: int = 8,
per_token: bool = True,
) -> torch.Tensor:
"""
Quantize activations using absmax scaling.
BitNet quantizes both weights (ternary) and activations (8-bit).
This function implements activation quantization with per-token scaling.
Args:
x: Input activations of shape [batch, seq_len, features]
bits: Number of bits for quantization (default: 8)
per_token: If True, scale per token; if False, global scaling
Returns:
Quantized activations (as float, simulating INT8)
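    Examples:
        With a per-token absmax of 127 the quantization step is exactly 1,
        so values snap to the nearest integer:
        >>> x = torch.tensor([[[1.0, -2.0, 0.3, 127.0]]])
        >>> quantize_activations_absmax(x, bits=8).tolist()
        [[[1.0, -2.0, 0.0, 127.0]]]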
    Notes:
        - Per-token scaling is important for handling outliers
        - Returns float so the op can sit in an autograd graph; note that
          torch.round has zero gradient almost everywhere, so training
          typically pairs this with a straight-through estimator
        - Simulates quantization without actual int8 storage
"""
    # Calculate the symmetric quantization range from the bit width; for
    # 8 bits this is [-127, 127] (the INT8 code -128 is left unused, as is
    # conventional for symmetric quantization)
    Q_max = 2 ** (bits - 1) - 1  # For 8-bit: 127
    Q_min = -Q_max  # -127
if per_token:
# Compute scale per token (across feature dimension)
# x shape: [batch, seq_len, features]
# Scale along last dimension, keeping dims for broadcasting
scale = torch.max(torch.abs(x), dim=-1, keepdim=True)[0]
scale = torch.clamp(scale, min=1e-5) # Avoid division by zero
else:
# Global scale for entire tensor
scale = torch.max(torch.abs(x))
scale = torch.clamp(scale, min=1e-5)
# Quantize: scale to [-Q_max, Q_max], round, and scale back
x_scaled = x / scale * Q_max
x_quant = torch.clamp(x_scaled, Q_min, Q_max)
x_quant = torch.round(x_quant)
# Dequantize back to float (simulating int8 → float32 for autograd)
x_dequant = x_quant * scale / Q_max
return x_dequant
def dequantize_scale(
x_quant: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
"""
Dequantize tensor back to float using scale.
Simple helper for:
x_float = x_quant * scale
Args:
x_quant: Quantized tensor (ternary or int8)
scale: Scaling factors
Returns:
Dequantized float tensor
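    Examples:
        >>> w_t = torch.tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]])
        >>> gamma = torch.tensor([0.5, 2.0])
        >>> dequantize_scale(w_t, gamma).tolist()
        [[0.5, 0.0, -0.5], [0.0, 2.0, 2.0]]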
"""
# Ensure scale has proper shape for broadcasting
if scale.dim() > 0 and scale.dim() < x_quant.dim():
# Add dimensions to the right to match x_quant shape
while scale.dim() < x_quant.dim():
scale = scale.unsqueeze(-1)
return x_quant * scale
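

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the module's API):
    # exercises the full round trip of weight ternarization, dequantization,
    # and activation fake-quantization using only the functions defined above.
    torch.manual_seed(0)

    W = torch.randn(512, 768)
    W_t, gamma = weight_to_ternary(W, per_channel=True)
    W_hat = dequantize_scale(W_t, gamma)
    rel_err = (torch.norm(W - W_hat) / torch.norm(W)).item()
    sparsity = (W_t == 0).float().mean().item()
    print(f"weight reconstruction rel. error: {rel_err:.3f}, sparsity: {sparsity:.2%}")

    x = torch.randn(2, 16, 768)
    x_q = quantize_activations_absmax(x, bits=8, per_token=True)
    max_err = (x - x_q).abs().max().item()
    print(f"activation max abs round-trip error: {max_err:.5f}")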