# BitLinear / setup.py
# Uploaded via huggingface_hub by krisaujla (commit fd8c8b9, verified)
"""
Setup script for BitLinear PyTorch extension.
This script builds the C++/CUDA extension using PyTorch's built-in
cpp_extension utilities. It handles:
- CPU-only builds (development)
- CUDA builds (production)
- Conditional compilation based on CUDA availability
"""
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
CUDA_HOME,
)
# Package metadata
VERSION = "0.1.0"
DESCRIPTION = "BitLinear: Ultra-Low-Precision Linear Layers for PyTorch"
LONG_DESCRIPTION = """
A research-grade PyTorch extension for ultra-low-precision (1.58-bit) ternary
linear layers inspired by BitNet and recent JMLR work on ternary representations
of neural networks.
Features:
- Drop-in replacement for nn.Linear with ternary weights
- 20x memory compression
- Optimized CUDA kernels for GPU acceleration
- Greedy ternary decomposition for improved expressiveness
"""
# Determine if CUDA is available
def cuda_is_available():
"""Check if CUDA is available for compilation."""
return torch.cuda.is_available() and CUDA_HOME is not None
def get_extensions():
"""
Build extension modules based on CUDA availability.
Returns:
List of extension modules to compile
"""
# Source files
source_dir = os.path.join("bitlinear", "cpp")
sources = [os.path.join(source_dir, "bitlinear.cpp")]
# Compiler flags
extra_compile_args = {
"cxx": ["-O3", "-std=c++17"],
}
# Define macros
define_macros = []
if cuda_is_available():
print("CUDA detected, building with GPU support")
# Add CUDA source
sources.append(os.path.join(source_dir, "bitlinear_kernel.cu"))
# CUDA compiler flags
extra_compile_args["nvcc"] = [
"-O3",
"-std=c++17",
"--use_fast_math",
"-gencode=arch=compute_70,code=sm_70", # V100
"-gencode=arch=compute_75,code=sm_75", # T4, RTX 20xx
"-gencode=arch=compute_80,code=sm_80", # A100
"-gencode=arch=compute_86,code=sm_86", # RTX 30xx
"-gencode=arch=compute_89,code=sm_89", # RTX 40xx
"-gencode=arch=compute_90,code=sm_90", # H100
]
# Define CUDA macro
define_macros.append(("WITH_CUDA", None))
# Create CUDA extension
extension = CUDAExtension(
name="bitlinear_cpp",
sources=sources,
extra_compile_args=extra_compile_args,
define_macros=define_macros,
)
else:
print("CUDA not detected, building CPU-only version")
# Create CPU-only extension
extension = CppExtension(
name="bitlinear_cpp",
sources=sources,
extra_compile_args=extra_compile_args["cxx"],
define_macros=define_macros,
)
return [extension]
# Read requirements
def read_requirements():
"""Read requirements from requirements.txt if it exists."""
req_file = "requirements.txt"
if os.path.exists(req_file):
with open(req_file, "r") as f:
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
return []
# Main setup
setup(
name="bitlinear",
version=VERSION,
author="BitLinear Contributors",
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown",
url="https://github.com/yourusername/bitlinear", # TODO: Update with actual repo
packages=find_packages(),
ext_modules=get_extensions(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)
},
install_requires=[
"torch>=2.0.0",
"numpy>=1.20.0",
],
extras_require={
"dev": [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"black>=22.0.0",
"flake8>=5.0.0",
"mypy>=0.990",
],
"test": [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
],
},
python_requires=">=3.8",
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
"Programming Language :: Python :: Implementation :: CPython",
],
keywords="pytorch deep-learning quantization ternary bitnet transformer",
project_urls={
"Bug Reports": "https://github.com/yourusername/bitlinear/issues",
"Source": "https://github.com/yourusername/bitlinear",
"Documentation": "https://github.com/yourusername/bitlinear/blob/main/README.md",
},
)