"""
|
|
|
Example: Using BitLinear as a drop-in replacement for nn.Linear in a Transformer.
|
|
|
|
|
|
This example demonstrates:
|
|
|
1. Creating a simple Transformer block with standard nn.Linear
|
|
|
2. Converting it to use BitLinear layers
|
|
|
3. Running forward passes to verify compatibility
|
|
|
4. Comparing memory usage and output similarity
|
|
|
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear


class TransformerBlock(nn.Module):
    """
    Simplified Transformer block for demonstration.

    Contains:
    - Multi-head self-attention with linear projections
    - Feed-forward network with two linear layers
    - Layer normalization and residual connections
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead

        # Attention projections: these are the layers the conversion will swap.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the Transformer block.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # Self-attention sub-block (pre-norm, residual connection).
        residual = x
        x = self.norm1(x)

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split into heads: [batch, nhead, seq_len, d_k].
        batch_size, seq_len, _ = x.shape
        q = q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back to [batch, seq_len, d_model].
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        attn_output = self.out_proj(attn_output)
        attn_output = self.dropout1(attn_output)

        x = residual + attn_output

        # Feed-forward sub-block (pre-norm, residual connection).
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = self.dropout2(x)

        x = residual + x

        return x
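

# Optional alternative to post-hoc conversion: build sub-modules from BitLinear
# directly. This is a minimal sketch; the helper name is made up for this
# example, and it assumes BitLinear mirrors nn.Linear's
# (in_features, out_features) constructor, which is what "drop-in replacement"
# implies. Adjust if your BitLinear build differs.
def build_bitlinear_ffn(
    d_model: int,
    dim_feedforward: int,
    dropout: float = 0.1,
) -> nn.Sequential:
    """Feed-forward sub-network built from BitLinear layers (illustrative)."""
    return nn.Sequential(
        BitLinear(d_model, dim_feedforward),
        nn.ReLU(),
        nn.Dropout(dropout),
        BitLinear(dim_feedforward, d_model),
    )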


def count_parameters(model: nn.Module) -> int:
    """Count total trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def estimate_memory_mb(model: nn.Module) -> float:
    """
    Estimate memory usage of model parameters in MB.

    Note: this measures parameters in the dtype they are stored in, so the
    reported savings depend on whether the BitLinear implementation packs its
    ternary weights or keeps full-precision shadow weights for training.
    """
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return total_bytes / (1024 ** 2)


def compare_outputs(
    output1: torch.Tensor,
    output2: torch.Tensor,
) -> dict:
    """
    Compare two output tensors and compute similarity metrics.

    Returns:
        Dictionary with comparison metrics
    """
    mse = F.mse_loss(output1, output2).item()
    cosine_sim = F.cosine_similarity(
        output1.flatten(), output2.flatten(), dim=0
    ).item()
    relative_error = (
        torch.norm(output1 - output2) / torch.norm(output1)
    ).item()

    return {
        "mse": mse,
        "cosine_similarity": cosine_sim,
        "relative_error": relative_error,
    }
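

# For reference, a hand-rolled module swap that does roughly what
# convert_linear_to_bitlinear is used for below. This is a sketch, not the
# library's implementation: it assumes BitLinear accepts
# (in_features, out_features, bias=...) like nn.Linear and says nothing about
# how (or whether) the float weights are carried over.
def swap_linear_for_bitlinear(module: nn.Module) -> nn.Module:
    """Recursively replace every nn.Linear child with a BitLinear of the same shape."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                BitLinear(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                ),
            )
        else:
            swap_linear_for_bitlinear(child)
    return module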


def main():
    """Main example demonstrating BitLinear usage in a Transformer."""
    print("=" * 80)
    print("BitLinear Transformer Example")
    print("=" * 80)

    batch_size = 32
    seq_len = 128
    d_model = 512
    nhead = 8
    dim_feedforward = 2048

    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput shape: {x.shape}")

    print("\n" + "-" * 80)
    print("1. Standard Transformer with nn.Linear")
    print("-" * 80)

    model_standard = TransformerBlock(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    )
    # Disable dropout so the two forward passes can be compared meaningfully.
    model_standard.eval()

    print(f"Parameters: {count_parameters(model_standard):,}")
    print(f"Memory: {estimate_memory_mb(model_standard):.2f} MB")

    with torch.no_grad():
        output_standard = model_standard(x)
    print(f"Output shape: {output_standard.shape}")

    print("\n" + "-" * 80)
    print("2. Transformer with BitLinear")
    print("-" * 80)

    model_bitlinear = convert_linear_to_bitlinear(model_standard, inplace=False)
    model_bitlinear.eval()

    print(f"Parameters: {count_parameters(model_bitlinear):,}")
    print(f"Memory: {estimate_memory_mb(model_bitlinear):.2f} MB")

    with torch.no_grad():
        output_bitlinear = model_bitlinear(x)
    print(f"Output shape: {output_bitlinear.shape}")

    print("\n" + "-" * 80)
    print("3. Output Comparison")
    print("-" * 80)

    metrics = compare_outputs(output_standard, output_bitlinear)
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"Cosine similarity: {metrics['cosine_similarity']:.6f}")
    print(f"Relative error: {metrics['relative_error']:.6f}")

    print("\n" + "-" * 80)
    print("4. Memory Savings")
    print("-" * 80)

    mem_standard = estimate_memory_mb(model_standard)
    mem_bitlinear = estimate_memory_mb(model_bitlinear)
    savings = (mem_standard - mem_bitlinear) / mem_standard * 100

    print(f"Standard model: {mem_standard:.2f} MB")
    print(f"BitLinear model: {mem_bitlinear:.2f} MB")
    print(f"Memory savings: {savings:.1f}%")
    print(f"Compression ratio: {mem_standard / mem_bitlinear:.1f}x")

    print("\n" + "-" * 80)
    print("5. Conversion Details")
    print("-" * 80)

    def count_linear_layers(model):
        count = 0
        for module in model.modules():
            if isinstance(module, nn.Linear):
                count += 1
        return count

    def count_bitlinear_layers(model):
        count = 0
        for module in model.modules():
            if isinstance(module, BitLinear):
                count += 1
        return count

    print(f"Original Linear layers: {count_linear_layers(model_standard)}")
    print(f"Converted BitLinear layers: {count_bitlinear_layers(model_bitlinear)}")

    print("\n" + "=" * 80)
    print("Example complete!")
    print("=" * 80)
    print("\nKey Takeaways:")
    print("- BitLinear is a drop-in replacement for nn.Linear")
    print("- Significant memory savings (~20x for weights)")
    print("- Output similarity is high (cosine sim > 0.99 typically)")
    print("- Slight accuracy trade-off due to ternary quantization")


if __name__ == "__main__":
    main()