""" Simple usage example for BitLinear. This demonstrates the basic API and shows how to use BitLinear as a drop-in replacement for nn.Linear with significant memory savings. """ import torch import torch.nn as nn from bitlinear import BitLinear, estimate_memory_savings def basic_usage(): """Basic usage example.""" print("BitLinear Basic Usage Example") print("=" * 80) # Create a BitLinear layer (same interface as nn.Linear) print("\n1. Creating BitLinear Layer") print("-" * 80) layer = BitLinear(in_features=512, out_features=1024, bias=True) print(f"Created: {layer}") print(f"Weight values (ternary): {torch.unique(layer.W_ternary)}") print(f"Gamma scaling factors shape: {layer.gamma.shape}") # Create input batch_size = 32 seq_len = 128 x = torch.randn(batch_size, seq_len, 512) # Forward pass (same as nn.Linear) print("\n2. Forward Pass") print("-" * 80) output = layer(x) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") print(f"Output dtype: {output.dtype}") # Memory savings print("\n3. Memory Savings") print("-" * 80) stats = estimate_memory_savings(512, 1024, num_layers=1) print(f"Float32 weights: {stats['float32_bytes'] / 1024:.2f} KB") print(f"Packed weights: {stats['packed_bytes'] / 1024:.2f} KB") print(f"Memory saved: {stats['savings_bytes'] / 1024:.2f} KB") print(f"Compression: {stats['compression_ratio']:.1f}x") def conversion_example(): """Example of converting existing nn.Linear to BitLinear.""" print("\n\nConversion Example") print("=" * 80) # Start with a pre-trained Linear layer print("\n1. Original nn.Linear Layer") print("-" * 80) linear = nn.Linear(512, 1024) print(f"Created: {linear}") # Simulate some training by setting random weights with torch.no_grad(): linear.weight.normal_(0, 0.02) # Convert to BitLinear print("\n2. Convert to BitLinear") print("-" * 80) bitlinear = BitLinear.from_linear(linear) print(f"Converted: {bitlinear}") print(f"Weight values: {torch.unique(bitlinear.W_ternary)}") # Use as drop-in replacement print("\n3. Forward Pass Comparison") print("-" * 80) x = torch.randn(16, 512) with torch.no_grad(): output_linear = linear(x) output_bitlinear = bitlinear(x) # Compare outputs mse = torch.mean((output_linear - output_bitlinear) ** 2).item() cosine_sim = torch.nn.functional.cosine_similarity( output_linear.flatten(), output_bitlinear.flatten(), dim=0 ).item() relative_error = (torch.norm(output_linear - output_bitlinear) / torch.norm(output_linear)).item() print(f"Original output shape: {output_linear.shape}") print(f"BitLinear output shape: {output_bitlinear.shape}") print(f"MSE: {mse:.6f}") print(f"Cosine similarity: {cosine_sim:.6f}") print(f"Relative error: {relative_error:.6f}") def multi_ternary_example(): """Example using MultiTernaryLinear for better approximation.""" print("\n\nMulti-Ternary Example") print("=" * 80) from bitlinear import MultiTernaryLinear # Create multi-ternary layer with k=3 components print("\n1. Creating MultiTernaryLinear Layer") print("-" * 80) layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3, bias=True) print(f"Created: {layer}") print(f"Number of components: {layer.k}") print(f"W_ternary shape: {layer.W_ternary.shape}") print(f"Gammas shape: {layer.gammas.shape}") # Forward pass print("\n2. Forward Pass") print("-" * 80) x = torch.randn(16, 512) output = layer(x) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}") # Compare with standard BitLinear print("\n3. Comparison with Standard BitLinear") print("-" * 80) linear = nn.Linear(512, 1024) bitlinear_k1 = BitLinear.from_linear(linear) bitlinear_k3 = MultiTernaryLinear.from_linear(linear, k=3) with torch.no_grad(): out_orig = linear(x) out_k1 = bitlinear_k1(x) out_k3 = bitlinear_k3(x) error_k1 = (torch.norm(out_orig - out_k1) / torch.norm(out_orig)).item() error_k3 = (torch.norm(out_orig - out_k3) / torch.norm(out_orig)).item() print(f"Relative error (k=1): {error_k1:.6f}") print(f"Relative error (k=3): {error_k3:.6f}") print(f"Improvement: {(error_k1 - error_k3) / error_k1 * 100:.1f}%") if __name__ == "__main__": basic_usage() conversion_example() multi_ternary_example() print("\n" + "=" * 80) print("All examples completed successfully!") print("=" * 80)