"""
Unit tests for quantization utilities.
These tests validate the ternary quantization, scaling, and packing functions. Test cases:
TestAbsmaxScale (3 tests)
1. test_global_scale - Tests global absmax scaling computation
2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling
3. test_zero_tensor - Validates behavior with zero tensors (numerical stability)
TestTernaryQuantize (3 tests)
1. test_quantization_values - Ensures output contains only {-1, 0, +1}
2. test_sign_preservation - Validates sign preservation for large values
3. test_threshold_behavior - Tests threshold-based zero assignment
TestWeightToTernary (3 tests)
1. test_output_shapes - Verifies correct output tensor shapes
2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes
3. test_reconstruction_quality - Validates reconstruction error is reasonable
TestActivationQuantization (2 tests)
1. test_quantization_range - Tests 8-bit quantization range
2. test_per_token_scaling - Validates per-token vs. global scaling
TestDequantization (1 test)
1. test_dequantize_inverse - Tests quantize → dequantize inverse operation
TestBase3Packing (3 tests)
1. test_pack_unpack_roundtrip - Validates pack → unpack recovers original
2. test_memory_efficiency - Tests that ~20x compression is achieved
3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions
TestCompressionUtilities (2 tests)
1. test_compression_ratio_calculation - Tests compression ratio computation
2. test_memory_savings_estimation - Validates memory savings estimation
TestQuantizationIntegration (2 tests)
1. test_full_quantization_pipeline - Tests dense → ternary → packed → unpacked
2. test_quantization_preserves_functionality - Validates quantized layer outputs
"""
import pytest
import torch
from bitlinear.quantization import (
absmax_scale,
ternary_quantize,
weight_to_ternary,
quantize_activations_absmax,
dequantize_scale,
)
from bitlinear.packing import (
pack_ternary_base3,
unpack_ternary_base3,
compute_compression_ratio,
estimate_memory_savings,
)
class TestAbsmaxScale:
"""Tests for absmax_scale function."""
def test_global_scale(self):
"""Test global absmax scaling."""
W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
scale = absmax_scale(W, dim=None)
assert torch.isclose(scale, torch.tensor(6.0))
def test_per_channel_scale(self):
"""Test per-channel (per-row) absmax scaling."""
W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
scale = absmax_scale(W, dim=1)
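        # dim=1 reduces over columns, yielding one absmax scale per row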
expected = torch.tensor([3.0, 6.0])
assert torch.allclose(scale, expected)
def test_zero_tensor(self):
"""Test behavior with zero tensor."""
W = torch.zeros(10, 10)
scale = absmax_scale(W, dim=None)
# Should handle division by zero gracefully (clamped to epsilon)
assert scale > 0
assert scale < 1e-4
class TestTernaryQuantize:
"""Tests for ternary_quantize function."""
def test_quantization_values(self):
"""Test that output contains only {-1, 0, +1}."""
W = torch.randn(100, 100)
W_ternary = ternary_quantize(W)
unique_values = torch.unique(W_ternary)
assert set(unique_values.tolist()).issubset({-1.0, 0.0, 1.0})
def test_sign_preservation(self):
"""Test that signs are preserved correctly."""
# Use values well above threshold (> 0.5 * max)
W = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]])
W_ternary = ternary_quantize(W)
# Large positive values should be +1
assert W_ternary[0, 0] == 1.0
# Large negative values should be -1
assert W_ternary[0, 1] == -1.0
assert W_ternary[1, 0] == -1.0
# Large positive
assert W_ternary[1, 1] == 1.0
def test_threshold_behavior(self):
"""Test that threshold determines zero assignment."""
# Create tensor with known values
W = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]])
W_ternary = ternary_quantize(W)
# Small values near zero should become 0
# Exact behavior depends on threshold, but there should be some zeros
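        # (The exact threshold is implementation-defined; common choices scale
        # with mean(|W|) or max(|W|), e.g. 0.7 * mean(|W|) in ternary weight
        # networks.)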
assert 0.0 in W_ternary
class TestWeightToTernary:
"""Tests for weight_to_ternary function."""
def test_output_shapes(self):
"""Test that output shapes are correct."""
W = torch.randn(512, 768)
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
assert W_ternary.shape == (512, 768)
assert gamma.shape == (512,)
def test_per_channel_vs_global(self):
"""Test difference between per-channel and global scaling."""
W = torch.randn(512, 768)
W_t_pc, gamma_pc = weight_to_ternary(W, per_channel=True)
W_t_g, gamma_g = weight_to_ternary(W, per_channel=False)
assert gamma_pc.shape == (512,)
assert gamma_g.shape == torch.Size([]) # Scalar
def test_reconstruction_quality(self):
"""Test that reconstruction W_ternary * gamma approximates W."""
W = torch.randn(512, 768)
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
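        # gamma has shape (512,); unsqueeze to (512, 1) so one scale
        # broadcasts across each row of W_ternary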
W_reconstructed = W_ternary * gamma.unsqueeze(1)
error = torch.norm(W - W_reconstructed) / torch.norm(W)
        # Ternary quantization has inherent error; quantizing to only 3 values
        # is aggressive, so we only require the reconstruction to beat the
        # trivial all-zeros baseline (relative error < 1.0)
        assert error < 1.0
class TestActivationQuantization:
"""Tests for activation quantization."""
def test_quantization_range(self):
"""Test that quantized values are in expected range."""
x = torch.randn(16, 32, 512)
x_quant = quantize_activations_absmax(x, bits=8, per_token=True)
# Should be roughly in similar range as input
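        # (absmax quantization rescales values back to float, so this checks a
        # fake-quantized output rather than raw int8 codes)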
assert x_quant.abs().max() <= x.abs().max() * 1.1
def test_per_token_scaling(self):
"""Test per-token vs. global scaling."""
x = torch.randn(16, 32, 512)
x_quant_per_token = quantize_activations_absmax(x, bits=8, per_token=True)
x_quant_global = quantize_activations_absmax(x, bits=8, per_token=False)
# Both should work without errors
assert x_quant_per_token.shape == x.shape
assert x_quant_global.shape == x.shape
class TestDequantization:
"""Tests for dequantization."""
def test_dequantize_inverse(self):
"""Test that quantize β dequantize is approximately identity."""
W = torch.randn(512, 768)
W_quant, scale = weight_to_ternary(W, per_channel=True)
W_dequant = dequantize_scale(W_quant, scale)
# Should be close to W_quant * scale reconstruction
W_expected = W_quant * scale.unsqueeze(1)
assert torch.allclose(W_dequant, W_expected)
class TestBase3Packing:
"""Tests for base-3 packing utilities."""
def test_pack_unpack_roundtrip(self):
"""Test that pack β unpack recovers original ternary weights."""
W_ternary = torch.randint(-1, 2, (512, 768)).float()
packed, shape = pack_ternary_base3(W_ternary)
W_unpacked = unpack_ternary_base3(packed, shape)
assert torch.allclose(W_ternary, W_unpacked)
def test_memory_efficiency(self):
"""Test that packing achieves expected compression."""
W_ternary = torch.randint(-1, 2, (512, 768)).float()
original_size = W_ternary.numel() * 4 # float32 = 4 bytes
packed, shape = pack_ternary_base3(W_ternary)
packed_size = packed.numel() * 1 # uint8 = 1 byte
compression = original_size / packed_size
        # Should achieve ~20x compression (32 bits → 1.6 bits)
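        # float32 spends 32 bits per weight; five trits per byte spends
        # 8 / 5 = 1.6 bits, so the ideal ratio is 32 / 1.6 = 20x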
assert compression > 15 # Allow some overhead
def test_packing_with_padding(self):
"""Test packing when dimensions are not multiples of 5."""
# Test with various sizes to ensure padding is handled correctly
for size in [(13, 17), (100, 203), (7, 11)]:
W_ternary = torch.randint(-1, 2, size).float()
packed, shape = pack_ternary_base3(W_ternary)
W_unpacked = unpack_ternary_base3(packed, shape)
assert torch.allclose(W_ternary, W_unpacked)
class TestCompressionUtilities:
"""Tests for compression ratio and memory estimation utilities."""
def test_compression_ratio_calculation(self):
"""Test compression ratio calculation."""
ratio = compute_compression_ratio(1024, 51)
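        # Assuming compute_compression_ratio(original, compressed) divides its
        # arguments: 1024 / 51 ≈ 20.08, within the 0.5 tolerance below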
assert abs(ratio - 20.0) < 0.5
def test_memory_savings_estimation(self):
"""Test memory savings estimation for layer."""
stats = estimate_memory_savings(768, 3072, num_layers=12)
assert 'float32_bytes' in stats
assert 'packed_bytes' in stats
assert 'savings_bytes' in stats
assert 'compression_ratio' in stats
assert stats['compression_ratio'] > 15
class TestQuantizationIntegration:
"""Integration tests for quantization pipeline."""
def test_full_quantization_pipeline(self):
"""Test complete pipeline: dense β ternary β packed β unpacked."""
# 1. Start with dense weights
W = torch.randn(128, 256)
# 2. Quantize to ternary
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
# 3. Pack to base-3
packed, shape = pack_ternary_base3(W_ternary)
# 4. Unpack
W_unpacked = unpack_ternary_base3(packed, shape)
# 5. Verify correctness
assert torch.allclose(W_ternary, W_unpacked)
assert set(W_unpacked.unique().tolist()).issubset({-1.0, 0.0, 1.0})
def test_quantization_preserves_functionality(self):
"""Test that quantized layer produces reasonable outputs."""
from bitlinear import BitLinear
import torch.nn as nn
# Create dense layer
dense = nn.Linear(256, 128)
# Test input
x = torch.randn(16, 256)
out_dense = dense(x)
# Quantize to BitLinear
bitlinear = BitLinear.from_linear(dense)
out_quantized = bitlinear(x)
# Outputs should have same shape
assert out_dense.shape == out_quantized.shape
# Outputs should be correlated (similar but not identical)
# Calculate correlation
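        # torch.corrcoef treats each row as a variable and returns the 2x2
        # correlation matrix; entry [0, 1] is the dense/quantized correlation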
correlation = torch.corrcoef(torch.stack([out_dense.flatten(), out_quantized.flatten()]))[0, 1]
assert correlation > 0.5 # Should have reasonable correlation