""" Example: Using BitLinear as a drop-in replacement for nn.Linear in a Transformer. This example demonstrates: 1. Creating a simple Transformer block with standard nn.Linear 2. Converting it to use BitLinear layers 3. Running forward passes to verify compatibility 4. Comparing memory usage and output similarity """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear class TransformerBlock(nn.Module): """ Simplified Transformer block for demonstration. Contains: - Multi-head self-attention with linear projections - Feed-forward network with two linear layers - Layer normalization and residual connections """ def __init__( self, d_model: int = 512, nhead: int = 8, dim_feedforward: int = 2048, dropout: float = 0.1, ): super().__init__() # Multi-head attention components self.d_model = d_model self.nhead = nhead self.d_k = d_model // nhead # Linear projections for Q, K, V self.q_proj = nn.Linear(d_model, d_model) self.k_proj = nn.Linear(d_model, d_model) self.v_proj = nn.Linear(d_model, d_model) self.out_proj = nn.Linear(d_model, d_model) # Feed-forward network self.ffn = nn.Sequential( nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) # Layer normalization self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) # Dropout self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass through Transformer block. Args: x: Input tensor [batch_size, seq_len, d_model] mask: Optional attention mask Returns: Output tensor [batch_size, seq_len, d_model] """ # Multi-head self-attention residual = x x = self.norm1(x) # Compute Q, K, V q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) # Reshape for multi-head attention batch_size, seq_len, _ = x.shape q = q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2) k = k.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2) v = v.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2) # Scaled dot-product attention scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = F.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, v) # Reshape and project back attn_output = attn_output.transpose(1, 2).contiguous().view( batch_size, seq_len, self.d_model ) attn_output = self.out_proj(attn_output) attn_output = self.dropout1(attn_output) # First residual connection x = residual + attn_output # Feed-forward network residual = x x = self.norm2(x) x = self.ffn(x) x = self.dropout2(x) # Second residual connection x = residual + x return x def count_parameters(model: nn.Module) -> int: """Count total trainable parameters in a model.""" return sum(p.numel() for p in model.parameters() if p.requires_grad) def estimate_memory_mb(model: nn.Module) -> float: """Estimate memory usage of model parameters in MB.""" total_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) return total_bytes / (1024 ** 2) def compare_outputs( output1: torch.Tensor, output2: torch.Tensor, ) -> dict: """ Compare two output tensors and compute similarity metrics. Returns: Dictionary with comparison metrics """ mse = F.mse_loss(output1, output2).item() cosine_sim = F.cosine_similarity( output1.flatten(), output2.flatten(), dim=0 ).item() relative_error = ( torch.norm(output1 - output2) / torch.norm(output1) ).item() return { "mse": mse, "cosine_similarity": cosine_sim, "relative_error": relative_error, } def main(): """Main example demonstrating BitLinear usage in Transformer.""" print("=" * 80) print("BitLinear Transformer Example") print("=" * 80) # Configuration batch_size = 32 seq_len = 128 d_model = 512 nhead = 8 dim_feedforward = 2048 # Create input x = torch.randn(batch_size, seq_len, d_model) print(f"\nInput shape: {x.shape}") # 1. Create standard Transformer block print("\n" + "-" * 80) print("1. Standard Transformer with nn.Linear") print("-" * 80) model_standard = TransformerBlock( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, ) print(f"Parameters: {count_parameters(model_standard):,}") print(f"Memory: {estimate_memory_mb(model_standard):.2f} MB") # Forward pass with torch.no_grad(): output_standard = model_standard(x) print(f"Output shape: {output_standard.shape}") # 2. Convert to BitLinear print("\n" + "-" * 80) print("2. Transformer with BitLinear") print("-" * 80) model_bitlinear = convert_linear_to_bitlinear(model_standard, inplace=False) print(f"Parameters: {count_parameters(model_bitlinear):,}") print(f"Memory: {estimate_memory_mb(model_bitlinear):.2f} MB") # Forward pass with torch.no_grad(): output_bitlinear = model_bitlinear(x) print(f"Output shape: {output_bitlinear.shape}") # 3. Compare outputs print("\n" + "-" * 80) print("3. Output Comparison") print("-" * 80) metrics = compare_outputs(output_standard, output_bitlinear) print(f"MSE: {metrics['mse']:.6f}") print(f"Cosine similarity: {metrics['cosine_similarity']:.6f}") print(f"Relative error: {metrics['relative_error']:.6f}") # 4. Memory savings print("\n" + "-" * 80) print("4. Memory Savings") print("-" * 80) mem_standard = estimate_memory_mb(model_standard) mem_bitlinear = estimate_memory_mb(model_bitlinear) savings = (mem_standard - mem_bitlinear) / mem_standard * 100 print(f"Standard model: {mem_standard:.2f} MB") print(f"BitLinear model: {mem_bitlinear:.2f} MB") print(f"Memory savings: {savings:.1f}%") print(f"Compression ratio: {mem_standard / mem_bitlinear:.1f}x") # 5. Count Linear layers converted print("\n" + "-" * 80) print("5. Conversion Details") print("-" * 80) def count_linear_layers(model): count = 0 for module in model.modules(): if isinstance(module, nn.Linear): count += 1 return count def count_bitlinear_layers(model): count = 0 for module in model.modules(): if isinstance(module, BitLinear): count += 1 return count print(f"Original Linear layers: {count_linear_layers(model_standard)}") print(f"Converted BitLinear layers: {count_bitlinear_layers(model_bitlinear)}") print("\n" + "=" * 80) print("Example complete!") print("=" * 80) print("\nKey Takeaways:") print("- BitLinear is a drop-in replacement for nn.Linear") print("- Significant memory savings (~20x for weights)") print("- Output similarity is high (cosine sim > 0.99 typically)") print("- Slight accuracy trade-off due to ternary quantization") if __name__ == "__main__": main()