/*
* Ternary Neural Network Kernel - AVX-512 optimized
*
* Weights are stored as two bitplanes per row:
* pos_mask: bit=1 where weight = +1
* neg_mask: bit=1 where weight = -1
* (both 0 = weight is 0)
*
 * The matvec becomes: y[i] = scale[i] * (sum(x[j] where pos) - sum(x[j] where neg))
 * No per-element multiplication: just masked adds/subtracts plus one scale per row.
*
* (c) 2026 OpenTransformers Ltd / Scott Bisset
*/
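/*
 * Worked example of the encoding (illustrative values, not from the repo):
 *   weights  = [+1,  0, -1, +1]
 *   pos_mask = 0b1001   (bits 0 and 3 set)
 *   neg_mask = 0b0100   (bit 2 set)
 *   x        = [0.5, 2.0, -1.0, 3.0]
 *   y        = (0.5 + 3.0) - (-1.0) = 4.5, then multiplied by the per-row scale.
 */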
#include <immintrin.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdio.h>
/* ============================================================
 * Core ternary matvec: y = W_ternary @ x
 *
 * W stored as bitplanes: pos_bits[out_dim][ceil(in_dim/64)] uint64
 *                        neg_bits[out_dim][ceil(in_dim/64)] uint64
 * x:      float32[in_dim] input activations
 * y:      float32[out_dim] output
 * scales: float32[out_dim] per-row scale factor (to recover magnitude)
 * ============================================================ */
void ternary_matvec_avx512(
const uint64_t *pos_bits, /* [out_dim * chunks] */
const uint64_t *neg_bits, /* [out_dim * chunks] */
const float *scales, /* [out_dim] per-row scale */
const float *x, /* [in_dim] input activations */
float *y, /* [out_dim] output */
int out_dim,
int in_dim
) {
int chunks = (in_dim + 63) / 64; /* 64 weights per uint64 */
/* Pad input to multiple of 16 floats for AVX-512 */
int in_padded = (in_dim + 15) & ~15;
    float *x_pad = (float *)aligned_alloc(64, in_padded * sizeof(float));
    if (!x_pad) return;  /* allocation failure: bail out rather than crash */
    memcpy(x_pad, x, in_dim * sizeof(float));
    memset(x_pad + in_dim, 0, (in_padded - in_dim) * sizeof(float));
for (int i = 0; i < out_dim; i++) {
__m512 acc = _mm512_setzero_ps();
const uint64_t *row_pos = pos_bits + (size_t)i * chunks;
const uint64_t *row_neg = neg_bits + (size_t)i * chunks;
/* Process 64 weights at a time (4 AVX-512 ops of 16 floats each) */
for (int c = 0; c < chunks; c++) {
uint64_t pb = row_pos[c];
uint64_t nb = row_neg[c];
int base = c * 64;
/* Process in groups of 16 floats */
for (int g = 0; g < 4 && (base + g * 16) < in_padded; g++) {
int offset = base + g * 16;
__m512 xv = _mm512_load_ps(x_pad + offset);
/* Extract 16 bits for this group */
__mmask16 pmask = (__mmask16)((pb >> (g * 16)) & 0xFFFF);
__mmask16 nmask = (__mmask16)((nb >> (g * 16)) & 0xFFFF);
/* Masked add where weight = +1, masked subtract where weight = -1 */
acc = _mm512_mask_add_ps(acc, pmask, acc, xv);
acc = _mm512_mask_sub_ps(acc, nmask, acc, xv);
}
}
/* Horizontal sum */
float sum = _mm512_reduce_add_ps(acc);
/* Apply per-row scale to recover magnitude */
y[i] = sum * scales[i];
}
free(x_pad);
}
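/* ============================================================
 * Scalar reference matvec (added sketch, not part of the original
 * kernel set; the name ternary_matvec_ref is hypothetical).
 * Walks the same bitplane layout bit by bit, which makes it handy
 * for validating the AVX-512 path on small shapes.
 * ============================================================ */
void ternary_matvec_ref(
    const uint64_t *pos_bits,
    const uint64_t *neg_bits,
    const float *scales,
    const float *x,
    float *y,
    int out_dim,
    int in_dim
) {
    int chunks = (in_dim + 63) / 64;
    for (int i = 0; i < out_dim; i++) {
        float sum = 0.0f;
        for (int j = 0; j < in_dim; j++) {
            /* Bit b of chunk c corresponds to weight index j = c*64 + b */
            size_t idx = (size_t)i * chunks + (size_t)(j / 64);
            uint64_t bit = 1ULL << (j % 64);
            if (pos_bits[idx] & bit)      sum += x[j];
            else if (neg_bits[idx] & bit) sum -= x[j];
        }
        y[i] = sum * scales[i];
    }
}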
/* ============================================================
* Batched version: Y = W_ternary @ X (multiple input vectors)
* X: [batch, in_dim], Y: [batch, out_dim]
* ============================================================ */
void ternary_matmul_avx512(
const uint64_t *pos_bits,
const uint64_t *neg_bits,
const float *scales,
const float *X,
float *Y,
int batch,
int out_dim,
int in_dim
) {
for (int b = 0; b < batch; b++) {
ternary_matvec_avx512(
pos_bits, neg_bits, scales,
X + (size_t)b * in_dim,
Y + (size_t)b * out_dim,
out_dim, in_dim
);
}
}
/* ============================================================
* RMSNorm: y = x * (1/rms(x)) * weight
* ============================================================ */
void rmsnorm_avx512(
const float *x,
const float *weight,
float *y,
int dim,
float eps
) {
/* Compute sum of squares */
__m512 sum_sq = _mm512_setzero_ps();
int i;
for (i = 0; i + 16 <= dim; i += 16) {
__m512 xv = _mm512_loadu_ps(x + i);
sum_sq = _mm512_fmadd_ps(xv, xv, sum_sq);
}
float ss = _mm512_reduce_add_ps(sum_sq);
/* Handle remainder */
for (; i < dim; i++) ss += x[i] * x[i];
    float inv_rms = 1.0f / sqrtf(ss / dim + eps);
    /* Apply norm and weight (broadcast the reciprocal once) */
    __m512 rv = _mm512_set1_ps(inv_rms);
    for (i = 0; i + 16 <= dim; i += 16) {
        __m512 xv = _mm512_loadu_ps(x + i);
        __m512 wv = _mm512_loadu_ps(weight + i);
        _mm512_storeu_ps(y + i, _mm512_mul_ps(_mm512_mul_ps(xv, rv), wv));
    }
    for (; i < dim; i++) y[i] = x[i] * inv_rms * weight[i];
}
/* ============================================================
* SiLU activation: x * sigmoid(x)
* ============================================================ */
static inline float silu_scalar(float x) {
return x / (1.0f + expf(-x));
}
void silu_avx512(float *x, int n) {
/* Scalar fallback — vectorized exp is complex */
for (int i = 0; i < n; i++) {
x[i] = silu_scalar(x[i]);
}
}
/* ============================================================
* Element-wise multiply: y = a * b
* ============================================================ */
void elemwise_mul_avx512(const float *a, const float *b, float *y, int n) {
int i;
for (i = 0; i + 16 <= n; i += 16) {
__m512 av = _mm512_loadu_ps(a + i);
__m512 bv = _mm512_loadu_ps(b + i);
_mm512_storeu_ps(y + i, _mm512_mul_ps(av, bv));
}
for (; i < n; i++) y[i] = a[i] * b[i];
}
/* ============================================================
* Softmax
* ============================================================ */
void softmax(float *x, int n) {
float max_val = x[0];
for (int i = 1; i < n; i++) if (x[i] > max_val) max_val = x[i];
float sum = 0;
for (int i = 0; i < n; i++) {
x[i] = expf(x[i] - max_val);
sum += x[i];
}
float inv_sum = 1.0f / sum;
for (int i = 0; i < n; i++) x[i] *= inv_sum;
}
/* ============================================================
* RoPE (Rotary Position Embedding)
* ============================================================ */
void apply_rope(
float *q, /* [n_heads, head_dim] */
float *k, /* [n_kv_heads, head_dim] */
int n_heads,
int n_kv_heads,
int head_dim,
int pos,
float rope_theta
) {
for (int h = 0; h < n_heads + n_kv_heads; h++) {
float *vec = (h < n_heads) ? q + h * head_dim : k + (h - n_heads) * head_dim;
for (int i = 0; i < head_dim; i += 2) {
float freq = 1.0f / powf(rope_theta, (float)i / head_dim);
float angle = pos * freq;
float cos_a = cosf(angle);
float sin_a = sinf(angle);
float v0 = vec[i];
float v1 = vec[i + 1];
vec[i] = v0 * cos_a - v1 * sin_a;
vec[i + 1] = v0 * sin_a + v1 * cos_a;
}
}
}
/* ============================================================
* Quantization: convert float weights to ternary
* Uses per-row threshold: threshold = alpha * mean(|w|)
* Returns: pos_bits, neg_bits, scales
* ============================================================ */
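/* Example (illustrative numbers): row = [0.9, -0.1, -0.8, 0.2], alpha = 0.8
 *   mean|w| = 0.5  ->  threshold = 0.4
 *   ternary = [+1, 0, -1, 0],  scale = (0.9 + 0.8) / 2 = 0.85 */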
void quantize_to_ternary(
const float *weights, /* [out_dim, in_dim] */
uint64_t *pos_bits, /* [out_dim * chunks] output */
uint64_t *neg_bits, /* [out_dim * chunks] output */
float *scales, /* [out_dim] output */
int out_dim,
int in_dim,
float alpha /* threshold multiplier, typically 0.7-1.0 */
) {
int chunks = (in_dim + 63) / 64;
for (int i = 0; i < out_dim; i++) {
const float *row = weights + (size_t)i * in_dim;
/* Compute mean absolute value for threshold */
float abs_sum = 0;
for (int j = 0; j < in_dim; j++) abs_sum += fabsf(row[j]);
float mean_abs = abs_sum / in_dim;
float threshold = alpha * mean_abs;
/* Compute scale: mean of absolute values of non-zero quantized weights */
float nz_sum = 0;
int nz_count = 0;
for (int j = 0; j < in_dim; j++) {
if (fabsf(row[j]) >= threshold) {
nz_sum += fabsf(row[j]);
nz_count++;
}
}
scales[i] = (nz_count > 0) ? (nz_sum / nz_count) : 1.0f;
/* Quantize to ternary bits */
for (int c = 0; c < chunks; c++) {
uint64_t pb = 0, nb = 0;
for (int b = 0; b < 64; b++) {
int j = c * 64 + b;
if (j >= in_dim) break;
if (row[j] >= threshold) {
pb |= (1ULL << b);
} else if (row[j] <= -threshold) {
nb |= (1ULL << b);
}
}
pos_bits[(size_t)i * chunks + c] = pb;
neg_bits[(size_t)i * chunks + c] = nb;
}
}
}
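/* ============================================================
 * Minimal end-to-end self-test (added sketch, not part of the
 * original file). Quantizes a small random matrix, runs the
 * AVX-512 matvec, and compares against a straightforward float
 * reconstruction of the ternarized weights. Compile with e.g.
 * -DTERNARY_KERNEL_TEST -mavx512f to enable; the macro name is
 * hypothetical.
 * ============================================================ */
#ifdef TERNARY_KERNEL_TEST
int main(void) {
    enum { OUT_DIM = 8, IN_DIM = 100, CHUNKS = (IN_DIM + 63) / 64 };
    static float W[OUT_DIM * IN_DIM], x[IN_DIM], y[OUT_DIM], scales[OUT_DIM];
    static uint64_t pos[OUT_DIM * CHUNKS], neg[OUT_DIM * CHUNKS];
    srand(42);
    for (int i = 0; i < OUT_DIM * IN_DIM; i++) W[i] = (float)rand() / RAND_MAX - 0.5f;
    for (int j = 0; j < IN_DIM; j++)           x[j] = (float)rand() / RAND_MAX - 0.5f;
    quantize_to_ternary(W, pos, neg, scales, OUT_DIM, IN_DIM, 0.75f);
    ternary_matvec_avx512(pos, neg, scales, x, y, OUT_DIM, IN_DIM);
    /* Reference: dequantize each row to {-scale, 0, +scale} and dot with x */
    for (int i = 0; i < OUT_DIM; i++) {
        float ref = 0.0f;
        for (int j = 0; j < IN_DIM; j++) {
            size_t idx = (size_t)i * CHUNKS + (size_t)(j / 64);
            uint64_t bit = 1ULL << (j % 64);
            if (pos[idx] & bit)      ref += scales[i] * x[j];
            else if (neg[idx] & bit) ref -= scales[i] * x[j];
        }
        printf("row %d: avx512=%+.5f ref=%+.5f\n", i, y[i], ref);
    }
    return 0;
}
#endif /* TERNARY_KERNEL_TEST */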