"""

Simple usage example for BitLinear.



This demonstrates the basic API and shows how to use BitLinear as a drop-in

replacement for nn.Linear with significant memory savings.

"""

import torch
import torch.nn as nn
from bitlinear import BitLinear, estimate_memory_savings


def basic_usage():
    """Basic usage example."""
    
    print("BitLinear Basic Usage Example")
    print("=" * 80)
    
    # Create a BitLinear layer (same interface as nn.Linear)
    print("\n1. Creating BitLinear Layer")
    print("-" * 80)
    layer = BitLinear(in_features=512, out_features=1024, bias=True)
    print(f"Created: {layer}")
    print(f"Weight values (ternary): {torch.unique(layer.W_ternary)}")
    print(f"Gamma scaling factors shape: {layer.gamma.shape}")
    
    # Create input
    batch_size = 32
    seq_len = 128
    x = torch.randn(batch_size, seq_len, 512)
    
    # Forward pass (same as nn.Linear)
    print("\n2. Forward Pass")
    print("-" * 80)
    output = layer(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output dtype: {output.dtype}")
    
    # Memory savings
    print("\n3. Memory Savings")
    print("-" * 80)
    stats = estimate_memory_savings(512, 1024, num_layers=1)
    print(f"Float32 weights: {stats['float32_bytes'] / 1024:.2f} KB")
    print(f"Packed weights:  {stats['packed_bytes'] / 1024:.2f} KB")
    print(f"Memory saved:    {stats['savings_bytes'] / 1024:.2f} KB")
    print(f"Compression:     {stats['compression_ratio']:.1f}x")


def conversion_example():
    """Example of converting existing nn.Linear to BitLinear."""
    
    print("\n\nConversion Example")
    print("=" * 80)
    
    # Start with a pre-trained Linear layer
    print("\n1. Original nn.Linear Layer")
    print("-" * 80)
    linear = nn.Linear(512, 1024)
    print(f"Created: {linear}")
    
    # Simulate some training by setting random weights
    with torch.no_grad():
        linear.weight.normal_(0, 0.02)
    
    # Convert to BitLinear
    print("\n2. Convert to BitLinear")
    print("-" * 80)
    bitlinear = BitLinear.from_linear(linear)
    print(f"Converted: {bitlinear}")
    print(f"Weight values: {torch.unique(bitlinear.W_ternary)}")
    
    # Use as drop-in replacement
    print("\n3. Forward Pass Comparison")
    print("-" * 80)
    x = torch.randn(16, 512)
    
    with torch.no_grad():
        output_linear = linear(x)
        output_bitlinear = bitlinear(x)
    
    # Compare outputs
    mse = torch.mean((output_linear - output_bitlinear) ** 2).item()
    cosine_sim = torch.nn.functional.cosine_similarity(
        output_linear.flatten(), 
        output_bitlinear.flatten(), 
        dim=0
    ).item()
    relative_error = (torch.norm(output_linear - output_bitlinear) / torch.norm(output_linear)).item()
    
    print(f"Original output shape:   {output_linear.shape}")
    print(f"BitLinear output shape:  {output_bitlinear.shape}")
    print(f"MSE:                     {mse:.6f}")
    print(f"Cosine similarity:       {cosine_sim:.6f}")
    print(f"Relative error:          {relative_error:.6f}")


def multi_ternary_example():
    """Example using MultiTernaryLinear for better approximation."""
    
    print("\n\nMulti-Ternary Example")
    print("=" * 80)
    
    from bitlinear import MultiTernaryLinear
    
    # Create multi-ternary layer with k=3 components
    print("\n1. Creating MultiTernaryLinear Layer")
    print("-" * 80)
    layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3, bias=True)
    print(f"Created: {layer}")
    print(f"Number of components: {layer.k}")
    print(f"W_ternary shape: {layer.W_ternary.shape}")
    print(f"Gammas shape: {layer.gammas.shape}")
    
    # Forward pass
    print("\n2. Forward Pass")
    print("-" * 80)
    x = torch.randn(16, 512)
    output = layer(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    
    # Compare with standard BitLinear
    print("\n3. Comparison with Standard BitLinear")
    print("-" * 80)
    linear = nn.Linear(512, 1024)
    bitlinear_k1 = BitLinear.from_linear(linear)
    bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3)
    
    with torch.no_grad():
        out_orig = linear(x)
        out_k1 = bitlinear_k1(x)
        out_k3 = bitlinear_k3(x)
    
    error_k1 = (torch.norm(out_orig - out_k1) / torch.norm(out_orig)).item()
    error_k3 = (torch.norm(out_orig - out_k3) / torch.norm(out_orig)).item()
    
    print(f"Relative error (k=1): {error_k1:.6f}")
    print(f"Relative error (k=3): {error_k3:.6f}")
    print(f"Improvement:          {(error_k1 - error_k3) / error_k1 * 100:.1f}%")


if __name__ == "__main__":
    basic_usage()
    conversion_example()
    multi_ternary_example()
    
    print("\n" + "=" * 80)
    print("All examples completed successfully!")
    print("=" * 80)