/*
 * BitLinear CUDA Kernels
 *
 * This file contains CUDA kernel implementations for BitLinear operations.
 * The kernels optimize ternary matrix multiplication for GPU execution.
 *
 * Key optimizations implemented:
 *   1. Ternary weight specialization (only -1, 0, +1)
 *   2. Shared memory tiling for reduced global memory access
 *   3. Warp-level reduction using shuffle intrinsics
 *   4. Memory coalescing for efficient global reads
 *   5. Thread coarsening for better instruction-level parallelism
 */

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>

// Tile size for shared memory - tuned for occupancy and cache utilization
constexpr int TILE_SIZE = 256;
constexpr int WARP_SIZE = 32;

/*
 * Warp-level reduction using shuffle intrinsics
 * Reduces a value across all threads in a warp
 */
template <typename scalar_t>
__device__ __forceinline__ scalar_t warp_reduce_sum(scalar_t val) {
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

/*
 * Block-level reduction using shared memory
 * Reduces partial sums from each warp to a single value
 */
template <typename scalar_t>
__device__ scalar_t block_reduce_sum(scalar_t val, scalar_t* shared_mem) {
    int lane = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;

    // First reduce within warp
    val = warp_reduce_sum(val);

    // Write reduced warp value to shared memory
    if (lane == 0) {
        shared_mem[warp_id] = val;
    }
    __syncthreads();

    // Read from shared memory only if this thread is in the first warp
    int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    val = (threadIdx.x < num_warps) ? shared_mem[lane] : scalar_t(0);

    // Final reduce within first warp
    if (warp_id == 0) {
        val = warp_reduce_sum(val);
    }
    return val;
}

/*
 * CUDA kernel for BitLinear forward pass
 *
 * Computes: output[batch, out] = sum_in (x[batch, in] * W[out, in]) * gamma[out]
 *
 * This is a specialized matrix multiplication kernel that exploits:
 *   - Ternary weights: only need additions/subtractions (no multiplications)
 *   - Shared memory tiling for reduced memory bandwidth
 *   - Warp shuffle for efficient reductions
 *
 * Grid/Block configuration:
 *   - Grid: (batch_size, out_features)
 *   - Block: TILE_SIZE threads
 *   - Each block computes one output element
 */
template <typename scalar_t>
__global__ void bitlinear_forward_kernel(
    const scalar_t* __restrict__ x,          // [batch_size, in_features]
    const scalar_t* __restrict__ W_ternary,  // [out_features, in_features]
    const scalar_t* __restrict__ gamma,      // [out_features]
    const scalar_t* __restrict__ bias,       // [out_features] or nullptr
    scalar_t* __restrict__ output,           // [batch_size, out_features]
    int batch_size,
    int in_features,
    int out_features
) {
    int batch_idx = blockIdx.x;
    int out_idx = blockIdx.y;
    int tid = threadIdx.x;

    // Shared memory for partial sums reduction
    extern __shared__ char shared_mem_raw[];
    scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);

    // Each thread computes partial dot product
    scalar_t partial_sum = scalar_t(0);

    // Coalesced access: each thread handles multiple elements strided by TILE_SIZE
    for (int i = tid; i < in_features; i += TILE_SIZE) {
        scalar_t x_val = x[batch_idx * in_features + i];
        scalar_t w_val = W_ternary[out_idx * in_features + i];

        // Exploit ternary structure: conditional accumulation (no multiply)
        // This is faster than general multiply when weights are truly ternary
        if (w_val > scalar_t(0)) {
            partial_sum += x_val;
        } else if (w_val < scalar_t(0)) {
            partial_sum -= x_val;
        }
        // w_val == 0: skip (implicit in else)
    }

    // Reduce partial sums across block
    partial_sum = block_reduce_sum(partial_sum, shared_mem);

    // Thread 0 writes the final result
    if (tid == 0) {
        // Apply gamma scaling
        scalar_t result = partial_sum * gamma[out_idx];

        // Add bias if present
        if (bias != nullptr) {
            result += bias[out_idx];
        }
        output[batch_idx * out_features + out_idx] = result;
    }
}
/*
 * CUDA kernel launcher for BitLinear forward
 *
 * This function:
 *   1. Handles multi-dimensional input by flattening
 *   2. Sets up grid and block dimensions
 *   3. Launches the CUDA kernel with dynamic shared memory
 *   4. Reshapes output to match input batch dimensions
 */
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Handle multi-dimensional input
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    const int in_features = x.size(-1);
    const int out_features = W_ternary.size(0);

    // Flatten input to 2D for kernel
    auto x_2d = x.view({batch_size, in_features}).contiguous();

    // Ensure all tensors are contiguous for efficient memory access
    auto W_cont = W_ternary.contiguous();
    auto gamma_cont = gamma.contiguous();

    // Allocate output
    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Calculate shared memory size for reduction
    int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;

    // Grid: one block per (batch, output feature) pair
    dim3 grid(batch_size, out_features);
    dim3 block(TILE_SIZE);

    // Get current CUDA stream
    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bitlinear_forward_cuda", ([&] {
        size_t shared_mem_size = num_warps * sizeof(scalar_t);
        bitlinear_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
            x_2d.data_ptr<scalar_t>(),
            W_cont.data_ptr<scalar_t>(),
            gamma_cont.data_ptr<scalar_t>(),
            bias.has_value() && bias.value().defined()
                ? bias.value().contiguous().data_ptr<scalar_t>()
                : nullptr,
            output.data_ptr<scalar_t>(),
            batch_size,
            in_features,
            out_features
        );
    }));

    // Check for CUDA errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("BitLinear CUDA kernel failed: ", cudaGetErrorString(err));
    }

    // Reshape output to match input batch dimensions
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    return output.view(out_shape);
}
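/*
 * Usage sketch (illustrative, not from the original file): how the launcher
 * above might be called from C++ once the extension is built against LibTorch.
 * The shapes and values below are placeholders.
 *
 *   auto x     = torch::randn({8, 512, 1024}, torch::kCUDA);
 *   auto W     = torch::randint(-1, 2, {4096, 1024},
 *                               torch::dtype(torch::kFloat32).device(torch::kCUDA));
 *   auto gamma = torch::ones({4096}, torch::kCUDA);
 *   auto y     = bitlinear_cuda_forward(x, W, gamma, c10::nullopt);  // [8, 512, 4096]
 */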
/*
 * CUDA kernel for multi-ternary forward pass
 *
 * Computes: output = sum_{i=1}^k [(x @ W_i^T) * gamma_i] + bias
 *
 * This kernel fuses k ternary matrix multiplications into a single kernel
 * to reduce memory bandwidth requirements. Each block handles one
 * (batch, output) pair and accumulates contributions from all k components.
 *
 * Grid/Block configuration:
 *   - Grid: (batch_size, out_features)
 *   - Block: TILE_SIZE threads
 */
template <typename scalar_t>
__global__ void multi_ternary_forward_kernel(
    const scalar_t* __restrict__ x,          // [batch_size, in_features]
    const scalar_t* __restrict__ W_ternary,  // [k, out_features, in_features]
    const scalar_t* __restrict__ gammas,     // [k, out_features]
    const scalar_t* __restrict__ bias,       // [out_features] or nullptr
    scalar_t* __restrict__ output,           // [batch_size, out_features]
    int batch_size,
    int in_features,
    int out_features,
    int k
) {
    int batch_idx = blockIdx.x;
    int out_idx = blockIdx.y;
    int tid = threadIdx.x;

    // Shared memory for reduction
    extern __shared__ char shared_mem_raw[];
    scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);

    // Accumulate total result across all k components
    scalar_t total_result = scalar_t(0);

    // Strides for indexing into W_ternary: [k, out_features, in_features]
    int W_out_stride = in_features;
    int W_k_stride = out_features * in_features;

    // Process each of the k components
    for (int comp = 0; comp < k; comp++) {
        scalar_t partial_sum = scalar_t(0);

        // Compute dot product for this component
        for (int i = tid; i < in_features; i += TILE_SIZE) {
            scalar_t x_val = x[batch_idx * in_features + i];
            scalar_t w_val = W_ternary[comp * W_k_stride + out_idx * W_out_stride + i];

            // Ternary conditional accumulation
            if (w_val > scalar_t(0)) {
                partial_sum += x_val;
            } else if (w_val < scalar_t(0)) {
                partial_sum -= x_val;
            }
        }

        // Reduce partial sums across block
        partial_sum = block_reduce_sum(partial_sum, shared_mem);
        __syncthreads();

        // Thread 0 accumulates with gamma scaling
        if (tid == 0) {
            scalar_t gamma_val = gammas[comp * out_features + out_idx];
            total_result += partial_sum * gamma_val;
        }
        __syncthreads();
    }

    // Thread 0 writes the final result
    if (tid == 0) {
        // Add bias if present
        if (bias != nullptr) {
            total_result += bias[out_idx];
        }
        output[batch_idx * out_features + out_idx] = total_result;
    }
}

/*
 * Launcher for multi-ternary CUDA kernel
 *
 * This function:
 *   1. Handles multi-dimensional input by flattening
 *   2. Sets up grid and block dimensions
 *   3. Launches the fused multi-ternary kernel
 *   4. Reshapes output to match input batch dimensions
 */
torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // Handle multi-dimensional input
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    const int in_features = x.size(-1);
    const int k = W_ternary.size(0);
    const int out_features = W_ternary.size(1);

    // Flatten input to 2D
    auto x_2d = x.view({batch_size, in_features}).contiguous();

    // Ensure tensors are contiguous
    auto W_cont = W_ternary.contiguous();
    auto gammas_cont = gammas.contiguous();

    // Allocate output
    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Calculate shared memory size
    int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;

    // Grid configuration
    dim3 grid(batch_size, out_features);
    dim3 block(TILE_SIZE);

    // Get current CUDA stream
    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "multi_ternary_forward_cuda", ([&] {
        size_t shared_mem_size = num_warps * sizeof(scalar_t);
        multi_ternary_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
            x_2d.data_ptr<scalar_t>(),
            W_cont.data_ptr<scalar_t>(),
            gammas_cont.data_ptr<scalar_t>(),
            bias.has_value() && bias.value().defined()
                ? bias.value().contiguous().data_ptr<scalar_t>()
                : nullptr,
            output.data_ptr<scalar_t>(),
            batch_size,
            in_features,
            out_features,
            k
        );
    }));

    // Check for CUDA errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("Multi-ternary CUDA kernel failed: ", cudaGetErrorString(err));
    }

    // Reshape output
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    return output.view(out_shape);
}
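/*
 * Usage sketch (illustrative, not from the original file): the fused launcher
 * expects the k ternary weight matrices stacked as [k, out_features, in_features]
 * and their scales as [k, out_features]. Shapes below are placeholders.
 *
 *   auto W      = torch::randint(-1, 2, {2, 4096, 1024},
 *                                torch::dtype(torch::kFloat32).device(torch::kCUDA));
 *   auto gammas = torch::ones({2, 4096}, torch::kCUDA);
 *   auto y      = multi_ternary_cuda_forward(x, W, gammas, c10::nullopt);
 */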
/*
 * Advanced optimization: Ternary matrix multiplication using Tensor Cores
 *
 * Modern GPUs (Volta+) have Tensor Cores that accelerate matrix operations.
 * While designed for FP16/INT8, we can potentially leverage them for ternary
 * operations by packing ternary values into INT4/INT8 formats.
 *
 * This is a future optimization once basic kernels are working.
 *
 * Potential approaches:
 *   1. Pack ternary values into INT8 and use INT8 Tensor Cores
 *   2. Use FP16 with ternary values for FP16 Tensor Cores
 *   3. Custom WMMA (Warp Matrix Multiply Accumulate) implementation
 */
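/*
 * Illustrative sketch (an assumption, not part of the plan above): before full
 * Tensor Core / WMMA support, approach 1 can be approximated with the INT8
 * dot-product intrinsic __dp4a (available on sm_61+). The sketch assumes the
 * caller has already quantized activations to int8 and stored the ternary
 * weights as int8 values in {-1, 0, +1}, both packed four values per 32-bit
 * word; the function name and data layout are hypothetical.
 */
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 610
__device__ __forceinline__ int ternary_dot_dp4a(
    const int* __restrict__ x_packed,  // int8x4 activations viewed as 32-bit words
    const int* __restrict__ w_packed,  // int8x4 ternary weights viewed as 32-bit words
    int num_words                      // in_features / 4
) {
    int acc = 0;
    for (int i = 0; i < num_words; i++) {
        // Each __dp4a performs four int8 multiply-accumulates in one instruction
        acc = __dp4a(x_packed[i], w_packed[i], acc);
    }
    return acc;
}
#endif  // __CUDA_ARCH__ >= 610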
/*
 * CUDA kernel for packing ternary weights to base-3 representation
 *
 * Maps {-1, 0, +1} to {0, 1, 2} and packs 5 values per byte.
 * Example: values (-1, 0, +1, +1, 0) map to digits (0, 1, 2, 2, 1), giving
 * 0*1 + 1*3 + 2*9 + 2*27 + 1*81 = 156.
 * Each thread packs five input values into one output byte.
 */
template <typename scalar_t>
__global__ void pack_ternary_kernel(
    const scalar_t* __restrict__ input,  // Flat ternary weights
    uint8_t* __restrict__ output,        // Packed output
    int64_t numel,                       // Number of input elements
    int64_t packed_size                  // Number of output bytes
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < packed_size) {
        int64_t base_idx = idx * 5;
        uint8_t packed_val = 0;
        uint8_t powers[5] = {1, 3, 9, 27, 81};

        #pragma unroll
        for (int j = 0; j < 5; j++) {
            int64_t in_idx = base_idx + j;
            if (in_idx < numel) {
                // Map {-1, 0, +1} to {0, 1, 2}
                int8_t val = static_cast<int8_t>(input[in_idx]) + 1;
                packed_val += static_cast<uint8_t>(val) * powers[j];
            } else {
                // Pad with 1 (representing 0)
                packed_val += 1 * powers[j];
            }
        }
        output[idx] = packed_val;
    }
}

/*
 * CUDA kernel for unpacking base-3 ternary weights
 *
 * Extracts 5 values per byte and maps {0, 1, 2} back to {-1, 0, +1}.
 */
template <typename scalar_t>
__global__ void unpack_ternary_kernel(
    const uint8_t* __restrict__ input,  // Packed input
    scalar_t* __restrict__ output,      // Unpacked output
    int64_t numel,                      // Number of output elements
    int64_t packed_size                 // Number of input bytes
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < packed_size) {
        int64_t base_idx = idx * 5;
        uint8_t packed_val = input[idx];

        #pragma unroll
        for (int j = 0; j < 5 && base_idx + j < numel; j++) {
            uint8_t val = packed_val % 3;
            packed_val /= 3;
            // Map {0, 1, 2} to {-1, 0, +1}
            output[base_idx + j] = static_cast<scalar_t>(val) - scalar_t(1);
        }
    }
}

/*
 * GPU-accelerated packing launcher
 */
torch::Tensor pack_ternary_cuda(torch::Tensor W_ternary) {
    auto flat = W_ternary.flatten().contiguous();
    int64_t numel = flat.numel();
    int64_t packed_size = (numel + 4) / 5;

    auto packed = torch::zeros(
        {packed_size}, torch::dtype(torch::kUInt8).device(W_ternary.device()));

    const int threads = 256;
    const int blocks = (packed_size + threads - 1) / threads;

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(W_ternary.scalar_type(), "pack_ternary_cuda", ([&] {
        pack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            flat.data_ptr<scalar_t>(),
            packed.data_ptr<uint8_t>(),
            numel,
            packed_size
        );
    }));

    return packed;
}

/*
 * GPU-accelerated unpacking launcher
 */
torch::Tensor unpack_ternary_cuda(
    torch::Tensor packed,
    std::vector<int64_t> original_shape,
    torch::ScalarType dtype
) {
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    auto packed_flat = packed.flatten().contiguous();
    int64_t packed_size = packed_flat.numel();

    auto unpacked = torch::zeros({numel}, torch::dtype(dtype).device(packed.device()));

    const int threads = 256;
    const int blocks = (packed_size + threads - 1) / threads;

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(dtype, "unpack_ternary_cuda", ([&] {
        unpack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            packed_flat.data_ptr<uint8_t>(),
            unpacked.data_ptr<scalar_t>(),
            numel,
            packed_size
        );
    }));

    return unpacked.view(original_shape);
}

// End of CUDA kernels
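/*
 * Binding sketch (an assumption; the module and exported names are placeholders):
 * if the launchers above are exposed directly from this translation unit as a
 * PyTorch C++ extension, a pybind11 module along these lines would make them
 * callable from Python. Many projects keep the bindings in a separate .cpp file
 * instead, so the sketch is shown commented out.
 *
 *   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 *       m.def("bitlinear_forward", &bitlinear_cuda_forward,
 *             "BitLinear forward (CUDA)");
 *       m.def("multi_ternary_forward", &multi_ternary_cuda_forward,
 *             "Multi-ternary forward (CUDA)");
 *       m.def("pack_ternary", &pack_ternary_cuda,
 *             "Pack ternary weights to base-3 bytes (CUDA)");
 *       m.def("unpack_ternary", &unpack_ternary_cuda,
 *             "Unpack base-3 bytes to ternary weights (CUDA)");
 *   }
 */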