|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <immintrin.h> |
|
|
#include <stdint.h> |
|
|
#include <stdlib.h> |
|
|
#include <string.h> |
|
|
#include <math.h> |
|
|
#include <stdio.h> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ternary_matvec_avx512( |
|
|
const uint64_t *pos_bits, |
|
|
const uint64_t *neg_bits, |
|
|
const float *scales, |
|
|
const float *x, |
|
|
float *y, |
|
|
int out_dim, |
|
|
int in_dim |
|
|
) { |
|
|
int chunks = (in_dim + 63) / 64; |
|
|
|
|
|
|
|
|
int in_padded = (in_dim + 15) & ~15; |
|
|
float *x_pad = (float *)aligned_alloc(64, in_padded * sizeof(float)); |
|
|
memcpy(x_pad, x, in_dim * sizeof(float)); |
|
|
memset(x_pad + in_dim, 0, (in_padded - in_dim) * sizeof(float)); |
|
|
|
|
|
for (int i = 0; i < out_dim; i++) { |
|
|
__m512 acc = _mm512_setzero_ps(); |
|
|
|
|
|
const uint64_t *row_pos = pos_bits + (size_t)i * chunks; |
|
|
const uint64_t *row_neg = neg_bits + (size_t)i * chunks; |
|
|
|
|
|
|
|
|
for (int c = 0; c < chunks; c++) { |
|
|
uint64_t pb = row_pos[c]; |
|
|
uint64_t nb = row_neg[c]; |
|
|
int base = c * 64; |
|
|
|
|
|
|
|
|
for (int g = 0; g < 4 && (base + g * 16) < in_padded; g++) { |
|
|
int offset = base + g * 16; |
|
|
__m512 xv = _mm512_load_ps(x_pad + offset); |
|
|
|
|
|
|
|
|
__mmask16 pmask = (__mmask16)((pb >> (g * 16)) & 0xFFFF); |
|
|
__mmask16 nmask = (__mmask16)((nb >> (g * 16)) & 0xFFFF); |
|
|
|
|
|
|
|
|
acc = _mm512_mask_add_ps(acc, pmask, acc, xv); |
|
|
acc = _mm512_mask_sub_ps(acc, nmask, acc, xv); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
float sum = _mm512_reduce_add_ps(acc); |
|
|
|
|
|
|
|
|
y[i] = sum * scales[i]; |
|
|
} |
|
|
|
|
|
free(x_pad); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Batched ternary matmul: Y = X @ W^T, one ternary matvec per batch row.
 *
 * X is batch rows of in_dim floats; Y receives batch rows of out_dim
 * floats. The weight bit-planes and scales are shared across the batch.
 */
void ternary_matmul_avx512(
    const uint64_t *pos_bits,
    const uint64_t *neg_bits,
    const float *scales,
    const float *X,
    float *Y,
    int batch,
    int out_dim,
    int in_dim
) {
    const float *x_row = X;
    float *y_row = Y;
    for (int b = 0; b < batch; b++) {
        ternary_matvec_avx512(pos_bits, neg_bits, scales,
                              x_row, y_row, out_dim, in_dim);
        x_row += in_dim;
        y_row += out_dim;
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void rmsnorm_avx512( |
|
|
const float *x, |
|
|
const float *weight, |
|
|
float *y, |
|
|
int dim, |
|
|
float eps |
|
|
) { |
|
|
|
|
|
__m512 sum_sq = _mm512_setzero_ps(); |
|
|
int i; |
|
|
for (i = 0; i + 16 <= dim; i += 16) { |
|
|
__m512 xv = _mm512_loadu_ps(x + i); |
|
|
sum_sq = _mm512_fmadd_ps(xv, xv, sum_sq); |
|
|
} |
|
|
float ss = _mm512_reduce_add_ps(sum_sq); |
|
|
|
|
|
for (; i < dim; i++) ss += x[i] * x[i]; |
|
|
|
|
|
float rms = 1.0f / sqrtf(ss / dim + eps); |
|
|
|
|
|
|
|
|
for (i = 0; i + 16 <= dim; i += 16) { |
|
|
__m512 xv = _mm512_loadu_ps(x + i); |
|
|
__m512 wv = _mm512_loadu_ps(weight + i); |
|
|
__m512 rv = _mm512_set1_ps(rms); |
|
|
__m512 out = _mm512_mul_ps(_mm512_mul_ps(xv, rv), wv); |
|
|
_mm512_storeu_ps(y + i, out); |
|
|
} |
|
|
for (; i < dim; i++) y[i] = x[i] * rms * weight[i]; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static inline float silu_scalar(float x) { |
|
|
return x / (1.0f + expf(-x)); |
|
|
} |
|
|
|
|
|
/* Applies SiLU in place over x[0..n-1]. Despite the name this is a
 * scalar loop (kept for interface parity with the vector kernels);
 * the sigmoid is inlined here instead of going through silu_scalar,
 * using the identical x / (1 + e^-x) operation sequence. */
void silu_avx512(float *x, int n) {
    for (int i = 0; i < n; ++i) {
        const float v = x[i];
        x[i] = v / (1.0f + expf(-v));
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void elemwise_mul_avx512(const float *a, const float *b, float *y, int n) { |
|
|
int i; |
|
|
for (i = 0; i + 16 <= n; i += 16) { |
|
|
__m512 av = _mm512_loadu_ps(a + i); |
|
|
__m512 bv = _mm512_loadu_ps(b + i); |
|
|
_mm512_storeu_ps(y + i, _mm512_mul_ps(av, bv)); |
|
|
} |
|
|
for (; i < n; i++) y[i] = a[i] * b[i]; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void softmax(float *x, int n) { |
|
|
float max_val = x[0]; |
|
|
for (int i = 1; i < n; i++) if (x[i] > max_val) max_val = x[i]; |
|
|
float sum = 0; |
|
|
for (int i = 0; i < n; i++) { |
|
|
x[i] = expf(x[i] - max_val); |
|
|
sum += x[i]; |
|
|
} |
|
|
float inv_sum = 1.0f / sum; |
|
|
for (int i = 0; i < n; i++) x[i] *= inv_sum; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void apply_rope( |
|
|
float *q, |
|
|
float *k, |
|
|
int n_heads, |
|
|
int n_kv_heads, |
|
|
int head_dim, |
|
|
int pos, |
|
|
float rope_theta |
|
|
) { |
|
|
for (int h = 0; h < n_heads + n_kv_heads; h++) { |
|
|
float *vec = (h < n_heads) ? q + h * head_dim : k + (h - n_heads) * head_dim; |
|
|
for (int i = 0; i < head_dim; i += 2) { |
|
|
float freq = 1.0f / powf(rope_theta, (float)i / head_dim); |
|
|
float angle = pos * freq; |
|
|
float cos_a = cosf(angle); |
|
|
float sin_a = sinf(angle); |
|
|
float v0 = vec[i]; |
|
|
float v1 = vec[i + 1]; |
|
|
vec[i] = v0 * cos_a - v1 * sin_a; |
|
|
vec[i + 1] = v0 * sin_a + v1 * cos_a; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Quantizes a dense float matrix to the ternary bit-plane format used
 * by ternary_matvec_avx512.
 *
 * Per row: threshold = alpha * mean(|w|); weights with |w| >= threshold
 * are kept (sign recorded in pos_bits / neg_bits), the rest become 0.
 * The row scale is the mean |w| over kept weights (1.0 if none kept).
 * Bit j of chunk c maps to column c*64 + j; bits past in_dim stay clear.
 *
 * weights:            out_dim x in_dim row-major floats
 * pos_bits, neg_bits: out_dim * ceil(in_dim/64) uint64 chunks (written)
 * scales:             out_dim floats (written)
 * alpha:              sparsity knob; larger => more weights zeroed
 */
void quantize_to_ternary(
    const float *weights,
    uint64_t *pos_bits,
    uint64_t *neg_bits,
    float *scales,
    int out_dim,
    int in_dim,
    float alpha
) {
    int chunks = (in_dim + 63) / 64;

    for (int i = 0; i < out_dim; i++) {
        const float *row = weights + (size_t)i * in_dim;
        uint64_t *row_pos = pos_bits + (size_t)i * chunks;
        uint64_t *row_neg = neg_bits + (size_t)i * chunks;

        /* Row threshold: alpha times the mean absolute weight. */
        float abs_sum = 0;
        for (int j = 0; j < in_dim; j++) {
            abs_sum += fabsf(row[j]);
        }
        float mean_abs = abs_sum / in_dim;
        float threshold = alpha * mean_abs;

        /* Row scale: mean magnitude of the weights that survive the
         * threshold, so the ternary product approximates the original. */
        float kept_sum = 0;
        int kept = 0;
        for (int j = 0; j < in_dim; j++) {
            float mag = fabsf(row[j]);
            if (mag >= threshold) {
                kept_sum += mag;
                kept++;
            }
        }
        scales[i] = kept ? (kept_sum / kept) : 1.0f;

        /* Pack signs: clear the row's chunks, then set one bit per
         * surviving weight. */
        for (int c = 0; c < chunks; c++) {
            row_pos[c] = 0;
            row_neg[c] = 0;
        }
        for (int j = 0; j < in_dim; j++) {
            uint64_t bit = 1ULL << (j & 63);
            if (row[j] >= threshold) {
                row_pos[j >> 6] |= bit;
            } else if (row[j] <= -threshold) {
                row_neg[j >> 6] |= bit;
            }
        }
    }
}
|
|
|