/*
BitLinear C++ Extension
This file provides the C++/PyBind11 interface for BitLinear operations.
It dispatches to CPU or CUDA implementations based on tensor device.
Architecture:
- Python (torch) → PyBind11 → C++ dispatcher → CPU/CUDA kernels
- This file handles: binding, type checking, device dispatch
- Actual computation in: CPU (this file) and CUDA (bitlinear_kernel.cu)
*/
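/*
Build note (an assumption about the surrounding project, not something this file
defines): one way to compile and load an extension like this is PyTorch's JIT
builder, for example:

    from torch.utils.cpp_extension import load
    bitlinear_cpp = load(
        name="bitlinear_cpp",
        sources=["bitlinear.cpp", "bitlinear_kernel.cu"],
        extra_cflags=["-DWITH_CUDA"],   # omit this flag and the .cu source for a CPU-only build
    )

The actual project may instead ship a setup.py or CMake build; this is only a sketch
of how the WITH_CUDA guard and the bindings at the bottom fit together.
*/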
#include <torch/extension.h>
#include <vector>
/*
* Forward declarations for CUDA kernels (implemented in bitlinear_kernel.cu)
* These will be linked at compile time if CUDA is available.
*/
#ifdef WITH_CUDA
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
);

torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
);
#endif
/*
 * CPU implementation of BitLinear forward pass
 *
 * Computes: output = (x @ W_ternary^T) * gamma + bias
 *
 * This is a reference implementation optimized for clarity.
 * Further optimizations can be added:
 *   - Vectorization (AVX/AVX512)
 *   - OpenMP parallelization
 *   - Cache-efficient tiling
 *
 * Args:
 *   x: Input tensor [..., in_features]
 *   W_ternary: Ternary weights [out_features, in_features], values in {-1, 0, 1}
 *   gamma: Scaling factors [out_features]
 *   bias: Optional bias [out_features]
 *
 * Returns:
 *   Output tensor [..., out_features]
 */
torch::Tensor bitlinear_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Handle multi-dimensional input by flattening to 2D
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    int64_t in_features = x_shape.back();
    int64_t out_features = W_ternary.size(0);

    // Reshape x to [batch_size, in_features]
    auto x_2d = x.view({batch_size, in_features});

    // Compute matmul: [batch, in] @ [in, out] = [batch, out]
    // W_ternary is [out, in], so transpose it
    auto output = torch::matmul(x_2d, W_ternary.t());

    // Apply gamma scaling: element-wise multiply by gamma[out_features]
    // gamma shape is [out_features], output is [batch, out_features]
    output = output * gamma.unsqueeze(0);

    // Add bias if present
    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Reshape output back to original batch dimensions
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
/*
 * CPU implementation of multi-ternary forward pass
 *
 * Computes: output = sum_{i=1}^k [(x @ W_i^T) * gamma_i] + bias
 *
 * Iterates over k ternary components and accumulates their contributions.
 *
 * Args:
 *   x: Input tensor [..., in_features]
 *   W_ternary: Stacked ternary weights [k, out_features, in_features]
 *   gammas: Stacked scaling factors [k, out_features]
 *   bias: Optional bias [out_features]
 *
 * Returns:
 *   Output tensor [..., out_features]
 */
torch::Tensor multi_ternary_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // W_ternary: [k, out_features, in_features]
    // gammas: [k, out_features]
    int64_t k = W_ternary.size(0);
    int64_t out_features = W_ternary.size(1);
    int64_t in_features = W_ternary.size(2);

    // Handle multi-dimensional input by flattening to 2D
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }

    // Reshape x to [batch_size, in_features]
    auto x_2d = x.view({batch_size, in_features});

    // Initialize output
    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Accumulate k ternary linear operations
    for (int64_t i = 0; i < k; i++) {
        // Get i-th component: W_i [out_features, in_features], gamma_i [out_features]
        auto W_i = W_ternary[i];
        auto gamma_i = gammas[i];

        // Compute: (x @ W_i^T) * gamma_i
        auto component = torch::matmul(x_2d, W_i.t());
        component = component * gamma_i.unsqueeze(0);

        // Accumulate
        output = output + component;
    }

    // Add bias if present
    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Reshape output back to original batch dimensions
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
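/*
 * Note: since gamma_i rescales the rows of W_i, the k-term accumulation above is
 * mathematically equivalent to a single matmul against a pre-folded weight:
 *   W_eff = sum_i gamma_i.unsqueeze(1) * W_i    (shape [out_features, in_features])
 *   output = x @ W_eff^T + bias
 * A minimal sketch of that shortcut, assuming the same shapes as above (untested,
 * and kept out of the implementation so the per-component structure stays explicit):
 *
 *   auto W_eff = (gammas.unsqueeze(2) * W_ternary).sum(0);   // [out, in]
 *   auto output = torch::matmul(x_2d, W_eff.t());            // [batch, out]
 */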
/*
 * Dispatcher: routes to CPU or CUDA implementation based on tensor device
 *
 * This is the main entry point called from Python.
 * Checks tensor device and dispatches accordingly.
 */
torch::Tensor bitlinear_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Type and shape checks
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 2, "W_ternary must be 2D");
    TORCH_CHECK(gamma.dim() == 1 || gamma.dim() == 2, "gamma must be 1D or 2D");

    // Device dispatch
    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return bitlinear_cuda_forward(x, W_ternary, gamma, bias);
#else
        AT_ERROR("BitLinear CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return bitlinear_cpu_forward(x, W_ternary, gamma, bias);
    }
}
/*
 * Multi-ternary dispatcher
 */
torch::Tensor multi_ternary_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // Type and shape checks
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 3, "W_ternary must be 3D [k, out_features, in_features]");
    TORCH_CHECK(gammas.dim() == 2, "gammas must be 2D [k, out_features]");

    // Device dispatch
    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return multi_ternary_cuda_forward(x, W_ternary, gammas, bias);
#else
        AT_ERROR("Multi-ternary CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return multi_ternary_cpu_forward(x, W_ternary, gammas, bias);
    }
}
/*
 * Utility: pack ternary weights to base-3 representation
 *
 * Packs ternary weights {-1, 0, +1} into bytes using base-3 encoding.
 * Each byte stores 5 ternary values: val0 + 3*val1 + 9*val2 + 27*val3 + 81*val4
 * Values are mapped: -1 -> 0, 0 -> 1, +1 -> 2
 * Max value: 2 + 6 + 18 + 54 + 162 = 242 (fits in uint8)
 *
 * Achieves ~20x memory compression vs float32 (one packed byte replaces five 4-byte floats).
 */
torch::Tensor pack_ternary_base3_cpp(torch::Tensor W_ternary) {
    // Flatten input
    auto flat = W_ternary.flatten().to(torch::kCPU).to(torch::kInt8);
    int64_t numel = flat.numel();

    // Map {-1, 0, +1} to {0, 1, 2}
    auto mapped = (flat + 1).to(torch::kUInt8);

    // Calculate output size: ceil(numel / 5)
    int64_t packed_size = (numel + 4) / 5;
    auto packed = torch::zeros({packed_size}, torch::dtype(torch::kUInt8).device(torch::kCPU));

    // Get data pointers
    auto mapped_ptr = mapped.data_ptr<uint8_t>();
    auto packed_ptr = packed.data_ptr<uint8_t>();

    // Powers of 3 for base-3 encoding
    const uint8_t powers[5] = {1, 3, 9, 27, 81};

    // Pack 5 values into each byte
    for (int64_t i = 0; i < packed_size; i++) {
        int64_t base_idx = i * 5;
        uint8_t packed_val = 0;
        for (int j = 0; j < 5; j++) {
            int64_t idx = base_idx + j;
            if (idx < numel) {
                packed_val += mapped_ptr[idx] * powers[j];
            } else {
                // Pad with 1 (representing 0) for consistent unpacking
                packed_val += 1 * powers[j];
            }
        }
        packed_ptr[i] = packed_val;
    }

    return packed;
}
/*
 * Utility: unpack base-3 ternary weights
 *
 * Unpacks bytes back to ternary weights {-1, 0, +1}.
 * Reverses the base-3 encoding: extracts 5 values per byte.
 * Maps {0, 1, 2} back to {-1, 0, +1} and returns a float32 tensor
 * in the requested original shape.
 */
torch::Tensor unpack_ternary_base3_cpp(
    torch::Tensor packed,
    std::vector<int64_t> original_shape
) {
    // Calculate expected number of elements
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    // Flatten packed input
    auto packed_flat = packed.flatten().to(torch::kCPU).to(torch::kUInt8);
    int64_t packed_size = packed_flat.numel();

    // Create output tensor
    auto unpacked = torch::zeros({numel}, torch::dtype(torch::kInt8).device(torch::kCPU));

    // Get data pointers
    auto packed_ptr = packed_flat.data_ptr<uint8_t>();
    auto unpacked_ptr = unpacked.data_ptr<int8_t>();

    // Unpack 5 values from each byte
    int64_t out_idx = 0;
    for (int64_t i = 0; i < packed_size && out_idx < numel; i++) {
        uint8_t packed_val = packed_ptr[i];

        // Extract 5 ternary values using base-3 decoding
        for (int j = 0; j < 5 && out_idx < numel; j++) {
            uint8_t val = packed_val % 3;  // Get current base-3 digit
            packed_val /= 3;               // Shift to next digit

            // Map {0, 1, 2} back to {-1, 0, +1}
            unpacked_ptr[out_idx] = static_cast<int8_t>(val) - 1;
            out_idx++;
        }
    }

    // Reshape to original shape and return as float32
    return unpacked.view(original_shape).to(torch::kFloat32);
}
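/*
 * Continuing the worked example: decoding byte 156 by repeated division by 3 yields
 * the digits 0, 1, 2, 2, 1 (least-significant first), which map back to
 * [-1, 0, +1, +1, 0], recovering the original ternary values.
 */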
/*
 * PyBind11 module definition
 *
 * This exposes the C++ functions to Python as:
 *   import bitlinear_cpp
 *   output = bitlinear_cpp.forward(x, W, gamma, bias)
 */
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &bitlinear_forward, "BitLinear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gamma"),
          py::arg("bias") = py::none());

    m.def("multi_ternary_forward", &multi_ternary_forward,
          "Multi-ternary linear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gammas"),
          py::arg("bias") = py::none());

    m.def("pack_ternary_base3", &pack_ternary_base3_cpp,
          "Pack ternary weights to base-3 (CPU)",
          py::arg("W_ternary"));

    m.def("unpack_ternary_base3", &unpack_ternary_base3_cpp,
          "Unpack base-3 ternary weights (CPU)",
          py::arg("packed"),
          py::arg("original_shape"));
}
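/*
Usage sketch (Python). Assumes the module is importable as bitlinear_cpp, as in the
comment above; shapes are arbitrary and chosen only for illustration:

    import torch
    import bitlinear_cpp

    x = torch.randn(4, 16, 256)                    # [..., in_features]
    W = torch.randint(-1, 2, (512, 256)).float()   # ternary weights [out, in]
    gamma = torch.rand(512)                        # per-output-channel scale
    y = bitlinear_cpp.forward(x, W, gamma)         # -> [4, 16, 512]

    # Round-trip the base-3 packing utilities
    packed = bitlinear_cpp.pack_ternary_base3(W)
    W_back = bitlinear_cpp.unpack_ternary_base3(packed, list(W.shape))
    assert torch.equal(W, W_back)
*/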