File size: 9,214 Bytes
19ed98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
/*
 * Ternary Neural Network Kernel - AVX-512 optimized
 * 
 * Weights are stored as two bitplanes per row:
 *   pos_mask: bit=1 where weight = +1
 *   neg_mask: bit=1 where weight = -1
 *   (both 0 = weight is 0)
 *
 * Matmul becomes: y[i] = sum(x[j] where pos) - sum(x[j] where neg)
 * No multiplication at all — just masked add/subtract.
 *
 * (c) 2026 OpenTransformers Ltd / Scott Bisset
 */

#include <immintrin.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdio.h>

/* ============================================================
 * Core ternary matmul: y = W_ternary @ x
 * 
 * W stored as bitplanes: pos_bits[out_dim][ceil(in_dim/64)] uint64
 *                        neg_bits[out_dim][ceil(in_dim/64)] uint64
 * x: float32[in_dim]
 * y: float32[out_dim]
 * scales: float32[out_dim] per-row scale factors (to recover magnitude)
 * ============================================================ */
void ternary_matvec_avx512(
    const uint64_t *pos_bits,   /* [out_dim * chunks] */
    const uint64_t *neg_bits,   /* [out_dim * chunks] */
    const float    *scales,     /* [out_dim] per-row scale */
    const float    *x,          /* [in_dim] input activations */
    float          *y,          /* [out_dim] output */
    int             out_dim,
    int             in_dim
) {
    int chunks = (in_dim + 63) / 64;  /* 64 weights per uint64 */
    
    /* Pad input to multiple of 16 floats for AVX-512 */
    int in_padded = (in_dim + 15) & ~15;
    float *x_pad = (float *)aligned_alloc(64, in_padded * sizeof(float));
    memcpy(x_pad, x, in_dim * sizeof(float));
    memset(x_pad + in_dim, 0, (in_padded - in_dim) * sizeof(float));
    
    for (int i = 0; i < out_dim; i++) {
        __m512 acc = _mm512_setzero_ps();
        
        const uint64_t *row_pos = pos_bits + (size_t)i * chunks;
        const uint64_t *row_neg = neg_bits + (size_t)i * chunks;
        
        /* Process 64 weights at a time (4 AVX-512 ops of 16 floats each) */
        for (int c = 0; c < chunks; c++) {
            uint64_t pb = row_pos[c];
            uint64_t nb = row_neg[c];
            int base = c * 64;
            
            /* Process in groups of 16 floats */
            for (int g = 0; g < 4 && (base + g * 16) < in_padded; g++) {
                int offset = base + g * 16;
                __m512 xv = _mm512_load_ps(x_pad + offset);
                
                /* Extract 16 bits for this group */
                __mmask16 pmask = (__mmask16)((pb >> (g * 16)) & 0xFFFF);
                __mmask16 nmask = (__mmask16)((nb >> (g * 16)) & 0xFFFF);
                
                /* Masked add where weight = +1, masked subtract where weight = -1 */
                acc = _mm512_mask_add_ps(acc, pmask, acc, xv);
                acc = _mm512_mask_sub_ps(acc, nmask, acc, xv);
            }
        }
        
        /* Horizontal sum */
        float sum = _mm512_reduce_add_ps(acc);
        
        /* Apply per-row scale to recover magnitude */
        y[i] = sum * scales[i];
    }
    
    free(x_pad);
}

/* ============================================================
 * Batched version: Y = W_ternary @ X  (multiple input vectors)
 * X: [batch, in_dim], Y: [batch, out_dim]
 * ============================================================ */
void ternary_matmul_avx512(
    const uint64_t *pos_bits,
    const uint64_t *neg_bits,
    const float    *scales,
    const float    *X,
    float          *Y,
    int             batch,
    int             out_dim,
    int             in_dim
) {
    /* One matvec per batch row; walk the input/output row pointers
     * instead of recomputing offsets each iteration. */
    const float *x_row = X;
    float       *y_row = Y;
    for (int b = 0; b < batch; b++) {
        ternary_matvec_avx512(pos_bits, neg_bits, scales,
                              x_row, y_row, out_dim, in_dim);
        x_row += in_dim;
        y_row += out_dim;
    }
}

/* ============================================================
 * RMSNorm: y = x * (1/rms(x)) * weight
 * ============================================================ */
void rmsnorm_avx512(
    const float *x,
    const float *weight,
    float       *y,
    int          dim,
    float        eps
) {
    /* Compute sum of squares */
    __m512 sum_sq = _mm512_setzero_ps();
    int i;
    for (i = 0; i + 16 <= dim; i += 16) {
        __m512 xv = _mm512_loadu_ps(x + i);
        sum_sq = _mm512_fmadd_ps(xv, xv, sum_sq);
    }
    float ss = _mm512_reduce_add_ps(sum_sq);
    /* Handle remainder */
    for (; i < dim; i++) ss += x[i] * x[i];
    
    float rms = 1.0f / sqrtf(ss / dim + eps);
    
    /* Apply norm and weight */
    for (i = 0; i + 16 <= dim; i += 16) {
        __m512 xv = _mm512_loadu_ps(x + i);
        __m512 wv = _mm512_loadu_ps(weight + i);
        __m512 rv = _mm512_set1_ps(rms);
        __m512 out = _mm512_mul_ps(_mm512_mul_ps(xv, rv), wv);
        _mm512_storeu_ps(y + i, out);
    }
    for (; i < dim; i++) y[i] = x[i] * rms * weight[i];
}

/* ============================================================
 * SiLU activation: x * sigmoid(x)
 * ============================================================ */
static inline float silu_scalar(float x) {
    return x / (1.0f + expf(-x));
}

void silu_avx512(float *x, int n) {
    /* Scalar fallback — vectorized exp is complex */
    for (int i = 0; i < n; i++) {
        x[i] = silu_scalar(x[i]);
    }
}

/* ============================================================
 * Element-wise multiply: y = a * b
 * ============================================================ */
void elemwise_mul_avx512(const float *a, const float *b, float *y, int n) {
    int i;
    for (i = 0; i + 16 <= n; i += 16) {
        __m512 av = _mm512_loadu_ps(a + i);
        __m512 bv = _mm512_loadu_ps(b + i);
        _mm512_storeu_ps(y + i, _mm512_mul_ps(av, bv));
    }
    for (; i < n; i++) y[i] = a[i] * b[i];
}

/* ============================================================
 * Softmax
 * ============================================================ */
void softmax(float *x, int n) {
    float max_val = x[0];
    for (int i = 1; i < n; i++) if (x[i] > max_val) max_val = x[i];
    float sum = 0;
    for (int i = 0; i < n; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    float inv_sum = 1.0f / sum;
    for (int i = 0; i < n; i++) x[i] *= inv_sum;
}

/* ============================================================
 * RoPE (Rotary Position Embedding)
 * ============================================================ */
void apply_rope(
    float *q,       /* [n_heads, head_dim] */
    float *k,       /* [n_kv_heads, head_dim] */
    int    n_heads,
    int    n_kv_heads,
    int    head_dim,
    int    pos,
    float  rope_theta
) {
    for (int h = 0; h < n_heads + n_kv_heads; h++) {
        float *vec = (h < n_heads) ? q + h * head_dim : k + (h - n_heads) * head_dim;
        for (int i = 0; i < head_dim; i += 2) {
            float freq = 1.0f / powf(rope_theta, (float)i / head_dim);
            float angle = pos * freq;
            float cos_a = cosf(angle);
            float sin_a = sinf(angle);
            float v0 = vec[i];
            float v1 = vec[i + 1];
            vec[i]     = v0 * cos_a - v1 * sin_a;
            vec[i + 1] = v0 * sin_a + v1 * cos_a;
        }
    }
}

/* ============================================================
 * Quantization: convert float weights to ternary
 * Uses per-row threshold: threshold = alpha * mean(|w|)
 * Returns: pos_bits, neg_bits, scales
 * ============================================================ */
void quantize_to_ternary(
    const float *weights,   /* [out_dim, in_dim] */
    uint64_t    *pos_bits,  /* [out_dim * chunks] output */
    uint64_t    *neg_bits,  /* [out_dim * chunks] output */
    float       *scales,    /* [out_dim] output */
    int          out_dim,
    int          in_dim,
    float        alpha      /* threshold multiplier, typically 0.7-1.0 */
) {
    int chunks = (in_dim + 63) / 64;

    for (int r = 0; r < out_dim; r++) {
        const float *row = weights + (size_t)r * in_dim;

        /* Per-row threshold: alpha * mean(|w|). */
        float abs_total = 0.0f;
        for (int j = 0; j < in_dim; j++) {
            abs_total += fabsf(row[j]);
        }
        float thr = alpha * (abs_total / in_dim);

        /* Single pass: pack the +/- bitplanes while accumulating the
         * stats for the per-row scale.  |w| >= thr holds exactly when
         * one of the two branch conditions fires. */
        float kept_sum = 0.0f;
        int   kept = 0;
        for (int c = 0; c < chunks; c++) {
            uint64_t pb = 0, nb = 0;
            int limit = in_dim - c * 64;
            if (limit > 64) limit = 64;
            for (int b = 0; b < limit; b++) {
                float w = row[c * 64 + b];
                if (w >= thr) {
                    pb |= 1ULL << b;
                    kept_sum += fabsf(w);
                    kept++;
                } else if (w <= -thr) {
                    nb |= 1ULL << b;
                    kept_sum += fabsf(w);
                    kept++;
                }
            }
            pos_bits[(size_t)r * chunks + c] = pb;
            neg_bits[(size_t)r * chunks + c] = nb;
        }

        /* Scale = mean |w| over surviving weights (1.0 if none survive),
         * so the matvec can recover the original magnitude. */
        scales[r] = (kept > 0) ? (kept_sum / kept) : 1.0f;
    }
}