# BitLinear: Ultra-Low-Precision Linear Layers for PyTorch

A production-ready PyTorch implementation of 1.58-bit ternary linear layers that achieves ~19x memory compression while maintaining high accuracy. A drop-in replacement for `nn.Linear` with optimized C++/CUDA kernels.
## Key Features

- **19.3x Memory Compression** - Near-theoretical maximum (20x)
- **Drop-in Replacement** - Same API as `nn.Linear`
- **Optimized Kernels** - C++ CPU and CUDA GPU implementations
- **Research-Grade** - Based on the BitNet and JMLR ternary-network papers
- **Production Ready** - Fully tested with comprehensive benchmarks
## Performance Highlights

### Memory Compression
Achieves 19.23x average compression across various layer sizes:
| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|---|---|---|---|
| 512×512 | 1.00 MB | 0.05 MB | 18.6x |
| 1024×1024 | 4.00 MB | 0.21 MB | 19.3x |
| 4096×4096 | 64.02 MB | 3.23 MB | 19.8x |
### Real-World Example: GPT-2 Small
Converting a GPT-2 Small model (12 layers, d_model=768, d_ff=3072):

- **Original**: 324 MB
- **BitLinear**: 16.8 MB
- **Saved**: 307 MB (19.3x compression)
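A back-of-envelope check of these figures, counting only the linear-layer weights and assuming the standard four attention projections plus two feed-forward matrices per layer:

```python
# Rough sanity check of the GPT-2 Small numbers above (linear weights only).
d_model, d_ff, n_layers = 768, 3072, 12
per_layer = 4 * d_model**2 + 2 * d_model * d_ff  # Q/K/V/out projections + two FFN matrices
total = n_layers * per_layer                     # ~84.9M weights
print(f"FP32: {total * 4 / 2**20:.0f} MB")       # ~324 MB
print(f"Packed: {total * 0.2 / 2**20:.1f} MB")   # ~16.2 MB at 1.6 bits/weight; per-channel
                                                 # scales and biases make up the rest of ~16.8 MB
```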
### Accuracy
Maintains high output similarity despite extreme quantization:
- **Cosine Similarity**: 96.3%
- **Relative Error**: ~28%
- **Multi-Ternary (k=3)**: 75% error reduction vs. k=1
See BENCHMARKS.md for detailed performance analysis.
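A minimal sketch of how such a comparison can be reproduced (the sizes are illustrative, and it assumes `convert_linear_to_bitlinear` also accepts a bare `nn.Linear`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from bitlinear import convert_linear_to_bitlinear

ref = nn.Linear(512, 1024)
quant = convert_linear_to_bitlinear(ref, inplace=False)  # leaves `ref` untouched

x = torch.randn(256, 512)
y_ref, y_q = ref(x), quant(x)
cos = F.cosine_similarity(y_ref.flatten(), y_q.flatten(), dim=0)
rel_err = (y_ref - y_q).norm() / y_ref.norm()
print(f"cosine similarity: {cos:.3f}, relative error: {rel_err:.3f}")
```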
## Quick Start

### Installation
```bash
# CPU-only build
pip install -e .

# With CUDA support (requires the CUDA toolkit)
CUDA_HOME=/usr/local/cuda pip install -e .
```
### Basic Usage
```python
import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear!

print(f"Weight values: {torch.unique(layer.W_ternary)}")  # [-1, 0, 1]
```
### Converting Existing Models
```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal - all Linear layers are now BitLinear
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```
### Multi-Ternary for Better Accuracy
```python
from bitlinear import MultiTernaryLinear

# Use k=3 components for a 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
## How It Works

BitLinear uses ternary quantization to represent weights with only three values: {-1, 0, +1}.

### Architecture
- **Quantization**: Weights are quantized to {-1, 0, +1} using absmax scaling (sketched below)
- **Scaling**: Per-output-channel scaling factors (gamma) compensate for the quantization
- **Packing**: Base-3 encoding stores 5 ternary values per byte
- **Computation**: Optimized kernels exploit the ternary structure (no multiplications needed)
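A minimal sketch of the quantization and scaling steps, assuming a mean-magnitude threshold rule; the exact recipe in `quantization.py` may differ:

```python
import torch

def ternary_quantize(W: torch.Tensor, threshold_factor: float = 0.75):
    """Quantize weights to {-1, 0, +1} with per-output-channel scaling (sketch)."""
    # Per-channel threshold based on the mean absolute weight (illustrative choice).
    delta = threshold_factor * W.abs().mean(dim=1, keepdim=True)
    W_ternary = torch.where(W.abs() > delta, W.sign(), torch.zeros_like(W))
    # Per-channel scale (gamma): mean magnitude of the surviving weights,
    # which minimizes the L2 error for a fixed ternary pattern.
    mask = W_ternary != 0
    gamma = (W.abs() * mask).sum(dim=1, keepdim=True) / mask.sum(dim=1, keepdim=True).clamp(min=1)
    return W_ternary, gamma

W = torch.randn(1024, 512)
W_t, gamma = ternary_quantize(W)
print(torch.unique(W_t))                       # tensor([-1., 0., 1.])
print((gamma * W_t - W).pow(2).mean().item())  # reconstruction error
```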
### Memory Efficiency

- **Theoretical**: log₂(3) ≈ 1.58 bits per weight
- **Actual**: 1.6 bits per weight (5 values per byte)
- **Efficiency**: 98.8% of the theoretical maximum
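The packing idea in a few lines: since 3⁵ = 243 ≤ 256, five ternary digits fit in one byte. A sketch of the scheme; the actual byte layout in `packing.py` may differ:

```python
import torch

BASE3 = torch.tensor([81, 27, 9, 3, 1])  # base-3 place values

def pack_base3(t: torch.Tensor) -> torch.Tensor:
    """Pack a ternary tensor (values in {-1, 0, +1}) into bytes, 5 values per byte."""
    trits = (t.flatten() + 1).to(torch.int64)         # map {-1, 0, 1} -> {0, 1, 2}
    pad = (-trits.numel()) % 5
    trits = torch.cat([trits, trits.new_zeros(pad)])  # pad to a multiple of 5
    return (trits.view(-1, 5) * BASE3).sum(dim=1).to(torch.uint8)  # max 242 < 256

def unpack_base3(packed: torch.Tensor, n: int) -> torch.Tensor:
    digits = (packed.to(torch.int64).unsqueeze(1) // BASE3) % 3
    return digits.flatten()[:n].to(torch.int8) - 1    # back to {-1, 0, +1}

t = torch.randint(-1, 2, (1000,))
packed = pack_base3(t)
assert torch.equal(unpack_base3(packed, 1000), t.to(torch.int8))
print(f"{t.numel()} weights -> {packed.numel()} bytes")  # 1000 -> 200, i.e. 1.6 bits/weight
```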
## Project Structure

```
BitLinear/
├── bitlinear/                   # Main package
│   ├── layers.py                # BitLinear and MultiTernaryLinear modules
│   ├── functional.py            # Core functional implementations
│   ├── quantization.py          # Ternary quantization utilities
│   ├── packing.py               # Base-3 packing for memory efficiency
│   └── cpp/                     # C++/CUDA extensions
│       ├── bitlinear.cpp        # PyBind11 bindings & CPU kernels
│       └── bitlinear_kernel.cu  # CUDA GPU kernels
├── tests/                       # Comprehensive test suite
├── examples/                    # Usage examples
│   ├── basic_usage.py           # Simple demonstrations
│   └── transformer_example.py   # Transformer integration
├── benchmarks/                  # Performance benchmarks
│   ├── benchmark_memory.py      # Memory analysis
│   └── benchmark_performance.py # Speed comparison
└── notebooks/                   # Interactive tutorials
    └── demo.md                  # Step-by-step guide
```
## Examples

### Example 1: Basic Layer
```python
from bitlinear import BitLinear, estimate_memory_savings

# Create layer
layer = BitLinear(512, 1024)

# Check memory savings
stats = estimate_memory_savings(512, 1024)
print(f"Compression: {stats['compression_ratio']:.1f}x")  # ~19x
```
### Example 2: Transformer Conversion
```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Original transformer
model = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=3072)

# Convert to BitLinear
model_bit = convert_linear_to_bitlinear(model)

# Compare memory
mem_original = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
mem_bitlinear = sum(p.numel() * p.element_size() for p in model_bit.parameters()) / 1024**2
print(f"Memory: {mem_original:.2f} MB → {mem_bitlinear:.2f} MB")
```
Run the complete examples:

```bash
python examples/basic_usage.py
python examples/transformer_example.py
```
## Benchmarks

Run the benchmarks to see performance on your hardware:

```bash
# Memory compression analysis
python benchmarks/benchmark_memory.py

# Forward pass performance
python benchmarks/benchmark_performance.py
```
## Testing

Comprehensive test suite with 60+ tests:

```bash
# Run all tests
pytest tests/ -v

# Run specific test modules
pytest tests/test_quantization.py -v
pytest tests/test_layers.py -v
```
## Research Background

This implementation is based on:

- *BitNet: Scaling 1-bit Transformers for Large Language Models*
- *Ternary Representations of Neural Networks* (JMLR)
### Key Innovations

- **Ternary Quantization**: Reduces weights to {-1, 0, +1}
- **Absmax Scaling**: Per-channel scaling for accuracy
- **Greedy Decomposition**: Multi-ternary components for a better approximation (sketched below)
- **Base-3 Packing**: Near-optimal memory compression
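A sketch of the greedy idea: repeatedly ternarize the current residual and subtract the scaled component. The threshold rule is illustrative; the package's implementation may differ:

```python
import torch

def greedy_decompose(W: torch.Tensor, k: int = 3):
    """Approximate W as a sum of k scaled ternary matrices (sketch)."""
    residual, components = W.clone(), []
    for _ in range(k):
        delta = 0.75 * residual.abs().mean(dim=1, keepdim=True)
        T = torch.where(residual.abs() > delta, residual.sign(), torch.zeros_like(residual))
        mask = T != 0
        gamma = (residual.abs() * mask).sum(dim=1, keepdim=True) / mask.sum(dim=1, keepdim=True).clamp(min=1)
        components.append((gamma, T))
        residual = residual - gamma * T  # the next component fits what is left over
    return components

W = torch.randn(256, 256)
for k in (1, 3):
    approx = sum(g * T for g, T in greedy_decompose(W, k))
    print(f"k={k}: relative error {((W - approx).norm() / W.norm()).item():.3f}")
```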
## Implementation Details

### Python Baseline

Pure PyTorch implementation for correctness and clarity:

- `bitlinear_python()` - Reference ternary matmul (see the sketch below)
- `greedy_ternary_decomposition()` - Multi-component quantization
- Full gradient support for training
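For orientation, the reference forward path plausibly has the following shape; this is a sketch, not the actual `bitlinear_python()` signature:

```python
import torch

def bitlinear_forward(x, W_ternary, gamma, bias=None):
    # Dense reference path: an ordinary matmul against the {-1, 0, +1} matrix,
    # followed by per-output-channel rescaling with gamma.
    y = (x @ W_ternary.to(x.dtype).t()) * gamma.view(-1)
    return y + bias if bias is not None else y
```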
### C++ Extensions

Optimized CPU kernels with PyBind11:

- Ternary-specific optimizations (no multiplications; illustrated below)
- Efficient memory access patterns
- Base-3 packing/unpacking
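The multiplication-free idea, illustrated in Python terms (the real kernels operate on packed weights in C++):

```python
import torch

def ternary_dot(x_row: torch.Tensor, w_row: torch.Tensor) -> torch.Tensor:
    # With weights in {-1, 0, +1}, a dot product reduces to two masked sums:
    # add the inputs where w == +1, subtract those where w == -1. No multiplies.
    return x_row[w_row == 1].sum() - x_row[w_row == -1].sum()

x = torch.randn(512)
w = torch.randint(-1, 2, (512,))
assert torch.allclose(ternary_dot(x, w), x @ w.to(x.dtype), atol=1e-5)
```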
### CUDA Kernels
GPU-accelerated implementation:
- Warp-level reductions using shuffle intrinsics
- Shared memory tiling
- Memory coalescing
- Fused multi-ternary kernels
## Use Cases

### Ideal For

- **Edge Deployment**: Mobile and embedded devices
- **Large Models**: Billion-parameter models with memory constraints
- **Production Inference**: Cost-effective serving at scale
- **Research**: Exploring ultra-low-precision networks
### Considerations

- **Training**: Best results come from quantization-aware training (QAT); see the sketch below
- **Accuracy**: A 3-5% accuracy drop is typical (acceptable for many tasks)
- **Speed**: The Python implementation may be slower; use the C++/CUDA kernels for production
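For QAT, a common recipe for BitNet-style layers is a straight-through estimator (STE) over full-precision latent weights. A minimal sketch, not the package's training API:

```python
import torch

class TernarySTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # Quantize to {-1, 0, +1}; the threshold rule here is illustrative.
        delta = 0.75 * w.abs().mean()
        return torch.where(w.abs() > delta, w.sign(), torch.zeros_like(w))

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass gradients unchanged to the latent weights.
        return grad_output

# Keep full-precision latent weights and quantize on the fly each forward pass.
w_latent = torch.randn(1024, 512, requires_grad=True)
w_q = TernarySTE.apply(w_latent)  # ternary forward, differentiable backward
w_q.sum().backward()              # gradients reach w_latent through the STE
print(w_latent.grad.abs().mean().item())
```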
## Documentation

- BENCHMARKS.md - Detailed performance analysis
- MODEL_CARD.md - HuggingFace model card
- notebooks/demo.md - Interactive tutorial
- read/IMPLEMENTATION_GUIDE.md - Implementation details (can be released on request; the pipeline is being extended to support future machine learning research)
## Contributing
Contributions welcome! Areas for improvement:
- AVX/AVX512 vectorization for CPU
- Tensor Core utilization for CUDA
- Additional quantization schemes
- Training examples and tutorials
## License

MIT License - see the LICENSE file for details.
## Citation

If you use BitLinear in your research, please cite:

```bibtex
@article{jmlr_ternary_2024,
  title   = {Ternary Representations of Neural Networks},
  journal = {Journal of Machine Learning Research},
  volume  = {26},
  year    = {2024},
  url     = {https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title   = {BitNet: Scaling 1-bit Transformers for Large Language Models},
  author  = {Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal = {arXiv preprint arXiv:2310.11453},
  year    = {2023}
}
```
## Acknowledgments
This implementation builds upon the groundbreaking work in:
- BitNet by Microsoft Research
- Ternary Neural Networks research (JMLR)
- PyTorch's extensibility framework
## Contact
For questions, issues, or collaboration:
- Open an issue on GitHub
- Check existing documentation
- Review examples and benchmarks
Please tag me if you use this in something you build; I'd love to see it.

Made with ❤️ for efficient deep learning