/*
 * BitLinear C++ Extension
 *
 * This file provides the C++/PyBind11 interface for BitLinear operations.
 * It dispatches to CPU or CUDA implementations based on tensor device.
 *
 * Architecture:
 * - Python (torch) → PyBind11 → C++ dispatcher → CPU/CUDA kernels
 * - This file handles: binding, type checking, device dispatch
 * - Actual computation in: CPU (this file) and CUDA (bitlinear_kernel.cu)
 */

#include <torch/extension.h>
#include <vector>

/*
 * Forward declarations for CUDA kernels (implemented in bitlinear_kernel.cu).
 * These will be linked at compile time if CUDA is available.
 */
#ifdef WITH_CUDA
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
);

torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
);
#endif

/*
 * CPU implementation of BitLinear forward pass
 *
 * Computes: output = (x @ W_ternary^T) * gamma + bias
 *
 * This is a reference implementation optimized for clarity.
 * Further optimizations can be added:
 * - Vectorization (AVX/AVX512)
 * - OpenMP parallelization
 * - Cache-efficient tiling
 *
 * Args:
 *   x: Input tensor [..., in_features]
 *   W_ternary: Ternary weights [out_features, in_features], values in {-1, 0, 1}
 *   gamma: Scaling factors [out_features]
 *   bias: Optional bias [out_features]
 *
 * Returns:
 *   Output tensor [..., out_features]
 */
torch::Tensor bitlinear_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Handle multi-dimensional input by flattening to 2D
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    int64_t in_features = x_shape.back();
    int64_t out_features = W_ternary.size(0);

    // Reshape x to [batch_size, in_features]
    auto x_2d = x.view({batch_size, in_features});

    // Compute matmul: [batch, in] @ [in, out] = [batch, out]
    // W_ternary is [out, in], so transpose it
    auto output = torch::matmul(x_2d, W_ternary.t());

    // Apply gamma scaling: element-wise multiply by gamma
    // gamma shape is [out_features], output is [batch, out_features]
    output = output * gamma.unsqueeze(0);

    // Add bias if present
    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Reshape output back to original batch dimensions
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
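
/*
 * Illustrative call of bitlinear_cpu_forward (a sketch, not compiled into the
 * module; the tensor shapes below are assumptions chosen only to show how the
 * extra batch dimensions flow through the flatten/reshape logic):
 *
 *   auto x     = torch::randn({2, 4, 8});                            // [batch, seq, in_features]
 *   auto W     = torch::randint(-1, 2, {16, 8}).to(torch::kFloat32); // ternary [out, in]
 *   auto gamma = torch::ones({16});                                  // per-output scale
 *   auto y     = bitlinear_cpu_forward(x, W, gamma, torch::nullopt); // -> [2, 4, 16]
 */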
/*
 * CPU implementation of multi-ternary forward pass
 *
 * Computes: output = sum_{i=1}^k [(x @ W_i^T) * gamma_i] + bias
 *
 * Iterates over k ternary components and accumulates their contributions.
 *
 * Args:
 *   x: Input tensor [..., in_features]
 *   W_ternary: Stacked ternary weights [k, out_features, in_features]
 *   gammas: Stacked scaling factors [k, out_features]
 *   bias: Optional bias [out_features]
 *
 * Returns:
 *   Output tensor [..., out_features]
 */
torch::Tensor multi_ternary_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // W_ternary: [k, out_features, in_features]
    // gammas:    [k, out_features]
    int64_t k = W_ternary.size(0);
    int64_t out_features = W_ternary.size(1);
    int64_t in_features = W_ternary.size(2);

    // Handle multi-dimensional input by flattening to 2D
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }

    // Reshape x to [batch_size, in_features]
    auto x_2d = x.view({batch_size, in_features});

    // Initialize output
    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Accumulate k ternary linear operations
    for (int64_t i = 0; i < k; i++) {
        // Get i-th component: W_i [out_features, in_features], gamma_i [out_features]
        auto W_i = W_ternary[i];
        auto gamma_i = gammas[i];

        // Compute: (x @ W_i^T) * gamma_i
        auto component = torch::matmul(x_2d, W_i.t());
        component = component * gamma_i.unsqueeze(0);

        // Accumulate
        output = output + component;
    }

    // Add bias if present
    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Reshape output back to original batch dimensions
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}

/*
 * Dispatcher: routes to CPU or CUDA implementation based on tensor device
 *
 * This is the main entry point called from Python.
 * Checks tensor device and dispatches accordingly.
 */
torch::Tensor bitlinear_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Type and shape checks
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 2, "W_ternary must be 2D");
    TORCH_CHECK(gamma.dim() == 1 || gamma.dim() == 2, "gamma must be 1D or 2D");

    // Device dispatch
    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return bitlinear_cuda_forward(x, W_ternary, gamma, bias);
#else
        AT_ERROR("BitLinear CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return bitlinear_cpu_forward(x, W_ternary, gamma, bias);
    }
}

/*
 * Multi-ternary dispatcher
 */
torch::Tensor multi_ternary_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // Type and shape checks
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 3, "W_ternary must be 3D [k, out_features, in_features]");
    TORCH_CHECK(gammas.dim() == 2, "gammas must be 2D [k, out_features]");

    // Device dispatch
    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return multi_ternary_cuda_forward(x, W_ternary, gammas, bias);
#else
        AT_ERROR("Multi-ternary CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return multi_ternary_cpu_forward(x, W_ternary, gammas, bias);
    }
}
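
/*
 * Illustrative call of multi_ternary_forward with k = 2 components (a sketch,
 * not compiled into the module; shapes are assumptions):
 *
 *   auto x      = torch::randn({3, 8});                                   // [batch, in_features]
 *   auto W      = torch::randint(-1, 2, {2, 16, 8}).to(torch::kFloat32);  // [k, out, in]
 *   auto gammas = torch::rand({2, 16});                                   // [k, out]
 *   auto y      = multi_ternary_forward(x, W, gammas, torch::nullopt);    // -> [3, 16]
 *
 * The result equals gammas[0] * (x @ W[0]^T) + gammas[1] * (x @ W[1]^T).
 */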
/*
 * Utility: pack ternary weights to base-3 representation
 *
 * Packs ternary weights {-1, 0, +1} into bytes using base-3 encoding.
 * Each byte stores 5 ternary values: val0 + 3*val1 + 9*val2 + 27*val3 + 81*val4
 * Values are mapped: -1 -> 0, 0 -> 1, +1 -> 2
 * Max value: 2 + 6 + 18 + 54 + 162 = 242 (fits in uint8)
 *
 * Achieves ~20x memory compression vs float32.
 */
torch::Tensor pack_ternary_base3_cpp(torch::Tensor W_ternary) {
    // Flatten input
    auto flat = W_ternary.flatten().to(torch::kCPU).to(torch::kInt8);
    int64_t numel = flat.numel();

    // Map {-1, 0, +1} to {0, 1, 2}
    auto mapped = (flat + 1).to(torch::kUInt8);

    // Calculate output size: ceil(numel / 5)
    int64_t packed_size = (numel + 4) / 5;
    auto packed = torch::zeros({packed_size}, torch::dtype(torch::kUInt8).device(torch::kCPU));

    // Get data pointers
    auto mapped_ptr = mapped.data_ptr<uint8_t>();
    auto packed_ptr = packed.data_ptr<uint8_t>();

    // Powers of 3 for base-3 encoding
    const uint8_t powers[5] = {1, 3, 9, 27, 81};

    // Pack 5 values into each byte
    for (int64_t i = 0; i < packed_size; i++) {
        int64_t base_idx = i * 5;
        uint8_t packed_val = 0;
        for (int j = 0; j < 5; j++) {
            int64_t idx = base_idx + j;
            if (idx < numel) {
                packed_val += mapped_ptr[idx] * powers[j];
            } else {
                // Pad with 1 (representing 0) for consistent unpacking
                packed_val += 1 * powers[j];
            }
        }
        packed_ptr[i] = packed_val;
    }

    return packed;
}

/*
 * Utility: unpack base-3 ternary weights
 *
 * Unpacks bytes back to ternary weights {-1, 0, +1}.
 * Reverses the base-3 encoding: extracts 5 values per byte.
 * Maps {0, 1, 2} back to {-1, 0, +1}.
 */
torch::Tensor unpack_ternary_base3_cpp(
    torch::Tensor packed,
    std::vector<int64_t> original_shape
) {
    // Calculate expected number of elements
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    // Flatten packed input
    auto packed_flat = packed.flatten().to(torch::kCPU).to(torch::kUInt8);
    int64_t packed_size = packed_flat.numel();

    // Create output tensor
    auto unpacked = torch::zeros({numel}, torch::dtype(torch::kInt8).device(torch::kCPU));

    // Get data pointers
    auto packed_ptr = packed_flat.data_ptr<uint8_t>();
    auto unpacked_ptr = unpacked.data_ptr<int8_t>();

    // Unpack 5 values from each byte
    int64_t out_idx = 0;
    for (int64_t i = 0; i < packed_size && out_idx < numel; i++) {
        uint8_t packed_val = packed_ptr[i];

        // Extract 5 ternary values using base-3 decoding
        for (int j = 0; j < 5 && out_idx < numel; j++) {
            uint8_t val = packed_val % 3;  // Get current base-3 digit
            packed_val /= 3;               // Shift to next digit

            // Map {0, 1, 2} back to {-1, 0, +1}
            unpacked_ptr[out_idx] = static_cast<int8_t>(val) - 1;
            out_idx++;
        }
    }

    // Reshape to original shape
    return unpacked.view(original_shape).to(torch::kFloat32);
}

/*
 * PyBind11 module definition
 *
 * This exposes C++ functions to Python as:
 *   import bitlinear_cpp
 *   output = bitlinear_cpp.forward(x, W, gamma, bias)
 */
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &bitlinear_forward,
          "BitLinear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gamma"),
          py::arg("bias") = py::none());

    m.def("multi_ternary_forward", &multi_ternary_forward,
          "Multi-ternary linear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gammas"),
          py::arg("bias") = py::none());

    m.def("pack_ternary_base3", &pack_ternary_base3_cpp,
          "Pack ternary weights to base-3 (CPU)",
          py::arg("W_ternary"));

    m.def("unpack_ternary_base3", &unpack_ternary_base3_cpp,
          "Unpack base-3 ternary weights (CPU)",
          py::arg("packed"),
          py::arg("original_shape"));
}
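
/*
 * Worked example of the base-3 packing (illustrative; the values are arbitrary):
 *
 *   ternary weights      {+1, -1,  0, +1, +1}
 *   mapped to base-3     { 2,  0,  1,  2,  2}
 *   packed byte          2*1 + 0*3 + 1*9 + 2*27 + 2*81 = 227
 *
 * Unpacking 227 by repeated "% 3" / "/= 3" recovers the digits 2, 0, 1, 2, 2,
 * which map back to +1, -1, 0, +1, +1.
 *
 * A sketch of a Python round-trip through these bindings (the module name
 * bitlinear_cpp follows the comment above; the tensor shapes are assumptions):
 *
 *   import torch, bitlinear_cpp
 *   W = torch.randint(-1, 2, (16, 8)).float()
 *   packed = bitlinear_cpp.pack_ternary_base3(W)
 *   W2 = bitlinear_cpp.unpack_ternary_base3(packed, list(W.shape))
 *   assert torch.equal(W, W2)
 */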