"""
Setup script for the BitLinear PyTorch extension.

This script builds the C++/CUDA extension using PyTorch's built-in
cpp_extension utilities. It handles:

- CPU-only builds (development)
- CUDA builds (production)
- Conditional compilation based on CUDA availability
"""

import os

import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)

VERSION = "0.1.0"
DESCRIPTION = "BitLinear: Ultra-Low-Precision Linear Layers for PyTorch"
LONG_DESCRIPTION = """
A research-grade PyTorch extension for ultra-low-precision (1.58-bit) ternary
linear layers inspired by BitNet and recent JMLR work on ternary representations
of neural networks.

Features:
- Drop-in replacement for nn.Linear with ternary weights
- ~20x memory compression relative to FP32 weights
- Optimized CUDA kernels for GPU acceleration
- Greedy ternary decomposition for improved expressiveness
"""


def cuda_is_available():
    """Check if CUDA is available for compilation."""
    return torch.cuda.is_available() and CUDA_HOME is not None
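
# Note: torch.cuda.is_available() needs a working driver and a visible GPU at
# build time, so builds on GPU-less machines (e.g. CI runners) fall back to the
# CPU-only extension even when the CUDA toolkit is installed.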


def get_extensions():
    """
    Build extension modules based on CUDA availability.

    Returns:
        List of extension modules to compile
    """
    source_dir = os.path.join("bitlinear", "cpp")
    sources = [os.path.join(source_dir, "bitlinear.cpp")]

    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"],
    }
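    # -O3/-std=c++17 assume a gcc/clang-style toolchain; building with MSVC
    # would need the /O2 and /std:c++17 spellings instead.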

    define_macros = []

    if cuda_is_available():
        print("CUDA detected, building with GPU support")

        sources.append(os.path.join(source_dir, "bitlinear_kernel.cu"))

        extra_compile_args["nvcc"] = [
            "-O3",
            "-std=c++17",
            "--use_fast_math",
            "-gencode=arch=compute_70,code=sm_70",
            "-gencode=arch=compute_75,code=sm_75",
            "-gencode=arch=compute_80,code=sm_80",
            "-gencode=arch=compute_86,code=sm_86",
            "-gencode=arch=compute_89,code=sm_89",
            "-gencode=arch=compute_90,code=sm_90",
        ]
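        # The -gencode flags above cover Volta (sm_70), Turing (sm_75), Ampere
        # (sm_80/sm_86), Ada Lovelace (sm_89), and Hopper (sm_90); the last two
        # require CUDA 11.8 or newer.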

        define_macros.append(("WITH_CUDA", None))
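        # WITH_CUDA is presumably checked via #ifdef in bitlinear.cpp so the
        # same source file can dispatch to the CUDA kernels when they are built.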

        extension = CUDAExtension(
            name="bitlinear_cpp",
            sources=sources,
            extra_compile_args=extra_compile_args,
            define_macros=define_macros,
        )
    else:
        print("CUDA not detected, building CPU-only version")

        extension = CppExtension(
            name="bitlinear_cpp",
            sources=sources,
            extra_compile_args=extra_compile_args["cxx"],
            define_macros=define_macros,
        )

    return [extension]


def read_requirements():
    """Read requirements from requirements.txt if it exists."""
    req_file = "requirements.txt"
    if os.path.exists(req_file):
        with open(req_file, "r") as f:
            return [
                line.strip()
                for line in f
                if line.strip() and not line.startswith("#")
            ]
    return []
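
# Note: setup() below pins install_requires directly; read_requirements() is
# provided as a helper and is not wired into the call as written.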


setup(
    name="bitlinear",
    version=VERSION,
    author="BitLinear Contributors",
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/bitlinear",
    packages=find_packages(),
    ext_modules=get_extensions(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True)
    },
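    # no_python_abi_suffix=True names the built module plainly (bitlinear_cpp.so)
    # rather than appending the interpreter's ABI tag.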
    install_requires=[
        "torch>=2.0.0",
        "numpy>=1.20.0",
    ],
    extras_require={
        "dev": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
            "black>=22.0.0",
            "flake8>=5.0.0",
            "mypy>=0.990",
        ],
        "test": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
        ],
    },
    python_requires=">=3.8",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
        "Programming Language :: Python :: Implementation :: CPython",
    ],
    keywords="pytorch deep-learning quantization ternary bitnet transformer",
    project_urls={
        "Bug Reports": "https://github.com/yourusername/bitlinear/issues",
        "Source": "https://github.com/yourusername/bitlinear",
        "Documentation": "https://github.com/yourusername/bitlinear/blob/main/README.md",
    },
)