/* * Ternary Neural Network Kernel - AVX-512 optimized * * Weights are stored as two bitplanes per row: * pos_mask: bit=1 where weight = +1 * neg_mask: bit=1 where weight = -1 * (both 0 = weight is 0) * * Matmul becomes: y[i] = sum(x[j] where pos) - sum(x[j] where neg) * No multiplication at all — just masked add/subtract. * * (c) 2026 OpenTransformers Ltd / Scott Bisset */ #include #include #include #include #include #include /* ============================================================ * Core ternary matmul: y = W_ternary @ x * * W stored as bitplanes: pos_bits[out_dim][ceil(in_dim/64)] uint64 * neg_bits[out_dim][ceil(in_dim/64)] uint64 * x: float32[in_dim] * y: float32[out_dim] * bias: float32[out_dim] or NULL * scale: float32 per-row scale factor (to recover magnitude) * ============================================================ */ void ternary_matvec_avx512( const uint64_t *pos_bits, /* [out_dim * chunks] */ const uint64_t *neg_bits, /* [out_dim * chunks] */ const float *scales, /* [out_dim] per-row scale */ const float *x, /* [in_dim] input activations */ float *y, /* [out_dim] output */ int out_dim, int in_dim ) { int chunks = (in_dim + 63) / 64; /* 64 weights per uint64 */ /* Pad input to multiple of 16 floats for AVX-512 */ int in_padded = (in_dim + 15) & ~15; float *x_pad = (float *)aligned_alloc(64, in_padded * sizeof(float)); memcpy(x_pad, x, in_dim * sizeof(float)); memset(x_pad + in_dim, 0, (in_padded - in_dim) * sizeof(float)); for (int i = 0; i < out_dim; i++) { __m512 acc = _mm512_setzero_ps(); const uint64_t *row_pos = pos_bits + (size_t)i * chunks; const uint64_t *row_neg = neg_bits + (size_t)i * chunks; /* Process 64 weights at a time (4 AVX-512 ops of 16 floats each) */ for (int c = 0; c < chunks; c++) { uint64_t pb = row_pos[c]; uint64_t nb = row_neg[c]; int base = c * 64; /* Process in groups of 16 floats */ for (int g = 0; g < 4 && (base + g * 16) < in_padded; g++) { int offset = base + g * 16; __m512 xv = _mm512_load_ps(x_pad + offset); /* Extract 16 bits for this group */ __mmask16 pmask = (__mmask16)((pb >> (g * 16)) & 0xFFFF); __mmask16 nmask = (__mmask16)((nb >> (g * 16)) & 0xFFFF); /* Masked add where weight = +1, masked subtract where weight = -1 */ acc = _mm512_mask_add_ps(acc, pmask, acc, xv); acc = _mm512_mask_sub_ps(acc, nmask, acc, xv); } } /* Horizontal sum */ float sum = _mm512_reduce_add_ps(acc); /* Apply per-row scale to recover magnitude */ y[i] = sum * scales[i]; } free(x_pad); } /* ============================================================ * Batched version: Y = W_ternary @ X (multiple input vectors) * X: [batch, in_dim], Y: [batch, out_dim] * ============================================================ */ void ternary_matmul_avx512( const uint64_t *pos_bits, const uint64_t *neg_bits, const float *scales, const float *X, float *Y, int batch, int out_dim, int in_dim ) { for (int b = 0; b < batch; b++) { ternary_matvec_avx512( pos_bits, neg_bits, scales, X + (size_t)b * in_dim, Y + (size_t)b * out_dim, out_dim, in_dim ); } } /* ============================================================ * RMSNorm: y = x * (1/rms(x)) * weight * ============================================================ */ void rmsnorm_avx512( const float *x, const float *weight, float *y, int dim, float eps ) { /* Compute sum of squares */ __m512 sum_sq = _mm512_setzero_ps(); int i; for (i = 0; i + 16 <= dim; i += 16) { __m512 xv = _mm512_loadu_ps(x + i); sum_sq = _mm512_fmadd_ps(xv, xv, sum_sq); } float ss = _mm512_reduce_add_ps(sum_sq); /* Handle remainder */ for (; i < dim; i++) ss += x[i] * x[i]; float rms = 1.0f / sqrtf(ss / dim + eps); /* Apply norm and weight */ for (i = 0; i + 16 <= dim; i += 16) { __m512 xv = _mm512_loadu_ps(x + i); __m512 wv = _mm512_loadu_ps(weight + i); __m512 rv = _mm512_set1_ps(rms); __m512 out = _mm512_mul_ps(_mm512_mul_ps(xv, rv), wv); _mm512_storeu_ps(y + i, out); } for (; i < dim; i++) y[i] = x[i] * rms * weight[i]; } /* ============================================================ * SiLU activation: x * sigmoid(x) * ============================================================ */ static inline float silu_scalar(float x) { return x / (1.0f + expf(-x)); } void silu_avx512(float *x, int n) { /* Scalar fallback — vectorized exp is complex */ for (int i = 0; i < n; i++) { x[i] = silu_scalar(x[i]); } } /* ============================================================ * Element-wise multiply: y = a * b * ============================================================ */ void elemwise_mul_avx512(const float *a, const float *b, float *y, int n) { int i; for (i = 0; i + 16 <= n; i += 16) { __m512 av = _mm512_loadu_ps(a + i); __m512 bv = _mm512_loadu_ps(b + i); _mm512_storeu_ps(y + i, _mm512_mul_ps(av, bv)); } for (; i < n; i++) y[i] = a[i] * b[i]; } /* ============================================================ * Softmax * ============================================================ */ void softmax(float *x, int n) { float max_val = x[0]; for (int i = 1; i < n; i++) if (x[i] > max_val) max_val = x[i]; float sum = 0; for (int i = 0; i < n; i++) { x[i] = expf(x[i] - max_val); sum += x[i]; } float inv_sum = 1.0f / sum; for (int i = 0; i < n; i++) x[i] *= inv_sum; } /* ============================================================ * RoPE (Rotary Position Embedding) * ============================================================ */ void apply_rope( float *q, /* [n_heads, head_dim] */ float *k, /* [n_kv_heads, head_dim] */ int n_heads, int n_kv_heads, int head_dim, int pos, float rope_theta ) { for (int h = 0; h < n_heads + n_kv_heads; h++) { float *vec = (h < n_heads) ? q + h * head_dim : k + (h - n_heads) * head_dim; for (int i = 0; i < head_dim; i += 2) { float freq = 1.0f / powf(rope_theta, (float)i / head_dim); float angle = pos * freq; float cos_a = cosf(angle); float sin_a = sinf(angle); float v0 = vec[i]; float v1 = vec[i + 1]; vec[i] = v0 * cos_a - v1 * sin_a; vec[i + 1] = v0 * sin_a + v1 * cos_a; } } } /* ============================================================ * Quantization: convert float weights to ternary * Uses per-row threshold: threshold = alpha * mean(|w|) * Returns: pos_bits, neg_bits, scales * ============================================================ */ void quantize_to_ternary( const float *weights, /* [out_dim, in_dim] */ uint64_t *pos_bits, /* [out_dim * chunks] output */ uint64_t *neg_bits, /* [out_dim * chunks] output */ float *scales, /* [out_dim] output */ int out_dim, int in_dim, float alpha /* threshold multiplier, typically 0.7-1.0 */ ) { int chunks = (in_dim + 63) / 64; for (int i = 0; i < out_dim; i++) { const float *row = weights + (size_t)i * in_dim; /* Compute mean absolute value for threshold */ float abs_sum = 0; for (int j = 0; j < in_dim; j++) abs_sum += fabsf(row[j]); float mean_abs = abs_sum / in_dim; float threshold = alpha * mean_abs; /* Compute scale: mean of absolute values of non-zero quantized weights */ float nz_sum = 0; int nz_count = 0; for (int j = 0; j < in_dim; j++) { if (fabsf(row[j]) >= threshold) { nz_sum += fabsf(row[j]); nz_count++; } } scales[i] = (nz_count > 0) ? (nz_sum / nz_count) : 1.0f; /* Quantize to ternary bits */ for (int c = 0; c < chunks; c++) { uint64_t pb = 0, nb = 0; for (int b = 0; b < 64; b++) { int j = c * 64 + b; if (j >= in_dim) break; if (row[j] >= threshold) { pb |= (1ULL << b); } else if (row[j] <= -threshold) { nb |= (1ULL << b); } } pos_bits[(size_t)i * chunks + c] = pb; neg_bits[(size_t)i * chunks + c] = nb; } } }