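// CUDA kernels for BitLinear-style layers with ternary ({-1, 0, +1}) weights:
// forward passes for single- and multi-component ternary linear layers, plus
// helpers that pack/unpack ternary weights five values per byte (base 3).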
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAContext.h>

#include <vector>

constexpr int TILE_SIZE = 256;
constexpr int WARP_SIZE = 32;
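// Sum `val` across the 32 lanes of a warp using shuffle-down; lane 0 ends up
// holding the full warp sum.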
template <typename scalar_t>
__device__ __forceinline__ scalar_t warp_reduce_sum(scalar_t val) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}
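// Block-wide sum: each warp reduces its own values, warp leaders stash their
// partial sums in shared memory, and warp 0 reduces those partials. Only
// thread 0 of the block is guaranteed to hold the final total.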
template <typename scalar_t>
__device__ scalar_t block_reduce_sum(scalar_t val, scalar_t* shared_mem) {
    int lane = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;

    val = warp_reduce_sum(val);

    if (lane == 0) {
        shared_mem[warp_id] = val;
    }
    __syncthreads();

    int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
    val = (threadIdx.x < num_warps) ? shared_mem[lane] : scalar_t(0);

    if (warp_id == 0) {
        val = warp_reduce_sum(val);
    }

    return val;
}
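// Forward pass for a ternary-weight linear layer. One block computes a single
// (batch element, output feature) pair: threads stride over in_features,
// adding x where the weight is positive and subtracting where it is negative
// (zeros are skipped), then the block-wide sum is scaled by gamma[out_idx]
// and the optional bias is added.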
template <typename scalar_t>
__global__ void bitlinear_forward_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ W_ternary,
    const scalar_t* __restrict__ gamma,
    const scalar_t* __restrict__ bias,
    scalar_t* __restrict__ output,
    int batch_size,
    int in_features,
    int out_features
) {
    int batch_idx = blockIdx.x;
    int out_idx = blockIdx.y;
    int tid = threadIdx.x;

    extern __shared__ char shared_mem_raw[];
    scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);

    scalar_t partial_sum = scalar_t(0);

    for (int i = tid; i < in_features; i += TILE_SIZE) {
        scalar_t x_val = x[batch_idx * in_features + i];
        scalar_t w_val = W_ternary[out_idx * in_features + i];

        if (w_val > scalar_t(0)) {
            partial_sum += x_val;
        } else if (w_val < scalar_t(0)) {
            partial_sum -= x_val;
        }
    }

    partial_sum = block_reduce_sum(partial_sum, shared_mem);

    if (tid == 0) {
        scalar_t result = partial_sum * gamma[out_idx];

        if (bias != nullptr) {
            result += bias[out_idx];
        }

        output[batch_idx * out_features + out_idx] = result;
    }
}
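// Host-side launcher: flattens the leading dimensions of x into a batch
// dimension, launches one block of TILE_SIZE threads per (batch, out_feature)
// pair, and reshapes the result back to x's leading dimensions plus
// out_features.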
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    const int in_features = x.size(-1);
    const int out_features = W_ternary.size(0);

    // reshape() (rather than view()) tolerates non-contiguous inputs.
    auto x_2d = x.reshape({batch_size, in_features}).contiguous();

    auto W_cont = W_ternary.contiguous();
    auto gamma_cont = gamma.contiguous();

    // Materialize the contiguous bias in a named tensor so its storage stays
    // alive until after the kernel launch.
    torch::Tensor bias_cont;
    if (bias.has_value() && bias.value().defined()) {
        bias_cont = bias.value().contiguous();
    }

    auto output = torch::zeros({batch_size, out_features}, x.options());

    int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;

    dim3 grid(batch_size, out_features);
    dim3 block(TILE_SIZE);

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bitlinear_forward_cuda", ([&] {
        size_t shared_mem_size = num_warps * sizeof(scalar_t);

        bitlinear_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
            x_2d.data_ptr<scalar_t>(),
            W_cont.data_ptr<scalar_t>(),
            gamma_cont.data_ptr<scalar_t>(),
            bias_cont.defined() ? bias_cont.data_ptr<scalar_t>() : nullptr,
            output.data_ptr<scalar_t>(),
            batch_size,
            in_features,
            out_features
        );
    }));

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("BitLinear CUDA kernel failed: ", cudaGetErrorString(err));
    }

    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);

    return output.view(out_shape);
}
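// Forward pass for a sum of k ternary components: W_ternary has shape
// [k, out_features, in_features] and gammas has shape [k, out_features].
// Each component's ternary dot product is reduced block-wide, scaled by its
// gamma, and accumulated by thread 0 before the optional bias is added.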
template <typename scalar_t>
__global__ void multi_ternary_forward_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ W_ternary,
    const scalar_t* __restrict__ gammas,
    const scalar_t* __restrict__ bias,
    scalar_t* __restrict__ output,
    int batch_size,
    int in_features,
    int out_features,
    int k
) {
    int batch_idx = blockIdx.x;
    int out_idx = blockIdx.y;
    int tid = threadIdx.x;

    extern __shared__ char shared_mem_raw[];
    scalar_t* shared_mem = reinterpret_cast<scalar_t*>(shared_mem_raw);

    scalar_t total_result = scalar_t(0);

    int W_out_stride = in_features;
    int W_k_stride = out_features * in_features;

    for (int comp = 0; comp < k; comp++) {
        scalar_t partial_sum = scalar_t(0);

        for (int i = tid; i < in_features; i += TILE_SIZE) {
            scalar_t x_val = x[batch_idx * in_features + i];
            scalar_t w_val = W_ternary[comp * W_k_stride + out_idx * W_out_stride + i];

            if (w_val > scalar_t(0)) {
                partial_sum += x_val;
            } else if (w_val < scalar_t(0)) {
                partial_sum -= x_val;
            }
        }

        partial_sum = block_reduce_sum(partial_sum, shared_mem);
        __syncthreads();

        if (tid == 0) {
            scalar_t gamma_val = gammas[comp * out_features + out_idx];
            total_result += partial_sum * gamma_val;
        }
        __syncthreads();
    }

    if (tid == 0) {
        if (bias != nullptr) {
            total_result += bias[out_idx];
        }

        output[batch_idx * out_features + out_idx] = total_result;
    }
}
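// Host-side launcher for the multi-component kernel; mirrors
// bitlinear_cuda_forward but derives k and out_features from W_ternary's
// first two dimensions.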
torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    const int in_features = x.size(-1);
    const int k = W_ternary.size(0);
    const int out_features = W_ternary.size(1);

    // reshape() (rather than view()) tolerates non-contiguous inputs.
    auto x_2d = x.reshape({batch_size, in_features}).contiguous();

    auto W_cont = W_ternary.contiguous();
    auto gammas_cont = gammas.contiguous();

    // Materialize the contiguous bias in a named tensor so its storage stays
    // alive until after the kernel launch.
    torch::Tensor bias_cont;
    if (bias.has_value() && bias.value().defined()) {
        bias_cont = bias.value().contiguous();
    }

    auto output = torch::zeros({batch_size, out_features}, x.options());

    int num_warps = (TILE_SIZE + WARP_SIZE - 1) / WARP_SIZE;

    dim3 grid(batch_size, out_features);
    dim3 block(TILE_SIZE);

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "multi_ternary_forward_cuda", ([&] {
        size_t shared_mem_size = num_warps * sizeof(scalar_t);

        multi_ternary_forward_kernel<scalar_t><<<grid, block, shared_mem_size, stream>>>(
            x_2d.data_ptr<scalar_t>(),
            W_cont.data_ptr<scalar_t>(),
            gammas_cont.data_ptr<scalar_t>(),
            bias_cont.defined() ? bias_cont.data_ptr<scalar_t>() : nullptr,
            output.data_ptr<scalar_t>(),
            batch_size,
            in_features,
            out_features,
            k
        );
    }));

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("Multi-ternary CUDA kernel failed: ", cudaGetErrorString(err));
    }

    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);

    return output.view(out_shape);
}
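// Pack five ternary values into one byte using base-3 digits: each value in
// {-1, 0, +1} is shifted to {0, 1, 2} and weighted by 3^j (1, 3, 9, 27, 81),
// so the packed byte is at most 242. Tail positions beyond numel are padded
// with the digit 1, which unpacks back to 0.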
template <typename scalar_t>
__global__ void pack_ternary_kernel(
    const scalar_t* __restrict__ input,
    uint8_t* __restrict__ output,
    int64_t numel,
    int64_t packed_size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < packed_size) {
        int64_t base_idx = idx * 5;
        uint8_t packed_val = 0;
        uint8_t powers[5] = {1, 3, 9, 27, 81};

#pragma unroll
        for (int j = 0; j < 5; j++) {
            int64_t in_idx = base_idx + j;
            if (in_idx < numel) {
                int8_t val = static_cast<int8_t>(input[in_idx]) + 1;
                packed_val += static_cast<uint8_t>(val) * powers[j];
            } else {
                packed_val += 1 * powers[j];
            }
        }
        output[idx] = packed_val;
    }
}
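// Inverse of pack_ternary_kernel: peel off base-3 digits with % 3 and / 3,
// then shift {0, 1, 2} back to {-1, 0, +1}.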
template <typename scalar_t>
__global__ void unpack_ternary_kernel(
    const uint8_t* __restrict__ input,
    scalar_t* __restrict__ output,
    int64_t numel,
    int64_t packed_size
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < packed_size) {
        int64_t base_idx = idx * 5;
        uint8_t packed_val = input[idx];

#pragma unroll
        for (int j = 0; j < 5 && base_idx + j < numel; j++) {
            uint8_t val = packed_val % 3;
            packed_val /= 3;

            output[base_idx + j] = static_cast<scalar_t>(val) - scalar_t(1);
        }
    }
}
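// Pack a ternary weight tensor into a flat uint8 tensor of ceil(numel / 5)
// bytes on the same device.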
torch::Tensor pack_ternary_cuda(torch::Tensor W_ternary) {
    auto flat = W_ternary.flatten().contiguous();
    int64_t numel = flat.numel();
    int64_t packed_size = (numel + 4) / 5;

    auto packed = torch::zeros({packed_size},
        torch::dtype(torch::kUInt8).device(W_ternary.device()));

    const int threads = 256;
    const int blocks = (packed_size + threads - 1) / threads;

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(W_ternary.scalar_type(), "pack_ternary_cuda", ([&] {
        pack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            flat.data_ptr<scalar_t>(),
            packed.data_ptr<uint8_t>(),
            numel,
            packed_size
        );
    }));

    return packed;
}
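// Unpack a flat uint8 tensor produced by pack_ternary_cuda back into a
// floating-point tensor with the given original shape and dtype.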
torch::Tensor unpack_ternary_cuda(
    torch::Tensor packed,
    std::vector<int64_t> original_shape,
    torch::ScalarType dtype
) {
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    auto packed_flat = packed.flatten().contiguous();
    int64_t packed_size = packed_flat.numel();

    auto unpacked = torch::zeros({numel},
        torch::dtype(dtype).device(packed.device()));

    const int threads = 256;
    const int blocks = (packed_size + threads - 1) / threads;

    auto stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(dtype, "unpack_ternary_cuda", ([&] {
        unpack_ternary_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            packed_flat.data_ptr<uint8_t>(),
            unpacked.data_ptr<scalar_t>(),
            numel,
            packed_size
        );
    }));

    return unpacked.view(original_shape);
}