"""Unit tests for the ``layers.py`` and ``packing.py`` implementations.

These tests validate the BitLinear layers and the ternary packing utilities.
Each test prints diagnostic output and asserts the properties it checks, so
the module works both under pytest and as a standalone script.

Test cases:
    test_bitlinear            - BitLinear initialization, forward pass, and
                                ternary weight constraints.
    test_multi_ternary_linear - MultiTernaryLinear with k-component
                                decomposition.
    test_from_linear          - Conversion from nn.Linear via
                                BitLinear.from_linear().
    test_convert_module       - Recursive model conversion via
                                convert_linear_to_bitlinear().
    test_packing              - Base-3 packing/unpacking round-trip
                                correctness.
    test_memory_estimation    - Memory-savings estimation for a typical
                                transformer layer configuration.
"""

import torch

from bitlinear.layers import (
    BitLinear,
    MultiTernaryLinear,
    convert_linear_to_bitlinear,
)
from bitlinear.packing import (
    estimate_memory_savings,
    pack_ternary_base3,
    unpack_ternary_base3,
)

# The only values a ternary weight tensor is allowed to contain.
TERNARY_VALUES = {-1.0, 0.0, 1.0}


def test_bitlinear():
    """Test BitLinear forward-pass shape and the ternary weight constraint."""
    print("Testing BitLinear layer...")

    # Create layer
    layer = BitLinear(128, 64, bias=True)

    # Test forward pass
    x = torch.randn(32, 128)
    y = layer(x)

    unique_values = layer.W_ternary.unique().tolist()
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {y.shape}")
    print(f"  W_ternary unique values: {unique_values}")
    print(f"  Gamma shape: {layer.gamma.shape}")

    assert y.shape == (32, 64), f"Unexpected output shape {tuple(y.shape)}"
    assert set(unique_values) <= TERNARY_VALUES, (
        f"W_ternary contains non-ternary values: {unique_values}"
    )
    print("  ✓ BitLinear works!\n")


def test_multi_ternary_linear():
    """Test MultiTernaryLinear with a k=3 component decomposition."""
    print("Testing MultiTernaryLinear layer...")

    # Create layer with k=3 components
    layer = MultiTernaryLinear(128, 64, k=3, bias=True)

    # Test forward pass
    x = torch.randn(32, 128)
    y = layer(x)

    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {y.shape}")
    print(f"  W_ternary shape: {layer.W_ternary.shape}")
    print(f"  Gammas shape: {layer.gammas.shape}")
    print(f"  Number of components: {layer.k}")

    assert y.shape == (32, 64), f"Unexpected output shape {tuple(y.shape)}"
    assert layer.k == 3, f"Expected k=3 components, got {layer.k}"
    unique_values = layer.W_ternary.unique().tolist()
    assert set(unique_values) <= TERNARY_VALUES, (
        f"W_ternary contains non-ternary values: {unique_values}"
    )
    print("  ✓ MultiTernaryLinear works!\n")


def test_from_linear():
    """Test conversion from nn.Linear via BitLinear.from_linear()."""
    print("Testing from_linear conversion...")

    # Create standard linear layer
    linear = torch.nn.Linear(128, 64)

    # Convert to BitLinear
    bitlinear = BitLinear.from_linear(linear)

    # Test that it works
    x = torch.randn(16, 128)
    y = bitlinear(x)

    print(f"  Original Linear: {linear.in_features} -> {linear.out_features}")
    print(f"  Converted BitLinear: {bitlinear.in_features} -> {bitlinear.out_features}")
    print(f"  Output shape: {y.shape}")

    # The converted layer must keep the original layer's dimensions.
    assert bitlinear.in_features == linear.in_features
    assert bitlinear.out_features == linear.out_features
    assert y.shape == (16, 64), f"Unexpected output shape {tuple(y.shape)}"
    print("  ✓ from_linear conversion works!\n")


def test_convert_module():
    """Test recursive conversion with convert_linear_to_bitlinear()."""
    print("Testing convert_linear_to_bitlinear...")

    # Create a simple model with Linear layers
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 128)
            self.fc2 = torch.nn.Linear(128, 64)
            self.fc3 = torch.nn.Linear(64, 10)

        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model = SimpleModel()

    # Count Linear layers before
    linear_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(f"  Linear layers before: {linear_count}")

    # Convert
    model = convert_linear_to_bitlinear(model)

    # Count BitLinear layers after
    bitlinear_count = sum(1 for m in model.modules() if isinstance(m, BitLinear))
    print(f"  BitLinear layers after: {bitlinear_count}")

    # Test forward pass
    x = torch.randn(8, 64)
    y = model(x)
    print(f"  Output shape: {y.shape}")

    # Every Linear layer in the model is expected to have been replaced.
    assert bitlinear_count == linear_count, (
        f"Converted {bitlinear_count} of {linear_count} Linear layers"
    )
    assert y.shape == (8, 10), f"Unexpected output shape {tuple(y.shape)}"
    print("  ✓ convert_linear_to_bitlinear works!\n")


def test_packing():
    """Test that base-3 packing followed by unpacking is lossless."""
    print("Testing base-3 packing...")

    # Create ternary weights
    W_ternary = torch.tensor([
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
    ], dtype=torch.float32)

    print(f"  Original shape: {W_ternary.shape}")
    print(f"  Original values: {W_ternary.flatten().tolist()}")

    # Pack
    packed, original_shape = pack_ternary_base3(W_ternary)
    # Use the real element sizes instead of assuming 4-byte input and a
    # 1-byte packed dtype.
    original_bytes = W_ternary.numel() * W_ternary.element_size()
    packed_bytes = packed.numel() * packed.element_size()
    print(f"  Packed shape: {packed.shape}")
    print(f"  Packed dtype: {packed.dtype}")
    print(f"  Compression: {original_bytes} bytes -> {packed_bytes} bytes")

    # Unpack
    W_unpacked = unpack_ternary_base3(packed, original_shape)
    print(f"  Unpacked shape: {W_unpacked.shape}")
    print(f"  Unpacked values: {W_unpacked.flatten().tolist()}")

    # Verify correctness. Cast first: unpacking may return a different dtype,
    # and torch.allclose/torch.equal do not accept mismatched dtypes. Ternary
    # values are exact, so exact equality is the right check.
    assert W_unpacked.shape == W_ternary.shape, (
        f"Shape mismatch: {tuple(W_unpacked.shape)} != {tuple(W_ternary.shape)}"
    )
    assert torch.equal(W_unpacked.to(W_ternary.dtype), W_ternary), (
        "Packing/unpacking mismatch!"
    )
    print("  ✓ Base-3 packing works!\n")


def test_memory_estimation():
    """Test memory-savings estimation for a typical transformer layer."""
    print("Testing memory estimation...")

    # Estimate for a typical transformer layer
    in_features, out_features, num_layers = 768, 3072, 12
    stats = estimate_memory_savings(in_features, out_features, num_layers=num_layers)

    print(f"  Configuration: {in_features} -> {out_features}, {num_layers} layers")
    print(f"  Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
    print(f"  Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
    print(f"  Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
    print(f"  Compression ratio: {stats['compression_ratio']:.2f}x")

    # Packing ternary weights must use less memory than float32 storage.
    assert stats["packed_bytes"] < stats["float32_bytes"]
    assert stats["savings_bytes"] > 0
    assert stats["compression_ratio"] > 1
    print("  ✓ Memory estimation works!\n")


if __name__ == "__main__":
    print("=" * 60)
    print("Testing layers.py and packing.py implementations")
    print("=" * 60 + "\n")

    test_bitlinear()
    test_multi_ternary_linear()
    test_from_linear()
    test_convert_module()
    test_packing()
    test_memory_estimation()

    print("=" * 60)
    print("All tests passed! ✓")
    print("=" * 60)