# BitLinear/tests/test_implementations.py
# (Uploaded via huggingface_hub by krisaujla, commit fd8c8b9.)
"""
Unit tests for layers.py and packing.py implementations.
These tests validate the core functionality of the BitLinear layers and the packing utilities. The test cases are:
test_bitlinear (1 test)
- Tests BitLinear layer initialization, forward pass, and ternary weight constraints
test_multi_ternary_linear (1 test)
- Tests MultiTernaryLinear layer with k-component decomposition
test_from_linear (1 test)
- Tests conversion from nn.Linear to BitLinear using from_linear() method
test_convert_module (1 test)
- Tests recursive model conversion using convert_linear_to_bitlinear()
test_packing (1 test)
- Tests base-3 packing/unpacking round-trip correctness
test_memory_estimation (1 test)
- Tests memory savings estimation for various layer configurations
"""
import torch
from bitlinear.layers import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
from bitlinear.packing import pack_ternary_base3, unpack_ternary_base3, estimate_memory_savings
def test_bitlinear():
    """Test BitLinear layer.

    Runs a forward pass and checks that a (batch, in_features) input yields a
    (batch, out_features) output, and that the quantized weights only take
    values in {-1, 0, 1}.
    """
    print("Testing BitLinear layer...")
    batch, in_features, out_features = 32, 128, 64
    # Create layer
    layer = BitLinear(in_features, out_features, bias=True)
    # Test forward pass
    x = torch.randn(batch, in_features)
    y = layer(x)
    unique_values = layer.W_ternary.unique().tolist()
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {y.shape}")
    print(f" W_ternary unique values: {unique_values}")
    print(f" Gamma shape: {layer.gamma.shape}")
    # Assert (not just print) so pytest can actually detect a regression.
    assert y.shape == (batch, out_features), f"unexpected output shape {tuple(y.shape)}"
    assert set(unique_values) <= {-1.0, 0.0, 1.0}, f"non-ternary weights: {unique_values}"
    print(" ✓ BitLinear works!\n")
def test_multi_ternary_linear():
    """Test MultiTernaryLinear layer.

    Builds a layer with k=3 ternary components, runs a forward pass, and
    checks the output shape, the stored component count, and that the
    component weights only take values in {-1, 0, 1}.
    """
    print("Testing MultiTernaryLinear layer...")
    batch, in_features, out_features, k = 32, 128, 64, 3
    # Create layer with k=3 components
    layer = MultiTernaryLinear(in_features, out_features, k=k, bias=True)
    # Test forward pass
    x = torch.randn(batch, in_features)
    y = layer(x)
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {y.shape}")
    print(f" W_ternary shape: {layer.W_ternary.shape}")
    print(f" Gammas shape: {layer.gammas.shape}")
    print(f" Number of components: {layer.k}")
    # Assert (not just print) so pytest can actually detect a regression.
    assert y.shape == (batch, out_features), f"unexpected output shape {tuple(y.shape)}"
    assert layer.k == k, f"expected k={k}, got {layer.k}"
    unique_values = layer.W_ternary.unique().tolist()
    assert set(unique_values) <= {-1.0, 0.0, 1.0}, f"non-ternary weights: {unique_values}"
    print(" ✓ MultiTernaryLinear works!\n")
def test_from_linear():
    """Test conversion from nn.Linear.

    Converts a standard ``torch.nn.Linear`` with ``BitLinear.from_linear`` and
    checks that the feature sizes are preserved and that the converted layer
    produces an output of the expected shape.
    """
    print("Testing from_linear conversion...")
    batch, in_features, out_features = 16, 128, 64
    # Create standard linear layer
    linear = torch.nn.Linear(in_features, out_features)
    # Convert to BitLinear
    bitlinear = BitLinear.from_linear(linear)
    # Test that it works
    x = torch.randn(batch, in_features)
    y = bitlinear(x)
    print(f" Original Linear: {linear.in_features} -> {linear.out_features}")
    print(f" Converted BitLinear: {bitlinear.in_features} -> {bitlinear.out_features}")
    print(f" Output shape: {y.shape}")
    # Assert (not just print) so pytest can actually detect a regression.
    assert bitlinear.in_features == linear.in_features
    assert bitlinear.out_features == linear.out_features
    assert y.shape == (batch, out_features), f"unexpected output shape {tuple(y.shape)}"
    print(" ✓ from_linear conversion works!\n")
def test_convert_module():
    """Test convert_linear_to_bitlinear utility.

    Builds a three-layer MLP, converts it, and checks that every original
    ``nn.Linear`` became a ``BitLinear`` and that the converted model still
    runs end to end.
    """
    print("Testing convert_linear_to_bitlinear...")

    # Create a simple model with Linear layers
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 128)
            self.fc2 = torch.nn.Linear(128, 64)
            self.fc3 = torch.nn.Linear(64, 10)

        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model = SimpleModel()
    # Count Linear layers before
    linear_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(f" Linear layers before: {linear_count}")
    # Convert (the utility returns the converted model)
    model = convert_linear_to_bitlinear(model)
    # Count BitLinear layers after
    bitlinear_count = sum(1 for m in model.modules() if isinstance(m, BitLinear))
    print(f" BitLinear layers after: {bitlinear_count}")
    # Test forward pass
    x = torch.randn(8, 64)
    y = model(x)
    print(f" Output shape: {y.shape}")
    # Assert (not just print) so pytest can actually detect a regression.
    assert linear_count == 3, f"expected 3 Linear layers, found {linear_count}"
    assert bitlinear_count == linear_count, (
        f"converted {bitlinear_count} of {linear_count} Linear layers"
    )
    assert y.shape == (8, 10), f"unexpected output shape {tuple(y.shape)}"
    print(" ✓ convert_linear_to_bitlinear works!\n")
def test_packing():
    """Test base-3 packing.

    Packs a small ternary matrix with ``pack_ternary_base3``, unpacks it with
    ``unpack_ternary_base3``, and checks that the round trip reproduces the
    original shape and values exactly.
    """
    print("Testing base-3 packing...")
    # Create ternary weights (10 values, so any multi-trit-per-element
    # packing has to deal with a non-trivial element count)
    W_ternary = torch.tensor([
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
    ], dtype=torch.float32)
    print(f" Original shape: {W_ternary.shape}")
    print(f" Original values: {W_ternary.flatten().tolist()}")
    # Pack
    packed, original_shape = pack_ternary_base3(W_ternary)
    print(f" Packed shape: {packed.shape}")
    print(f" Packed dtype: {packed.dtype}")
    # Use element_size() instead of assuming 4 bytes for the input and
    # 1 byte per packed element.
    original_bytes = W_ternary.numel() * W_ternary.element_size()
    packed_bytes = packed.numel() * packed.element_size()
    print(f" Compression: {original_bytes} bytes -> {packed_bytes} bytes")
    # Unpack
    W_unpacked = unpack_ternary_base3(packed, original_shape)
    print(f" Unpacked shape: {W_unpacked.shape}")
    print(f" Unpacked values: {W_unpacked.flatten().tolist()}")
    # Verify correctness. Check the shape explicitly because allclose
    # broadcasts. Ternary values are exact, so compare exactly.
    assert W_unpacked.shape == W_ternary.shape, (
        f"shape mismatch: {tuple(W_unpacked.shape)} vs {tuple(W_ternary.shape)}"
    )
    assert torch.equal(W_unpacked.to(W_ternary.dtype), W_ternary), "Packing/unpacking mismatch!"
    print(" ✓ Base-3 packing works!\n")
def test_memory_estimation():
    """Test memory estimation.

    Estimates memory for a typical transformer projection (768 -> 3072,
    12 layers). Checks that the packed representation is reported as smaller
    than float32, with positive savings and a compression ratio above 1.
    """
    print("Testing memory estimation...")
    # Estimate for a typical transformer layer; keep the printed
    # configuration in sync with the arguments actually passed.
    in_features, out_features, num_layers = 768, 3072, 12
    stats = estimate_memory_savings(in_features, out_features, num_layers=num_layers)
    print(f" Configuration: {in_features} -> {out_features}, {num_layers} layers")
    print(f" Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
    print(f" Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
    print(f" Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
    print(f" Compression ratio: {stats['compression_ratio']:.2f}x")
    # Assert (not just print) so pytest can actually detect a regression.
    assert stats["packed_bytes"] < stats["float32_bytes"], "packing should reduce memory"
    assert stats["savings_bytes"] > 0, "expected positive savings"
    assert stats["compression_ratio"] > 1, "expected compression ratio above 1"
    print(" ✓ Memory estimation works!\n")
if __name__ == "__main__":
    # Run every test in order when the file is executed directly
    # (pytest discovers the test_* functions on its own).
    banner = "=" * 60
    print(banner)
    print("Testing layers.py and packing.py implementations")
    print(banner + "\n")
    for run_test in (
        test_bitlinear,
        test_multi_ternary_linear,
        test_from_linear,
        test_convert_module,
        test_packing,
        test_memory_estimation,
    ):
        run_test()
    print(banner)
    print("All tests passed! ✓")
    print(banner)