File size: 5,000 Bytes
fd8c8b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""
Simple usage example for BitLinear.
This demonstrates the basic API and shows how to use BitLinear as a drop-in
replacement for nn.Linear with significant memory savings.
"""
import torch
import torch.nn as nn
from bitlinear import BitLinear, estimate_memory_savings
def basic_usage():
    """Walk through the core BitLinear workflow: build a layer, run a
    forward pass, and report the estimated memory savings."""
    rule = "-" * 80  # section divider reused below

    print("BitLinear Basic Usage Example")
    print("=" * 80)

    # Construction mirrors nn.Linear's signature exactly.
    print("\n1. Creating BitLinear Layer")
    print(rule)
    bit_layer = BitLinear(in_features=512, out_features=1024, bias=True)
    print(f"Created: {bit_layer}")
    print(f"Weight values (ternary): {torch.unique(bit_layer.W_ternary)}")
    print(f"Gamma scaling factors shape: {bit_layer.gamma.shape}")

    # Random (batch=32, seq=128, features=512) activations for the demo.
    activations = torch.randn(32, 128, 512)

    # Calling the layer works just like nn.Linear.
    print("\n2. Forward Pass")
    print(rule)
    result = bit_layer(activations)
    print(f"Input shape: {activations.shape}")
    print(f"Output shape: {result.shape}")
    print(f"Output dtype: {result.dtype}")

    # Compare float32 vs. packed-ternary storage for one layer.
    print("\n3. Memory Savings")
    print(rule)
    mem = estimate_memory_savings(512, 1024, num_layers=1)
    print(f"Float32 weights: {mem['float32_bytes'] / 1024:.2f} KB")
    print(f"Packed weights: {mem['packed_bytes'] / 1024:.2f} KB")
    print(f"Memory saved: {mem['savings_bytes'] / 1024:.2f} KB")
    print(f"Compression: {mem['compression_ratio']:.1f}x")
def conversion_example():
    """Convert a trained nn.Linear into a BitLinear and quantify the
    approximation error (MSE, cosine similarity, relative error)."""
    rule = "-" * 80  # section divider reused below

    print("\n\nConversion Example")
    print("=" * 80)

    # A stand-in for a pre-trained layer.
    print("\n1. Original nn.Linear Layer")
    print(rule)
    dense = nn.Linear(512, 1024)
    print(f"Created: {dense}")

    # Pretend training happened: fill weights with small Gaussian noise.
    with torch.no_grad():
        dense.weight.normal_(0, 0.02)

    # One-call conversion to the ternary representation.
    print("\n2. Convert to BitLinear")
    print(rule)
    ternary = BitLinear.from_linear(dense)
    print(f"Converted: {ternary}")
    print(f"Weight values: {torch.unique(ternary.W_ternary)}")

    # Run the same batch through both layers and compare.
    print("\n3. Forward Pass Comparison")
    print(rule)
    batch = torch.randn(16, 512)
    with torch.no_grad():
        ref_out = dense(batch)
        ternary_out = ternary(batch)

    diff = ref_out - ternary_out
    mse = diff.pow(2).mean().item()
    cosine_sim = torch.nn.functional.cosine_similarity(
        ref_out.flatten(), ternary_out.flatten(), dim=0
    ).item()
    relative_error = (torch.norm(diff) / torch.norm(ref_out)).item()

    print(f"Original output shape: {ref_out.shape}")
    print(f"BitLinear output shape: {ternary_out.shape}")
    print(f"MSE: {mse:.6f}")
    print(f"Cosine similarity: {cosine_sim:.6f}")
    print(f"Relative error: {relative_error:.6f}")
def multi_ternary_example():
    """Show MultiTernaryLinear (k ternary components) and compare its
    approximation error against single-component BitLinear."""
    rule = "-" * 80  # section divider reused below

    print("\n\nMulti-Ternary Example")
    print("=" * 80)

    # Imported here so the rest of the script works without this class.
    from bitlinear import MultiTernaryLinear

    # k=3 stacks three ternary weight components per layer.
    print("\n1. Creating MultiTernaryLinear Layer")
    print(rule)
    multi = MultiTernaryLinear(in_features=512, out_features=1024, k=3, bias=True)
    print(f"Created: {multi}")
    print(f"Number of components: {multi.k}")
    print(f"W_ternary shape: {multi.W_ternary.shape}")
    print(f"Gammas shape: {multi.gammas.shape}")

    print("\n2. Forward Pass")
    print(rule)
    batch = torch.randn(16, 512)
    multi_out = multi(batch)
    print(f"Input shape: {batch.shape}")
    print(f"Output shape: {multi_out.shape}")

    # Convert one float layer both ways and measure relative error of each.
    print("\n3. Comparison with Standard BitLinear")
    print(rule)
    dense = nn.Linear(512, 1024)
    single = BitLinear.from_linear(dense)
    triple = MultiTernaryLinear.from_linear(dense, k=3)

    with torch.no_grad():
        ref_out = dense(batch)
        single_out = single(batch)
        triple_out = triple(batch)

    ref_norm = torch.norm(ref_out)
    error_k1 = (torch.norm(ref_out - single_out) / ref_norm).item()
    error_k3 = (torch.norm(ref_out - triple_out) / ref_norm).item()

    print(f"Relative error (k=1): {error_k1:.6f}")
    print(f"Relative error (k=3): {error_k3:.6f}")
    print(f"Improvement: {(error_k1 - error_k3) / error_k1 * 100:.1f}%")
if __name__ == "__main__":
    # Run every demo in order, then print a closing banner.
    for demo in (basic_usage, conversion_example, multi_ternary_example):
        demo()
    print("\n" + "=" * 80)
    print("All examples completed successfully!")
    print("=" * 80)
|