#include "llama.h"

#include <vector>
#include <string>
#include <cstring>
#include <cstdio>
#include <cstdlib>
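
// Global llama.cpp state shared by the exported C API below: one model, one
// context, and one sampler chain, reused across calls. Not thread-safe.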
static llama_model* g_model = nullptr;
static llama_context* g_ctx = nullptr;
static llama_sampler* g_smpl = nullptr;
extern "C" { |
|
|
|
|
|
|
|
|
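
// Loads the model, creates a context sized for multi-sequence decoding, and
// sets up a greedy sampler chain. Returns false on failure.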
bool init_model(const char* model_path) {
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = 0; // CPU only; increase to offload layers to the GPU

    g_model = llama_model_load_from_file(model_path, model_params);
    if (!g_model) return false;

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 4096;
    ctx_params.n_batch = 4096; // must fit the whole multi-prompt prefill in one llama_decode() call
    ctx_params.n_threads = 8;
    ctx_params.n_threads_batch = 8;
    ctx_params.n_seq_max = 16; // maximum number of parallel sequences

    g_ctx = llama_init_from_model(g_model, ctx_params);
    if (!g_ctx) return false;

    auto sparams = llama_sampler_chain_default_params();
    g_smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(g_smpl, llama_sampler_init_greedy());

    return true;
}
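
// Internal helper: appends one token to a llama_batch, recording its position,
// the sequences it belongs to, and whether logits should be computed for it.
// It takes C++ types, so it is kept static rather than exported through the C API.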
static void batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token[batch.n_tokens] = id;
    batch.pos[batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits[batch.n_tokens] = logits;
    batch.n_tokens++;
}
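
// Per-batch generation state, one entry per sequence. g_logits_idx[i] records
// where sequence i's last token sits in g_batch so its logits can be sampled
// on the next decode_step() call.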
static int g_count = 0;
static int g_step = 0;
static int g_max_tokens = 0;
static std::vector<std::string> g_responses;
static std::vector<bool> g_active;
static std::vector<int> g_n_pos;
static std::vector<int> g_logits_idx;
static std::vector<std::vector<llama_token>> g_all_tokens;
static llama_batch g_batch;
static const llama_vocab* g_vocab = nullptr;
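
// Tokenizes `count` prompts, clears the KV cache, and runs a single prefill
// decode with each prompt assigned to its own sequence id (0..count-1).
// Assumes count does not exceed the context's n_seq_max (16 in init_model).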
void start_batch(const char** prompts, int count, int max_tokens) {
    if (!g_ctx || count == 0) return;

    g_vocab = llama_model_get_vocab(g_model);
    g_count = count;
    g_max_tokens = max_tokens;
    g_step = 0;

    g_responses.assign(count, "");
    g_active.assign(count, true);
    g_n_pos.assign(count, 0);
    g_logits_idx.assign(count, -1);
    g_all_tokens.clear();

    // Tokenize each prompt. The first call with a NULL buffer returns the
    // negated token count; the second call fills the buffer.
    for (int i = 0; i < count; i++) {
        int n_prompt = -llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), NULL, 0, true, true);
        std::vector<llama_token> tokens(n_prompt);
        llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), tokens.data(), tokens.size(), true, true);
        g_all_tokens.push_back(tokens);
    }

    // Drop any KV-cache state left over from a previous batch.
    llama_memory_clear(llama_get_memory(g_ctx), true);

    if (g_batch.token) llama_batch_free(g_batch);
    g_batch = llama_batch_init(4096, 0, 1);

    // Prefill: pack all prompts into one batch, each in its own sequence. Only
    // the last token of each prompt requests logits; remember where it sits.
    g_batch.n_tokens = 0;
    for (int i = 0; i < count; i++) {
        for (size_t j = 0; j < g_all_tokens[i].size(); j++) {
            bool is_last = (j == g_all_tokens[i].size() - 1);
            if (is_last) g_logits_idx[i] = g_batch.n_tokens;
            batch_add(g_batch, g_all_tokens[i][j], g_n_pos[i]++, { (llama_seq_id)i }, is_last);
        }
    }

    if (llama_decode(g_ctx, g_batch)) {
        fprintf(stderr, "Failed to decode prefill\n");
    }
}
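
// Generates one token for every active sequence. Each sampled piece is appended
// to that sequence's response and returned to the caller as a heap-allocated
// string in results[i] (nullptr for finished sequences); the caller must free it.
// Returns false when all sequences are done, the token budget is exhausted, or
// decoding fails.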
bool decode_step(const char** results) {
    if (g_step >= g_max_tokens) return false;

    g_batch.n_tokens = 0;
    bool any_active = false;
    std::vector<int> next_logits_idx(g_count, -1);
    int current_batch_pos = 0;

    for (int i = 0; i < g_count; i++) {
        results[i] = nullptr;

        if (!g_active[i]) continue;

        // Sample from the logits produced for this sequence by the previous
        // llama_decode(). llama_sampler_sample() also accepts the token into
        // the sampler's state, so no separate llama_sampler_accept() is needed.
        llama_token id = llama_sampler_sample(g_smpl, g_ctx, g_logits_idx[i]);

        // Stop the sequence on end-of-generation or when the context is full.
        if (llama_vocab_is_eog(g_vocab, id) || g_n_pos[i] >= 4096) {
            g_active[i] = false;
            continue;
        }

        // Convert the token to text. llama_token_to_piece() does not
        // null-terminate, so terminate the buffer explicitly.
        char buf[256];
        int n = llama_token_to_piece(g_vocab, id, buf, sizeof(buf) - 1, 0, true);
        if (n < 0) {
            buf[0] = '\0';
        } else {
            buf[n] = '\0';
            g_responses[i] += buf;
        }

        results[i] = strdup(buf);

        next_logits_idx[i] = current_batch_pos++;
        batch_add(g_batch, id, g_n_pos[i]++, { (llama_seq_id)i }, true);
        any_active = true;
    }

    if (!any_active) return false;

    g_logits_idx = next_logits_idx;
    if (llama_decode(g_ctx, g_batch)) {
        return false;
    }

    g_step++;
    return true;
}
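
// Releases the batch, sampler chain, context, and model, then shuts down the
// llama.cpp backend. Pointers are reset so init_model() can be called again.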
void cleanup() {
    if (g_batch.token) { llama_batch_free(g_batch); g_batch = {}; }
    if (g_smpl) { llama_sampler_free(g_smpl); g_smpl = nullptr; }
    if (g_ctx) { llama_free(g_ctx); g_ctx = nullptr; }
    if (g_model) { llama_model_free(g_model); g_model = nullptr; }
    llama_backend_free();
}

} // extern "C"
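
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the exported API): shows how a host
// program might drive the functions above. Compile with -DBATCH_DEMO_MAIN to
// build it; the model path and prompts are placeholders.
// ---------------------------------------------------------------------------
#ifdef BATCH_DEMO_MAIN
int main() {
    if (!init_model("model.gguf")) { // placeholder path
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const char* prompts[2] = { "Hello", "The capital of France is" };
    start_batch(prompts, 2, 32); // up to 32 tokens per sequence

    const char* results[2];
    while (decode_step(results)) {
        for (int i = 0; i < 2; i++) {
            if (results[i]) {
                fputs(results[i], stdout);
                free((void*)results[i]); // decode_step() strdup()s each piece
            }
        }
    }

    cleanup();
    return 0;
}
#endif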