/*
BitLinear CUDA Kernels
This file contains CUDA kernel implementations for BitLinear operations.
The kernels optimize ternary matrix multiplication for GPU execution.
Key optimizations implemented:
1. Ternary weight specialization (only -1, 0, +1)
2. Shared memory for block-level partial-sum reduction
3. Warp-level reduction using shuffle intrinsics
4. Memory coalescing for efficient global reads
5. Thread coarsening for better instruction-level parallelism
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
// Threads per block (also the per-thread stride over in_features) - tuned for occupancy
constexpr int TILE_SIZE = 256;
constexpr int WARP_SIZE = 32;
/*
* Warp-level reduction using shuffle intrinsics
 * Sums a value across the threads of a warp; lane 0 ends up with the warp total
*/
template <typename scalar_t>
__device__ __forceinline__ scalar_t warp_reduce_sum(scalar_t val) {
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
return val;
}
/*
* Block-level reduction using shared memory
 * Reduces per-thread partial sums to a single block total (valid in thread 0 only)
*/
template <typename scalar_t>
__device__ scalar_t block_reduce_sum(scalar_t val, scalar_t* shared_mem) {
int lane = threadIdx.x % WARP_SIZE;
int warp_id = threadIdx.x / WARP_SIZE;
// First reduce within warp
val = warp_reduce_sum(val);
// Write reduced warp value to shared memory
if (lane == 0) {
shared_mem[warp_id] = val;
}
__syncthreads();
// Read from shared memory only if this thread is in the first warp
int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
val = (threadIdx.x < num_warps) ? shared_mem[lane] : scalar_t(0);
// Final reduce within first warp
if (warp_id == 0) {
val = warp_reduce_sum(val);
}
return val;
}
/*
* CUDA kernel for BitLinear forward pass
*
* Computes: output[batch, out] = sum_in (x[batch, in] * W[out, in]) * gamma[out]
*
* This is a specialized matrix multiplication kernel that exploits:
* - Ternary weights: only need additions/subtractions (no multiplications)
 * - Shared memory for the block-level partial-sum reduction
* - Warp shuffle for efficient reductions
*
* Grid/Block configuration:
* - Grid: (batch_size, out_features)
* - Block: TILE_SIZE threads
* - Each block computes one output element
*/
template <typename scalar_t>
__global__ void bitlinear_forward_kernel(
const scalar_t* __restrict__ x, // [batch_size, in_features]
const scalar_t* __restrict__ W_ternary, // [out_features, in_features]
const scalar_t* __restrict__ gamma, // [out_features]
const scalar_t* __restrict__ bias, // [out_features] or nullptr
scalar_t* __restrict__ output, // [batch_size, out_features]
int batch_size,
int in_features,
int out_features
) {
int batch_idx = blockIdx.x;
int out_idx = blockIdx.y;
int tid = threadIdx.x;
// Shared memory for partial sums reduction
extern __shared__ char shared_mem_raw[];
scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);
// Each thread computes partial dot product
scalar_t partial_sum = scalar_t(0);
// Coalesced access: each thread handles multiple elements strided by TILE_SIZE
for (int i = tid; i < in_features; i += TILE_SIZE) {
scalar_t x_val = x[batch_idx * in_features + i];
scalar_t w_val = W_ternary[out_idx * in_features + i];
// Exploit ternary structure: conditional accumulation (no multiply)
// This avoids the floating-point multiply when the weights are truly ternary
if (w_val > scalar_t(0)) {
partial_sum += x_val;
} else if (w_val < scalar_t(0)) {
partial_sum -= x_val;
}
// w_val == 0: skip (implicit in else)
}
// Reduce partial sums across block
partial_sum = block_reduce_sum(partial_sum, shared_mem);
// Thread 0 writes the final result
if (tid == 0) {
// Apply gamma scaling
scalar_t result = partial_sum * gamma[out_idx];
// Add bias if present
if (bias != nullptr) {
result += bias[out_idx];
}
output[batch_idx * out_features + out_idx] = result;
}
}
/*
* CUDA kernel launcher for BitLinear forward
*
* This function:
* 1. Handles multi-dimensional input by flattening
* 2. Sets up grid and block dimensions
* 3. Launches the CUDA kernel with dynamic shared memory
* 4. Reshapes output to match input batch dimensions
*/
torch::Tensor bitlinear_cuda_forward(
torch::Tensor x,
torch::Tensor W_ternary,
torch::Tensor gamma,
torch::optional<torch::Tensor> bias
) {
// Handle multi-dimensional input
auto x_shape = x.sizes().vec();
int64_t batch_size = 1;
for (size_t i = 0; i < x_shape.size() - 1; i++) {
batch_size *= x_shape[i];
}
const int in_features = x.size(-1);
const int out_features = W_ternary.size(0);
// Flatten input to 2D for kernel
    auto x_2d = x.reshape({batch_size, in_features}).contiguous();  // reshape() also handles non-contiguous inputs
// Ensure all tensors are contiguous for efficient memory access
auto W_cont = W_ternary.contiguous();
auto gamma_cont = gamma.contiguous();
// Allocate output
auto output = torch::zeros({batch_size, out_features}, x.options());
// Calculate shared memory size for reduction
int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;
// Grid: one block per (batch, output feature) pair
dim3 grid(batch_size, out_features);
dim3 block(TILE_SIZE);
// Get current CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bitlinear_forward_cuda", ([&] {
size_t shared_mem_size = num_warps * sizeof(scalar_t);
bitlinear_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
x_2d.data_ptr<scalar_t>(),
W_cont.data_ptr<scalar_t>(),
gamma_cont.data_ptr<scalar_t>(),
bias.has_value() && bias.value().defined()
? bias.value().contiguous().data_ptr<scalar_t>()
: nullptr,
output.data_ptr<scalar_t>(),
batch_size,
in_features,
out_features
);
}));
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
AT_ERROR("BitLinear CUDA kernel failed: ", cudaGetErrorString(err));
}
// Reshape output to match input batch dimensions
std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
out_shape.push_back(out_features);
return output.view(out_shape);
}
/*
* CUDA kernel for multi-ternary forward pass
*
* Computes: output = sum_{i=1}^k [(x @ W_i^T) * gamma_i] + bias
*
* This kernel fuses k ternary matrix multiplications into a single kernel
* to reduce memory bandwidth requirements. Each block handles one
* (batch, output) pair and accumulates contributions from all k components.
*
* Grid/Block configuration:
* - Grid: (batch_size, out_features)
* - Block: TILE_SIZE threads
*/
template <typename scalar_t>
__global__ void multi_ternary_forward_kernel(
const scalar_t* __restrict__ x, // [batch_size, in_features]
const scalar_t* __restrict__ W_ternary, // [k, out_features, in_features]
const scalar_t* __restrict__ gammas, // [k, out_features]
const scalar_t* __restrict__ bias, // [out_features] or nullptr
scalar_t* __restrict__ output, // [batch_size, out_features]
int batch_size,
int in_features,
int out_features,
int k
) {
int batch_idx = blockIdx.x;
int out_idx = blockIdx.y;
int tid = threadIdx.x;
// Shared memory for reduction
extern __shared__ char shared_mem_raw[];
scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);
// Accumulate total result across all k components
scalar_t total_result = scalar_t(0);
// Stride for indexing into W_ternary: [k, out_features, in_features]
int W_out_stride = in_features;
int W_k_stride = out_features * in_features;
// Process each of the k components
for (int comp = 0; comp < k; comp++) {
scalar_t partial_sum = scalar_t(0);
// Compute dot product for this component
for (int i = tid; i < in_features; i += TILE_SIZE) {
scalar_t x_val = x[batch_idx * in_features + i];
scalar_t w_val = W_ternary[comp * W_k_stride + out_idx * W_out_stride + i];
// Ternary conditional accumulation
if (w_val > scalar_t(0)) {
partial_sum += x_val;
} else if (w_val < scalar_t(0)) {
partial_sum -= x_val;
}
}
// Reduce partial sums across block
partial_sum = block_reduce_sum(partial_sum, shared_mem);
__syncthreads();
// Thread 0 accumulates with gamma scaling
if (tid == 0) {
scalar_t gamma_val = gammas[comp * out_features + out_idx];
total_result += partial_sum * gamma_val;
}
__syncthreads();
}
// Thread 0 writes the final result
if (tid == 0) {
// Add bias if present
if (bias != nullptr) {
total_result += bias[out_idx];
}
output[batch_idx * out_features + out_idx] = total_result;
}
}
/*
* Launcher for multi-ternary CUDA kernel
*
* This function:
* 1. Handles multi-dimensional input by flattening
* 2. Sets up grid and block dimensions
* 3. Launches the fused multi-ternary kernel
* 4. Reshapes output to match input batch dimensions
*/
torch::Tensor multi_ternary_cuda_forward(
torch::Tensor x,
torch::Tensor W_ternary,
torch::Tensor gammas,
torch::optional<torch::Tensor> bias
) {
// Handle multi-dimensional input
auto x_shape = x.sizes().vec();
int64_t batch_size = 1;
for (size_t i = 0; i < x_shape.size() - 1; i++) {
batch_size *= x_shape[i];
}
const int in_features = x.size(-1);
const int k = W_ternary.size(0);
const int out_features = W_ternary.size(1);
// Flatten input to 2D
    auto x_2d = x.reshape({batch_size, in_features}).contiguous();  // reshape() also handles non-contiguous inputs
// Ensure tensors are contiguous
auto W_cont = W_ternary.contiguous();
auto gammas_cont = gammas.contiguous();
// Allocate output
auto output = torch::zeros({batch_size, out_features}, x.options());
// Calculate shared memory size
int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;
// Grid configuration
dim3 grid(batch_size, out_features);
dim3 block(TILE_SIZE);
// Get current CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "multi_ternary_forward_cuda", ([&] {
size_t shared_mem_size = num_warps * sizeof(scalar_t);
multi_ternary_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
x_2d.data_ptr<scalar_t>(),
W_cont.data_ptr<scalar_t>(),
gammas_cont.data_ptr<scalar_t>(),
bias.has_value() && bias.value().defined()
? bias.value().contiguous().data_ptr<scalar_t>()
: nullptr,
output.data_ptr<scalar_t>(),
batch_size,
in_features,
out_features,
k
);
}));
// Check for CUDA errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
AT_ERROR("Multi-ternary CUDA kernel failed: ", cudaGetErrorString(err));
}
// Reshape output
std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
out_shape.push_back(out_features);
return output.view(out_shape);
}
/*
* Advanced optimization: Ternary matrix multiplication using Tensor Cores
*
* Modern GPUs (Volta+) have Tensor Cores that accelerate matrix operations.
* While designed for FP16/INT8, we can potentially leverage them for ternary
* operations by packing ternary values into INT4/INT8 formats.
*
* This is a future optimization once basic kernels are working.
*
* Potential approaches:
 * 1. Pack ternary values into INT8 and use INT8 Tensor Cores (sketched below)
* 2. Use FP16 with ternary values for FP16 Tensor Cores
* 3. Custom WMMA (Warp Matrix Multiply Accumulate) implementation
*/
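/*
 * Illustrative sketch only (not used by the kernels in this file): one way to
 * realize approach 1 above is to quantize the activations to int8, store the
 * ternary weights as int8 values in {-1, 0, +1}, and let the __dp4a intrinsic
 * (SM 6.1+) accumulate four 8-bit products per instruction. The helper name,
 * the int8 activation quantization, and the alignment / divisibility
 * assumptions are hypothetical.
 */
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 610
__device__ __forceinline__ int dp4a_ternary_dot(
    const int8_t* __restrict__ x_q,  // int8-quantized activations, 4-byte aligned (assumption)
    const int8_t* __restrict__ w_t,  // ternary weights stored as int8 {-1, 0, +1}
    int n                            // length, assumed to be a multiple of 4
) {
    // Reinterpret four consecutive int8 values as one packed 32-bit word
    const int* x4 = reinterpret_cast<const int*>(x_q);
    const int* w4 = reinterpret_cast<const int*>(w_t);
    int acc = 0;
    for (int i = 0; i < n / 4; i++) {
        // __dp4a multiplies four packed signed 8-bit pairs and adds them to acc
        acc = __dp4a(x4[i], w4[i], acc);
    }
    return acc;
}
#endif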
/*
* CUDA kernel for packing ternary weights to base-3 representation
*
* Maps {-1, 0, +1} to {0, 1, 2} and packs 5 values per byte.
 * Each thread packs five consecutive ternary values into one output byte.
*/
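/*
 * Worked example: the ternary weights [-1, +1, 0, +1, -1] map to the base-3
 * digits [0, 2, 1, 2, 0], so the packed byte is
 * 0*1 + 2*3 + 1*9 + 2*27 + 0*81 = 69.
 */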
template <typename scalar_t>
__global__ void pack_ternary_kernel(
const scalar_t* __restrict__ input, // Flat ternary weights
uint8_t* __restrict__ output, // Packed output
int64_t numel, // Number of input elements
int64_t packed_size // Number of output bytes
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < packed_size) {
int64_t base_idx = idx * 5;
uint8_t packed_val = 0;
uint8_t powers[5] = {1, 3, 9, 27, 81};
#pragma unroll
for (int j = 0; j < 5; j++) {
int64_t in_idx = base_idx + j;
if (in_idx < numel) {
// Map {-1, 0, +1} to {0, 1, 2}
int8_t val = static_cast<int8_t>(input[in_idx]) + 1;
packed_val += static_cast<uint8_t>(val) * powers[j];
} else {
// Pad with 1 (representing 0)
packed_val += 1 * powers[j];
}
}
output[idx] = packed_val;
}
}
/*
* CUDA kernel for unpacking base-3 ternary weights
*
* Extracts 5 values per byte and maps {0, 1, 2} back to {-1, 0, +1}.
*/
template <typename scalar_t>
__global__ void unpack_ternary_kernel(
const uint8_t* __restrict__ input, // Packed input
scalar_t* __restrict__ output, // Unpacked output
int64_t numel, // Number of output elements
int64_t packed_size // Number of input bytes
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < packed_size) {
int64_t base_idx = idx * 5;
uint8_t packed_val = input[idx];
#pragma unroll
for (int j = 0; j < 5 && base_idx + j < numel; j++) {
uint8_t val = packed_val % 3;
packed_val /= 3;
// Map {0, 1, 2} to {-1, 0, +1}
output[base_idx + j] = static_cast<scalar_t>(val) - scalar_t(1);
}
}
}
/*
* GPU-accelerated packing launcher
*/
torch::Tensor pack_ternary_cuda(torch::Tensor W_ternary) {
auto flat = W_ternary.flatten().contiguous();
int64_t numel = flat.numel();
int64_t packed_size = (numel + 4) / 5;
auto packed = torch::zeros({packed_size},
torch::dtype(torch::kUInt8).device(W_ternary.device()));
const int threads = 256;
const int blocks = (packed_size + threads - 1) / threads;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(W_ternary.scalar_type(), "pack_ternary_cuda", ([&] {
pack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
flat.data_ptr<scalar_t>(),
packed.data_ptr<uint8_t>(),
numel,
packed_size
);
}));
return packed;
}
/*
* GPU-accelerated unpacking launcher
*/
torch::Tensor unpack_ternary_cuda(
torch::Tensor packed,
std::vector<int64_t> original_shape,
torch::ScalarType dtype
) {
int64_t numel = 1;
for (auto dim : original_shape) {
numel *= dim;
}
auto packed_flat = packed.flatten().contiguous();
int64_t packed_size = packed_flat.numel();
auto unpacked = torch::zeros({numel},
torch::dtype(dtype).device(packed.device()));
const int threads = 256;
const int blocks = (packed_size + threads - 1) / threads;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(dtype, "unpack_ternary_cuda", ([&] {
unpack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
packed_flat.data_ptr<uint8_t>(),
unpacked.data_ptr<scalar_t>(),
numel,
packed_size
);
}));
return unpacked.view(original_shape);
}
// End of CUDA kernels
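/*
 * Usage sketch (illustrative only): the launchers above are plain C++
 * functions returning torch::Tensor, so a companion binding file could expose
 * them to Python roughly as below. The module and function names are
 * hypothetical; the repository's actual bindings may differ or live in a
 * separate .cpp file, so the sketch is kept inside a comment.
 *
 *   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 *       m.def("forward", &bitlinear_cuda_forward, "BitLinear forward (CUDA)");
 *       m.def("multi_ternary_forward", &multi_ternary_cuda_forward,
 *             "Multi-ternary forward (CUDA)");
 *       m.def("pack_ternary", &pack_ternary_cuda,
 *             "Pack ternary weights into base-3 bytes (CUDA)");
 *       m.def("unpack_ternary", &unpack_ternary_cuda,
 *             "Unpack base-3 bytes into ternary weights (CUDA)");
 *   }
 */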