""" Setup script for BitLinear PyTorch extension. This script builds the C++/CUDA extension using PyTorch's built-in cpp_extension utilities. It handles: - CPU-only builds (development) - CUDA builds (production) - Conditional compilation based on CUDA availability """ import os import torch from setuptools import setup, find_packages from torch.utils.cpp_extension import ( BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, ) # Package metadata VERSION = "0.1.0" DESCRIPTION = "BitLinear: Ultra-Low-Precision Linear Layers for PyTorch" LONG_DESCRIPTION = """ A research-grade PyTorch extension for ultra-low-precision (1.58-bit) ternary linear layers inspired by BitNet and recent JMLR work on ternary representations of neural networks. Features: - Drop-in replacement for nn.Linear with ternary weights - 20x memory compression - Optimized CUDA kernels for GPU acceleration - Greedy ternary decomposition for improved expressiveness """ # Determine if CUDA is available def cuda_is_available(): """Check if CUDA is available for compilation.""" return torch.cuda.is_available() and CUDA_HOME is not None def get_extensions(): """ Build extension modules based on CUDA availability. Returns: List of extension modules to compile """ # Source files source_dir = os.path.join("bitlinear", "cpp") sources = [os.path.join(source_dir, "bitlinear.cpp")] # Compiler flags extra_compile_args = { "cxx": ["-O3", "-std=c++17"], } # Define macros define_macros = [] if cuda_is_available(): print("CUDA detected, building with GPU support") # Add CUDA source sources.append(os.path.join(source_dir, "bitlinear_kernel.cu")) # CUDA compiler flags extra_compile_args["nvcc"] = [ "-O3", "-std=c++17", "--use_fast_math", "-gencode=arch=compute_70,code=sm_70", # V100 "-gencode=arch=compute_75,code=sm_75", # T4, RTX 20xx "-gencode=arch=compute_80,code=sm_80", # A100 "-gencode=arch=compute_86,code=sm_86", # RTX 30xx "-gencode=arch=compute_89,code=sm_89", # RTX 40xx "-gencode=arch=compute_90,code=sm_90", # H100 ] # Define CUDA macro define_macros.append(("WITH_CUDA", None)) # Create CUDA extension extension = CUDAExtension( name="bitlinear_cpp", sources=sources, extra_compile_args=extra_compile_args, define_macros=define_macros, ) else: print("CUDA not detected, building CPU-only version") # Create CPU-only extension extension = CppExtension( name="bitlinear_cpp", sources=sources, extra_compile_args=extra_compile_args["cxx"], define_macros=define_macros, ) return [extension] # Read requirements def read_requirements(): """Read requirements from requirements.txt if it exists.""" req_file = "requirements.txt" if os.path.exists(req_file): with open(req_file, "r") as f: return [line.strip() for line in f if line.strip() and not line.startswith("#")] return [] # Main setup setup( name="bitlinear", version=VERSION, author="BitLinear Contributors", description=DESCRIPTION, long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown", url="https://github.com/yourusername/bitlinear", # TODO: Update with actual repo packages=find_packages(), ext_modules=get_extensions(), cmdclass={ "build_ext": BuildExtension.with_options(no_python_abi_suffix=True) }, install_requires=[ "torch>=2.0.0", "numpy>=1.20.0", ], extras_require={ "dev": [ "pytest>=7.0.0", "pytest-cov>=4.0.0", "black>=22.0.0", "flake8>=5.0.0", "mypy>=0.990", ], "test": [ "pytest>=7.0.0", "pytest-cov>=4.0.0", ], }, python_requires=">=3.8", classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: 
Artificial Intelligence", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", ], keywords="pytorch deep-learning quantization ternary bitnet transformer", project_urls={ "Bug Reports": "https://github.com/yourusername/bitlinear/issues", "Source": "https://github.com/yourusername/bitlinear", "Documentation": "https://github.com/yourusername/bitlinear/blob/main/README.md", }, )
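
# Build notes (a sketch of the standard setuptools / torch.utils.cpp_extension
# workflow; these commands are the usual ones, not specific to this repo):
#
#   pip install -e .                        # editable install; compiles bitlinear_cpp
#   python setup.py build_ext --inplace     # in-place build for local development
#
# Whether the CUDA or CPU-only extension is produced is decided at build time by
# cuda_is_available() above, i.e. torch.cuda.is_available() and CUDA_HOME being set.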