|
|
"""
|
|
|
Simple usage example for BitLinear.
|
|
|
|
|
|
This demonstrates the basic API and shows how to use BitLinear as a drop-in
|
|
|
replacement for nn.Linear with significant memory savings.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from bitlinear import BitLinear, estimate_memory_savings
|
|
|
|
|
|
|
|
|
def basic_usage():
    """Walk through building a BitLinear layer, running it, and sizing it.

    Prints three sections to stdout:

    1. The layer repr, the distinct values found in ``layer.W_ternary``
       (via ``torch.unique``), and the shape of ``layer.gamma``.
    2. Input/output shapes and the output dtype of a forward pass on a
       random tensor of shape (32, 128, 512).
    3. The figures returned by ``estimate_memory_savings(512, 1024,
       num_layers=1)``, with byte counts shown divided by 1024 as "KB".
    """
    rule = "-" * 80

    print("BitLinear Basic Usage Example")
    print("=" * 80)

    # Section 1: construct the layer and peek at its quantized parameters.
    print("\n1. Creating BitLinear Layer")
    print(rule)
    layer = BitLinear(in_features=512, out_features=1024, bias=True)
    print(f"Created: {layer}")
    print(f"Weight values (ternary): {torch.unique(layer.W_ternary)}")
    print(f"Gamma scaling factors shape: {layer.gamma.shape}")

    # Random input; the last dimension (512) matches in_features above.
    batch_size, seq_len = 32, 128
    inputs = torch.randn(batch_size, seq_len, 512)

    # Section 2: a single forward pass.
    print("\n2. Forward Pass")
    print(rule)
    result = layer(inputs)
    forward_report = (
        ("Input shape", inputs.shape),
        ("Output shape", result.shape),
        ("Output dtype", result.dtype),
    )
    for label, value in forward_report:
        print(f"{label}: {value}")

    # Section 3: analytic memory estimate for one 512 -> 1024 layer.
    print("\n3. Memory Savings")
    print(rule)
    stats = estimate_memory_savings(512, 1024, num_layers=1)
    byte_fields = (
        ("Float32 weights", "float32_bytes"),
        ("Packed weights", "packed_bytes"),
        ("Memory saved", "savings_bytes"),
    )
    for label, key in byte_fields:
        print(f"{label}: {stats[key] / 1024:.2f} KB")
    print(f"Compression: {stats['compression_ratio']:.1f}x")
|
|
|
|
|
|
|
|
|
def conversion_example():
    """Show converting an existing ``nn.Linear`` into a ``BitLinear``.

    Creates a 512 -> 1024 ``nn.Linear``, re-draws its weight from
    N(0, 0.02), converts it with ``BitLinear.from_linear``, then runs both
    layers on the same random (16, 512) batch under ``torch.no_grad()``.
    It prints the output shapes along with the MSE, flattened cosine
    similarity and relative L2 error between the two outputs.
    """
    rule = "-" * 80

    print("\n\nConversion Example")
    print("=" * 80)

    # Section 1: the full-precision source layer.
    print("\n1. Original nn.Linear Layer")
    print(rule)
    linear = nn.Linear(512, 1024)
    print(f"Created: {linear}")

    # In-place re-initialisation of a parameter that requires grad must be
    # done outside autograd tracking.
    with torch.no_grad():
        linear.weight.normal_(0, 0.02)

    # Section 2: quantize it.
    print("\n2. Convert to BitLinear")
    print(rule)
    quantized = BitLinear.from_linear(linear)
    print(f"Converted: {quantized}")
    print(f"Weight values: {torch.unique(quantized.W_ternary)}")

    # Section 3: compare outputs on identical input.
    print("\n3. Forward Pass Comparison")
    print(rule)
    x = torch.randn(16, 512)
    with torch.no_grad():
        reference = linear(x)
        approx = quantized(x)

    residual = reference - approx
    mse = residual.pow(2).mean().item()
    cosine_sim = nn.functional.cosine_similarity(
        reference.flatten(), approx.flatten(), dim=0
    ).item()
    relative_error = (torch.norm(residual) / torch.norm(reference)).item()

    print(f"Original output shape: {reference.shape}")
    print(f"BitLinear output shape: {approx.shape}")
    print(f"MSE: {mse:.6f}")
    print(f"Cosine similarity: {cosine_sim:.6f}")
    print(f"Relative error: {relative_error:.6f}")
|
|
|
|
|
|
|
|
|
def multi_ternary_example():
    """Demonstrate MultiTernaryLinear and compare it with single-component BitLinear.

    Prints three sections to stdout:

    1. The repr of a k=3 ``MultiTernaryLinear`` (512 -> 1024) plus its
       ``k`` and the shapes of its ``W_ternary`` and ``gammas`` tensors.
    2. Input/output shapes for a forward pass on a random (16, 512) batch.
    3. The relative L2 output error of ``BitLinear.from_linear`` (k=1) and
       ``MultiTernaryLinear.from_linear(..., k=3)`` against the same fresh
       ``nn.Linear``, plus the percentage reduction from k=1 to k=3.
       The reduction is reported as "n/a" when the k=1 error is not
       positive, instead of dividing by zero.
    """
    # Kept as a local import, as in the original example.
    from bitlinear import MultiTernaryLinear

    print("\n\nMulti-Ternary Example")
    print("=" * 80)

    print("\n1. Creating MultiTernaryLinear Layer")
    print("-" * 80)
    layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3, bias=True)
    print(f"Created: {layer}")
    print(f"Number of components: {layer.k}")
    print(f"W_ternary shape: {layer.W_ternary.shape}")
    print(f"Gammas shape: {layer.gammas.shape}")

    print("\n2. Forward Pass")
    print("-" * 80)
    x = torch.randn(16, 512)
    output = layer(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    print("\n3. Comparison with Standard BitLinear")
    print("-" * 80)
    linear = nn.Linear(512, 1024)
    bitlinear_k1 = BitLinear.from_linear(linear)
    bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3)

    with torch.no_grad():
        out_orig = linear(x)
        out_k1 = bitlinear_k1(x)
        out_k3 = bitlinear_k3(x)

    # Relative error ||reference - approx|| / ||reference|| over the whole batch.
    ref_norm = torch.norm(out_orig)
    error_k1 = (torch.norm(out_orig - out_k1) / ref_norm).item()
    error_k3 = (torch.norm(out_orig - out_k3) / ref_norm).item()

    print(f"Relative error (k=1): {error_k1:.6f}")
    print(f"Relative error (k=3): {error_k3:.6f}")

    # error_k1 is a Python float (from .item()), so dividing by 0.0 would
    # raise ZeroDivisionError and abort the rest of the script. A NaN error
    # (zero reference norm) also fails this check and is reported as n/a.
    if error_k1 > 0:
        improvement = (error_k1 - error_k3) / error_k1 * 100
        print(f"Improvement: {improvement:.1f}%")
    else:
        print("Improvement: n/a")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every demo in order; each one prints its own section header.
    for example in (basic_usage, conversion_example, multi_ternary_example):
        example()

    banner = "=" * 80
    print("\n" + banner)
    print("All examples completed successfully!")
    print(banner)
|
|
|
|