flopml committed
Commit 1485644 · 1 Parent(s): 0ebab57

added c inferer (mamba.h)
c_inferer/Makefile ADDED
@@ -0,0 +1,64 @@
+ CC := gcc
+ CFLAGS := -std=c99 -O3 -Ofast -mtune=native -march=native -fopenmp -static
+ CLIBS := -lm -lc
+
+ C_TOKENIZER := tokenizer.bin
+ C_MODEL := model.bin
+ C_CONFIG := config.h
+
+ SRC := $(wildcard *.c)
+ TARGET := a.out
+
+
+ all: $(TARGET)
+
+ clean:
+ 	$(RM) $(TARGET) *.o
+
+ wipe:
+ 	$(MAKE) clean
+ 	$(RM) $(C_TOKENIZER) $(C_MODEL) $(C_CONFIG)
+
+ run: $(TARGET)
+ 	OMP_NUM_THREADS=4 ./$< -t 0 -n 512 -p "One day, " -v
+
+
+ # fallback byte-level tokenizer: 256 packed (byte, prev = 0) entries
+ $(C_TOKENIZER):
+ 	awk 'BEGIN {for (i = 0; i <= 255; i++) printf("%c%c%c", i, 0, 0)}' > $@
+
+
+ # embed the model weights in the binary as the symbol _embedded_binary_model
+ model.o: $(C_MODEL)
+ 	objcopy --input-target binary \
+ 	--output-target elf64-x86-64 \
+ 	--redefine-sym _binary_model_bin_start=_embedded_binary_model \
+ 	$< $@
+
+
+ # same trick for the tokenizer blob
+ tokenizer.o: $(C_TOKENIZER)
+ 	objcopy --input-target binary \
+ 	--output-target elf64-x86-64 \
+ 	--redefine-sym _binary_tokenizer_bin_start=_embedded_binary_tokenizer \
+ 	$< $@
+
+
+ $(TARGET): $(SRC) model.o tokenizer.o
+ 	$(CC) $(CFLAGS) -o $@ $^ $(CLIBS)
+
+ .PHONY: all wipe clean run
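The Makefile embeds model.bin and tokenizer.bin directly into the executable via objcopy, so the final a.out is fully self-contained. A typical session (assuming model.bin and config.h have already been produced by the export side, which this commit does not include) looks like:

    make        # generates a byte-level tokenizer.bin if missing, embeds both blobs, links a.out
    make run    # OMP_NUM_THREADS=4 ./a.out -t 0 -n 512 -p "One day, " -v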
c_inferer/main.c ADDED
@@ -0,0 +1,87 @@
+ #include <stdlib.h>
+ #include <time.h>
+ #include <sys/time.h>
+
+ #include "qmamba.h"
+ #include "tokenizer.h"
+ #include "sampler.h"
+
+
+ static void help(char *name, void *defaults[]) {
+     LOG("Usage: %s [-pntsvh]\n\n", name);
+     LOG("Infers a Mamba language model.\n\n");
+     LOG("Options:\n");
+     LOG("\t-p <seed_text>   The seed_text to start the generation with. (default NONE)\n");
+     LOG("\t-n <n_predict>   The number of tokens to predict. (default %lu)\n", *(uint64_t *)defaults[0]);
+     LOG("\t-t <temperature> The temperature of the softmax. (default %.1f)\n", *(fp32_t *)defaults[1]);
+     LOG("\t-s <seed>        The seed for the random number generator. (default %lu)\n", *(uint64_t *)defaults[2]);
+     LOG("\t-v               Enables verbose mode. (default %s)\n", *(bool *)defaults[3] ? "true" : "false");
+     LOG("\t-h               Prints this help message.\n\n");
+
+     exit(EXIT_FAILURE);
+ }
+
+
+ int main(int argc, char *argv[]) {
+
+     char *seed_text = NULL;
+     uint64_t n_predict = 256;
+
+     for (int i = 1; i < argc; i++) {
+         if (argv[i][0] != '-') {
+             LOG("Invalid argument: %s\n", argv[i]);
+             return 1;
+         }
+
+         switch (argv[i][1]) {
+             case 'p':
+                 if (++i >= argc) goto help_;   // flag requires a value
+                 seed_text = argv[i];
+                 break;
+             case 'n':
+                 if (++i >= argc) goto help_;
+                 n_predict = strtoull(argv[i], NULL, 10);
+                 break;
+             case 't':
+                 if (++i >= argc) goto help_;
+                 sampler.temperature = strtod(argv[i], NULL);
+                 break;
+             case 's':
+                 if (++i >= argc) goto help_;
+                 sampler.rng_seed = strtoull(argv[i], NULL, 10);
+                 break;
+             case 'v':
+                 sampler.verbose = true;
+                 break;
+             case 'h':
+                 goto help_;
+             default:
+                 LOG("Invalid argument: %s\n", argv[i]);
+                 return 1;
+         }
+     }
+
+     if (sampler.verbose)
+         mamba.log(&mamba);
+
+     if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE)
+         goto help_;
+
+     return 0;
+
+ help_:
+     help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose});
+ }
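For reference, the flag parser above accepts invocations such as the following (values are illustrative):

    ./a.out -p "Once upon a time" -n 128 -t 0.8 -s 1234 -v

i.e. feed the prompt, then predict 128 tokens at temperature 0.8 using RNG seed 1234, with verbose stats logged to stderr.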
c_inferer/mamba.h ADDED
@@ -0,0 +1,343 @@
+ #pragma once
+
+ #include <stdio.h>
+ #include <string.h>
+ #include <math.h>
+ #include <stdint.h>
+
+ #include "config.h"
+
+
+ // model weights, linked into the binary by objcopy (see Makefile)
+ extern char _embedded_binary_model[];
+
+ typedef float fp32_t;
+
+ typedef struct MambaWeights MambaWeights;
+ typedef struct MambaConfig MambaConfig;
+ typedef struct MambaState MambaState;
+ typedef struct Mamba Mamba;
+
+ static void MambaLog(Mamba *);
+ static fp32_t *MambaForwardLayer(Mamba *, uint64_t);
+ static fp32_t *MambaForward(Mamba *, uint64_t);
+
+ #define Tensor(NAME, X, Y, Z) fp32_t NAME[(X) * (Y) * (Z)]
+
+ struct MambaWeights {
+     Tensor(embed, ROUNDED_VOCAB_SIZE, D_MODEL, 1);
+     Tensor(in_proj, N_LAYER, 2 * D_INNER, D_MODEL);
+     Tensor(conv1d_weight, N_LAYER, D_INNER, D_CONV);
+     Tensor(conv1d_bias, N_LAYER, D_INNER, 1);
+     Tensor(x_proj, N_LAYER, DT_RANK + 2 * D_STATE, D_INNER);
+     Tensor(dt_proj_weight, N_LAYER, D_INNER, DT_RANK);
+     Tensor(dt_proj_bias, N_LAYER, D_INNER, 1);
+     Tensor(A, N_LAYER, D_INNER, D_STATE);
+     Tensor(D, N_LAYER, D_INNER, 1);
+     Tensor(out_proj, N_LAYER, D_MODEL, D_INNER);
+     Tensor(norm, N_LAYER, D_MODEL, 1);
+     Tensor(norm_f, D_MODEL, 1, 1);
+     Tensor(lm_head, ROUNDED_VOCAB_SIZE, D_MODEL, 1);
+ };
+
+ struct MambaConfig {
+     uint64_t vocab_size; // vocabulary size, rounded to the nearest multiple of 8
+     uint64_t n_layer;    // number of layers
+     uint64_t d_model;    // embedding dim
+     uint64_t d_inner;
+     uint64_t dt_rank;
+     uint64_t d_state;
+     uint64_t d_conv;
+ };
+
+ struct MambaState {
+     Tensor(hidden_state, D_MODEL, 1, 1);
+     Tensor(conv_state, N_LAYER, D_INNER, D_CONV);
+     Tensor(ssm_state, N_LAYER, D_INNER, D_STATE);
+ };
+
+ struct Mamba {
+     MambaConfig config;
+     MambaState state;
+
+     MambaWeights *weights;
+
+     void    (*log)           (Mamba *);
+     fp32_t *(*forward_layer) (Mamba *, uint64_t);
+     fp32_t *(*forward)       (Mamba *, uint64_t);
+ };
+
+
+ Mamba mamba = {
+     .config = {
+         .vocab_size = ROUNDED_VOCAB_SIZE,
+         .n_layer = N_LAYER,
+         .d_model = D_MODEL,
+         .d_inner = D_INNER,
+         .dt_rank = DT_RANK,
+         .d_state = D_STATE,
+         .d_conv = D_CONV,
+     },
+
+     .state = {},
+
+     .weights = (MambaWeights *) _embedded_binary_model,
+
+     .log = MambaLog,
+     .forward_layer = MambaForwardLayer,
+     .forward = MambaForward
+ };
+
+
+ static void rmsnorm(fp32_t* y, fp32_t* x, fp32_t* weight, uint64_t size) {
+     fp32_t ss = 0.0f;
+     for (uint64_t j = 0; j < size; ++j)
+         ss += x[j] * x[j];
+
+     ss /= size;
+     ss += 1e-5f;
+     ss = 1.0f / sqrtf(ss);
+
+     for (uint64_t j = 0; j < size; ++j)
+         y[j] = x[j] * weight[j] * ss;
+ }
+
+ static fp32_t softplus(fp32_t x) { return logf(1.0f + expf(x)); }
+ static fp32_t sigmoid(fp32_t x) { return 1.0f / (1.0f + expf(-x)); }
+ static fp32_t silu(fp32_t x) { return x * sigmoid(x); }
+
+ // roll each causal-conv window left by one and append x as the newest column
+ static void shift_and_update_last_column(fp32_t* matrix, fp32_t* x, uint64_t rows, uint64_t cols) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < rows; ++i) {
+         fp32_t* row = matrix + i * cols;
+
+         for (uint64_t j = 0; j < cols - 1; ++j)
+             row[j] = row[j + 1];
+
+         row[cols - 1] = x[i];
+     }
+ }
+
+ static void conv1d_silu(fp32_t* x, fp32_t* conv_state, fp32_t* conv1d_weight, fp32_t* conv1d_bias, uint64_t d_inner, uint64_t d_conv) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d_inner; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < d_conv; ++j) {
+             uint64_t index = i * d_conv + j;
+             val += conv_state[index] * conv1d_weight[index];
+         }
+
+         x[i] = silu(val + conv1d_bias[i]);
+     }
+ }
+
+ static void dense_softplus(fp32_t* y, fp32_t* x, fp32_t* w, fp32_t* b, uint64_t d, uint64_t n) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < n; ++j)
+             val += w[i * n + j] * x[j];
+
+         y[i] = softplus(val + b[i]);
+     }
+ }
+
+ // one step of the discretized SSM recurrence per channel, gated by silu(z)
+ static void selective_scan(fp32_t* y, fp32_t* ssm_state, fp32_t* dt, fp32_t* A, fp32_t* B, fp32_t* C, fp32_t* D, fp32_t* x, fp32_t* z, uint64_t d_inner, uint64_t d_state) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d_inner; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < d_state; ++j) {
+             uint64_t index = i * d_state + j;
+
+             fp32_t dA = expf(dt[i] * A[index]);
+             fp32_t dB = dt[i] * B[j];
+
+             ssm_state[index] = ssm_state[index] * dA + x[i] * dB;
+
+             val += ssm_state[index] * C[j];
+         }
+
+         val += D[i] * x[i];
+         y[i] = val * silu(z[i]);
+     }
+ }
+
+
+ static void matmul(fp32_t* y, fp32_t* x, fp32_t* w, uint64_t d, uint64_t n) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < n; ++j)
+             val += w[i * n + j] * x[j];
+
+         y[i] = val;
+     }
+ }
+
+
+ #define LOG(...) fprintf(stderr, __VA_ARGS__)
+ #define CLOG(X, ...) if(X) LOG(__VA_ARGS__)
+
+ static char *shortScale(char *result, int64_t number) {
+     char *suffixes[] = {"", "k", "m", "b"};
+     uint64_t magnitude = 0;
+     fp32_t num = (fp32_t)number;
+
+     if (number < 1000) {
+         sprintf(result, "%ld", number);
+         return result;
+     }
+
+     while (number >= 1000 || number <= -1000) {
+         magnitude++;
+         number /= 1000;
+         num /= 1000.0;
+     }
+
+     sprintf(result, "%.0f%s", num, suffixes[magnitude]);
+     return result;
+ }
+
+
+ static inline void nparams(uint64_t dim) {
+     char buff[16];
+     shortScale(buff, dim);
+
+     LOG("%12lu (%s)", dim, buff);
+ }
+ #define NPARAMS(X) nparams(sizeof(X) / sizeof(fp32_t))
+
+ static void MambaLog(Mamba *mamba) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     LOG("Mamba Config:");
+     LOG("\n\tvocab_size: %12lu", p->vocab_size);
+     LOG("\n\tn_layer:    %12lu", p->n_layer);
+     LOG("\n\td_model:    %12lu", p->d_model);
+     LOG("\n\td_inner:    %12lu", p->d_inner);
+     LOG("\n\tdt_rank:    %12lu", p->dt_rank);
+     LOG("\n\td_state:    %12lu", p->d_state);
+     LOG("\n\td_conv:     %12lu", p->d_conv);
+     LOG("\n\n\n");
+
+     LOG("Parameters Count:");
+     LOG("\n\tembed:          "); NPARAMS(w->embed);
+     LOG("\n\tin_proj:        "); NPARAMS(w->in_proj);
+     LOG("\n\tconv1d_weight:  "); NPARAMS(w->conv1d_weight);
+     LOG("\n\tconv1d_bias:    "); NPARAMS(w->conv1d_bias);
+     LOG("\n\tx_proj:         "); NPARAMS(w->x_proj);
+     LOG("\n\tdt_proj_weight: "); NPARAMS(w->dt_proj_weight);
+     LOG("\n\tdt_proj_bias:   "); NPARAMS(w->dt_proj_bias);
+     LOG("\n\tA:              "); NPARAMS(w->A);
+     LOG("\n\tD:              "); NPARAMS(w->D);
+     LOG("\n\tout_proj:       "); NPARAMS(w->out_proj);
+     LOG("\n\tnorm:           "); NPARAMS(w->norm);
+     LOG("\n\tnorm_f:         "); NPARAMS(w->norm_f);
+     LOG("\n\tlm_head:        "); NPARAMS(w->lm_head);
+     LOG("\n\n\tTotal:          "); NPARAMS(MambaWeights);
+     LOG("\n\n\n");
+
+     LOG("Recurrent State:");
+     LOG("\n\thidden_state: "); NPARAMS(s->hidden_state);
+     LOG("\n\tconv_state:   "); NPARAMS(s->conv_state);
+     LOG("\n\tssm_state:    "); NPARAMS(s->ssm_state);
+     LOG("\n\n\tTotal:        "); NPARAMS(MambaState);
+     LOG("\n\n\n");
+ }
+
+ static fp32_t *MambaForwardLayer(Mamba *mamba, uint64_t layer) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     uint64_t d_model = p->d_model,
+              d_inner = p->d_inner,
+              d_conv  = p->d_conv,
+              d_state = p->d_state,
+              dt_rank = p->dt_rank;
+
+     fp32_t *hidden_state = s->hidden_state,
+            *conv_state = s->conv_state + layer * d_inner * d_conv,
+            *ssm_state  = s->ssm_state  + layer * d_inner * d_state;
+
+     Tensor( xz,   2 * D_INNER, 1, 1);
+     Tensor( x_db, DT_RANK + 2 * D_STATE, 1, 1);
+     Tensor( dt,   D_INNER, 1, 1);
+     Tensor( y,    D_INNER, 1, 1);
+
+     fp32_t *x = xz,
+            *z = xz + d_inner,
+            *B = x_db + dt_rank,
+            *C = x_db + dt_rank + d_state;
+
+
+     // Proj input
+     matmul(xz, hidden_state, w->in_proj + layer * 2 * d_inner * d_model, 2 * d_inner, d_model);
+
+     // Conv
+     shift_and_update_last_column(conv_state, x, d_inner, d_conv);
+     conv1d_silu(x, conv_state, w->conv1d_weight + layer * d_inner * d_conv, w->conv1d_bias + layer * d_inner, d_inner, d_conv);
+
+     // SSM
+     matmul(x_db, x, w->x_proj + layer * (dt_rank + 2 * d_state) * d_inner, dt_rank + 2 * d_state, d_inner);
+     dense_softplus(dt, x_db, w->dt_proj_weight + layer * d_inner * dt_rank, w->dt_proj_bias + layer * d_inner, d_inner, dt_rank);
+     selective_scan(y, ssm_state, dt, w->A + layer * d_inner * d_state, B, C, w->D + layer * d_inner, x, z, d_inner, d_state);
+
+     // Proj output
+     matmul(hidden_state, y, w->out_proj + layer * d_model * d_inner, d_model, d_inner);
+
+     return hidden_state;
+ }
+
+ static fp32_t *MambaForward(Mamba *mamba, uint64_t token) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     uint64_t d_model = p->d_model,
+              n_layer = p->n_layer,
+              vocab_size = p->vocab_size;
+
+     Tensor( input, D_MODEL, 1, 1);
+     static Tensor( logits, ROUNDED_VOCAB_SIZE, 1, 1);
+
+     fp32_t *hidden_state = s->hidden_state,
+            *content_row = w->embed + token * d_model;
+
+     memcpy(input, content_row, d_model * sizeof(fp32_t));
+
+     for (uint64_t layer = 0; layer < n_layer; ++layer) {
+         rmsnorm(hidden_state, input, w->norm + layer * d_model, d_model);
+         mamba->forward_layer(mamba, layer);
+
+         // residual connection
+         for (uint64_t i = 0; i < d_model; ++i) {
+             hidden_state[i] += input[i];
+             input[i] = hidden_state[i];
+         }
+     }
+
+     rmsnorm(hidden_state, hidden_state, w->norm_f, d_model);
+     matmul(logits, hidden_state, w->lm_head, vocab_size, d_model);
+
+     return logits;
+ }
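For reference, selective_scan above performs one step of the discretized state-space recurrence, independently for each of the d_inner channels i with a d_state-wide hidden state h; this is just a restatement of the loop body:

    h[i][j] = exp(dt[i] * A[i][j]) * h[i][j] + (dt[i] * B[j]) * x[i]
    y[i]    = sum_j C[j] * h[i][j]  +  D[i] * x[i]
    out[i]  = y[i] * silu(z[i])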
c_inferer/qmamba.h ADDED
@@ -0,0 +1,444 @@
+ #pragma once
+
+ #include <stdio.h>
+ #include <string.h>
+ #include <math.h>
+ #include <stdint.h>
+
+ #include "config.h"
+
+
+ // model weights, linked into the binary by objcopy (see Makefile)
+ extern char _embedded_binary_model[];
+
+ typedef float fp32_t;
+
+ typedef struct MambaWeights MambaWeights;
+ typedef struct MambaConfig MambaConfig;
+ typedef struct MambaState MambaState;
+ typedef struct Mamba Mamba;
+
+ static void MambaLog(Mamba *);
+ static fp32_t *MambaForwardLayer(Mamba *, uint64_t);
+ static fp32_t *MambaForward(Mamba *, uint64_t);
+
+ // quantization group size: one fp32 scale per GS int8 weights
+ #define GS 64
+
+ #define Tensor(NAME, X, Y, Z) fp32_t NAME[(X) * (Y) * (Z)]
+ #define QTensor(NAME, X, Y, Z) struct { int8_t q[(X) * (Y) * (Z)]; fp32_t s[((X) * (Y) * (Z)) / GS]; } NAME
+ #define QPointer(NAME, QT, P) struct { int8_t *q; fp32_t *s; } NAME = { .q = QT.q + (P), .s = QT.s + ((P) / GS) }
+
+ struct MambaWeights {
+     QTensor(embed, ROUNDED_VOCAB_SIZE, D_MODEL, 1);
+     QTensor(in_proj, N_LAYER, 2 * D_INNER, D_MODEL);
+
+     Tensor(conv1d_weight, N_LAYER, D_INNER, D_CONV);
+     Tensor(conv1d_bias, N_LAYER, D_INNER, 1);
+
+     QTensor(x_proj, N_LAYER, DT_RANK + 2 * D_STATE, D_INNER);
+
+     Tensor(dt_proj_weight, N_LAYER, D_INNER, DT_RANK);
+     Tensor(dt_proj_bias, N_LAYER, D_INNER, 1);
+     Tensor(A, N_LAYER, D_INNER, D_STATE);
+     Tensor(D, N_LAYER, D_INNER, 1);
+
+     QTensor(out_proj, N_LAYER, D_MODEL, D_INNER);
+
+     Tensor(norm, N_LAYER, D_MODEL, 1);
+     Tensor(norm_f, D_MODEL, 1, 1);
+
+     QTensor(lm_head, ROUNDED_VOCAB_SIZE, D_MODEL, 1);
+ };
+
+ struct MambaConfig {
+     uint64_t vocab_size; // vocabulary size, rounded to the nearest multiple of 8
+     uint64_t n_layer;    // number of layers
+     uint64_t d_model;    // embedding dim
+     uint64_t d_inner;
+     uint64_t dt_rank;
+     uint64_t d_state;
+     uint64_t d_conv;
+ };
+
+ struct MambaState {
+     Tensor(hidden_state, D_MODEL, 1, 1);
+     Tensor(conv_state, N_LAYER, D_INNER, D_CONV);
+     Tensor(ssm_state, N_LAYER, D_INNER, D_STATE);
+ };
+
+ struct Mamba {
+     MambaConfig config;
+     MambaState state;
+
+     MambaWeights *weights;
+
+     void    (*log)           (Mamba *);
+     fp32_t *(*forward_layer) (Mamba *, uint64_t);
+     fp32_t *(*forward)       (Mamba *, uint64_t);
+ };
+
+
+ Mamba mamba = {
+     .config = {
+         .vocab_size = ROUNDED_VOCAB_SIZE,
+         .n_layer = N_LAYER,
+         .d_model = D_MODEL,
+         .d_inner = D_INNER,
+         .dt_rank = DT_RANK,
+         .d_state = D_STATE,
+         .d_conv = D_CONV,
+     },
+
+     .state = {},
+
+     .weights = (MambaWeights *) _embedded_binary_model,
+
+     .log = MambaLog,
+     .forward_layer = MambaForwardLayer,
+     .forward = MambaForward
+ };
+
+
+ static void rmsnorm(fp32_t* y, fp32_t* x, fp32_t* weight, uint64_t size) {
+     fp32_t ss = 0.0f;
+     for (uint64_t j = 0; j < size; ++j)
+         ss += x[j] * x[j];
+
+     ss /= size;
+     ss += 1e-5f;
+     ss = 1.0f / sqrtf(ss);
+
+     for (uint64_t j = 0; j < size; ++j)
+         y[j] = x[j] * weight[j] * ss;
+ }
+
+ static fp32_t softplus(fp32_t x) { return logf(1.0f + expf(x)); }
+ static fp32_t sigmoid(fp32_t x) { return 1.0f / (1.0f + expf(-x)); }
+ static fp32_t silu(fp32_t x) { return x * sigmoid(x); }
+
+ // roll each causal-conv window left by one and append x as the newest column
+ static void shift_and_update_last_column(fp32_t* matrix, fp32_t* x, uint64_t rows, uint64_t cols) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < rows; ++i) {
+         fp32_t* row = matrix + i * cols;
+
+         for (uint64_t j = 0; j < cols - 1; ++j)
+             row[j] = row[j + 1];
+
+         row[cols - 1] = x[i];
+     }
+ }
+
+ static void conv1d_silu(fp32_t* x, fp32_t* conv_state, fp32_t* conv1d_weight, fp32_t* conv1d_bias, uint64_t d_inner, uint64_t d_conv) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d_inner; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < d_conv; ++j) {
+             uint64_t index = i * d_conv + j;
+             val += conv_state[index] * conv1d_weight[index];
+         }
+
+         x[i] = silu(val + conv1d_bias[i]);
+     }
+ }
+
+ static void dense_softplus(fp32_t* y, fp32_t* x, fp32_t* w, fp32_t* b, uint64_t d, uint64_t n) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < n; ++j)
+             val += w[i * n + j] * x[j];
+
+         y[i] = softplus(val + b[i]);
+     }
+ }
+
+ // one step of the discretized SSM recurrence per channel, gated by silu(z)
+ static void selective_scan(fp32_t* y, fp32_t* ssm_state, fp32_t* dt, fp32_t* A, fp32_t* B, fp32_t* C, fp32_t* D, fp32_t* x, fp32_t* z, uint64_t d_inner, uint64_t d_state) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d_inner; ++i) {
+         fp32_t val = 0.0f;
+
+         for (uint64_t j = 0; j < d_state; ++j) {
+             uint64_t index = i * d_state + j;
+
+             fp32_t dA = expf(dt[i] * A[index]);
+             fp32_t dB = dt[i] * B[j];
+
+             ssm_state[index] = ssm_state[index] * dA + x[i] * dB;
+
+             val += ssm_state[index] * C[j];
+         }
+
+         val += D[i] * x[i];
+         y[i] = val * silu(z[i]);
+     }
+ }
+
+
+ // int8 matmul: int32 dot products per group, rescaled by both group scales
+ static void qmatmul(fp32_t* y, int8_t *xq, fp32_t *xs, int8_t *wq, fp32_t *ws, uint64_t d, uint64_t n) {
+     #pragma omp parallel for
+     for (uint64_t i = 0; i < d; i++) {
+
+         fp32_t val = 0.0f;
+         uint64_t in = i * n;
+
+         // do the matmul in groups of GS
+         for (uint64_t j = 0; j + GS <= n; j += GS) {
+             int32_t ival = 0;
+
+             for (uint64_t k = 0; k < GS; k++)
+                 ival += ((int32_t) xq[j + k]) * ((int32_t) wq[in + j + k]);
+
+             val += ((fp32_t) ival) * ws[(in + j) / GS] * xs[j / GS];
+         }
+
+         y[i] = val;
+     }
+ }
+
+ static void dequantize(fp32_t* x, int8_t *xq, fp32_t *xs, uint64_t n) {
+     for (uint64_t i = 0; i < n; i++)
+         x[i] = xq[i] * xs[i / GS];
+ }
+
+
+ static void quantize(int8_t *xq, fp32_t *xs, fp32_t* x, uint64_t n) {
+     uint64_t num_groups = n / GS;
+     fp32_t Q_MAX = 127.0f;
+
+     for (uint64_t group = 0; group < num_groups; group++) {
+
+         // find the max absolute value in the current group
+         fp32_t wmax = 0.0f;
+         for (uint64_t i = 0; i < GS; i++) {
+             fp32_t val = fabsf(x[group * GS + i]);
+             if (val > wmax) {
+                 wmax = val;
+             }
+         }
+
+         // calculate and write the scaling factor (guard the all-zero group)
+         fp32_t scale = wmax / Q_MAX;
+         if (scale == 0.0f) scale = 1.0f;
+         xs[group] = scale;
+
+         // calculate and write the quantized values
+         for (uint64_t i = 0; i < GS; i++) {
+             fp32_t quant_value = x[group * GS + i] / scale; // scale
+             int8_t quantized = (int8_t) roundf(quant_value); // round
+             xq[group * GS + i] = quantized;
+         }
+     }
+ }
+
+
+ #define LOG(...) fprintf(stderr, __VA_ARGS__)
+ #define CLOG(X, ...) if(X) LOG(__VA_ARGS__)
+
+ static char *shortScale(char *result, int64_t number) {
+     char *suffixes[] = {"", "k", "m", "b"};
+     uint64_t magnitude = 0;
+     fp32_t num = (fp32_t)number;
+
+     if (number < 1000) {
+         sprintf(result, "%ld", number);
+         return result;
+     }
+
+     while (number >= 1000 || number <= -1000) {
+         magnitude++;
+         number /= 1000;
+         num /= 1000.0;
+     }
+
+     sprintf(result, "%.0f%s", num, suffixes[magnitude]);
+     return result;
+ }
+
+
+ static uint64_t global_scale = 0; // running total of base parameters, for NTOTAL
+
+ static inline void nparams(uint64_t dim) {
+     char buff[16];
+     shortScale(buff, dim);
+
+     global_scale += dim;
+
+     LOG("%12lu (%s)", dim, buff);
+ }
+ #define NPARAMS(X) nparams(sizeof(X) / sizeof(fp32_t))
+
+ // qdim is the fp32-equivalent footprint of the int8 weights (bytes / 4)
+ static inline void nqparams(uint64_t qdim) {
+     uint64_t dim = qdim * 4;
+
+     char buff[16];
+     shortScale(buff, qdim);
+
+     nparams(dim);
+     LOG("\t%12lu (%s)", qdim, buff);
+ }
+ #define NQPARAMS(X) nqparams(sizeof(X.q) / sizeof(fp32_t))
+
+ static inline void ntotal(uint64_t qdim) {
+     char buff[16];
+     shortScale(buff, global_scale);
+     LOG("%12lu (%s)", global_scale, buff);
+
+     shortScale(buff, qdim);
+     LOG("\t%12lu (%s)", qdim, buff);
+
+     fp32_t factor = (((fp32_t) (global_scale - qdim)) / global_scale) * 100;
+     LOG(" < %.2f%%", factor);
+ }
+ #define NTOTAL(X) ntotal(sizeof(X) / sizeof(fp32_t))
+
+ static void MambaLog(Mamba *mamba) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     LOG("Mamba Config:");
+     LOG("\n\tvocab_size: %12lu", p->vocab_size);
+     LOG("\n\tn_layer:    %12lu", p->n_layer);
+     LOG("\n\td_model:    %12lu", p->d_model);
+     LOG("\n\td_inner:    %12lu", p->d_inner);
+     LOG("\n\tdt_rank:    %12lu", p->dt_rank);
+     LOG("\n\td_state:    %12lu", p->d_state);
+     LOG("\n\td_conv:     %12lu", p->d_conv);
+     LOG("\n\n\n");
+
+     LOG("Parameters Count:                 Base           Quantized   Factor");
+     LOG("\n\tembed:          "); NQPARAMS(w->embed);
+     LOG("\n\tin_proj:        "); NQPARAMS(w->in_proj);
+     LOG("\n\tconv1d_weight:  "); NPARAMS(w->conv1d_weight);
+     LOG("\n\tconv1d_bias:    "); NPARAMS(w->conv1d_bias);
+     LOG("\n\tx_proj:         "); NQPARAMS(w->x_proj);
+     LOG("\n\tdt_proj_weight: "); NPARAMS(w->dt_proj_weight);
+     LOG("\n\tdt_proj_bias:   "); NPARAMS(w->dt_proj_bias);
+     LOG("\n\tA:              "); NPARAMS(w->A);
+     LOG("\n\tD:              "); NPARAMS(w->D);
+     LOG("\n\tout_proj:       "); NQPARAMS(w->out_proj);
+     LOG("\n\tnorm:           "); NPARAMS(w->norm);
+     LOG("\n\tnorm_f:         "); NPARAMS(w->norm_f);
+     LOG("\n\tlm_head:        "); NQPARAMS(w->lm_head);
+     LOG("\n\n\tTotal:          "); NTOTAL(MambaWeights);
+     LOG("\n\n\n");
+
+     LOG("Recurrent State:");
+     LOG("\n\thidden_state: "); NPARAMS(s->hidden_state);
+     LOG("\n\tconv_state:   "); NPARAMS(s->conv_state);
+     LOG("\n\tssm_state:    "); NPARAMS(s->ssm_state);
+     LOG("\n\n\tTotal:        "); NPARAMS(MambaState);
+     LOG("\n\n\n");
+ }
+
+
+ static fp32_t *MambaForwardLayer(Mamba *mamba, uint64_t layer) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     uint64_t d_model = p->d_model,
+              d_inner = p->d_inner,
+              d_conv  = p->d_conv,
+              d_state = p->d_state,
+              dt_rank = p->dt_rank;
+
+     fp32_t *hidden_state = s->hidden_state,
+            *conv_state = s->conv_state + layer * d_inner * d_conv,
+            *ssm_state  = s->ssm_state  + layer * d_inner * d_state;
+
+     QTensor( qhidden_state, D_MODEL, 1, 1);
+
+     Tensor(  xz,   2 * D_INNER, 1, 1);
+     QTensor( qx,   D_INNER, 1, 1);
+     Tensor(  x_db, DT_RANK + 2 * D_STATE, 1, 1);
+     Tensor(  dt,   D_INNER, 1, 1);
+     Tensor(  y,    D_INNER, 1, 1);
+     QTensor( qy,   D_INNER, 1, 1);
+
+     fp32_t *x = xz,
+            *z = xz + d_inner,
+            *B = x_db + dt_rank,
+            *C = x_db + dt_rank + d_state;
+
+     QPointer( in_proj,  w->in_proj,  layer * 2 * d_inner * d_model);
+     QPointer( x_proj,   w->x_proj,   layer * (dt_rank + 2 * d_state) * d_inner);
+     QPointer( out_proj, w->out_proj, layer * d_model * d_inner);
+
+
+     // Proj input
+     quantize(qhidden_state.q, qhidden_state.s, hidden_state, d_model);
+     qmatmul(xz, qhidden_state.q, qhidden_state.s, in_proj.q, in_proj.s, 2 * d_inner, d_model);
+
+     // Conv
+     shift_and_update_last_column(conv_state, x, d_inner, d_conv);
+     conv1d_silu(x, conv_state, w->conv1d_weight + layer * d_inner * d_conv, w->conv1d_bias + layer * d_inner, d_inner, d_conv);
+
+     // SSM
+     quantize(qx.q, qx.s, x, d_inner);
+     qmatmul(x_db, qx.q, qx.s, x_proj.q, x_proj.s, dt_rank + 2 * d_state, d_inner);
+
+     dense_softplus(dt, x_db, w->dt_proj_weight + layer * d_inner * dt_rank, w->dt_proj_bias + layer * d_inner, d_inner, dt_rank);
+     selective_scan(y, ssm_state, dt, w->A + layer * d_inner * d_state, B, C, w->D + layer * d_inner, x, z, d_inner, d_state);
+
+     // Proj output
+     quantize(qy.q, qy.s, y, d_inner);
+     qmatmul(hidden_state, qy.q, qy.s, out_proj.q, out_proj.s, d_model, d_inner);
+
+     return hidden_state;
+ }
+
+
+ static fp32_t *MambaForward(Mamba *mamba, uint64_t token) {
+     MambaConfig *p = &mamba->config;
+     MambaWeights *w = mamba->weights;
+     MambaState *s = &mamba->state;
+
+     uint64_t d_model = p->d_model,
+              n_layer = p->n_layer,
+              vocab_size = p->vocab_size;
+
+     Tensor( input, D_MODEL, 1, 1);
+     static Tensor( logits, ROUNDED_VOCAB_SIZE, 1, 1);
+
+     QTensor( qhidden_state, D_MODEL, 1, 1);
+
+     fp32_t *hidden_state = s->hidden_state;
+
+     QPointer( row, w->embed, token * d_model);
+
+     // embeddings are stored quantized, so the input row must be dequantized
+     dequantize(input, row.q, row.s, d_model);
+
+     for (uint64_t layer = 0; layer < n_layer; ++layer) {
+         rmsnorm(hidden_state, input, w->norm + layer * d_model, d_model);
+         mamba->forward_layer(mamba, layer);
+
+         // residual connection
+         for (uint64_t i = 0; i < d_model; ++i) {
+             hidden_state[i] += input[i];
+             input[i] = hidden_state[i];
+         }
+     }
+
+     rmsnorm(hidden_state, hidden_state, w->norm_f, d_model);
+
+     quantize(qhidden_state.q, qhidden_state.s, hidden_state, d_model);
+     qmatmul(logits, qhidden_state.q, qhidden_state.s, w->lm_head.q, w->lm_head.s, vocab_size, d_model);
+
+     return logits;
+ }
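A minimal sketch of the symmetric Q8 group scheme used above, one fp32 scale per GS = 64 int8 values (the buffers here are illustrative, not part of the commit):

    fp32_t x[GS], x2[GS];          // x holds some activations
    int8_t q[GS];
    fp32_t s[1];                   // one scale per group

    quantize(q, s, x, GS);         // s[0] = max|x| / 127, q[i] = roundf(x[i] / s[0])
    dequantize(x2, q, s, GS);      // x2[i] = q[i] * s[0], so x2 is x up to rounding error

Note that qmatmul never dequantizes the weights: it accumulates int8 x int8 products in an int32 per group and applies the two scales once per group, which is where the memory and speed savings come from.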
c_inferer/sampler.h ADDED
@@ -0,0 +1,166 @@
+ #pragma once
+
+ //#include <stdbool.h>
+
+
+ #define PRINT(C) fputc((char)C, stdout), fflush(stdout)
+
+ typedef enum {false, true} bool;
+
+ typedef struct Sampler Sampler;
+ struct Sampler {
+     Mamba *model;
+     Tokenizer *tokenizer;
+
+     uint64_t rng_seed;
+     fp32_t temperature;
+     bool verbose;
+
+     bool     (*generate) (Sampler *, char *, uint64_t);
+     uint64_t (*sample)   (Sampler *, fp32_t *);
+ };
+
+
+ static void softmax(fp32_t* x, uint64_t size) {
+     fp32_t max_val = x[0];
+     for (uint64_t i = 1; i < size; ++i)
+         if (x[i] > max_val) max_val = x[i];
+
+     fp32_t sum = 0.0f;
+     for (uint64_t i = 0; i < size; ++i) {
+         x[i] = expf(x[i] - max_val);
+         sum += x[i];
+     }
+
+     for (uint64_t i = 0; i < size; ++i)
+         x[i] /= sum;
+ }
+
+
+ // xorshift* PRNG: the multiply is output-only and must not touch the state
+ static uint64_t random_u32(uint64_t *rng_seed) {
+     *rng_seed ^= *rng_seed >> 12;
+     *rng_seed ^= *rng_seed << 25;
+     *rng_seed ^= *rng_seed >> 27;
+     return (*rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
+ }
+
+ // uniform in [0, 1)
+ static inline fp32_t random_f32(uint64_t *rng_seed) { return (random_u32(rng_seed) >> 8) / 16777216.0f; }
+
+ static uint64_t time_in_ms(void) {
+     struct timeval tv;
+     gettimeofday(&tv, NULL);
+     return tv.tv_sec * 1000 + tv.tv_usec / 1000;
+ }
+
+ static inline uint64_t sample_argmax(fp32_t* probabilities, uint64_t n) {
+     uint64_t max_i = 0;
+     fp32_t max_p = probabilities[0];
+
+     for (uint64_t i = 1; i < n; ++i)
+         if (probabilities[i] > max_p)
+             max_i = i, max_p = probabilities[i];
+
+     return max_i;
+ }
+
+ // inverse-CDF sampling from a probability distribution
+ static inline uint64_t sample_mult(fp32_t* probabilities, uint64_t n, fp32_t coin) {
+     fp32_t cdf = 0.0f;
+
+     for (uint64_t i = 0; i < n; ++i) {
+         cdf += probabilities[i];
+
+         if (coin < cdf) return i;
+     }
+
+     return n - 1;
+ }
+
+ static uint64_t SamplerSample(Sampler *sampler, fp32_t* logits) {
+     uint64_t next,
+              vocab_size = sampler->tokenizer->vocab_size,
+              *rng_seed = &sampler->rng_seed;
+
+     fp32_t temperature = sampler->temperature;
+
+     if (temperature == 0.0f) next = sample_argmax(logits, vocab_size);
+     else {
+         for (uint64_t q = 0; q < vocab_size; ++q)
+             logits[q] /= temperature;
+
+         softmax(logits, vocab_size);
+
+         fp32_t coin = random_f32(rng_seed);
+         next = sample_mult(logits, vocab_size, coin);
+     }
+
+     return next;
+ }
+
+ static bool SamplerGenerate(Sampler *sampler, char *seed_text, uint64_t n_predict) {
+     Mamba *model = sampler->model;
+     Tokenizer *tokenizer = sampler->tokenizer;
+     bool verbose = sampler->verbose;
+
+     uint64_t token;
+     fp32_t *logits;
+     char *text;
+
+     // an empty prompt would leave logits uninitialized below
+     if (seed_text == NULL || *seed_text == '\0') return EXIT_FAILURE;
+
+     // feed the prompt, echoing each token as it is consumed
+     while (*seed_text) {
+
+         token = tokenizer->encode(tokenizer, (uint8_t **) &seed_text);
+         text = (char *) tokenizer->decode(tokenizer, token);
+
+         fputs(text, stdout);
+         fflush(stdout);
+
+         logits = model->forward(model, token);
+     }
+
+     uint64_t time_start;
+     if (verbose) time_start = time_in_ms();
+
+     for (uint64_t i = 0; i < n_predict; ++i) {
+
+         token = sampler->sample(sampler, logits);
+         text = (char *) tokenizer->decode(tokenizer, token);
+
+         fputs(text, stdout);
+         fflush(stdout);
+
+         logits = model->forward(model, token);
+     }
+
+     CLOG(verbose, "\nachieved tok/s: %f\n", n_predict / (double)(time_in_ms() - time_start) * 1000);
+
+     return EXIT_SUCCESS;
+ }
+
+
+ Sampler sampler = {
+     .model = &mamba,
+     .tokenizer = &tokenizer,
+
+     .rng_seed = 42,
+     .temperature = 0.0f,
+     .verbose = false,
+
+     .generate = SamplerGenerate,
+     .sample = SamplerSample
+ };
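Note on the temperature path: at temperature 0, SamplerSample is greedy (pure argmax over the logits); otherwise each logit is divided by the temperature before the softmax, so for example t = 0.5 doubles every logit and sharpens the distribution while t = 2.0 flattens it, and the token is then drawn by inverse-CDF sampling (sample_mult) with a xorshift-derived coin in [0, 1).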
c_inferer/tokenizer.h ADDED
@@ -0,0 +1,83 @@
+ #pragma once
+
+ #include <stdint.h>
+
+
+ // tokenizer blob, linked into the binary by objcopy (see Makefile)
+ extern char _embedded_binary_tokenizer[];
+
+
+ #define MAX_WORD_LEN 24
+
+ // each token is one byte appended to a previous token (prev = 0: none)
+ typedef struct __attribute__((packed)) token_t {
+     uint8_t byte;
+     uint16_t prev;
+ } token_t;
+
+
+ typedef struct Tokenizer Tokenizer;
+ struct Tokenizer {
+     token_t *vocab;
+     uint16_t vocab_size;
+
+     uint16_t (*find)   (Tokenizer *, uint8_t, uint16_t);
+     uint16_t (*encode) (Tokenizer *, uint8_t **);
+     uint8_t *(*decode) (Tokenizer *, uint16_t);
+ };
+
+
+ // linear scan for the token extending prev with byte; starting at prev
+ // assumes a merged token is never stored before its prefix; 0 = not found
+ static uint16_t TokenizerFind(Tokenizer *tokenizer, uint8_t byte, uint16_t prev) {
+     for (uint16_t i = prev; i < tokenizer->vocab_size; ++i)
+         if (tokenizer->vocab[i].byte == byte && tokenizer->vocab[i].prev == prev)
+             return i;
+
+     return 0;
+ }
+
+
+ // greedily match the longest known token at *seed_text, advancing the pointer
+ static uint16_t TokenizerEncode(Tokenizer *tokenizer, uint8_t **seed_text) {
+
+     uint16_t prev = 0;
+     for (; **seed_text; ++*seed_text) {
+         uint16_t next = tokenizer->find(tokenizer, **seed_text, prev);
+         if (next == 0) break;
+         prev = next;
+     }
+
+     return prev;
+ }
+
+
+ // walk the prev links, writing bytes right to left into a static buffer
+ // (words longer than MAX_WORD_LEN are truncated)
+ static uint8_t *TokenizerDecode(Tokenizer *tokenizer, uint16_t token) {
+
+     static uint8_t dest[MAX_WORD_LEN + 1];
+     dest[MAX_WORD_LEN] = '\0';
+
+     uint16_t prev = token;
+     uint16_t i = MAX_WORD_LEN - 1;
+
+     for (; prev && i > 0; prev = tokenizer->vocab[prev].prev, --i)
+         dest[i] = tokenizer->vocab[prev].byte;
+
+     return dest + i + 1;
+ }
+
+
+ Tokenizer tokenizer = {
+     .vocab = (token_t *) _embedded_binary_tokenizer,
+     .vocab_size = VOCAB_SIZE,
+
+     .find = TokenizerFind,
+     .encode = TokenizerEncode,
+     .decode = TokenizerDecode
+ };
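To illustrate the reverse-linked vocab (the entries here are hypothetical): if entry 72 is ('H', prev 0) and entry 300 is ('i', prev 72), then encode consumes "Hi..." via find('H', 0) -> 72 and find('i', 72) -> 300, stopping at the first byte with no matching extension and emitting token 300; decode(300) follows the prev links (300 -> 72 -> done), writing the bytes right to left into its static buffer and returning a pointer to "Hi".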