"""

Unit tests for quantization utilities.



These tests are here to validate ternary quantization, scaling, and packing functions. Here are the following test cases:



TestAbsmaxScale (3 tests)

    1. test_global_scale - Tests global absmax scaling computation

    2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling

    3. test_zero_tensor - Validates behavior with zero tensors (numerical stability)



TestTernaryQuantize (3 tests)

    1. test_quantization_values - Ensures output contains only {-1, 0, +1}

    2. test_sign_preservation - Validates sign preservation for large values

    3. test_threshold_behavior - Tests threshold-based zero assignment



TestWeightToTernary (3 tests)

    1. test_output_shapes - Verifies correct output tensor shapes

    2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes

    3. test_reconstruction_quality - Validates reconstruction error is reasonable



TestActivationQuantization (2 tests)

    1. test_quantization_range - Tests 8-bit quantization range

    2. test_per_token_scaling - Validates per-token vs. global scaling



TestDequantization (1 test)

    1. test_dequantize_inverse - Tests quantize β†’ dequantize inverse operation



TestBase3Packing (3 tests)

    1. test_pack_unpack_roundtrip - Validates pack β†’ unpack recovers original

    2. test_memory_efficiency - Tests ~20x compression achievement

    3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions



TestCompressionUtilities (2 tests)

    1. test_compression_ratio_calculation - Tests compression ratio computation

    2. test_memory_savings_estimation - Validates memory savings estimation



TestQuantizationIntegration (2 tests)

    1. test_full_quantization_pipeline - Tests dense β†’ ternary β†’ packed β†’ unpacked

    2. test_quantization_preserves_functionality - Validates quantized layer outputs

"""

import pytest
import torch

from bitlinear.quantization import (
    absmax_scale,
    ternary_quantize,
    weight_to_ternary,
    quantize_activations_absmax,
    dequantize_scale,
)
from bitlinear.packing import (
    pack_ternary_base3,
    unpack_ternary_base3,
    compute_compression_ratio,
    estimate_memory_savings,
)


class TestAbsmaxScale:
    """Tests for absmax_scale function."""
    
    def test_global_scale(self):
        """Test global absmax scaling."""
        W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        scale = absmax_scale(W, dim=None)
        assert torch.isclose(scale, torch.tensor(6.0))
    
    def test_per_channel_scale(self):
        """Test per-channel (per-row) absmax scaling."""
        W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
        scale = absmax_scale(W, dim=1)
        expected = torch.tensor([3.0, 6.0])
        assert torch.allclose(scale, expected)
    
    def test_zero_tensor(self):
        """Test behavior with zero tensor."""
        W = torch.zeros(10, 10)
        scale = absmax_scale(W, dim=None)
        # Should handle division by zero gracefully (clamped to epsilon)
        assert scale > 0
        assert scale < 1e-4
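

# For reference, the absmax scaling exercised above fits in two lines of
# plain PyTorch. A minimal sketch of the semantics the assertions imply
# (global max of |W| when dim is None, per-row max when dim=1, clamped away
# from zero); this is an assumption inferred from the tests, not
# bitlinear's actual implementation:
def _reference_absmax_scale(W: torch.Tensor, dim=None, eps: float = 1e-5):
    """Sketch: absolute-maximum scale, clamped so zero tensors divide safely."""
    scale = W.abs().amax() if dim is None else W.abs().amax(dim=dim)
    return scale.clamp(min=eps)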


class TestTernaryQuantize:
    """Tests for ternary_quantize function."""
    
    def test_quantization_values(self):
        """Test that output contains only {-1, 0, +1}."""
        W = torch.randn(100, 100)
        W_ternary = ternary_quantize(W)
        unique_values = torch.unique(W_ternary)
        assert set(unique_values.tolist()).issubset({-1.0, 0.0, 1.0})
    
    def test_sign_preservation(self):
        """Test that signs are preserved correctly."""
        # Use values well above threshold (> 0.5 * max)
        W = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]])
        W_ternary = ternary_quantize(W)
        # Large positive values should be +1
        assert W_ternary[0, 0] == 1.0
        # Large negative values should be -1
        assert W_ternary[0, 1] == -1.0
        assert W_ternary[1, 0] == -1.0
        # Large positive
        assert W_ternary[1, 1] == 1.0
    
    def test_threshold_behavior(self):
        """Test that threshold determines zero assignment."""
        # Create tensor with known values
        W = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]])
        W_ternary = ternary_quantize(W)
        # Small values near zero should become 0
        # Exact behavior depends on threshold, but there should be some zeros
        assert 0.0 in W_ternary
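

# A minimal sketch of threshold-based ternary quantization consistent with
# the tests above. The 0.5 * max|W| cutoff is inferred from the comment in
# test_sign_preservation and is an assumption, not confirmed library behavior:
def _reference_ternary_quantize(W: torch.Tensor, threshold: float = 0.5):
    """Sketch: map W to {-1, 0, +1}, zeroing entries below the cutoff."""
    cutoff = threshold * W.abs().max()
    # torch.sign yields {-1, 0, +1}; the mask zeroes sub-threshold entries
    return torch.sign(W) * (W.abs() > cutoff).float()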


class TestWeightToTernary:
    """Tests for weight_to_ternary function."""
    
    def test_output_shapes(self):
        """Test that output shapes are correct."""
        W = torch.randn(512, 768)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        assert W_ternary.shape == (512, 768)
        assert gamma.shape == (512,)
    
    def test_per_channel_vs_global(self):
        """Test difference between per-channel and global scaling."""
        W = torch.randn(512, 768)
        W_t_pc, gamma_pc = weight_to_ternary(W, per_channel=True)
        W_t_g, gamma_g = weight_to_ternary(W, per_channel=False)
        
        assert gamma_pc.shape == (512,)
        assert gamma_g.shape == torch.Size([])  # Scalar
    
    def test_reconstruction_quality(self):
        """Test that reconstruction W_ternary * gamma approximates W."""
        W = torch.randn(512, 768)
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
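        # Per-channel scaling stores one gamma per output row, so
        # W_hat[i, j] = gamma[i] * W_ternary[i, j]; unsqueeze(1) makes the
        # (512,) vector a (512, 1) column that broadcasts across 768 columns.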
        W_reconstructed = W_ternary * gamma.unsqueeze(1)
        error = torch.norm(W - W_reconstructed) / torch.norm(W)
        # Ternary quantization has inherent error: collapsing each channel to
        # only 3 values cannot reconstruct W exactly, so just require the
        # relative error to stay below 1.0
        assert error < 1.0


class TestActivationQuantization:
    """Tests for activation quantization."""
    
    def test_quantization_range(self):
        """Test that quantized values are in expected range."""
        x = torch.randn(16, 32, 512)
        x_quant = quantize_activations_absmax(x, bits=8, per_token=True)
        # Quantized values should stay roughly within the input's range
        assert x_quant.abs().max() <= x.abs().max() * 1.1
    
    def test_per_token_scaling(self):
        """Test per-token vs. global scaling."""
        x = torch.randn(16, 32, 512)
        x_quant_per_token = quantize_activations_absmax(x, bits=8, per_token=True)
        x_quant_global = quantize_activations_absmax(x, bits=8, per_token=False)
        # Both should work without errors
        assert x_quant_per_token.shape == x.shape
        assert x_quant_global.shape == x.shape
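

# A minimal sketch of 8-bit absmax fake-quantization consistent with the
# range check above (outputs stay roughly in the input's scale, so the
# function is assumed to dequantize back to float rather than return raw
# int8). A hedged illustration, not bitlinear's quantize_activations_absmax:
def _reference_quantize_activations(x: torch.Tensor, bits: int = 8):
    """Sketch: scale to [-qmax, qmax], round, then rescale back to float."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    scale = x.abs().amax().clamp(min=1e-5)
    return (x / scale * qmax).round().clamp(-qmax, qmax) * (scale / qmax)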


class TestDequantization:
    """Tests for dequantization."""
    
    def test_dequantize_inverse(self):
        """Test that quantize β†’ dequantize is approximately identity."""
        W = torch.randn(512, 768)
        W_quant, scale = weight_to_ternary(W, per_channel=True)
        W_dequant = dequantize_scale(W_quant, scale)
        # Should be close to W_quant * scale reconstruction
        W_expected = W_quant * scale.unsqueeze(1)
        assert torch.allclose(W_dequant, W_expected)


class TestBase3Packing:
    """Tests for base-3 packing utilities."""
    
    def test_pack_unpack_roundtrip(self):
        """Test that pack β†’ unpack recovers original ternary weights."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        packed, shape = pack_ternary_base3(W_ternary)
        W_unpacked = unpack_ternary_base3(packed, shape)
        assert torch.allclose(W_ternary, W_unpacked)
    
    def test_memory_efficiency(self):
        """Test that packing achieves expected compression."""
        W_ternary = torch.randint(-1, 2, (512, 768)).float()
        original_size = W_ternary.numel() * 4  # float32 = 4 bytes
        
        packed, shape = pack_ternary_base3(W_ternary)
        packed_size = packed.numel() * 1  # uint8 = 1 byte
        
        compression = original_size / packed_size
        # Should achieve ~20x compression (32 bits → 1.6 bits per weight)
        assert compression > 15  # Allow some overhead
    
    def test_packing_with_padding(self):
        """Test packing when dimensions are not multiples of 5."""
        # Test with various sizes to ensure padding is handled correctly
        for size in [(13, 17), (100, 203), (7, 11)]:
            W_ternary = torch.randint(-1, 2, size).float()
            packed, shape = pack_ternary_base3(W_ternary)
            W_unpacked = unpack_ternary_base3(packed, shape)
            assert torch.allclose(W_ternary, W_unpacked)
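

# A minimal sketch of the base-3 packing scheme these tests imply: shift each
# ternary weight from {-1, 0, +1} to a trit in {0, 1, 2}, then store 5 trits
# per uint8, which works because 3**5 = 243 <= 255. That is 8 / 5 = 1.6 bits
# per weight versus 32 bits for float32, hence the ~20x figure above. An
# illustration of the arithmetic, not the library's pack_ternary_base3:
def _reference_pack5(trits):
    """Sketch: pack five trits (ints in {0, 1, 2}) into one byte (0..242)."""
    assert len(trits) == 5 and all(t in (0, 1, 2) for t in trits)
    byte = 0
    for t in reversed(trits):
        byte = byte * 3 + t
    return byte


def _reference_unpack5(byte):
    """Sketch: recover the five trits packed by _reference_pack5."""
    trits = []
    for _ in range(5):
        trits.append(byte % 3)
        byte //= 3
    return trits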


class TestCompressionUtilities:
    """Tests for compression ratio and memory estimation utilities."""
    
    def test_compression_ratio_calculation(self):
        """Test compression ratio calculation."""
        ratio = compute_compression_ratio(1024, 51)
        assert abs(ratio - 20.0) < 0.5
    
    def test_memory_savings_estimation(self):
        """Test memory savings estimation for layer."""
        stats = estimate_memory_savings(768, 3072, num_layers=12)
        assert 'float32_bytes' in stats
        assert 'packed_bytes' in stats
        assert 'savings_bytes' in stats
        assert 'compression_ratio' in stats
        assert stats['compression_ratio'] > 15
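

# Worked numbers behind the 20x target: float32 spends 32 bits per weight,
# base-3 packing spends 8 / 5 = 1.6 bits per weight, and 32 / 1.6 = 20. The
# assertions above leave headroom (within 0.5 of 20, ratio > 15) for padding
# and per-tensor metadata overhead.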


class TestQuantizationIntegration:
    """Integration tests for quantization pipeline."""
    
    def test_full_quantization_pipeline(self):
        """Test complete pipeline: dense β†’ ternary β†’ packed β†’ unpacked."""
        # 1. Start with dense weights
        W = torch.randn(128, 256)
        
        # 2. Quantize to ternary
        W_ternary, gamma = weight_to_ternary(W, per_channel=True)
        
        # 3. Pack to base-3
        packed, shape = pack_ternary_base3(W_ternary)
        
        # 4. Unpack
        W_unpacked = unpack_ternary_base3(packed, shape)
        
        # 5. Verify correctness
        assert torch.allclose(W_ternary, W_unpacked)
        assert set(W_unpacked.unique().tolist()).issubset({-1.0, 0.0, 1.0})
    
    def test_quantization_preserves_functionality(self):
        """Test that quantized layer produces reasonable outputs."""
        from bitlinear import BitLinear
        import torch.nn as nn
        
        # Create dense layer
        dense = nn.Linear(256, 128)
        
        # Test input
        x = torch.randn(16, 256)
        out_dense = dense(x)
        
        # Quantize to BitLinear
        bitlinear = BitLinear.from_linear(dense)
        out_quantized = bitlinear(x)
        
        # Outputs should have same shape
        assert out_dense.shape == out_quantized.shape
        
        # Outputs should be correlated (similar but not identical)
        # Calculate correlation
        correlation = torch.corrcoef(torch.stack([out_dense.flatten(), out_quantized.flatten()]))[0, 1]
        assert correlation > 0.5  # Should have reasonable correlation