"""
Unit tests for quantization utilities.

These tests validate the ternary quantization, scaling, and packing functions.
Test cases:

TestAbsmaxScale (3 tests)
1. test_global_scale - Tests global absmax scaling computation
2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling
3. test_zero_tensor - Validates behavior with zero tensors (numerical stability)

TestTernaryQuantize (3 tests)
1. test_quantization_values - Ensures output contains only {-1, 0, +1}
2. test_sign_preservation - Validates sign preservation for large values
3. test_threshold_behavior - Tests threshold-based zero assignment

TestWeightToTernary (3 tests)
1. test_output_shapes - Verifies correct output tensor shapes
2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes
3. test_reconstruction_quality - Validates that reconstruction error is reasonable

TestActivationQuantization (2 tests)
1. test_quantization_range - Tests 8-bit quantization range
2. test_per_token_scaling - Validates per-token vs. global scaling

TestDequantization (1 test)
1. test_dequantize_inverse - Tests quantize → dequantize inverse operation

TestBase3Packing (3 tests)
1. test_pack_unpack_roundtrip - Validates pack → unpack recovers original
2. test_memory_efficiency - Tests ~20x compression achievement
3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions

TestCompressionUtilities (2 tests)
1. test_compression_ratio_calculation - Tests compression ratio computation
2. test_memory_savings_estimation - Validates memory savings estimation

TestQuantizationIntegration (2 tests)
1. test_full_quantization_pipeline - Tests dense → ternary → packed → unpacked
2. test_quantization_preserves_functionality - Validates quantized layer outputs
"""
import pytest
import torch

from bitlinear.quantization import (
    absmax_scale,
    ternary_quantize,
    weight_to_ternary,
    quantize_activations_absmax,
    dequantize_scale,
)
from bitlinear.packing import (
    pack_ternary_base3,
    unpack_ternary_base3,
    compute_compression_ratio,
    estimate_memory_savings,
)
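

# Rough mental model of the scheme exercised below (the precise thresholds and
# scale definitions live in bitlinear.quantization / bitlinear.packing and may
# differ in detail from this sketch):
#
#   absmax_scale(W, dim)    -> max(|W|), globally or along `dim`
#   ternary_quantize(W)     -> values in {-1, 0, +1}; large |W| keeps its sign,
#                              entries below a threshold are zeroed
#   weight_to_ternary(W)    -> (W_ternary, gamma) with W ~= W_ternary * gamma
#   pack_ternary_base3(W_t) -> ~5 ternary values per uint8, ~20x smaller than float32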


class TestAbsmaxScale:
    """Tests for absmax_scale function."""

    def test_global_scale(self):
        """Test global absmax scaling."""
        W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        scale = absmax_scale(W, dim=None)
        assert torch.isclose(scale, torch.tensor(6.0))

    def test_per_channel_scale(self):
        """Test per-channel (per-row) absmax scaling."""
        W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        scale = absmax_scale(W, dim=1)
        expected = torch.tensor([3.0, 6.0])
        assert torch.allclose(scale, expected)

    def test_zero_tensor(self):
        """Test behavior with zero tensor."""
        W = torch.zeros(10, 10)
        scale = absmax_scale(W, dim=None)
        # The scale should be clamped to a tiny positive value so later
        # divisions by it stay finite.
        assert scale > 0
        assert scale < 1e-4


class TestTernaryQuantize:
    """Tests for ternary_quantize function."""

    def test_quantization_values(self):
        """Test that output contains only {-1, 0, +1}."""
        W = torch.randn(100, 100)
        W_ternary = ternary_quantize(W)
        unique_values = torch.unique(W_ternary)
        assert set(unique_values.tolist()).issubset({-1.0, 0.0, 1.0})

    def test_sign_preservation(self):
        """Test that signs are preserved correctly."""
        # Large-magnitude entries must keep their sign after quantization.
        W = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]])
        W_ternary = ternary_quantize(W)
        assert W_ternary[0, 0] == 1.0
        assert W_ternary[0, 1] == -1.0
        assert W_ternary[1, 0] == -1.0
        assert W_ternary[1, 1] == 1.0

    def test_threshold_behavior(self):
        """Test that threshold determines zero assignment."""
        # Small-magnitude entries should fall below the quantization
        # threshold and be mapped to zero.
        W = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]])
        W_ternary = ternary_quantize(W)
        assert 0.0 in W_ternary


class TestWeightToTernary:
    """Tests for weight_to_ternary function."""

    def test_output_shapes(self):
        """Test that output shapes are correct."""
        W = torch.randn(512, 768)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        assert W_ternary.shape == (512, 768)
        assert gamma.shape == (512,)

    def test_per_channel_vs_global(self):
        """Test difference between per-channel and global scaling."""
        W = torch.randn(512, 768)
        W_t_pc, gamma_pc = weight_to_ternary(W, per_channel=True)
        W_t_g, gamma_g = weight_to_ternary(W, per_channel=False)
        # Per-channel scaling yields one scale per output row; global scaling
        # yields a single scalar.
        assert gamma_pc.shape == (512,)
        assert gamma_g.shape == torch.Size([])

    def test_reconstruction_quality(self):
        """Test that reconstruction W_ternary * gamma approximates W."""
        W = torch.randn(512, 768)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        W_reconstructed = W_ternary * gamma.unsqueeze(1)
        error = torch.norm(W - W_reconstructed) / torch.norm(W)
        # Ternary quantization is lossy, so we only require the relative
        # reconstruction error to stay below 1.0 (better than predicting zero).
        assert error < 1.0


class TestActivationQuantization:
    """Tests for activation quantization."""

    def test_quantization_range(self):
        """Test that quantized values are in expected range."""
        x = torch.randn(16, 32, 512)
        x_quant = quantize_activations_absmax(x, bits=8, per_token=True)
        # Quantization should not inflate the dynamic range (small slack for rounding).
        assert x_quant.abs().max() <= x.abs().max() * 1.1

    def test_per_token_scaling(self):
        """Test per-token vs. global scaling."""
        x = torch.randn(16, 32, 512)
        x_quant_per_token = quantize_activations_absmax(x, bits=8, per_token=True)
        x_quant_global = quantize_activations_absmax(x, bits=8, per_token=False)
        # Both modes must preserve the input shape.
        assert x_quant_per_token.shape == x.shape
        assert x_quant_global.shape == x.shape


class TestDequantization:
    """Tests for dequantization."""

    def test_dequantize_inverse(self):
        """Test that quantize → dequantize is approximately identity."""
        W = torch.randn(512, 768)
        W_quant, scale = weight_to_ternary(W, per_channel=True)
        W_dequant = dequantize_scale(W_quant, scale)
        # dequantize_scale should reproduce the explicit per-row rescaling.
        W_expected = W_quant * scale.unsqueeze(1)
        assert torch.allclose(W_dequant, W_expected)
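

# Arithmetic behind the ~20x figure below (a rough sketch of what
# pack_ternary_base3 is expected to do, not its exact implementation):
# five ternary digits fit in one uint8 because 3**5 = 243 <= 256, so packed
# storage costs ~1/5 byte per weight versus 4 bytes for float32, i.e. ~20x.
# Tensors whose element count is not a multiple of 5 need padding before
# packing, which test_packing_with_padding exercises.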
class TestBase3Packing:
    """Tests for base-3 packing utilities."""

    def test_pack_unpack_roundtrip(self):
        """Test that pack → unpack recovers original ternary weights."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        packed, shape = pack_ternary_base3(W_ternary)
        W_unpacked = unpack_ternary_base3(packed, shape)
        assert torch.allclose(W_ternary, W_unpacked)

    def test_memory_efficiency(self):
        """Test that packing achieves expected compression."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        original_size = W_ternary.numel() * 4  # float32: 4 bytes per weight

        packed, shape = pack_ternary_base3(W_ternary)
        packed_size = packed.numel() * 1  # uint8: 1 byte per packed element

        compression = original_size / packed_size
        # ~5 ternary values per byte vs. 4 bytes per float32 gives ~20x;
        # allow slack for padding and metadata.
        assert compression > 15

    def test_packing_with_padding(self):
        """Test packing when dimensions are not multiples of 5."""
        for size in [(13, 17), (100, 203), (7, 11)]:
            W_ternary = torch.randint(-1, 2, size).float()
            packed, shape = pack_ternary_base3(W_ternary)
            W_unpacked = unpack_ternary_base3(packed, shape)
            assert torch.allclose(W_ternary, W_unpacked)


class TestCompressionUtilities:
    """Tests for compression ratio and memory estimation utilities."""

    def test_compression_ratio_calculation(self):
        """Test compression ratio calculation."""
        # 1024 bytes compressed to 51 bytes is a ratio of roughly 20x.
        ratio = compute_compression_ratio(1024, 51)
        assert abs(ratio - 20.0) < 0.5

    def test_memory_savings_estimation(self):
        """Test memory savings estimation for a stack of layers."""
        stats = estimate_memory_savings(768, 3072, num_layers=12)
        assert 'float32_bytes' in stats
        assert 'packed_bytes' in stats
        assert 'savings_bytes' in stats
        assert 'compression_ratio' in stats
        assert stats['compression_ratio'] > 15


class TestQuantizationIntegration:
    """Integration tests for quantization pipeline."""

    def test_full_quantization_pipeline(self):
        """Test complete pipeline: dense → ternary → packed → unpacked."""
        W = torch.randn(128, 256)

        # Quantize to ternary with per-channel scales.
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)

        # Pack into base-3 uint8 storage and unpack again.
        packed, shape = pack_ternary_base3(W_ternary)
        W_unpacked = unpack_ternary_base3(packed, shape)

        # The round trip must be exact and stay ternary.
        assert torch.allclose(W_ternary, W_unpacked)
        assert set(W_unpacked.unique().tolist()).issubset({-1.0, 0.0, 1.0})

    def test_quantization_preserves_functionality(self):
        """Test that quantized layer produces reasonable outputs."""
        from bitlinear import BitLinear
        import torch.nn as nn

        dense = nn.Linear(256, 128)
        x = torch.randn(16, 256)
        out_dense = dense(x)

        bitlinear = BitLinear.from_linear(dense)
        out_quantized = bitlinear(x)

        assert out_dense.shape == out_quantized.shape

        # The quantized layer should stay positively correlated with the dense layer.
        correlation = torch.corrcoef(
            torch.stack([out_dense.flatten(), out_quantized.flatten()])
        )[0, 1]
        assert correlation > 0.5
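

if __name__ == "__main__":
    # Optional convenience entry point so the file can be run directly;
    # the usual way to run these tests is simply `pytest` from the project root.
    pytest.main([__file__, "-v"])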