|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <immintrin.h> |
|
|
#include <stdint.h> |
|
|
#include <stdlib.h> |
|
|
#include <string.h> |
|
|
#include <math.h> |
|
|
#include <stdio.h> |
|
|
#include <time.h> |
|
|
#include <omp.h> |
|
|
|
|
|
/* Compile-time model dimensions and numeric constants. */
#define HIDDEN      1536          /* length of the hidden-state vectors h / h2 */
#define INTER       8960          /* length of the MLP scratch vectors sg / su / sd */
#define N_HEADS     12            /* number of query heads */
#define N_KV_HEADS  2             /* number of key/value heads held in the cache */
#define HEAD_DIM    128           /* floats per attention head */
#define N_LAYERS    28            /* number of decoder layers */
#define VOCAB       151936        /* length of the logits buffer */
#define RMS_EPS     1e-6f         /* epsilon added under the square root in rn() */
#define ROPE_THETA  1000000.0f    /* base of the rotary frequency schedule in rope() */
#define MAX_SEQ     4096          /* positions available in the KV cache */
#define GQA_RATIO   (N_HEADS / N_KV_HEADS)   /* query heads sharing one KV head */
|
|
|
|
|
/*
 * Quantised linear layer consumed by pmv().
 *
 * Each weight is stored as an unsigned 8-bit magnitude plus one sign bit; a
 * per-row float scale is applied to the accumulated row sum.
 */
typedef struct {
    uint8_t  *mags;        /* out_dim rows of in_dim magnitude bytes, row-major */
    uint64_t *sign_bits;   /* per row: (in_dim + 63) / 64 words; a set bit means "subtract" */
    float    *scales;      /* one multiplier per output row */
    float    *bias;        /* optional per-row bias; NULL means none */
    int       out_dim;     /* number of output rows */
    int       in_dim;      /* number of input columns */
    uint8_t  *row_maxmag;  /* one byte per output row (presumably the row's largest magnitude) */
} PL;
|
|
|
|
|
/* Dense half-precision linear layer consumed by fmv(). */
typedef struct {
    uint16_t *w;   /* od rows of id half-precision values, row-major */
    int       od;  /* number of output rows */
    int       id;  /* number of input columns */
} FL;
|
|
|
|
|
typedef struct { |
|
|
PL qp, kp, vp, op, gp, up, dp; |
|
|
float *in_norm, *pn_norm; |
|
|
float *qb, *kb, *vb; |
|
|
} Lay; |
|
|
|
|
|
typedef struct { |
|
|
uint16_t *emb; |
|
|
Lay lay[N_LAYERS]; |
|
|
float *fnorm; |
|
|
FL lmh; |
|
|
float *kc, *vc; |
|
|
float *h, *h2; |
|
|
float *sq, *sk, *sv, *ao; |
|
|
float *sg, *su, *sd; |
|
|
float *lg, *as; |
|
|
} M; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static void pmv(const PL *L, const float *x, float *y) { |
|
|
const int od = L->out_dim, id = L->in_dim; |
|
|
const int chunks = (id + 63) / 64; |
|
|
const int id16 = (id + 15) & ~15; |
|
|
|
|
|
float *xp = (float*)aligned_alloc(64, id16 * sizeof(float)); |
|
|
memcpy(xp, x, id * sizeof(float)); |
|
|
if (id16 > id) memset(xp + id, 0, (id16 - id) * sizeof(float)); |
|
|
|
|
|
#pragma omp parallel for schedule(dynamic, 64) |
|
|
for (int i = 0; i < od; i++) { |
|
|
const uint8_t *row_mag = L->mags + (size_t)i * id; |
|
|
const uint64_t *row_sign = L->sign_bits + (size_t)i * chunks; |
|
|
const int rmax = L->row_maxmag[i]; |
|
|
|
|
|
__m512 acc = _mm512_setzero_ps(); |
|
|
|
|
|
for (int j = 0; j < id; j += 16) { |
|
|
if (j >= id16) break; |
|
|
|
|
|
|
|
|
__m128i mv = _mm_loadu_si128((__m128i*)(row_mag + j)); |
|
|
|
|
|
|
|
|
if (_mm_testz_si128(mv, mv)) continue; |
|
|
|
|
|
__m512 xv = _mm512_load_ps(xp + j); |
|
|
|
|
|
|
|
|
int chunk_idx = j / 64; |
|
|
int bit_off = j % 64; |
|
|
uint64_t sbits = row_sign[chunk_idx]; |
|
|
uint16_t signs = (uint16_t)((sbits >> bit_off) & 0xFFFF); |
|
|
|
|
|
|
|
|
|
|
|
__m128i mx = mv; |
|
|
mx = _mm_max_epu8(mx, _mm_srli_si128(mx, 8)); |
|
|
mx = _mm_max_epu8(mx, _mm_srli_si128(mx, 4)); |
|
|
mx = _mm_max_epu8(mx, _mm_srli_si128(mx, 2)); |
|
|
mx = _mm_max_epu8(mx, _mm_srli_si128(mx, 1)); |
|
|
int local_max = _mm_extract_epi8(mx, 0); |
|
|
|
|
|
|
|
|
for (int m = 1; m <= local_max; m++) { |
|
|
|
|
|
__m128i thresh = _mm_set1_epi8((char)m); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__m128i cmp = _mm_cmpeq_epi8(_mm_max_epu8(mv, thresh), mv); |
|
|
uint16_t active = (uint16_t)_mm_movemask_epi8(cmp); |
|
|
|
|
|
__mmask16 pos = (__mmask16)(active & ~signs); |
|
|
__mmask16 neg = (__mmask16)(active & signs); |
|
|
|
|
|
acc = _mm512_mask_add_ps(acc, pos, acc, xv); |
|
|
acc = _mm512_mask_sub_ps(acc, neg, acc, xv); |
|
|
} |
|
|
} |
|
|
|
|
|
y[i] = _mm512_reduce_add_ps(acc) * L->scales[i]; |
|
|
if (L->bias) y[i] += L->bias[i]; |
|
|
} |
|
|
free(xp); |
|
|
} |
|
|
|
|
|
|
|
|
static void fmv(const FL *L, const float *x, float *y) { |
|
|
#pragma omp parallel for schedule(dynamic, 256) |
|
|
for (int i = 0; i < L->od; i++) { |
|
|
__m512 acc = _mm512_setzero_ps(); |
|
|
const uint16_t *row = L->w + (size_t)i * L->id; |
|
|
int j; |
|
|
for (j = 0; j + 16 <= L->id; j += 16) { |
|
|
__m256i h = _mm256_loadu_si256((__m256i*)(row + j)); |
|
|
acc = _mm512_fmadd_ps(_mm512_cvtph_ps(h), _mm512_loadu_ps(x + j), acc); |
|
|
} |
|
|
float s = _mm512_reduce_add_ps(acc); |
|
|
for (; j < L->id; j++) { |
|
|
float wf; _mm_store_ss(&wf, _mm_cvtph_ps(_mm_set1_epi16(row[j]))); |
|
|
s += wf * x[j]; |
|
|
} |
|
|
y[i] = s; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
static void rn(const float *x, const float *w, float *y, int d) { |
|
|
__m512 sq = _mm512_setzero_ps(); |
|
|
int i; |
|
|
for (i = 0; i+16 <= d; i += 16) { |
|
|
__m512 v = _mm512_loadu_ps(x+i); |
|
|
sq = _mm512_fmadd_ps(v, v, sq); |
|
|
} |
|
|
float ss = _mm512_reduce_add_ps(sq); |
|
|
for (; i < d; i++) ss += x[i]*x[i]; |
|
|
float r = 1.0f / sqrtf(ss/d + RMS_EPS); |
|
|
__m512 rv = _mm512_set1_ps(r); |
|
|
for (i = 0; i+16 <= d; i += 16) |
|
|
_mm512_storeu_ps(y+i, _mm512_mul_ps(_mm512_mul_ps( |
|
|
_mm512_loadu_ps(x+i), rv), _mm512_loadu_ps(w+i))); |
|
|
for (; i < d; i++) y[i] = x[i]*r*w[i]; |
|
|
} |
|
|
|
|
|
static void silu(float *x, int n) { |
|
|
for (int i = 0; i < n; i++) x[i] /= (1.0f + expf(-x[i])); |
|
|
} |
|
|
static void emul(const float *a, const float *b, float *c, int n) { |
|
|
int i; |
|
|
for (i = 0; i+16 <= n; i += 16) |
|
|
_mm512_storeu_ps(c+i, _mm512_mul_ps(_mm512_loadu_ps(a+i), _mm512_loadu_ps(b+i))); |
|
|
for (; i < n; i++) c[i] = a[i]*b[i]; |
|
|
} |
|
|
static void va(float *y, const float *x, int n) { |
|
|
int i; |
|
|
for (i = 0; i+16 <= n; i += 16) |
|
|
_mm512_storeu_ps(y+i, _mm512_add_ps(_mm512_loadu_ps(y+i), _mm512_loadu_ps(x+i))); |
|
|
for (; i < n; i++) y[i] += x[i]; |
|
|
} |
|
|
static void rope(float *v, int pos, int d) { |
|
|
for (int i = 0; i < d; i += 2) { |
|
|
float f = 1.0f / powf(ROPE_THETA, (float)i/d); |
|
|
float a = pos*f, co = cosf(a), si = sinf(a); |
|
|
float v0 = v[i], v1 = v[i+1]; |
|
|
v[i] = v0*co - v1*si; v[i+1] = v0*si + v1*co; |
|
|
} |
|
|
} |
|
|
static void sm(float *x, int n) { |
|
|
float mx = x[0]; |
|
|
for (int i = 1; i < n; i++) if (x[i] > mx) mx = x[i]; |
|
|
float s = 0; |
|
|
for (int i = 0; i < n; i++) { x[i] = expf(x[i]-mx); s += x[i]; } |
|
|
float iv = 1.0f/s; |
|
|
for (int i = 0; i < n; i++) x[i] *= iv; |
|
|
} |
|
|
static void etok(const M *m, int t, float *o) { |
|
|
const uint16_t *r = m->emb + (size_t)t * HIDDEN; |
|
|
int i; |
|
|
for (i = 0; i+16 <= HIDDEN; i += 16) |
|
|
_mm512_storeu_ps(o+i, _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(r+i)))); |
|
|
for (; i < HIDDEN; i++) _mm_store_ss(o+i, _mm_cvtph_ps(_mm_set1_epi16(r[i]))); |
|
|
} |
|
|
static float* kvp(float *c, int l, int p, int h) { |
|
|
return c + ((size_t)l*MAX_SEQ*N_KV_HEADS + (size_t)p*N_KV_HEADS + h)*HEAD_DIM; |
|
|
} |
|
|
|
|
|
/*
 * Single-token self-attention for layer l at sequence position pos.
 *
 * Reads the normalised hidden state from m->h2, projects it to Q/K/V, adds the
 * optional biases, applies the rotary embedding to Q and K, stores K/V in the
 * cache at pos, attends over cache positions 0..pos and writes the output
 * projection back into m->h2.  m->sq, m->sk, m->sv, m->ao and m->as are used
 * as scratch.  Neither l nor pos is range-checked here.
 */
static void do_attn(M *m, int l, int pos) {
    Lay *layer = &m->lay[l];
    const int q_width = N_HEADS * HEAD_DIM;
    const int kv_width = N_KV_HEADS * HEAD_DIM;

    /* Q, K and V projections of the same input. */
    pmv(&layer->qp, m->h2, m->sq);
    pmv(&layer->kp, m->h2, m->sk);
    pmv(&layer->vp, m->h2, m->sv);

    if (layer->qb) va(m->sq, layer->qb, q_width);
    if (layer->kb) va(m->sk, layer->kb, kv_width);
    if (layer->vb) va(m->sv, layer->vb, kv_width);

    for (int qi = 0; qi < N_HEADS; qi++) {
        rope(m->sq + qi * HEAD_DIM, pos, HEAD_DIM);
    }

    /* Rotate each new key, then append key and value to the cache. */
    for (int ki = 0; ki < N_KV_HEADS; ki++) {
        float *k_new = m->sk + ki * HEAD_DIM;
        float *v_new = m->sv + ki * HEAD_DIM;
        rope(k_new, pos, HEAD_DIM);
        memcpy(kvp(m->kc, l, pos, ki), k_new, HEAD_DIM * sizeof(float));
        memcpy(kvp(m->vc, l, pos, ki), v_new, HEAD_DIM * sizeof(float));
    }

    const float inv_sqrt_d = 1.0f / sqrtf((float)HEAD_DIM);
    const int n_ctx = pos + 1;   /* cached positions visible to this token */

    memset(m->ao, 0, (size_t)q_width * sizeof(float));

    for (int qi = 0; qi < N_HEADS; qi++) {
        const int ki = qi / GQA_RATIO;   /* KV head shared by this query head */
        float *q = m->sq + qi * HEAD_DIM;
        float *out = m->ao + qi * HEAD_DIM;

        /* Scaled dot-product scores against every cached key. */
        for (int t = 0; t < n_ctx; t++) {
            float *k = kvp(m->kc, l, t, ki);
            __m512 vdot = _mm512_setzero_ps();
            int d = 0;
            for (; d + 16 <= HEAD_DIM; d += 16) {
                vdot = _mm512_fmadd_ps(_mm512_loadu_ps(q + d), _mm512_loadu_ps(k + d), vdot);
            }
            float score = _mm512_reduce_add_ps(vdot);
            for (; d < HEAD_DIM; d++) {
                score += q[d] * k[d];
            }
            m->as[t] = score * inv_sqrt_d;
        }

        sm(m->as, n_ctx);

        /* Probability-weighted sum of cached values. */
        for (int t = 0; t < n_ctx; t++) {
            const float p = m->as[t];
            if (p < 1e-8f) continue;   /* negligible weight: skip the row */

            float *v = kvp(m->vc, l, t, ki);
            const __m512 vp = _mm512_set1_ps(p);
            int d = 0;
            for (; d + 16 <= HEAD_DIM; d += 16) {
                _mm512_storeu_ps(out + d,
                    _mm512_fmadd_ps(vp, _mm512_loadu_ps(v + d), _mm512_loadu_ps(out + d)));
            }
            for (; d < HEAD_DIM; d++) {
                out[d] += p * v[d];
            }
        }
    }

    pmv(&layer->op, m->ao, m->h2);
}
|
|
|
|
|
/*
 * Gated MLP for layer l: m->h2 <- down( silu(gate(h2)) * up(h2) ).
 * m->sg, m->su and m->sd are used as scratch.
 */
static void do_mlp(M *m, int l) {
    Lay *layer = &m->lay[l];

    /* Gate branch, activated in place. */
    pmv(&layer->gp, m->h2, m->sg);
    silu(m->sg, INTER);

    /* Up branch, then the element-wise gate. */
    pmv(&layer->up, m->h2, m->su);
    emul(m->sg, m->su, m->sd, INTER);

    /* Project back; overwrites the input buffer. */
    pmv(&layer->dp, m->sd, m->h2);
}
|
|
|
|
|
/*
 * Run one token through the whole model.
 *
 * m    model with all weights set
 * tid  token id, 0 <= tid < VOCAB
 * pos  sequence position of the token, 0 <= pos < MAX_SEQ; the token's K/V are
 *      written to the cache at this position
 *
 * Returns m->lg, the logits produced by the output head, or NULL without
 * touching any state when tid or pos is out of range (an out-of-range pos
 * would otherwise write past the KV cache and the attention score buffer).
 */
float* forward_token(M *m, int tid, int pos) {
    if (tid < 0 || tid >= VOCAB || pos < 0 || pos >= MAX_SEQ) {
        return NULL;
    }

    etok(m, tid, m->h);

    for (int l = 0; l < N_LAYERS; l++) {
        /* Attention sub-layer with residual connection. */
        rn(m->h, m->lay[l].in_norm, m->h2, HIDDEN);
        do_attn(m, l, pos);
        va(m->h, m->h2, HIDDEN);

        /* MLP sub-layer with residual connection. */
        rn(m->h, m->lay[l].pn_norm, m->h2, HIDDEN);
        do_mlp(m, l);
        va(m->h, m->h2, HIDDEN);
    }

    rn(m->h, m->fnorm, m->h2, HIDDEN);
    fmv(&m->lmh, m->h2, m->lg);
    return m->lg;
}
|
|
|
|
|
/*
 * Sample a token id from logits with temperature and nucleus (top-p)
 * filtering, considering at most the 50 most probable tokens.
 *
 * lg  logits, V floats; overwritten with the softmax probabilities
 * V   vocabulary size
 * T   temperature; logits are divided by T when T > 0
 * tp  cumulative-probability cutoff for the nucleus
 *
 * Candidates are selected most-probable first until their cumulative
 * probability reaches tp.  At least one candidate is always selected, so a
 * non-positive tp yields the most probable token.  If the scratch buffers
 * cannot be allocated the most probable token is returned.  Returns 0 when
 * V <= 0.
 */
static int samp(float *lg, int V, float T, float tp) {
    enum { TOP_K = 50 };

    if (V <= 0) return 0;

    if (T > 0) {
        const float it = 1.0f / T;
        for (int i = 0; i < V; i++) lg[i] *= it;
    }
    sm(lg, V);

    float *pr = (float*)malloc((size_t)V * sizeof *pr);
    int *ix = (int*)malloc((size_t)V * sizeof *ix);
    if (!pr || !ix) {
        /* Out of memory: degrade to greedy selection. */
        free(pr);
        free(ix);
        int best = 0;
        for (int i = 1; i < V; i++) if (lg[i] > lg[best]) best = i;
        return best;
    }

    memcpy(pr, lg, (size_t)V * sizeof *pr);
    for (int i = 0; i < V; i++) ix[i] = i;

    /* Partial selection sort: move the next most probable token to slot nk. */
    float cum = 0;
    int nk = 0;
    do {
        int b = nk;
        for (int i = nk + 1; i < V; i++) if (pr[i] > pr[b]) b = i;

        const float t = pr[nk]; pr[nk] = pr[b]; pr[b] = t;
        const int ti = ix[nk]; ix[nk] = ix[b]; ix[b] = ti;

        cum += pr[nk];
        nk++;
    } while (cum < tp && nk < V && nk < TOP_K);

    /* cum is the total mass of the selected candidates; draw within it. */
    const float r = (float)rand() / RAND_MAX * cum;
    float ac = 0;
    int ch = ix[0];
    for (int i = 0; i < nk; i++) {
        ac += pr[i];
        if (ac >= r) { ch = ix[i]; break; }
    }

    free(pr);
    free(ix);
    return ch;
}
|
|
|
|
|
/*
 * Feed a prompt through the model and generate up to mx further tokens.
 *
 * m    model
 * pr   prompt token ids, pl entries
 * pl   prompt length, 1 <= pl <= MAX_SEQ
 * out  receives the generated token ids (room for mx entries)
 * mx   maximum number of tokens to generate
 * T    temperature; T <= 0 selects the arg-max token (greedy)
 * tp   top-p cutoff passed to samp() when T > 0
 * eos  token id that ends generation (it is still written to out)
 *
 * Returns the number of tokens written to out.  Returns 0 without running the
 * model when the prompt is empty (there would be no logits to sample from) or
 * longer than MAX_SEQ.  Generation also stops once the KV cache is full, i.e.
 * when the next token would sit at position MAX_SEQ.
 *
 * The C library random generator is reseeded from the clock on every call.
 */
int generate(M *m, const int *pr, int pl, int *out, int mx,
             float T, float tp, int eos) {
    if (pl <= 0 || pl > MAX_SEQ) return 0;

    srand(time(NULL));

    /* Prefill: only the logits of the last prompt token are used below. */
    for (int i = 0; i < pl; i++) forward_token(m, pr[i], i);

    int pos = pl, gen = 0;
    for (int t = 0; t < mx; t++) {
        int nx;
        if (T <= 0) {
            nx = 0;
            for (int i = 1; i < VOCAB; i++) if (m->lg[i] > m->lg[nx]) nx = i;
        } else {
            nx = samp(m->lg, VOCAB, T, tp);
        }

        out[t] = nx;
        gen++;
        if (nx == eos) break;

        /* No cache slot left for another position. */
        if (pos >= MAX_SEQ) break;

        forward_token(m, nx, pos);
        pos++;
    }
    return gen;
}
|
|
|
|
|
M* model_alloc(void) { |
|
|
M *m = (M*)calloc(1, sizeof(M)); |
|
|
size_t kv = (size_t)N_LAYERS*MAX_SEQ*N_KV_HEADS*HEAD_DIM; |
|
|
m->kc = (float*)calloc(kv,4); m->vc = (float*)calloc(kv,4); |
|
|
m->h = (float*)aligned_alloc(64,HIDDEN*4); |
|
|
m->h2 = (float*)aligned_alloc(64,HIDDEN*4); |
|
|
m->sq = (float*)aligned_alloc(64,N_HEADS*HEAD_DIM*4); |
|
|
m->sk = (float*)aligned_alloc(64,N_KV_HEADS*HEAD_DIM*4); |
|
|
m->sv = (float*)aligned_alloc(64,N_KV_HEADS*HEAD_DIM*4); |
|
|
m->ao = (float*)aligned_alloc(64,N_HEADS*HEAD_DIM*4); |
|
|
m->sg = (float*)aligned_alloc(64,INTER*4); |
|
|
m->su = (float*)aligned_alloc(64,INTER*4); |
|
|
m->sd = (float*)aligned_alloc(64,INTER*4); |
|
|
m->lg = (float*)aligned_alloc(64,VOCAB*4); |
|
|
m->as = (float*)aligned_alloc(64,MAX_SEQ*4); |
|
|
m->fnorm = (float*)aligned_alloc(64,HIDDEN*4); |
|
|
printf("Alloc: KV=%zuMB\n", kv*2*4/1024/1024); |
|
|
return m; |
|
|
} |
|
|
|
|
|
/* Attach the embedding table.  The pointer is stored, not copied. */
void model_set_embed(M *m, uint16_t *d) {
    m->emb = d;
}
|
|
/* Copy HIDDEN final-norm weights from d into the model's own buffer. */
void model_set_final_norm(M *m, float *d) {
    memcpy(m->fnorm, d, HIDDEN * sizeof(float));
}
|
|
/*
 * Attach the output head: d holds o rows of i half-precision weights.
 * The pointer is stored, not copied.
 */
void model_set_lm_head(M *m, uint16_t *d, int o, int i) {
    FL *head = &m->lmh;
    head->w = d;
    head->od = o;
    head->id = i;
}
|
|
/*
 * Attach the two norm weight vectors of layer l: i is applied before
 * attention, p before the MLP.  Pointers are stored, not copied.
 *
 * A layer index outside [0, N_LAYERS) is reported on stderr and ignored
 * instead of writing outside the layer array.
 */
void layer_set_norms(M *m, int l, float *i, float *p) {
    if (l < 0 || l >= N_LAYERS) {
        fprintf(stderr, "layer_set_norms: layer index %d out of range\n", l);
        return;
    }
    m->lay[l].in_norm = i;
    m->lay[l].pn_norm = p;
}
|
|
/*
 * Attach the optional Q/K/V biases of layer l; pass NULL for a bias that does
 * not exist.  Pointers are stored, not copied.
 *
 * A layer index outside [0, N_LAYERS) is reported on stderr and ignored
 * instead of writing outside the layer array.
 */
void layer_set_bias(M *m, int l, float *q, float *k, float *v) {
    if (l < 0 || l >= N_LAYERS) {
        fprintf(stderr, "layer_set_bias: layer index %d out of range\n", l);
        return;
    }
    m->lay[l].qb = q;
    m->lay[l].kb = k;
    m->lay[l].vb = v;
}
|
|
/*
 * Fill in a quantised linear layer descriptor.  All pointers are stored, not
 * copied; the bias is always cleared to NULL.
 *
 * mags    magnitude bytes, od rows of id columns
 * signs   packed sign bits
 * scales  per-row scale factors
 * rmm     per-row byte table stored as row_maxmag
 * od, id  output and input dimensions
 */
void set_pl(PL *p, uint8_t *mags, uint64_t *signs, float *scales,
            uint8_t *rmm, int od, int id) {
    p->out_dim = od;
    p->in_dim = id;

    p->mags = mags;
    p->sign_bits = signs;
    p->scales = scales;
    p->row_maxmag = rmm;

    p->bias = NULL;
}
|
|
/*
 * Attach the seven quantised projections of layer l.
 *
 * Each projection takes six arguments in the order expected by set_pl():
 * magnitudes, sign bits, scales, row_maxmag table, output dim, input dim.
 * The groups are, in order: q, k, v, o (attention) and g, u, d (MLP).
 * Pointers are stored, not copied.
 *
 * A layer index outside [0, N_LAYERS) is reported on stderr and ignored
 * instead of writing outside the layer array.
 */
void layer_set_linears(M *m, int l,
    uint8_t*qm,uint64_t*qs,float*qc,uint8_t*qx,int qo,int qi,
    uint8_t*km,uint64_t*ks,float*kc,uint8_t*kx,int ko,int ki,
    uint8_t*vm,uint64_t*vs,float*vc,uint8_t*vx,int vo,int vi,
    uint8_t*om,uint64_t*os_,float*oc,uint8_t*ox,int oo,int oi,
    uint8_t*gm,uint64_t*gs,float*gc,uint8_t*gx,int go,int gi,
    uint8_t*um,uint64_t*us,float*uc,uint8_t*ux,int uo,int ui,
    uint8_t*dm,uint64_t*ds,float*dc,uint8_t*dx,int doo,int di) {
    if (l < 0 || l >= N_LAYERS) {
        fprintf(stderr, "layer_set_linears: layer index %d out of range\n", l);
        return;
    }

    Lay *layer = &m->lay[l];

    /* Attention projections. */
    set_pl(&layer->qp, qm, qs, qc, qx, qo, qi);
    set_pl(&layer->kp, km, ks, kc, kx, ko, ki);
    set_pl(&layer->vp, vm, vs, vc, vx, vo, vi);
    set_pl(&layer->op, om, os_, oc, ox, oo, oi);

    /* MLP projections. */
    set_pl(&layer->gp, gm, gs, gc, gx, go, gi);
    set_pl(&layer->up, um, us, uc, ux, uo, ui);
    set_pl(&layer->dp, dm, ds, dc, dx, doo, di);
}
|
|
/* Zero both KV caches in full. */
void model_reset_cache(M *m) {
    const size_t n_floats = (size_t)N_LAYERS * MAX_SEQ * N_KV_HEADS * HEAD_DIM;
    const size_t n_bytes = n_floats * 4;

    memset(m->kc, 0, n_bytes);
    memset(m->vc, 0, n_bytes);
}
|
|
/*
 * Release a model created by model_alloc(): the KV caches, every scratch
 * buffer, the final-norm copy and the model struct itself.  Weight buffers
 * attached through the setter functions are not freed here.
 *
 * Passing NULL is a no-op, mirroring free().
 */
void model_free(M *m) {
    if (!m) return;

    free(m->kc);
    free(m->vc);

    free(m->h);
    free(m->h2);

    free(m->sq);
    free(m->sk);
    free(m->sv);
    free(m->ao);

    free(m->sg);
    free(m->su);
    free(m->sd);

    free(m->lg);
    free(m->as);
    free(m->fnorm);

    free(m);
}
|
|
|