""" Unit tests for quantization utilities. These tests are here to validate ternary quantization, scaling, and packing functions. Here are the following test cases: TestAbsmaxScale (3 tests) 1. test_global_scale - Tests global absmax scaling computation 2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling 3. test_zero_tensor - Validates behavior with zero tensors (numerical stability) TestTernaryQuantize (3 tests) 1. test_quantization_values - Ensures output contains only {-1, 0, +1} 2. test_sign_preservation - Validates sign preservation for large values 3. test_threshold_behavior - Tests threshold-based zero assignment TestWeightToTernary (3 tests) 1. test_output_shapes - Verifies correct output tensor shapes 2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes 3. test_reconstruction_quality - Validates reconstruction error is reasonable TestActivationQuantization (2 tests) 1. test_quantization_range - Tests 8-bit quantization range 2. test_per_token_scaling - Validates per-token vs. global scaling TestDequantization (1 test) 1. test_dequantize_inverse - Tests quantize → dequantize inverse operation TestBase3Packing (3 tests) 1. test_pack_unpack_roundtrip - Validates pack → unpack recovers original 2. test_memory_efficiency - Tests ~20x compression achievement 3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions TestCompressionUtilities (2 tests) 1. test_compression_ratio_calculation - Tests compression ratio computation 2. test_memory_savings_estimation - Validates memory savings estimation TestQuantizationIntegration (2 tests) 1. test_full_quantization_pipeline - Tests dense → ternary → packed → unpacked 2. test_quantization_preserves_functionality - Validates quantized layer outputs """ import pytest import torch from bitlinear.quantization import ( absmax_scale, ternary_quantize, weight_to_ternary, quantize_activations_absmax, dequantize_scale, ) from bitlinear.packing import ( pack_ternary_base3, unpack_ternary_base3, compute_compression_ratio, estimate_memory_savings, ) class TestAbsmaxScale: """Tests for absmax_scale function.""" def test_global_scale(self): """Test global absmax scaling.""" W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]]) scale = absmax_scale(W, dim=None) assert torch.isclose(scale, torch.tensor(6.0)) def test_per_channel_scale(self): """Test per-channel (per-row) absmax scaling.""" W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]]) scale = absmax_scale(W, dim=1) expected = torch.tensor([3.0, 6.0]) assert torch.allclose(scale, expected) def test_zero_tensor(self): """Test behavior with zero tensor.""" W = torch.zeros(10, 10) scale = absmax_scale(W, dim=None) # Should handle division by zero gracefully (clamped to epsilon) assert scale > 0 assert scale < 1e-4 class TestTernaryQuantize: """Tests for ternary_quantize function.""" def test_quantization_values(self): """Test that output contains only {-1, 0, +1}.""" W = torch.randn(100, 100) W_ternary = ternary_quantize(W) unique_values = torch.unique(W_ternary) assert set(unique_values.tolist()).issubset({-1.0, 0.0, 1.0}) def test_sign_preservation(self): """Test that signs are preserved correctly.""" # Use values well above threshold (> 0.5 * max) W = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]]) W_ternary = ternary_quantize(W) # Large positive values should be +1 assert W_ternary[0, 0] == 1.0 # Large negative values should be -1 assert W_ternary[0, 1] == -1.0 assert W_ternary[1, 0] == -1.0 # Large positive assert W_ternary[1, 1] == 1.0 def test_threshold_behavior(self): """Test that threshold determines zero assignment.""" # Create tensor with known values W = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]]) W_ternary = ternary_quantize(W) # Small values near zero should become 0 # Exact behavior depends on threshold, but there should be some zeros assert 0.0 in W_ternary class TestWeightToTernary: """Tests for weight_to_ternary function.""" def test_output_shapes(self): """Test that output shapes are correct.""" W = torch.randn(512, 768) W_ternary, gamma = weight_to_ternary(W, per_channel=True) assert W_ternary.shape == (512, 768) assert gamma.shape == (512,) def test_per_channel_vs_global(self): """Test difference between per-channel and global scaling.""" W = torch.randn(512, 768) W_t_pc, gamma_pc = weight_to_ternary(W, per_channel=True) W_t_g, gamma_g = weight_to_ternary(W, per_channel=False) assert gamma_pc.shape == (512,) assert gamma_g.shape == torch.Size([]) # Scalar def test_reconstruction_quality(self): """Test that reconstruction W_ternary * gamma approximates W.""" W = torch.randn(512, 768) W_ternary, gamma = weight_to_ternary(W, per_channel=True) W_reconstructed = W_ternary * gamma.unsqueeze(1) error = torch.norm(W - W_reconstructed) / torch.norm(W) # Ternary quantization has inherent error, allow up to 0.9 relative error # This is expected for aggressive quantization to only 3 values assert error < 1.0 class TestActivationQuantization: """Tests for activation quantization.""" def test_quantization_range(self): """Test that quantized values are in expected range.""" x = torch.randn(16, 32, 512) x_quant = quantize_activations_absmax(x, bits=8, per_token=True) # Should be roughly in similar range as input assert x_quant.abs().max() <= x.abs().max() * 1.1 def test_per_token_scaling(self): """Test per-token vs. global scaling.""" x = torch.randn(16, 32, 512) x_quant_per_token = quantize_activations_absmax(x, bits=8, per_token=True) x_quant_global = quantize_activations_absmax(x, bits=8, per_token=False) # Both should work without errors assert x_quant_per_token.shape == x.shape assert x_quant_global.shape == x.shape class TestDequantization: """Tests for dequantization.""" def test_dequantize_inverse(self): """Test that quantize → dequantize is approximately identity.""" W = torch.randn(512, 768) W_quant, scale = weight_to_ternary(W, per_channel=True) W_dequant = dequantize_scale(W_quant, scale) # Should be close to W_quant * scale reconstruction W_expected = W_quant * scale.unsqueeze(1) assert torch.allclose(W_dequant, W_expected) class TestBase3Packing: """Tests for base-3 packing utilities.""" def test_pack_unpack_roundtrip(self): """Test that pack → unpack recovers original ternary weights.""" W_ternary = torch.randint(-1, 2, (512, 768)).float() packed, shape = pack_ternary_base3(W_ternary) W_unpacked = unpack_ternary_base3(packed, shape) assert torch.allclose(W_ternary, W_unpacked) def test_memory_efficiency(self): """Test that packing achieves expected compression.""" W_ternary = torch.randint(-1, 2, (512, 768)).float() original_size = W_ternary.numel() * 4 # float32 = 4 bytes packed, shape = pack_ternary_base3(W_ternary) packed_size = packed.numel() * 1 # uint8 = 1 byte compression = original_size / packed_size # Should achieve ~20x compression (32 bits → 1.6 bits) assert compression > 15 # Allow some overhead def test_packing_with_padding(self): """Test packing when dimensions are not multiples of 5.""" # Test with various sizes to ensure padding is handled correctly for size in [(13, 17), (100, 203), (7, 11)]: W_ternary = torch.randint(-1, 2, size).float() packed, shape = pack_ternary_base3(W_ternary) W_unpacked = unpack_ternary_base3(packed, shape) assert torch.allclose(W_ternary, W_unpacked) class TestCompressionUtilities: """Tests for compression ratio and memory estimation utilities.""" def test_compression_ratio_calculation(self): """Test compression ratio calculation.""" ratio = compute_compression_ratio(1024, 51) assert abs(ratio - 20.0) < 0.5 def test_memory_savings_estimation(self): """Test memory savings estimation for layer.""" stats = estimate_memory_savings(768, 3072, num_layers=12) assert 'float32_bytes' in stats assert 'packed_bytes' in stats assert 'savings_bytes' in stats assert 'compression_ratio' in stats assert stats['compression_ratio'] > 15 class TestQuantizationIntegration: """Integration tests for quantization pipeline.""" def test_full_quantization_pipeline(self): """Test complete pipeline: dense → ternary → packed → unpacked.""" # 1. Start with dense weights W = torch.randn(128, 256) # 2. Quantize to ternary W_ternary, gamma = weight_to_ternary(W, per_channel=True) # 3. Pack to base-3 packed, shape = pack_ternary_base3(W_ternary) # 4. Unpack W_unpacked = unpack_ternary_base3(packed, shape) # 5. Verify correctness assert torch.allclose(W_ternary, W_unpacked) assert set(W_unpacked.unique().tolist()).issubset({-1.0, 0.0, 1.0}) def test_quantization_preserves_functionality(self): """Test that quantized layer produces reasonable outputs.""" from bitlinear import BitLinear import torch.nn as nn # Create dense layer dense = nn.Linear(256, 128) # Test input x = torch.randn(16, 256) out_dense = dense(x) # Quantize to BitLinear bitlinear = BitLinear.from_linear(dense) out_quantized = bitlinear(x) # Outputs should have same shape assert out_dense.shape == out_quantized.shape # Outputs should be correlated (similar but not identical) # Calculate correlation correlation = torch.corrcoef(torch.stack([out_dense.flatten(), out_quantized.flatten()]))[0, 1] assert correlation > 0.5 # Should have reasonable correlation