added c inferer (mamba.h)
- c_inferer/Makefile +64 -0
- c_inferer/main.c +87 -0
- c_inferer/mamba.h +343 -0
- c_inferer/qmamba.h +444 -0
- c_inferer/sampler.h +166 -0
- c_inferer/tokenizer.h +83 -0
c_inferer/Makefile
ADDED
@@ -0,0 +1,64 @@
CC := gcc
CFLAGS := -std=c99 -O3 -Ofast -mtune=native -march=native -fopenmp -static
CLIBS := -lm -lc

C_TOKENIZER := tokenizer.bin
C_MODEL := model.bin
C_CONFIG := config.h

SRC := $(wildcard *.c)
TARGET := a.out

all: $(TARGET)

clean:
	$(RM) $(TARGET) *.o

wipe:
	make clean
	$(RM) $(C_TOKENIZER) $(C_MODEL) $(C_CONFIG)

run: $(TARGET)
	OMP_NUM_THREADS=4 ./$< -t 0 -n 512 -p "One day, " -v

$(C_TOKENIZER):
	awk 'BEGIN {for (i = 0; i <= 255; i++) printf("%c%c%c", i, 0, 0)}' > $@

model.o: $(C_MODEL)
	objcopy --input-target binary \
		--output-target elf64-x86-64 \
		--redefine-sym _binary_model_bin_start=_embedded_binary_model \
		$< $@

tokenizer.o: $(C_TOKENIZER)
	objcopy --input-target binary \
		--output-target elf64-x86-64 \
		--redefine-sym _binary_tokenizer_bin_start=_embedded_binary_tokenizer \
		$< $@

$(TARGET): $(SRC) model.o tokenizer.o
	$(CC) $(CFLAGS) -o $@ $^ $(CLIBS)

.PHONY: all wipe clean run
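
Note: the two objcopy rules are what let the binary run with no file I/O — the model and tokenizer blobs become ordinary linker symbols, and `--redefine-sym` renames the start symbol objcopy generates. A minimal sketch of how such an embedded blob looks from C (hypothetical standalone file, not part of this commit; `_binary_model_bin_size` is the size symbol objcopy emits by default, which the rename above leaves untouched):

/* peek.c -- hypothetical check; build: gcc peek.c model.o */
#include <stdio.h>

extern char _embedded_binary_model[];   /* renamed start symbol from model.o */
extern char _binary_model_bin_size[];   /* objcopy's default size symbol     */

int main(void) {
    /* the size symbol's *address* is the byte count */
    unsigned long size = (unsigned long) _binary_model_bin_size;
    printf("embedded model: %lu bytes, first byte 0x%02x\n",
           size, (unsigned char) _embedded_binary_model[0]);
    return 0;
}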
c_inferer/main.c
ADDED
@@ -0,0 +1,87 @@
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>

#include "qmamba.h"
#include "tokenizer.h"
#include "sampler.h"


static void help(char *name, void *defaults[]) {
    LOG("Usage: %s [-pntsvh]\n\n", name);
    LOG("Infers a Mamba language model.\n\n");
    LOG("Options:\n");
    LOG("\t-p <seed_text>   The seed text to start the generation with. (default NONE)\n");
    LOG("\t-n <n_predict>   The number of tokens to predict. (default %lu)\n", *(uint64_t *)defaults[0]);
    LOG("\t-t <temperature> The temperature of the softmax. (default %.1f)\n", *(fp32_t *)defaults[1]);
    LOG("\t-s <seed>        The seed for the random number generator. (default %lu)\n", *(uint64_t *)defaults[2]);
    LOG("\t-v               Enables verbose mode. (default %s)\n", *(bool *)defaults[3] ? "true" : "false");
    LOG("\t-h               Prints this help message.\n\n");

    exit(EXIT_FAILURE);
}


int main(int argc, char *argv[]) {

    char *seed_text = NULL;
    uint64_t n_predict = 256;

    for (int i = 1; i < argc; i++) {
        if (argv[i][0] != '-') {
            LOG("Invalid argument: %s\n", argv[i]);
            return 1;
        }

        switch (argv[i][1]) {
            case 'p':
                if (++i == argc) goto help_;
                seed_text = argv[i];
                break;
            case 'n':
                if (++i == argc) goto help_;
                n_predict = strtoull(argv[i], NULL, 10);
                break;
            case 't':
                if (++i == argc) goto help_;
                sampler.temperature = strtod(argv[i], NULL);
                break;
            case 's':
                if (++i == argc) goto help_;
                sampler.rng_seed = strtoull(argv[i], NULL, 10);
                break;
            case 'v':
                sampler.verbose = true;
                break;
            case 'h':
                goto help_;
            default:
                LOG("Invalid argument: %s\n", argv[i]);
                return 1;
        }
    }

    if (sampler.verbose)
        mamba.log(&mamba);

    if (sampler.generate(&sampler, seed_text, n_predict) == EXIT_FAILURE)
        goto help_;

    return 0;

help_:
    help(argv[0], (void *[]) {&n_predict, &sampler.temperature, &sampler.rng_seed, &sampler.verbose});
}
c_inferer/mamba.h
ADDED
@@ -0,0 +1,343 @@
#pragma once

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdint.h>

#include "config.h"


extern char _embedded_binary_model[];

typedef float fp32_t;

typedef struct MambaWeights MambaWeights;
typedef struct MambaConfig MambaConfig;
typedef struct MambaState MambaState;
typedef struct Mamba Mamba;

static void MambaLog(Mamba *);
static fp32_t *MambaForwardLayer(Mamba *, uint64_t);
static fp32_t *MambaForward(Mamba *, uint64_t);

#define Tensor(NAME, X, Y, Z) fp32_t NAME[(X) * (Y) * (Z)]

struct MambaWeights {
    Tensor(embed,          ROUNDED_VOCAB_SIZE, D_MODEL, 1);
    Tensor(in_proj,        N_LAYER, 2 * D_INNER, D_MODEL);
    Tensor(conv1d_weight,  N_LAYER, D_INNER, D_CONV);
    Tensor(conv1d_bias,    N_LAYER, D_INNER, 1);
    Tensor(x_proj,         N_LAYER, DT_RANK + 2 * D_STATE, D_INNER);
    Tensor(dt_proj_weight, N_LAYER, D_INNER, DT_RANK);
    Tensor(dt_proj_bias,   N_LAYER, D_INNER, 1);
    Tensor(A,              N_LAYER, D_INNER, D_STATE);
    Tensor(D,              N_LAYER, D_INNER, 1);
    Tensor(out_proj,       N_LAYER, D_MODEL, D_INNER);
    Tensor(norm,           N_LAYER, D_MODEL, 1);
    Tensor(norm_f,         D_MODEL, 1, 1);
    Tensor(lm_head,        ROUNDED_VOCAB_SIZE, D_MODEL, 1);
};

struct MambaConfig {
    uint64_t vocab_size; // vocabulary size, rounded to the nearest multiple of 8
    uint64_t n_layer;    // number of layers
    uint64_t d_model;    // embedding dim
    uint64_t d_inner;
    uint64_t dt_rank;
    uint64_t d_state;
    uint64_t d_conv;
};

struct MambaState {
    Tensor(hidden_state, D_MODEL, 1, 1);
    Tensor(conv_state,   N_LAYER, D_INNER, D_CONV);
    Tensor(ssm_state,    N_LAYER, D_INNER, D_STATE);
};

struct Mamba {
    MambaConfig config;
    MambaState state;

    MambaWeights *weights;

    void    (*log)           (Mamba *);
    fp32_t *(*forward_layer) (Mamba *, uint64_t);
    fp32_t *(*forward)       (Mamba *, uint64_t);
};


Mamba mamba = {
    .config = {
        .vocab_size = ROUNDED_VOCAB_SIZE,
        .n_layer = N_LAYER,
        .d_model = D_MODEL,
        .d_inner = D_INNER,
        .dt_rank = DT_RANK,
        .d_state = D_STATE,
        .d_conv = D_CONV,
    },

    .state = {0},

    .weights = (MambaWeights *) _embedded_binary_model,

    .log = MambaLog,
    .forward_layer = MambaForwardLayer,
    .forward = MambaForward
};


static void rmsnorm(fp32_t* y, fp32_t* x, fp32_t* weight, uint64_t size) {
    fp32_t ss = 0.0f;
    for (uint64_t j = 0; j < size; ++j)
        ss += x[j] * x[j];

    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);

    for (uint64_t j = 0; j < size; ++j)
        y[j] = x[j] * weight[j] * ss;
}

static fp32_t softplus(fp32_t x) { return logf(1.0f + expf(x)); }
static fp32_t sigmoid(fp32_t x) { return 1.0f / (1.0f + expf(-x)); }
static fp32_t silu(fp32_t x) { return x * sigmoid(x); }

static void shift_and_update_last_column(fp32_t* matrix, fp32_t* x, uint64_t rows, uint64_t cols) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < rows; ++i) {
        fp32_t* row = matrix + i * cols;

        for (uint64_t j = 0; j < cols - 1; ++j)
            row[j] = row[j + 1];

        row[cols - 1] = x[i];
    }
}

static void conv1d_silu(fp32_t* x, fp32_t* conv_state, fp32_t* conv1d_weight, fp32_t* conv1d_bias, uint64_t d_inner, uint64_t d_conv) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < d_conv; ++j) {
            uint64_t index = i * d_conv + j;
            val += conv_state[index] * conv1d_weight[index];
        }

        x[i] = silu(val + conv1d_bias[i]);
    }
}

static void dense_softplus(fp32_t* y, fp32_t* x, fp32_t* w, fp32_t* b, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < n; ++j)
            val += w[i * n + j] * x[j];

        y[i] = softplus(val + b[i]);
    }
}

static void selective_scan(fp32_t* y, fp32_t* ssm_state, fp32_t* dt, fp32_t* A, fp32_t* B, fp32_t* C, fp32_t* D, fp32_t* x, fp32_t* z, uint64_t d_inner, uint64_t d_state) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < d_state; ++j) {
            uint64_t index = i * d_state + j;

            fp32_t dA = expf(dt[i] * A[index]);
            fp32_t dB = dt[i] * B[j];

            ssm_state[index] = ssm_state[index] * dA + x[i] * dB;

            val += ssm_state[index] * C[j];
        }

        val += D[i] * x[i];
        y[i] = val * silu(z[i]);
    }
}


static void matmul(fp32_t* y, fp32_t* x, fp32_t* w, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < n; ++j)
            val += w[i * n + j] * x[j];

        y[i] = val;
    }
}


#define LOG(...) fprintf(stderr, __VA_ARGS__)
#define CLOG(X, ...) if(X) LOG(__VA_ARGS__)

static char *shortScale(char *result, int64_t number) {
    char *suffixes[] = {"", "k", "m", "b"};
    uint64_t magnitude = 0;
    fp32_t num = (fp32_t)number;

    if (number < 1000) {
        sprintf(result, "%ld", number);
        return result;
    }

    while (number >= 1000 || number <= -1000) {
        magnitude++;
        number /= 1000;
        num /= 1000.0f;
    }

    sprintf(result, "%.0f%s", num, suffixes[magnitude]);
    return result;
}


static inline void nparams(uint64_t dim) {
    char buff[16];
    shortScale(buff, dim);

    LOG("%12lu (%s)", dim, buff);
}
#define NPARAMS(X) nparams(sizeof(X) / sizeof(fp32_t))

static void MambaLog(Mamba *mamba) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    LOG("Mamba Config:");
    LOG("\n\tvocab_size: %12lu", p->vocab_size);
    LOG("\n\tn_layer:    %12lu", p->n_layer);
    LOG("\n\td_model:    %12lu", p->d_model);
    LOG("\n\td_inner:    %12lu", p->d_inner);
    LOG("\n\tdt_rank:    %12lu", p->dt_rank);
    LOG("\n\td_state:    %12lu", p->d_state);
    LOG("\n\td_conv:     %12lu", p->d_conv);
    LOG("\n\n\n");

    LOG("Parameters Count:");
    LOG("\n\tembed:          "); NPARAMS(w->embed);
    LOG("\n\tin_proj:        "); NPARAMS(w->in_proj);
    LOG("\n\tconv1d_weight:  "); NPARAMS(w->conv1d_weight);
    LOG("\n\tconv1d_bias:    "); NPARAMS(w->conv1d_bias);
    LOG("\n\tx_proj:         "); NPARAMS(w->x_proj);
    LOG("\n\tdt_proj_weight: "); NPARAMS(w->dt_proj_weight);
    LOG("\n\tdt_proj_bias:   "); NPARAMS(w->dt_proj_bias);
    LOG("\n\tA:              "); NPARAMS(w->A);
    LOG("\n\tD:              "); NPARAMS(w->D);
    LOG("\n\tout_proj:       "); NPARAMS(w->out_proj);
    LOG("\n\tnorm:           "); NPARAMS(w->norm);
    LOG("\n\tnorm_f:         "); NPARAMS(w->norm_f);
    LOG("\n\tlm_head:        "); NPARAMS(w->lm_head);
    LOG("\n\n\tTotal:          "); NPARAMS(MambaWeights);
    LOG("\n\n\n");

    LOG("Recurrent State:");
    LOG("\n\thidden_state:   "); NPARAMS(s->hidden_state);
    LOG("\n\tconv_state:     "); NPARAMS(s->conv_state);
    LOG("\n\tssm_state:      "); NPARAMS(s->ssm_state);
    LOG("\n\n\tTotal:          "); NPARAMS(MambaState);
    LOG("\n\n\n");
}

static fp32_t *MambaForwardLayer(Mamba *mamba, uint64_t layer) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model,
             d_inner = p->d_inner,
             d_conv  = p->d_conv,
             d_state = p->d_state,
             dt_rank = p->dt_rank;

    fp32_t *hidden_state = s->hidden_state,
           *conv_state = s->conv_state + layer * d_inner * d_conv,
           *ssm_state  = s->ssm_state  + layer * d_inner * d_state;

    Tensor( xz,   2 * D_INNER, 1, 1);
    Tensor( x_db, DT_RANK + 2 * D_STATE, 1, 1);
    Tensor( dt,   D_INNER, 1, 1);
    Tensor( y,    D_INNER, 1, 1);

    fp32_t *x = xz,
           *z = xz + d_inner,
           *B = x_db + dt_rank,
           *C = x_db + dt_rank + d_state;


    // Proj input
    matmul(xz, hidden_state, w->in_proj + layer * 2 * d_inner * d_model, 2 * d_inner, d_model);


    // Conv
    shift_and_update_last_column(conv_state, x, d_inner, d_conv);
    conv1d_silu(x, conv_state, w->conv1d_weight + layer * d_inner * d_conv, w->conv1d_bias + layer * d_inner, d_inner, d_conv);


    // SSM
    matmul(x_db, x, w->x_proj + layer * (dt_rank + 2 * d_state) * d_inner, dt_rank + 2 * d_state, d_inner);
    dense_softplus(dt, x_db, w->dt_proj_weight + layer * d_inner * dt_rank, w->dt_proj_bias + layer * d_inner, d_inner, dt_rank);
    selective_scan(y, ssm_state, dt, w->A + layer * d_inner * d_state, B, C, w->D + layer * d_inner, x, z, d_inner, d_state);


    // Proj output
    matmul(hidden_state, y, w->out_proj + layer * d_model * d_inner, d_model, d_inner);

    return hidden_state;
}

static fp32_t *MambaForward(Mamba *mamba, uint64_t token) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model,
             n_layer = p->n_layer,
             vocab_size = p->vocab_size;

    Tensor( input, D_MODEL, 1, 1);
    static Tensor( logits, ROUNDED_VOCAB_SIZE, 1, 1);

    fp32_t *hidden_state = s->hidden_state,
           *content_row = w->embed + token * d_model;


    memcpy(input, content_row, d_model * sizeof(fp32_t));

    for (uint64_t layer = 0; layer < n_layer; ++layer) {
        rmsnorm(hidden_state, input, w->norm + layer * d_model, d_model);
        mamba->forward_layer(mamba, layer);

        for (uint64_t i = 0; i < d_model; ++i) {
            hidden_state[i] += input[i];
            input[i] = hidden_state[i];
        }
    }

    rmsnorm(hidden_state, hidden_state, w->norm_f, d_model);
    matmul(logits, hidden_state, w->lm_head, vocab_size, d_model);

    return logits;
}
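
Note: written out as math, `selective_scan` above steps the discretized state-space recurrence one token at a time; with $h$ = `ssm_state` and $\Delta$ = `dt`, each channel $i$ and state dimension $j$ computes exactly what the inner loop reads:

$$h_{ij} \leftarrow h_{ij}\, e^{\Delta_i A_{ij}} + \Delta_i B_j x_i, \qquad y_i = \sum_j C_j h_{ij} + D_i x_i, \qquad y_i \leftarrow y_i \cdot \operatorname{silu}(z_i).$$

Because the state $h$ is carried in `MambaState` rather than recomputed, each generated token costs a fixed amount of work regardless of sequence length.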
c_inferer/qmamba.h
ADDED
@@ -0,0 +1,444 @@
#pragma once

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdint.h>

#include "config.h"


extern char _embedded_binary_model[];

typedef float fp32_t;

typedef struct MambaWeights MambaWeights;
typedef struct MambaConfig MambaConfig;
typedef struct MambaState MambaState;
typedef struct Mamba Mamba;

static void MambaLog(Mamba *);
static fp32_t *MambaForwardLayer(Mamba *, uint64_t);
static fp32_t *MambaForward(Mamba *, uint64_t);

#define GS 64

#define Tensor(NAME, X, Y, Z)  fp32_t NAME[(X) * (Y) * (Z)]
#define QTensor(NAME, X, Y, Z) struct { int8_t q[(X) * (Y) * (Z)]; fp32_t s[((X) * (Y) * (Z)) / GS]; } NAME
#define QPointer(NAME, QT, P)  struct { int8_t *q; fp32_t *s; } NAME = { .q = QT.q + (P), .s = QT.s + ((P) / GS) }

struct MambaWeights {
    QTensor(embed,   ROUNDED_VOCAB_SIZE, D_MODEL, 1);
    QTensor(in_proj, N_LAYER, 2 * D_INNER, D_MODEL);

    Tensor(conv1d_weight, N_LAYER, D_INNER, D_CONV);
    Tensor(conv1d_bias,   N_LAYER, D_INNER, 1);

    QTensor(x_proj, N_LAYER, DT_RANK + 2 * D_STATE, D_INNER);

    Tensor(dt_proj_weight, N_LAYER, D_INNER, DT_RANK);
    Tensor(dt_proj_bias,   N_LAYER, D_INNER, 1);
    Tensor(A,              N_LAYER, D_INNER, D_STATE);
    Tensor(D,              N_LAYER, D_INNER, 1);

    QTensor(out_proj, N_LAYER, D_MODEL, D_INNER);

    Tensor(norm,   N_LAYER, D_MODEL, 1);
    Tensor(norm_f, D_MODEL, 1, 1);

    QTensor(lm_head, ROUNDED_VOCAB_SIZE, D_MODEL, 1);
};

struct MambaConfig {
    uint64_t vocab_size; // vocabulary size, rounded to the nearest multiple of 8
    uint64_t n_layer;    // number of layers
    uint64_t d_model;    // embedding dim
    uint64_t d_inner;
    uint64_t dt_rank;
    uint64_t d_state;
    uint64_t d_conv;
};

struct MambaState {
    Tensor(hidden_state, D_MODEL, 1, 1);
    Tensor(conv_state,   N_LAYER, D_INNER, D_CONV);
    Tensor(ssm_state,    N_LAYER, D_INNER, D_STATE);
};

struct Mamba {
    MambaConfig config;
    MambaState state;

    MambaWeights *weights;

    void    (*log)           (Mamba *);
    fp32_t *(*forward_layer) (Mamba *, uint64_t);
    fp32_t *(*forward)       (Mamba *, uint64_t);
};


Mamba mamba = {
    .config = {
        .vocab_size = ROUNDED_VOCAB_SIZE,
        .n_layer = N_LAYER,
        .d_model = D_MODEL,
        .d_inner = D_INNER,
        .dt_rank = DT_RANK,
        .d_state = D_STATE,
        .d_conv = D_CONV,
    },

    .state = {0},

    .weights = (MambaWeights *) _embedded_binary_model,

    .log = MambaLog,
    .forward_layer = MambaForwardLayer,
    .forward = MambaForward
};


static void rmsnorm(fp32_t* y, fp32_t* x, fp32_t* weight, uint64_t size) {
    fp32_t ss = 0.0f;
    for (uint64_t j = 0; j < size; ++j)
        ss += x[j] * x[j];

    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);

    for (uint64_t j = 0; j < size; ++j)
        y[j] = x[j] * weight[j] * ss;
}

static fp32_t softplus(fp32_t x) { return logf(1.0f + expf(x)); }
static fp32_t sigmoid(fp32_t x) { return 1.0f / (1.0f + expf(-x)); }
static fp32_t silu(fp32_t x) { return x * sigmoid(x); }

static void shift_and_update_last_column(fp32_t* matrix, fp32_t* x, uint64_t rows, uint64_t cols) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < rows; ++i) {
        fp32_t* row = matrix + i * cols;

        for (uint64_t j = 0; j < cols - 1; ++j)
            row[j] = row[j + 1];

        row[cols - 1] = x[i];
    }
}

static void conv1d_silu(fp32_t* x, fp32_t* conv_state, fp32_t* conv1d_weight, fp32_t* conv1d_bias, uint64_t d_inner, uint64_t d_conv) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < d_conv; ++j) {
            uint64_t index = i * d_conv + j;
            val += conv_state[index] * conv1d_weight[index];
        }

        x[i] = silu(val + conv1d_bias[i]);
    }
}

static void dense_softplus(fp32_t* y, fp32_t* x, fp32_t* w, fp32_t* b, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < n; ++j)
            val += w[i * n + j] * x[j];

        y[i] = softplus(val + b[i]);
    }
}

static void selective_scan(fp32_t* y, fp32_t* ssm_state, fp32_t* dt, fp32_t* A, fp32_t* B, fp32_t* C, fp32_t* D, fp32_t* x, fp32_t* z, uint64_t d_inner, uint64_t d_state) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;

        for (uint64_t j = 0; j < d_state; ++j) {
            uint64_t index = i * d_state + j;

            fp32_t dA = expf(dt[i] * A[index]);
            fp32_t dB = dt[i] * B[j];

            ssm_state[index] = ssm_state[index] * dA + x[i] * dB;

            val += ssm_state[index] * C[j];
        }

        val += D[i] * x[i];
        y[i] = val * silu(z[i]);
    }
}


static void qmatmul(fp32_t* y, int8_t *xq, fp32_t *xs, int8_t *wq, fp32_t *ws, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; i++) {

        fp32_t val = 0.0f;
        uint64_t in = i * n;

        // do the matmul in groups of GS, accumulating in int32 (signed!)
        for (uint64_t j = 0; j + GS <= n; j += GS) {
            int32_t ival = 0;

            for (uint64_t k = 0; k < GS; k++)
                ival += ((int32_t) xq[j + k]) * ((int32_t) wq[in + j + k]);

            val += ((fp32_t) ival) * ws[(in + j) / GS] * xs[j / GS];
        }

        y[i] = val;
    }
}

void dequantize(fp32_t* x, int8_t *xq, fp32_t *xs, uint64_t n) {
    for (uint64_t i = 0; i < n; i++)
        x[i] = xq[i] * xs[i / GS];
}


void quantize(int8_t *xq, fp32_t *xs, fp32_t* x, uint64_t n) {
    uint64_t num_groups = n / GS;
    fp32_t Q_MAX = 127.0f;

    for (uint64_t group = 0; group < num_groups; group++) {

        // find the max absolute value in the current group
        fp32_t wmax = 0.0f;
        for (uint64_t i = 0; i < GS; i++) {
            fp32_t val = fabsf(x[group * GS + i]);
            if (val > wmax) {
                wmax = val;
            }
        }

        // calculate and write the scaling factor
        fp32_t scale = wmax / Q_MAX;
        xs[group] = scale;

        // calculate and write the quantized values
        for (uint64_t i = 0; i < GS; i++) {
            fp32_t quant_value = x[group * GS + i] / scale;  // scale
            int8_t quantized = (int8_t) roundf(quant_value); // round; values already lie in [-127, 127]
            xq[group * GS + i] = quantized;
        }
    }
}


#define LOG(...) fprintf(stderr, __VA_ARGS__)
#define CLOG(X, ...) if(X) LOG(__VA_ARGS__)

static char *shortScale(char *result, int64_t number) {
    char *suffixes[] = {"", "k", "m", "b"};
    uint64_t magnitude = 0;
    fp32_t num = (fp32_t)number;

    if (number < 1000) {
        sprintf(result, "%ld", number);
        return result;
    }

    while (number >= 1000 || number <= -1000) {
        magnitude++;
        number /= 1000;
        num /= 1000.0f;
    }

    sprintf(result, "%.0f%s", num, suffixes[magnitude]);
    return result;
}


uint64_t global_scale = 0;
static inline void nparams(uint64_t dim) {
    char buff[16];
    shortScale(buff, dim);

    global_scale += dim;

    LOG("%12lu (%s)", dim, buff);
}
#define NPARAMS(X) nparams(sizeof(X) / sizeof(fp32_t))

// qdim is the int8 payload measured in fp32-sized units;
// the original weight count is 4x that
static inline void nqparams(uint64_t qdim) {
    uint64_t dim = qdim * 4;

    char buff[16];
    shortScale(buff, qdim);

    nparams(dim);
    LOG("\t%12lu (%s)", qdim, buff);
}
#define NQPARAMS(X) nqparams(sizeof(X.q) / sizeof(fp32_t))

static inline void ntotal(uint64_t qdim) {
    char buff[16];
    shortScale(buff, global_scale);
    LOG("%12lu (%s)", global_scale, buff);

    shortScale(buff, qdim);
    LOG("\t%12lu (%s)", qdim, buff);

    fp32_t factor = (((fp32_t) (global_scale - qdim)) / global_scale) * 100;
    LOG(" < %.2f%%", factor);
}
#define NTOTAL(X) ntotal(sizeof(X) / sizeof(fp32_t))

static void MambaLog(Mamba *mamba) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    LOG("Mamba Config:");
    LOG("\n\tvocab_size: %12lu", p->vocab_size);
    LOG("\n\tn_layer:    %12lu", p->n_layer);
    LOG("\n\td_model:    %12lu", p->d_model);
    LOG("\n\td_inner:    %12lu", p->d_inner);
    LOG("\n\tdt_rank:    %12lu", p->dt_rank);
    LOG("\n\td_state:    %12lu", p->d_state);
    LOG("\n\td_conv:     %12lu", p->d_conv);
    LOG("\n\n\n");

    LOG("Parameters Count:           Base            Quantized   Factor");
    LOG("\n\tembed:          "); NQPARAMS(w->embed);
    LOG("\n\tin_proj:        "); NQPARAMS(w->in_proj);
    LOG("\n\tconv1d_weight:  "); NPARAMS(w->conv1d_weight);
    LOG("\n\tconv1d_bias:    "); NPARAMS(w->conv1d_bias);
    LOG("\n\tx_proj:         "); NQPARAMS(w->x_proj);
    LOG("\n\tdt_proj_weight: "); NPARAMS(w->dt_proj_weight);
    LOG("\n\tdt_proj_bias:   "); NPARAMS(w->dt_proj_bias);
    LOG("\n\tA:              "); NPARAMS(w->A);
    LOG("\n\tD:              "); NPARAMS(w->D);
    LOG("\n\tout_proj:       "); NQPARAMS(w->out_proj);
    LOG("\n\tnorm:           "); NPARAMS(w->norm);
    LOG("\n\tnorm_f:         "); NPARAMS(w->norm_f);
    LOG("\n\tlm_head:        "); NQPARAMS(w->lm_head);
    LOG("\n\n\tTotal:          "); NTOTAL(MambaWeights);
    LOG("\n\n\n");

    LOG("Recurrent State:");
    LOG("\n\thidden_state:   "); NPARAMS(s->hidden_state);
    LOG("\n\tconv_state:     "); NPARAMS(s->conv_state);
    LOG("\n\tssm_state:      "); NPARAMS(s->ssm_state);
    LOG("\n\n\tTotal:          "); NPARAMS(MambaState);
    LOG("\n\n\n");
}


static fp32_t *MambaForwardLayer(Mamba *mamba, uint64_t layer) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model,
             d_inner = p->d_inner,
             d_conv  = p->d_conv,
             d_state = p->d_state,
             dt_rank = p->dt_rank;

    fp32_t *hidden_state = s->hidden_state,
           *conv_state = s->conv_state + layer * d_inner * d_conv,
           *ssm_state  = s->ssm_state  + layer * d_inner * d_state;

    QTensor( qhidden_state, D_MODEL, 1, 1);

    Tensor(  xz,   2 * D_INNER, 1, 1);
    QTensor( qx,   D_INNER, 1, 1);
    Tensor(  x_db, DT_RANK + 2 * D_STATE, 1, 1);
    Tensor(  dt,   D_INNER, 1, 1);
    Tensor(  y,    D_INNER, 1, 1);
    QTensor( qy,   D_INNER, 1, 1);

    fp32_t *x = xz,
           *z = xz + d_inner,
           *B = x_db + dt_rank,
           *C = x_db + dt_rank + d_state;

    QPointer( in_proj,  w->in_proj,  layer * 2 * d_inner * d_model);
    QPointer( x_proj,   w->x_proj,   layer * (dt_rank + 2 * d_state) * d_inner);
    QPointer( out_proj, w->out_proj, layer * d_model * d_inner);


    // Proj input
    quantize(qhidden_state.q, qhidden_state.s, hidden_state, d_model);
    qmatmul(xz, qhidden_state.q, qhidden_state.s, in_proj.q, in_proj.s, 2 * d_inner, d_model);


    // Conv
    shift_and_update_last_column(conv_state, x, d_inner, d_conv);
    conv1d_silu(x, conv_state, w->conv1d_weight + layer * d_inner * d_conv, w->conv1d_bias + layer * d_inner, d_inner, d_conv);


    // SSM
    quantize(qx.q, qx.s, x, d_inner);
    qmatmul(x_db, qx.q, qx.s, x_proj.q, x_proj.s, dt_rank + 2 * d_state, d_inner);

    dense_softplus(dt, x_db, w->dt_proj_weight + layer * d_inner * dt_rank, w->dt_proj_bias + layer * d_inner, d_inner, dt_rank);
    selective_scan(y, ssm_state, dt, w->A + layer * d_inner * d_state, B, C, w->D + layer * d_inner, x, z, d_inner, d_state);


    // Proj output
    quantize(qy.q, qy.s, y, d_inner);
    qmatmul(hidden_state, qy.q, qy.s, out_proj.q, out_proj.s, d_model, d_inner);

    return hidden_state;
}


static fp32_t *MambaForward(Mamba *mamba, uint64_t token) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model,
             n_layer = p->n_layer,
             vocab_size = p->vocab_size;

    Tensor( input, D_MODEL, 1, 1);
    static Tensor( logits, ROUNDED_VOCAB_SIZE, 1, 1);

    QTensor( qhidden_state, D_MODEL, 1, 1);

    fp32_t *hidden_state = s->hidden_state;

    QPointer( row, w->embed, token * d_model);


    dequantize(input, row.q, row.s, d_model);

    for (uint64_t layer = 0; layer < n_layer; ++layer) {
        rmsnorm(hidden_state, input, w->norm + layer * d_model, d_model);
        mamba->forward_layer(mamba, layer);

        for (uint64_t i = 0; i < d_model; ++i) {
            hidden_state[i] += input[i];
            input[i] = hidden_state[i];
        }
    }

    rmsnorm(hidden_state, hidden_state, w->norm_f, d_model);

    quantize(qhidden_state.q, qhidden_state.s, hidden_state, d_model);
    qmatmul(logits, qhidden_state.q, qhidden_state.s, w->lm_head.q, w->lm_head.s, vocab_size, d_model);

    return logits;
}
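
Note: the Q8 scheme above stores one fp32 scale per group of GS int8 values, with the scale chosen so the group's max maps to 127; the round-trip error per value is therefore at most half a scale step. A self-contained sketch of the same scheme (a hypothetical test driver, not part of the commit; assumes GS divides the length, as the kernels do — build with -lm):

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define GS 64   /* group size, matching qmamba.h */

int main(void) {
    enum { N = 256 };
    float x[N], scale[N / GS];
    int8_t q[N];

    for (int i = 0; i < N; i++)
        x[i] = (float) rand() / RAND_MAX - 0.5f;

    /* quantize: per group, scale = max|x| / 127, then round x / scale */
    for (int g = 0; g < N / GS; g++) {
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {
            float v = fabsf(x[g * GS + i]);
            if (v > wmax) wmax = v;
        }
        scale[g] = wmax / 127.0f;
        for (int i = 0; i < GS; i++)
            q[g * GS + i] = (int8_t) roundf(x[g * GS + i] / scale[g]);
    }

    /* dequantize and report the worst round-trip error */
    float max_err = 0.0f;
    for (int i = 0; i < N; i++) {
        float e = fabsf(q[i] * scale[i / GS] - x[i]);
        if (e > max_err) max_err = e;
    }
    printf("max round-trip error: %g (bound: scale/2)\n", max_err);
    return 0;
}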
c_inferer/sampler.h
ADDED
@@ -0,0 +1,166 @@
#pragma once


//#include <stdbool.h>


#define PRINT(C) fputc((char)C, stdout), fflush(stdout)

typedef enum {false, true} bool;

typedef struct Sampler Sampler;
struct Sampler {
    Mamba *model;
    Tokenizer *tokenizer;

    uint64_t rng_seed;
    fp32_t temperature;
    bool verbose;

    bool     (*generate) (Sampler *, char *, uint64_t);
    uint64_t (*sample)   (Sampler *, fp32_t *);
};


static void softmax(fp32_t* x, uint64_t size) {
    fp32_t max_val = x[0];
    for (uint64_t i = 1; i < size; ++i)
        if (x[i] > max_val) max_val = x[i];

    fp32_t sum = 0.0f;
    for (uint64_t i = 0; i < size; ++i) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }

    for (uint64_t i = 0; i < size; ++i)
        x[i] /= sum;
}


// xorshift-style PRNG (the state is folded to 32 bits on each call)
static uint64_t random_u32(uint64_t *rng_seed) {
    *rng_seed ^= *rng_seed >> 12;
    *rng_seed ^= *rng_seed << 25;
    *rng_seed ^= *rng_seed >> 27;
    *rng_seed = (*rng_seed * 0x2545F4914F6CDD1Dull) >> 32;
    return *rng_seed;
}

static inline fp32_t random_f32(uint64_t *rng_seed) { return (random_u32(rng_seed) >> 8) / 16777216.0f; }

static uint64_t time_in_ms(void) {
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return tv.tv_sec * 1000 + tv.tv_usec / 1000;
}

static inline uint64_t sample_argmax(fp32_t* probabilities, uint64_t n) {
    uint64_t max_i = 0;
    fp32_t max_p = probabilities[0];

    for (uint64_t i = 1; i < n; ++i)
        if (probabilities[i] > max_p)
            max_i = i, max_p = probabilities[i];

    return max_i;
}

static inline uint64_t sample_mult(fp32_t* probabilities, uint64_t n, fp32_t coin) {
    fp32_t cdf = 0.0f;

    for (uint64_t i = 0; i < n; ++i) {
        cdf += probabilities[i];

        if (coin < cdf) return i;
    }

    return n - 1;
}

static uint64_t SamplerSample(Sampler *sampler, fp32_t* logits) {
    uint64_t next,
             vocab_size = sampler->tokenizer->vocab_size,
             *rng_seed = &sampler->rng_seed;

    fp32_t temperature = sampler->temperature;

    if (temperature == 0.0f) next = sample_argmax(logits, vocab_size);
    else {
        for (uint64_t q = 0; q < vocab_size; ++q)
            logits[q] /= temperature;

        softmax(logits, vocab_size);

        fp32_t coin = random_f32(rng_seed);
        next = sample_mult(logits, vocab_size, coin);
    }

    return next;
}

static bool SamplerGenerate(Sampler *sampler, char *seed_text, uint64_t n_predict) {
    Mamba *model = sampler->model;
    Tokenizer *tokenizer = sampler->tokenizer;
    bool verbose = sampler->verbose;

    uint64_t token;
    fp32_t *logits;
    char *text;

    // an empty prompt would leave `logits` uninitialized below
    if (seed_text == NULL || *seed_text == '\0') return EXIT_FAILURE;

    for (; *seed_text; ) {
        token = tokenizer->encode(tokenizer, (uint8_t **) &seed_text);
        text = (char *) tokenizer->decode(tokenizer, token);

        fputs(text, stdout);
        fflush(stdout);

        logits = model->forward(model, token);
    }

    uint64_t time_start = 0;
    if (verbose) time_start = time_in_ms();

    for (uint64_t i = 0; i < n_predict; ++i) {
        token = sampler->sample(sampler, logits);
        text = (char *) tokenizer->decode(tokenizer, token);

        fputs(text, stdout);
        fflush(stdout);

        logits = model->forward(model, token);
    }

    CLOG(verbose, "\nachieved tok/s: %f\n", n_predict / (double)(time_in_ms() - time_start) * 1000);

    return EXIT_SUCCESS;
}


Sampler sampler = {
    .model = &mamba,
    .tokenizer = &tokenizer,

    .rng_seed = 42,
    .temperature = 0.0f,
    .verbose = false,

    .generate = SamplerGenerate,
    .sample = SamplerSample
};
c_inferer/tokenizer.h
ADDED
@@ -0,0 +1,83 @@
#pragma once


#include <stdint.h>


extern char _embedded_binary_tokenizer[];


#define MAX_WORD_LEN 24

typedef struct __attribute__((packed)) token_t {
    uint8_t byte;
    uint16_t prev;
} token_t;


typedef struct Tokenizer Tokenizer;
struct Tokenizer {
    token_t *vocab;
    uint16_t vocab_size;

    uint16_t (*find)   (Tokenizer *, uint8_t, uint16_t);
    uint16_t (*encode) (Tokenizer *, uint8_t **);
    uint8_t *(*decode) (Tokenizer *, uint16_t);
};


static uint16_t TokenizerFind(Tokenizer *tokenizer, uint8_t byte, uint16_t prev) {
    for (uint16_t i = prev; i < tokenizer->vocab_size; ++i)
        if (tokenizer->vocab[i].byte == byte && tokenizer->vocab[i].prev == prev)
            return i;

    return 0;
}


static uint16_t TokenizerEncode(Tokenizer *tokenizer, uint8_t **seed_text) {

    uint16_t prev = 0;
    for (; **seed_text; ++*seed_text) {
        uint16_t next = tokenizer->find(tokenizer, **seed_text, prev);
        if (next == 0) break;
        prev = next;
    }

    return prev;
}


static uint8_t *TokenizerDecode(Tokenizer *tokenizer, uint16_t token) {

    static uint8_t dest[MAX_WORD_LEN + 1];
    dest[MAX_WORD_LEN] = '\0';

    uint16_t prev = token;
    uint16_t i = MAX_WORD_LEN - 1;

    for (; prev && i > 0; prev = tokenizer->vocab[prev].prev, --i)
        dest[i] = tokenizer->vocab[prev].byte;

    return dest + i + 1;
}


Tokenizer tokenizer = {
    .vocab = (token_t *) _embedded_binary_tokenizer,

    .vocab_size = VOCAB_SIZE,

    .find = TokenizerFind,
    .encode = TokenizerEncode,
    .decode = TokenizerDecode
};
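
Note: each vocab entry is (byte, prev) — the token is "token `prev` extended by one byte", so decoding walks the prev links backwards while encoding greedily extends a match. A toy illustration (hypothetical data; in the real tokenizer.bin the first 256 entries are the raw bytes with prev = 0, which is exactly what the awk rule in the Makefile emits, and merged tokens follow):

#include <stdint.h>
#include <stdio.h>

typedef struct { uint8_t byte; uint16_t prev; } tok;

static const tok vocab[] = {
    {0,   0},  /* 0: sentinel / byte 0 */
    {'a', 0},  /* 1: "a"  */
    {'b', 0},  /* 2: "b"  */
    {'b', 1},  /* 3: "ab" = token 1 extended by 'b' */
};

int main(void) {
    /* decode(3): walk prev links, emitting bytes in reverse */
    char out[8]; int i = sizeof out - 1; out[i] = '\0';
    for (uint16_t t = 3; t; t = vocab[t].prev)
        out[--i] = (char) vocab[t].byte;
    printf("token 3 decodes to \"%s\"\n", out + i);  /* prints "ab" */
    return 0;
}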