"""
Unit tests for layers.py and packing.py implementations.

These tests validate the functionality of the BitLinear layers and the packing
utilities. Test cases:

test_bitlinear (1 test)
- Tests BitLinear layer initialization, forward pass, and ternary weight constraints

test_multi_ternary_linear (1 test)
- Tests MultiTernaryLinear layer with k-component decomposition

test_from_linear (1 test)
- Tests conversion from nn.Linear to BitLinear using the from_linear() method

test_convert_module (1 test)
- Tests recursive model conversion using convert_linear_to_bitlinear()

test_packing (1 test)
- Tests base-3 packing/unpacking round-trip correctness

test_memory_estimation (1 test)
- Tests memory savings estimation for various layer configurations
"""

import torch

from bitlinear.layers import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
from bitlinear.packing import pack_ternary_base3, unpack_ternary_base3, estimate_memory_savings


def test_bitlinear():
    """Test BitLinear layer."""
    print("Testing BitLinear layer...")

    layer = BitLinear(128, 64, bias=True)

    x = torch.randn(32, 128)
    y = layer(x)

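    # Sanity checks: assuming nn.Linear-style shape semantics, the output should be
    # (batch, out_features), and W_ternary is assumed to be quantized to {-1, 0, +1}
    # after the forward pass (as the prints below imply).
    assert y.shape == (32, 64)
    assert set(layer.W_ternary.unique().tolist()).issubset({-1.0, 0.0, 1.0})
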
print(f" Input shape: {x.shape}")
|
|
|
print(f" Output shape: {y.shape}")
|
|
|
print(f" W_ternary unique values: {layer.W_ternary.unique().tolist()}")
|
|
|
print(f" Gamma shape: {layer.gamma.shape}")
|
|
|
print(" ✓ BitLinear works!\n")
|
|
|
|
|
|
def test_multi_ternary_linear():
    """Test MultiTernaryLinear layer."""
    print("Testing MultiTernaryLinear layer...")

    layer = MultiTernaryLinear(128, 64, k=3, bias=True)

    x = torch.randn(32, 128)
    y = layer(x)

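    # Sanity checks: assuming nn.Linear-style shape semantics, the output should be
    # (batch, out_features), and the layer should expose the k=3 components it was
    # constructed with.
    assert y.shape == (32, 64)
    assert layer.k == 3
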
print(f" Input shape: {x.shape}")
|
|
|
print(f" Output shape: {y.shape}")
|
|
|
print(f" W_ternary shape: {layer.W_ternary.shape}")
|
|
|
print(f" Gammas shape: {layer.gammas.shape}")
|
|
|
print(f" Number of components: {layer.k}")
|
|
|
print(" ✓ MultiTernaryLinear works!\n")
|
|
|
|
|
|
def test_from_linear():
    """Test conversion from nn.Linear."""
    print("Testing from_linear conversion...")

    linear = torch.nn.Linear(128, 64)

    bitlinear = BitLinear.from_linear(linear)

    x = torch.randn(16, 128)
    y = bitlinear(x)

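    # Sanity checks: the converted layer is assumed to preserve the original
    # in/out feature sizes and to produce an output of shape (batch, out_features).
    assert bitlinear.in_features == linear.in_features
    assert bitlinear.out_features == linear.out_features
    assert y.shape == (16, 64)
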
print(f" Original Linear: {linear.in_features} -> {linear.out_features}")
|
|
|
print(f" Converted BitLinear: {bitlinear.in_features} -> {bitlinear.out_features}")
|
|
|
print(f" Output shape: {y.shape}")
|
|
|
print(" ✓ from_linear conversion works!\n")
|
|
|
|
|
|
def test_convert_module():
    """Test convert_linear_to_bitlinear utility."""
    print("Testing convert_linear_to_bitlinear...")

    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 128)
            self.fc2 = torch.nn.Linear(128, 64)
            self.fc3 = torch.nn.Linear(64, 10)

        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model = SimpleModel()

    linear_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(f" Linear layers before: {linear_count}")

    model = convert_linear_to_bitlinear(model)

    bitlinear_count = sum(1 for m in model.modules() if isinstance(m, BitLinear))
    print(f" BitLinear layers after: {bitlinear_count}")

    x = torch.randn(8, 64)
    y = model(x)
    print(f" Output shape: {y.shape}")
    print(" ✓ convert_linear_to_bitlinear works!\n")


def test_packing():
    """Test base-3 packing."""
    print("Testing base-3 packing...")

    W_ternary = torch.tensor([
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
    ], dtype=torch.float32)

    print(f" Original shape: {W_ternary.shape}")
    print(f" Original values: {W_ternary.flatten().tolist()}")

    packed, original_shape = pack_ternary_base3(W_ternary)
    print(f" Packed shape: {packed.shape}")
    print(f" Packed dtype: {packed.dtype}")
    print(f" Compression: {W_ternary.numel() * 4} bytes -> {packed.numel()} bytes")

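    # Sanity check: the byte comparison printed above assumes the packed tensor uses
    # one byte per element (e.g. uint8); either way, the packed form should occupy
    # fewer bytes than the float32 original.
    assert packed.numel() * packed.element_size() < W_ternary.numel() * 4
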
    W_unpacked = unpack_ternary_base3(packed, original_shape)
    print(f" Unpacked shape: {W_unpacked.shape}")
    print(f" Unpacked values: {W_unpacked.flatten().tolist()}")

    assert torch.allclose(W_ternary, W_unpacked), "Packing/unpacking mismatch!"
    print(" ✓ Base-3 packing works!\n")


def test_memory_estimation():
    """Test memory estimation."""
    print("Testing memory estimation...")

    stats = estimate_memory_savings(768, 3072, num_layers=12)

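    # Sanity checks: using the keys already read in the prints below, the packed size
    # should be smaller than the float32 size, so the compression ratio should exceed 1.
    assert stats["packed_bytes"] < stats["float32_bytes"]
    assert stats["compression_ratio"] > 1.0
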
print(f" Configuration: 768 -> 3072, 12 layers")
|
|
|
print(f" Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
|
|
|
print(f" Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
|
|
|
print(f" Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
|
|
|
print(f" Compression ratio: {stats['compression_ratio']:.2f}x")
|
|
|
print(" ✓ Memory estimation works!\n")
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("Testing layers.py and packing.py implementations")
    print("=" * 60 + "\n")

    test_bitlinear()
    test_multi_ternary_linear()
    test_from_linear()
    test_convert_module()
    test_packing()
    test_memory_estimation()

    print("=" * 60)
    print("All tests passed! ✓")
    print("=" * 60)