// BitLinear C++/CUDA extension: CPU reference implementations, CUDA dispatch,
// and base-3 packing helpers for ternary ({-1, 0, +1}) weight matrices.
#include <torch/extension.h>
#include <vector>

// CUDA kernel declarations, available only when the extension is built with CUDA support.
#ifdef WITH_CUDA
torch::Tensor bitlinear_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
);

torch::Tensor multi_ternary_cuda_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
);
#endif

// CPU reference path: y = (x @ W_ternary^T) * gamma + bias, applied over the
// flattened leading (batch) dimensions of x.
torch::Tensor bitlinear_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    // Collapse all leading dimensions into a single batch dimension.
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }
    int64_t in_features = x_shape.back();
    int64_t out_features = W_ternary.size(0);

    auto x_2d = x.view({batch_size, in_features});

    // Ternary matmul followed by the scale factor gamma (broadcast across the batch).
    auto output = torch::matmul(x_2d, W_ternary.t());
    output = output * gamma.unsqueeze(0);

    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Restore the original leading dimensions with out_features as the last axis.
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
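
// Illustrative call of the CPU path (not part of the module; shapes and values
// below are assumed purely for the sketch):
//
//   auto x     = torch::randn({4, 64});                                  // [batch, in_features]
//   auto W     = torch::randint(-1, 2, {128, 64}).to(torch::kFloat32);   // ternary [out, in]
//   auto gamma = torch::ones({128});                                     // per-output scale
//   auto y     = bitlinear_cpu_forward(x, W, gamma, torch::nullopt);     // -> [4, 128]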

// CPU reference path for a sum of k ternary components:
//   y = sum_i gammas[i] * (x @ W_ternary[i]^T) + bias
torch::Tensor multi_ternary_cpu_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    // W_ternary is [k, out_features, in_features]; gammas is [k, out_features].
    int64_t k = W_ternary.size(0);
    int64_t out_features = W_ternary.size(1);
    int64_t in_features = W_ternary.size(2);

    // Collapse all leading dimensions of x into a single batch dimension.
    auto x_shape = x.sizes().vec();
    int64_t batch_size = 1;
    for (size_t i = 0; i < x_shape.size() - 1; i++) {
        batch_size *= x_shape[i];
    }

    auto x_2d = x.view({batch_size, in_features});
    auto output = torch::zeros({batch_size, out_features}, x.options());

    // Accumulate the k scaled ternary components.
    for (int64_t i = 0; i < k; i++) {
        auto W_i = W_ternary[i];
        auto gamma_i = gammas[i];

        auto component = torch::matmul(x_2d, W_i.t());
        component = component * gamma_i.unsqueeze(0);

        output = output + component;
    }

    if (bias.has_value() && bias.value().defined()) {
        output = output + bias.value().unsqueeze(0);
    }

    // Restore the original leading dimensions with out_features as the last axis.
    std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
    out_shape.push_back(out_features);
    output = output.view(out_shape);

    return output;
}
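
// For reference, the accumulation loop above is equivalent to a single batched
// contraction (illustrative only; not used by the code above):
//
//   auto output = torch::einsum("bn,kon,ko->bo", {x_2d, W_ternary, gammas});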

// Dispatcher: validates shapes and routes to the CUDA kernel when available,
// otherwise to the CPU reference implementation.
torch::Tensor bitlinear_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gamma,
    torch::optional<torch::Tensor> bias
) {
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 2, "W_ternary must be 2D");
    TORCH_CHECK(gamma.dim() == 1 || gamma.dim() == 2, "gamma must be 1D or 2D");

    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return bitlinear_cuda_forward(x, W_ternary, gamma, bias);
#else
        AT_ERROR("BitLinear CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return bitlinear_cpu_forward(x, W_ternary, gamma, bias);
    }
}

// Dispatcher for the multi-ternary path.
torch::Tensor multi_ternary_forward(
    torch::Tensor x,
    torch::Tensor W_ternary,
    torch::Tensor gammas,
    torch::optional<torch::Tensor> bias
) {
    TORCH_CHECK(x.dim() >= 2, "Input must have at least 2 dimensions");
    TORCH_CHECK(W_ternary.dim() == 3, "W_ternary must be 3D [k, out_features, in_features]");
    TORCH_CHECK(gammas.dim() == 2, "gammas must be 2D [k, out_features]");

    if (x.is_cuda()) {
#ifdef WITH_CUDA
        return multi_ternary_cuda_forward(x, W_ternary, gammas, bias);
#else
        AT_ERROR("Multi-ternary CUDA kernels not compiled. Rebuild with CUDA support.");
#endif
    } else {
        return multi_ternary_cpu_forward(x, W_ternary, gammas, bias);
    }
}

// Pack ternary weights into base-3 digits: each value in {-1, 0, +1} is mapped
// to a trit in {0, 1, 2}, and 5 trits are stored per byte (3^5 = 243 <= 255).
torch::Tensor pack_ternary_base3_cpp(torch::Tensor W_ternary) {
    auto flat = W_ternary.flatten().to(torch::kCPU).to(torch::kInt8);
    int64_t numel = flat.numel();

    // Map {-1, 0, +1} -> {0, 1, 2}.
    auto mapped = (flat + 1).to(torch::kUInt8);

    // One byte holds 5 trits; round the byte count up.
    int64_t packed_size = (numel + 4) / 5;
    auto packed = torch::zeros({packed_size}, torch::dtype(torch::kUInt8).device(torch::kCPU));

    auto mapped_ptr = mapped.data_ptr<uint8_t>();
    auto packed_ptr = packed.data_ptr<uint8_t>();

    const uint8_t powers[5] = {1, 3, 9, 27, 81};

    for (int64_t i = 0; i < packed_size; i++) {
        int64_t base_idx = i * 5;
        uint8_t packed_val = 0;

        for (int j = 0; j < 5; j++) {
            int64_t idx = base_idx + j;
            if (idx < numel) {
                packed_val += mapped_ptr[idx] * powers[j];
            } else {
                // Pad the trailing byte with the trit for 0 (mapped value 1).
                packed_val += 1 * powers[j];
            }
        }
        packed_ptr[i] = packed_val;
    }

    return packed;
}

// Unpack base-3 packed weights back to a float tensor of the given shape,
// inverting pack_ternary_base3_cpp (5 trits per byte, least-significant digit first).
torch::Tensor unpack_ternary_base3_cpp(
    torch::Tensor packed,
    std::vector<int64_t> original_shape
) {
    int64_t numel = 1;
    for (auto dim : original_shape) {
        numel *= dim;
    }

    auto packed_flat = packed.flatten().to(torch::kCPU).to(torch::kUInt8);
    int64_t packed_size = packed_flat.numel();

    auto unpacked = torch::zeros({numel}, torch::dtype(torch::kInt8).device(torch::kCPU));

    auto packed_ptr = packed_flat.data_ptr<uint8_t>();
    auto unpacked_ptr = unpacked.data_ptr<int8_t>();

    int64_t out_idx = 0;
    for (int64_t i = 0; i < packed_size && out_idx < numel; i++) {
        uint8_t packed_val = packed_ptr[i];

        // Peel off the 5 trits of this byte and map {0, 1, 2} back to {-1, 0, +1}.
        for (int j = 0; j < 5 && out_idx < numel; j++) {
            uint8_t val = packed_val % 3;
            packed_val /= 3;

            unpacked_ptr[out_idx] = static_cast<int8_t>(val) - 1;
            out_idx++;
        }
    }

    return unpacked.view(original_shape).to(torch::kFloat32);
}
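
// Illustrative round trip (shapes assumed for the sketch): 8 * 16 = 128 trits
// pack into (128 + 4) / 5 = 26 bytes and unpack back to the original values.
//
//   auto W      = torch::randint(-1, 2, {8, 16}).to(torch::kFloat32);
//   auto packed = pack_ternary_base3_cpp(W);                  // 26 bytes
//   auto W_back = unpack_ternary_base3_cpp(packed, {8, 16});  // equal to W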

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &bitlinear_forward, "BitLinear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gamma"),
          py::arg("bias") = py::none());

    m.def("multi_ternary_forward", &multi_ternary_forward,
          "Multi-ternary linear forward (CPU/CUDA)",
          py::arg("x"),
          py::arg("W_ternary"),
          py::arg("gammas"),
          py::arg("bias") = py::none());

    m.def("pack_ternary_base3", &pack_ternary_base3_cpp,
          "Pack ternary weights to base-3 (CPU)",
          py::arg("W_ternary"));

    m.def("unpack_ternary_base3", &unpack_ternary_base3_cpp,
          "Unpack base-3 ternary weights (CPU)",
          py::arg("packed"),
          py::arg("original_shape"));
}
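
// Example Python usage of the bindings above (assuming the extension is built
// and imported as `bitlinear_cpp`; the actual module name depends on the build setup):
//
//   y      = bitlinear_cpp.forward(x, W_ternary, gamma, bias=None)
//   y      = bitlinear_cpp.multi_ternary_forward(x, W_ternary, gammas, bias=None)
//   packed = bitlinear_cpp.pack_ternary_base3(W_ternary)
//   W      = bitlinear_cpp.unpack_ternary_base3(packed, list(W_ternary.shape))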