BitLinear: Ultra-Low-Precision Linear Layers for PyTorch

License: MIT · Python 3.8+ · PyTorch 2.0+

A production-ready PyTorch implementation of 1.58-bit ternary linear layers that achieves ~19x memory compression while maintaining high accuracy. Drop-in replacement for nn.Linear with optimized C++/CUDA kernels.

Key Features

  • 19.3x Memory Compression - Near-theoretical maximum (20x)
  • Drop-in Replacement - Same API as nn.Linear
  • Optimized Kernels - C++ CPU and CUDA GPU implementations
  • Research-Grade - Based on BitNet and JMLR ternary networks papers
  • Production Ready - Fully tested with comprehensive benchmarks

📊 Performance Highlights

Memory Compression

Achieves 19.23x average compression across various layer sizes:

Layer Size    nn.Linear   BitLinear (Packed)   Compression
512×512       1.00 MB     0.05 MB              18.6x
1024×1024     4.00 MB     0.21 MB              19.3x
4096×4096     64.02 MB    3.23 MB              19.8x

Real-World Example: GPT-2 Small

Converting a GPT-2 Small model (12 layers, d_model=768, d_ff=3072):

  • Original: 324 MB
  • BitLinear: 16.8 MB
  • Saved: 307 MB (19.3x compression)
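
These figures are easy to sanity-check. The back-of-envelope estimate below is illustrative only; it assumes that only the per-block attention and feed-forward projection weights are counted (no embeddings or LayerNorms) and that the original weights are stored in fp32:

# Rough estimate of the GPT-2 Small numbers above
# (assumption: only the per-block Linear weights, stored in fp32, are counted).
n_layers, d_model, d_ff = 12, 768, 3072

attn_params = 4 * d_model * d_model            # Q, K, V and output projections
ffn_params = 2 * d_model * d_ff                # up- and down-projections
total_params = n_layers * (attn_params + ffn_params)

fp32_mb = total_params * 4 / 1024**2           # 4 bytes per weight
ternary_mb = total_params * 1.6 / 8 / 1024**2  # ~1.6 bits per weight after packing

print(f"{fp32_mb:.0f} MB -> {ternary_mb:.1f} MB")
# ~324 MB -> ~16.2 MB; the 16.8 MB figure above presumably also includes per-channel scales and metadata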

Accuracy

Maintains high output similarity despite extreme quantization:

  • Cosine Similarity: 96.3%
  • Relative Error: ~28%
  • Multi-Ternary (k=3): 75% error reduction vs k=1

See BENCHMARKS.md for detailed performance analysis.
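
You can also spot-check the similarity numbers on your own layers. A minimal sketch, assuming (as in the conversion example below) that convert_linear_to_bitlinear replaces each nn.Linear with a quantized BitLinear copy; exact numbers will vary with the weight distribution:

import torch
import torch.nn as nn
import torch.nn.functional as F
from bitlinear import convert_linear_to_bitlinear

torch.manual_seed(0)
ref = nn.Sequential(nn.Linear(512, 1024))               # full-precision reference
bit = convert_linear_to_bitlinear(ref, inplace=False)   # ternary copy

x = torch.randn(256, 512)
y_ref, y_bit = ref(x), bit(x)

cos = F.cosine_similarity(y_ref.flatten(), y_bit.flatten(), dim=0)
rel_err = (y_ref - y_bit).norm() / y_ref.norm()
print(f"cosine similarity: {cos:.3f}, relative error: {rel_err:.3f}")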

🚀 Quick Start

Installation

# CPU-only build
pip install -e .

# With CUDA support (requires CUDA toolkit)
CUDA_HOME=/usr/local/cuda pip install -e .

Basic Usage

import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear!

print(f"Weight values: {torch.unique(layer.W_ternary)}")  # [-1, 0, 1]

Converting Existing Models

import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal - all Linear layers are now BitLinear
x = torch.randn(10, 32, 512)
output = model_compressed(x)

Multi-Ternary for Better Accuracy

from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)

📖 How It Works

BitLinear uses ternary quantization to represent weights with only three values: {-1, 0, +1}.

Architecture

  1. Quantization: Weights quantized to {-1, 0, +1} using absmax scaling (see the sketch after this list)
  2. Scaling: Per-output-channel scaling factors (gamma) compensate for quantization
  3. Packing: Base-3 encoding stores 5 ternary values per byte
  4. Computation: Optimized kernels exploit ternary structure (no multiplications needed)
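
To make steps 1-2 concrete, here is an illustrative PyTorch sketch of absmax-style ternary quantization with per-output-channel scales. The threshold rule is an assumption for illustration and may differ from the package's internals (see bitlinear/quantization.py):

import torch

def ternary_quantize(W: torch.Tensor, threshold: float = 0.05):
    # Illustrative absmax-style ternary quantization; not necessarily the package's exact rule.
    # W has shape (out_features, in_features).
    absmax = W.abs().amax(dim=1, keepdim=True)                # per-output-channel absmax
    W_ternary = torch.sign(W) * (W.abs() > threshold * absmax)
    # Per-channel scale gamma: least-squares fit of W by gamma * W_ternary.
    gamma = (W * W_ternary).sum(dim=1) / W_ternary.abs().sum(dim=1).clamp(min=1)
    return W_ternary, gamma

W = torch.randn(1024, 512)
W_ternary, gamma = ternary_quantize(W)
W_approx = gamma.unsqueeze(1) * W_ternary                     # dequantized approximation of W
print(torch.unique(W_ternary))                                # tensor([-1., 0., 1.])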

Memory Efficiency

  • Theoretical: log₂(3) ≈ 1.58 bits per weight
  • Actual: 1.6 bits per weight (5 values per byte; see the packing sketch below)
  • Efficiency: 98.8% of theoretical maximum
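
The packing trick works because 3^5 = 243 fits in a single byte, so five ternary digits ("trits") can share one uint8. A self-contained sketch of the idea, independent of the package's own packing.py:

import torch

def pack_base3(ternary: torch.Tensor) -> torch.Tensor:
    # Pack a flat tensor of {-1, 0, +1} values into bytes, 5 values per byte (3**5 = 243 <= 255).
    trits = (ternary.flatten() + 1).to(torch.uint8)           # map {-1, 0, +1} -> {0, 1, 2}
    pad = (-trits.numel()) % 5
    trits = torch.cat([trits, trits.new_zeros(pad)]).view(-1, 5)
    weights = torch.tensor([81, 27, 9, 3, 1], dtype=torch.uint8)
    return (trits * weights).sum(dim=1).to(torch.uint8)       # base-3 digits -> one byte

def unpack_base3(packed: torch.Tensor, numel: int) -> torch.Tensor:
    # Inverse of pack_base3: recover the first `numel` ternary values.
    vals = packed.to(torch.int64).unsqueeze(1)
    digits = (vals // torch.tensor([81, 27, 9, 3, 1])) % 3    # base-3 digits, most significant first
    return (digits.flatten()[:numel] - 1).to(torch.float32)   # map {0, 1, 2} back to {-1, 0, +1}

w = torch.randint(-1, 2, (1024,)).float()
packed = pack_base3(w)
assert torch.equal(unpack_base3(packed, w.numel()), w)
print(f"{w.numel()} ternary values -> {packed.numel()} bytes")  # 1024 -> 205 bytes (~1.6 bits/value)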

πŸ“ Project Structure

BitLinear/
├── bitlinear/              # Main package
│   ├── layers.py           # BitLinear and MultiTernaryLinear modules
│   ├── functional.py       # Core functional implementations
│   ├── quantization.py     # Ternary quantization utilities
│   ├── packing.py          # Base-3 packing for memory efficiency
│   └── cpp/                # C++/CUDA extensions
│       ├── bitlinear.cpp   # PyBind11 bindings & CPU kernels
│       └── bitlinear_kernel.cu  # CUDA GPU kernels
├── tests/                  # Comprehensive test suite
├── examples/               # Usage examples
│   ├── basic_usage.py      # Simple demonstrations
│   └── transformer_example.py  # Transformer integration
├── benchmarks/             # Performance benchmarks
│   ├── benchmark_memory.py     # Memory analysis
│   └── benchmark_performance.py # Speed comparison
└── notebooks/              # Interactive tutorials
    └── demo.md             # Step-by-step guide

🧪 Examples

Example 1: Basic Layer

from bitlinear import BitLinear, estimate_memory_savings

# Create layer
layer = BitLinear(512, 1024)

# Check memory savings
stats = estimate_memory_savings(512, 1024)
print(f"Compression: {stats['compression_ratio']:.1f}x")  # ~19x

Example 2: Transformer Conversion

import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Original transformer
model = nn.TransformerEncoderLayer(d_model=768, nhead=8, dim_feedforward=3072)

# Convert to BitLinear
model_bit = convert_linear_to_bitlinear(model)

# Compare memory
mem_original = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
mem_bitlinear = sum(p.numel() * p.element_size() for p in model_bit.parameters()) / 1024**2
print(f"Memory: {mem_original:.2f} MB β†’ {mem_bitlinear:.2f} MB")

Run complete examples:

python examples/basic_usage.py
python examples/transformer_example.py

📈 Benchmarks

Run benchmarks to see performance on your hardware:

# Memory compression analysis
python benchmarks/benchmark_memory.py

# Forward pass performance
python benchmarks/benchmark_performance.py

🧪 Testing

Comprehensive test suite with 60+ tests:

# Run all tests
pytest tests/ -v

# Run specific test modules
pytest tests/test_quantization.py -v
pytest tests/test_layers.py -v

🎓 Research Background

This implementation is based on BitNet (Wang et al., 2023) and the JMLR work on ternary representations of neural networks; see the Citation section below.

Key Innovations

  1. Ternary Quantization: Reduces weights to {-1, 0, +1}
  2. Absmax Scaling: Per-channel scaling for accuracy
  3. Greedy Decomposition: Multi-ternary for better approximation
  4. Base-3 Packing: Near-optimal memory compression

🛠️ Implementation Details

Python Baseline

Pure PyTorch implementation for correctness and clarity:

  • bitlinear_python() - Reference ternary matmul
  • greedy_ternary_decomposition() - Multi-component quantization (sketched below)
  • Full gradient support for training
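
For intuition, greedy multi-ternary decomposition can be written as repeated ternary quantization of the residual. The sketch below is illustrative and may differ from the package's greedy_ternary_decomposition() in details such as the threshold rule:

import torch

def greedy_ternary_decomposition(W: torch.Tensor, k: int = 3, threshold: float = 0.05):
    # Approximate W as sum_i gamma_i * T_i with T_i in {-1, 0, +1}, fitting each
    # component to the residual left by the previous ones. Illustrative sketch only.
    residual = W.clone()
    components = []
    for _ in range(k):
        absmax = residual.abs().amax(dim=1, keepdim=True)
        T = torch.sign(residual) * (residual.abs() > threshold * absmax)
        gamma = (residual * T).sum(dim=1) / T.abs().sum(dim=1).clamp(min=1)
        components.append((gamma, T))
        residual = residual - gamma.unsqueeze(1) * T
    return components

W = torch.randn(1024, 512)
for k in (1, 3):
    comps = greedy_ternary_decomposition(W, k=k)
    W_hat = sum(g.unsqueeze(1) * T for g, T in comps)
    print(f"k={k}: relative error {(W - W_hat).norm() / W.norm():.3f}")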

C++ Extensions

Optimized CPU kernels with PyBind11:

  • Ternary-specific optimizations (no multiplications)
  • Efficient memory access patterns
  • Base-3 packing/unpacking

CUDA Kernels

GPU-accelerated implementation:

  • Warp-level reductions using shuffle intrinsics
  • Shared memory tiling
  • Memory coalescing
  • Fused multi-ternary kernels

🎯 Use Cases

Ideal For:

  • Edge Deployment: Mobile and embedded devices
  • Large Models: Billion-parameter models with memory constraints
  • Production Inference: Cost-effective serving at scale
  • Research: Exploring ultra-low-precision networks

Considerations:

  • Training: Best results with quantization-aware training (QAT); a typical straight-through-estimator pattern is sketched below
  • Accuracy: A 3-5% accuracy drop is typical (acceptable for many tasks)
  • Speed: The pure-Python path may be slower than nn.Linear; use the C++/CUDA kernels for production
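
For quantization-aware training, the common pattern is a straight-through estimator (STE): ternarize on the forward pass, but let gradients flow to the latent full-precision weights as if quantization were the identity. A generic sketch of that pattern, not the package's own training code:

import torch

class TernarySTE(torch.autograd.Function):
    # Straight-through estimator: ternarize on the forward pass,
    # pass gradients through to the latent full-precision weights unchanged.
    @staticmethod
    def forward(ctx, w):
        absmax = w.abs().amax(dim=1, keepdim=True)
        return torch.sign(w) * (w.abs() > 0.05 * absmax)  # illustrative threshold

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # identity gradient (straight-through)

# Usage inside a layer's forward pass: quantize the latent weight, then matmul.
w = torch.randn(1024, 512, requires_grad=True)
x = torch.randn(32, 512)
y = x @ TernarySTE.apply(w).t()
y.sum().backward()          # gradients reach the full-precision weight w
print(w.grad.shape)         # torch.Size([1024, 512])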

📚 Documentation

🤝 Contributing

Contributions welcome! Areas for improvement:

  • AVX/AVX512 vectorization for CPU
  • Tensor Core utilization for CUDA
  • Additional quantization schemes
  • Training examples and tutorials

📄 License

MIT License - see LICENSE file for details.

📖 Citation

If you use BitLinear in your research, please cite:

@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}

🌟 Acknowledgments

This implementation builds upon the groundbreaking work in:

  • BitNet by Microsoft Research
  • Ternary Neural Networks research (JMLR)
  • PyTorch's extensibility framework

📞 Contact

For questions, issues, or collaboration:

  • Open an issue on GitHub
  • Check existing documentation
  • Review examples and benchmarks

Please tag me if you use this in anything you build; I'd love to see it.

Made with ❤️ for efficient deep learning
