# BitLinear/examples/basic_usage.py
# Uploaded by krisaujla via huggingface_hub (commit fd8c8b9, verified).
"""
Simple usage example for BitLinear.
This demonstrates the basic API and shows how to use BitLinear as a drop-in
replacement for nn.Linear with significant memory savings.
"""
import torch
import torch.nn as nn
from bitlinear import BitLinear, estimate_memory_savings
def basic_usage(in_features=512, out_features=1024, batch_size=32, seq_len=128):
    """Demonstrate creating a BitLinear layer, running it, and estimating savings.

    BitLinear is constructed with the same keyword arguments as ``nn.Linear``.
    The example prints the layer, the distinct values of its ``W_ternary``
    weights, the shape of its ``gamma`` scaling factors, the result of a
    forward pass on a random ``(batch_size, seq_len, in_features)`` input, and
    the statistics returned by ``estimate_memory_savings`` for one layer of
    the same size.

    Args:
        in_features: Input dimension of the layer (default 512).
        out_features: Output dimension of the layer (default 1024).
        batch_size: Leading dimension of the random example input.
        seq_len: Second dimension of the random example input.
    """
    print("BitLinear Basic Usage Example")
    print("=" * 80)

    # Create a BitLinear layer (same interface as nn.Linear)
    print("\n1. Creating BitLinear Layer")
    print("-" * 80)
    layer = BitLinear(in_features=in_features, out_features=out_features, bias=True)
    print(f"Created: {layer}")
    print(f"Weight values (ternary): {torch.unique(layer.W_ternary)}")
    print(f"Gamma scaling factors shape: {layer.gamma.shape}")

    # Random input; its last dimension must match the layer's in_features.
    x = torch.randn(batch_size, seq_len, in_features)

    # Forward pass (same as nn.Linear)
    print("\n2. Forward Pass")
    print("-" * 80)
    # Inference-only demo: skip recording an autograd graph for the output.
    with torch.no_grad():
        output = layer(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output dtype: {output.dtype}")

    # Memory savings for a single layer with the same dimensions as above.
    print("\n3. Memory Savings")
    print("-" * 80)
    stats = estimate_memory_savings(in_features, out_features, num_layers=1)
    # Byte counts are divided by 1024 for display in KB.
    print(f"Float32 weights: {stats['float32_bytes'] / 1024:.2f} KB")
    print(f"Packed weights: {stats['packed_bytes'] / 1024:.2f} KB")
    print(f"Memory saved: {stats['savings_bytes'] / 1024:.2f} KB")
    print(f"Compression: {stats['compression_ratio']:.1f}x")
def conversion_example():
    """Show how an existing ``nn.Linear`` is converted into a BitLinear.

    The weights of a 512 -> 1024 ``nn.Linear`` are redrawn from N(0, 0.02)
    to stand in for trained parameters. The layer is then converted with
    ``BitLinear.from_linear``, and both layers are run on the same random
    batch of 16 inputs. Their outputs are compared by mean squared error,
    by cosine similarity of the flattened outputs, and by relative L2 error.
    """
    rule = "-" * 80
    print("\n\nConversion Example")
    print("=" * 80)

    # Step 1: a stand-in for a pre-trained dense layer.
    print("\n1. Original nn.Linear Layer")
    print(rule)
    dense = nn.Linear(512, 1024)
    print(f"Created: {dense}")
    # Overwrite the default initialisation to mimic trained weights.
    with torch.no_grad():
        dense.weight.normal_(0, 0.02)

    # Step 2: quantise through the alternate constructor.
    print("\n2. Convert to BitLinear")
    print(rule)
    ternary = BitLinear.from_linear(dense)
    print(f"Converted: {ternary}")
    print(f"Weight values: {torch.unique(ternary.W_ternary)}")

    # Step 3: feed both layers identical input and measure how far apart they are.
    print("\n3. Forward Pass Comparison")
    print(rule)
    inputs = torch.randn(16, 512)
    with torch.no_grad():
        reference = dense(inputs)
        quantized = ternary(inputs)

    delta = reference - quantized
    mse = delta.pow(2).mean().item()
    cosine_sim = nn.functional.cosine_similarity(
        reference.flatten(), quantized.flatten(), dim=0
    ).item()
    relative_error = (delta.norm() / reference.norm()).item()

    print(f"Original output shape: {reference.shape}")
    print(f"BitLinear output shape: {quantized.shape}")
    print(f"MSE: {mse:.6f}")
    print(f"Cosine similarity: {cosine_sim:.6f}")
    print(f"Relative error: {relative_error:.6f}")
def multi_ternary_example(k=3):
    """Demonstrate MultiTernaryLinear and compare it with single-term BitLinear.

    A MultiTernaryLinear is built from scratch, and its ``k``, ``W_ternary``
    shape and ``gammas`` shape are printed. It is then run on a random batch.
    Next, one ``nn.Linear`` is converted both into a ``BitLinear`` (labelled
    k=1) and into a ``MultiTernaryLinear`` with ``k`` components. Each
    converted layer's relative L2 output error is reported against the
    original layer.

    Args:
        k: Number of ternary components used for the MultiTernaryLinear
            layers (default 3).
    """
    print("\n\nMulti-Ternary Example")
    print("=" * 80)
    from bitlinear import MultiTernaryLinear

    # Create a multi-ternary layer with k components
    print("\n1. Creating MultiTernaryLinear Layer")
    print("-" * 80)
    layer = MultiTernaryLinear(in_features=512, out_features=1024, k=k, bias=True)
    print(f"Created: {layer}")
    print(f"Number of components: {layer.k}")
    print(f"W_ternary shape: {layer.W_ternary.shape}")
    print(f"Gammas shape: {layer.gammas.shape}")

    # Forward pass (inference only, so no autograd graph is recorded)
    print("\n2. Forward Pass")
    print("-" * 80)
    x = torch.randn(16, 512)
    with torch.no_grad():
        output = layer(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Compare with standard BitLinear, using both converted from the same nn.Linear
    print("\n3. Comparison with Standard BitLinear")
    print("-" * 80)
    linear = nn.Linear(512, 1024)
    bitlinear_k1 = BitLinear.from_linear(linear)
    bitlinear_kn = MultiTernaryLinear.from_linear(linear, k=k)
    with torch.no_grad():
        out_orig = linear(x)
        out_k1 = bitlinear_k1(x)
        out_kn = bitlinear_kn(x)
    orig_norm = torch.norm(out_orig)
    error_k1 = (torch.norm(out_orig - out_k1) / orig_norm).item()
    error_kn = (torch.norm(out_orig - out_kn) / orig_norm).item()
    print(f"Relative error (k=1): {error_k1:.6f}")
    print(f"Relative error (k={k}): {error_kn:.6f}")
    # error_k1 is a Python float, so dividing by 0.0 would raise
    # ZeroDivisionError. A negative value means the k-term layer did worse.
    if error_k1 > 0:
        print(f"Improvement: {(error_k1 - error_kn) / error_k1 * 100:.1f}%")
    else:
        print("Improvement: n/a (k=1 relative error is zero)")
if __name__ == "__main__":
    # Run every example in order, then print a closing banner.
    for example in (basic_usage, conversion_example, multi_ternary_example):
        example()
    banner = "=" * 80
    print("\n" + banner)
    print("All examples completed successfully!")
    print(banner)