"""

Unit tests for layers.py and packing.py implementations.



These tests are here to validate the complete functionality of BitLinear layers and packing utilities. Here are the following test cases:



test_bitlinear (1 test)

    - Tests BitLinear layer initialization, forward pass, and ternary weight constraints

    

test_multi_ternary_linear (1 test)

    - Tests MultiTernaryLinear layer with k-component decomposition

    

test_from_linear (1 test)

    - Tests conversion from nn.Linear to BitLinear using from_linear() method

    

test_convert_module (1 test)

    - Tests recursive model conversion using convert_linear_to_bitlinear()

    

test_packing (1 test)

    - Tests base-3 packing/unpacking round-trip correctness

    

test_memory_estimation (1 test)

    - Tests memory savings estimation for various layer configurations

"""
import torch
from bitlinear.layers import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
from bitlinear.packing import pack_ternary_base3, unpack_ternary_base3, estimate_memory_savings

def test_bitlinear():
    """Test BitLinear layer."""
    print("Testing BitLinear layer...")
    
    # Create layer
    layer = BitLinear(128, 64, bias=True)
    
    # Test forward pass
    x = torch.randn(32, 128)
    y = layer(x)
    
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {y.shape}")
    print(f"  W_ternary unique values: {layer.W_ternary.unique().tolist()}")
    print(f"  Gamma shape: {layer.gamma.shape}")
    print("  ✓ BitLinear works!\n")

def test_multi_ternary_linear():
    """Test MultiTernaryLinear layer."""
    print("Testing MultiTernaryLinear layer...")
    
    # Create layer with k=3 components
    layer = MultiTernaryLinear(128, 64, k=3, bias=True)
    
    # Test forward pass
    x = torch.randn(32, 128)
    y = layer(x)
    
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {y.shape}")
    print(f"  W_ternary shape: {layer.W_ternary.shape}")
    print(f"  Gammas shape: {layer.gammas.shape}")
    print(f"  Number of components: {layer.k}")
    print("  ✓ MultiTernaryLinear works!\n")

def test_from_linear():
    """Test conversion from nn.Linear."""
    print("Testing from_linear conversion...")
    
    # Create standard linear layer
    linear = torch.nn.Linear(128, 64)
    
    # Convert to BitLinear
    bitlinear = BitLinear.from_linear(linear)
    
    # Test that it works
    x = torch.randn(16, 128)
    y = bitlinear(x)
    
    print(f"  Original Linear: {linear.in_features} -> {linear.out_features}")
    print(f"  Converted BitLinear: {bitlinear.in_features} -> {bitlinear.out_features}")
    print(f"  Output shape: {y.shape}")
    print("  ✓ from_linear conversion works!\n")

def test_convert_module():
    """Test convert_linear_to_bitlinear utility."""
    print("Testing convert_linear_to_bitlinear...")
    
    # Create a simple model with Linear layers
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 128)
            self.fc2 = torch.nn.Linear(128, 64)
            self.fc3 = torch.nn.Linear(64, 10)
        
        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    model = SimpleModel()
    
    # Count Linear layers before
    linear_count = sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear))
    print(f"  Linear layers before: {linear_count}")
    
    # Convert
    model = convert_linear_to_bitlinear(model)
    
    # Count BitLinear layers after
    bitlinear_count = sum(1 for m in model.modules() if isinstance(m, BitLinear))
    print(f"  BitLinear layers after: {bitlinear_count}")
    
    # Test forward pass
    x = torch.randn(8, 64)
    y = model(x)
    print(f"  Output shape: {y.shape}")
    print("  ✓ convert_linear_to_bitlinear works!\n")

def test_packing():
    """Test base-3 packing."""
    print("Testing base-3 packing...")
    
    # Create ternary weights
    W_ternary = torch.tensor([
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
    ], dtype=torch.float32)
    
    print(f"  Original shape: {W_ternary.shape}")
    print(f"  Original values: {W_ternary.flatten().tolist()}")
    
    # Pack
    packed, original_shape = pack_ternary_base3(W_ternary)
    print(f"  Packed shape: {packed.shape}")
    print(f"  Packed dtype: {packed.dtype}")
    print(f"  Compression: {W_ternary.numel() * 4} bytes -> {packed.numel()} bytes")
    
    # Unpack
    W_unpacked = unpack_ternary_base3(packed, original_shape)
    print(f"  Unpacked shape: {W_unpacked.shape}")
    print(f"  Unpacked values: {W_unpacked.flatten().tolist()}")
    
    # Verify correctness
    assert torch.allclose(W_ternary, W_unpacked), "Packing/unpacking mismatch!"
    print("  ✓ Base-3 packing works!\n")

def test_memory_estimation():
    """Test memory estimation."""
    print("Testing memory estimation...")
    
    # Estimate for a typical transformer layer
    stats = estimate_memory_savings(768, 3072, num_layers=12)
    
    print(f"  Configuration: 768 -> 3072, 12 layers")
    print(f"  Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
    print(f"  Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
    print(f"  Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
    print(f"  Compression ratio: {stats['compression_ratio']:.2f}x")
    print("  ✓ Memory estimation works!\n")

if __name__ == "__main__":
    print("=" * 60)
    print("Testing layers.py and packing.py implementations")
    print("=" * 60 + "\n")
    
    test_bitlinear()
    test_multi_ternary_linear()
    test_from_linear()
    test_convert_module()
    test_packing()
    test_memory_estimation()
    
    print("=" * 60)
    print("All tests passed! ✓")
    print("=" * 60)