#include "llama.h"
#include <vector>
#include <string>
#include <cstring>
#include <cstdio>
#include <cstdlib>
// Global context for the loaded model
static llama_model* g_model = nullptr;
static llama_context* g_ctx = nullptr;
static llama_sampler* g_smpl = nullptr;
extern "C" {
// Initialize the model
bool init_model(const char* model_path) {
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = 0; // CPU only for now
    g_model = llama_model_load_from_file(model_path, model_params);
    if (!g_model) return false;

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = 4096;   // 256 per sequence for 16 users
    ctx_params.n_batch = 4096; // must cover the whole multi-prompt prefill submitted in one llama_decode call
    ctx_params.n_threads = 8;
    ctx_params.n_threads_batch = 8;
    ctx_params.n_seq_max = 16;
    g_ctx = llama_init_from_model(g_model, ctx_params);
    if (!g_ctx) return false;

    auto sparams = llama_sampler_chain_default_params();
    g_smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(g_smpl, llama_sampler_init_greedy());
    return true;
}
// Helper to add a token to a batch manually
void batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token[batch.n_tokens] = id;
    batch.pos[batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits[batch.n_tokens] = logits;
    batch.n_tokens++;
}
// Global state for streaming
static int g_count = 0;
static int g_step = 0;
static int g_max_tokens = 0;
static std::vector<std::string> g_responses;
static std::vector<bool> g_active;
static std::vector<int> g_n_pos;
static std::vector<int> g_logits_idx;
static std::vector<std::vector<llama_token>> g_all_tokens;
static llama_batch g_batch;
static const llama_vocab* g_vocab = nullptr;
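// Streaming protocol (summary of the two entry points below): the Python/ctypes
// caller invokes start_batch() once to tokenize and prefill all prompts, then
// calls decode_step() in a loop; each call emits at most one new token per
// active sequence until every sequence finishes or max_tokens is reached.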
// 1. Start a new batch (Prefill)
void start_batch(const char** prompts, int count, int max_tokens) {
    if (!g_ctx || count == 0) return;

    g_vocab = llama_model_get_vocab(g_model);
    g_count = count;
    g_max_tokens = max_tokens;
    g_step = 0;

    // Reset state
    g_responses.assign(count, "");
    g_active.assign(count, true);
    g_n_pos.assign(count, 0);
    g_logits_idx.assign(count, -1);
    g_all_tokens.clear();

    // Tokenize each prompt (first call measures the length, second fills the buffer)
    for (int i = 0; i < count; i++) {
        int n_prompt = -llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), NULL, 0, true, true);
        std::vector<llama_token> tokens(n_prompt);
        llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), tokens.data(), tokens.size(), true, true);
        g_all_tokens.push_back(tokens);
    }

    // Clear KV cache
    llama_memory_clear(llama_get_memory(g_ctx), true);

    // Init batch
    if (g_batch.token) llama_batch_free(g_batch); // free a previously allocated batch
    g_batch = llama_batch_init(4096, 0, 1); // room for the full prefill and the per-step decode tokens

    // Prefill: pack every prompt into one batch, tagging each token with its sequence id
    g_batch.n_tokens = 0;
    for (int i = 0; i < count; i++) {
        for (size_t j = 0; j < g_all_tokens[i].size(); j++) {
            bool is_last = (j == g_all_tokens[i].size() - 1);
            if (is_last) g_logits_idx[i] = g_batch.n_tokens; // remember where this sequence's logits will land
            batch_add(g_batch, g_all_tokens[i][j], g_n_pos[i]++, { (llama_seq_id)i }, is_last);
        }
    }

    // Decode the prefill in a single call
    if (llama_decode(g_ctx, g_batch)) {
        fprintf(stderr, "Failed to decode prefill\n");
    }
}
// 2. Decode one step (generate the next token for all active sequences).
// Returns true if any sequence is still active, false if all are done.
// Each results[i] is either nullptr (sequence inactive / finished) or a
// strdup'd UTF-8 piece; the caller owns that string and must free() it.
bool decode_step(const char** results) {
    if (g_step >= g_max_tokens) return false;
    g_batch.n_tokens = 0;
    bool any_active = false;
    std::vector<int> next_logits_idx(g_count, -1);
    int current_batch_pos = 0;
    for (int i = 0; i < g_count; i++) {
        results[i] = nullptr; // default: no new token this step
        if (!g_active[i]) continue;
        // Sample from the logits of this sequence's last decoded position
        llama_token id = llama_sampler_sample(g_smpl, g_ctx, g_logits_idx[i]);
        llama_sampler_accept(g_smpl, id);
        // Stop on end-of-generation or when the position reaches the context limit
        if (llama_vocab_is_eog(g_vocab, id) || g_n_pos[i] >= 4096) { // hard limit matches n_ctx
            g_active[i] = false;
            continue;
        }
        // Convert the token to text. llama_token_to_piece does not null-terminate,
        // so terminate the buffer before copying it out.
        char buf[256];
        int n = llama_token_to_piece(g_vocab, id, buf, sizeof(buf) - 1, 0, true);
        if (n < 0) n = 0; // conversion failed: emit an empty piece
        buf[n] = '\0';
        // Hand back a heap copy that stays valid after this call returns; the
        // Python/ctypes side consumes it immediately and is expected to free() it.
        results[i] = strdup(buf);
        next_logits_idx[i] = current_batch_pos++;
        batch_add(g_batch, id, g_n_pos[i]++, { (llama_seq_id)i }, true);
        any_active = true;
    }
    if (!any_active) return false;
    g_logits_idx = next_logits_idx;
    if (llama_decode(g_ctx, g_batch)) {
        return false;
    }
    g_step++;
    return true;
}
// Cleanup
void cleanup() {
    if (g_batch.token) llama_batch_free(g_batch);
    if (g_smpl) llama_sampler_free(g_smpl);
    if (g_ctx) llama_free(g_ctx);
    if (g_model) llama_model_free(g_model);
    llama_backend_free();
}
}
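// ---------------------------------------------------------------------------
// Usage sketch (not part of the shared-library build): a minimal, hypothetical
// driver that exercises the extern "C" API above from C++. The guard macro
// ENGINE_EXAMPLE_MAIN, the prompts, and the token limit are illustrative
// placeholders, not part of the engine itself.
// ---------------------------------------------------------------------------
#ifdef ENGINE_EXAMPLE_MAIN
int main(int argc, char** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }
    if (!init_model(argv[1])) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const char* prompts[] = { "Hello, how are you?", "Write a haiku about the sea." };
    const int count = 2;
    start_batch(prompts, count, /*max_tokens=*/64);

    // Stream tokens: decode_step fills one piece per active sequence each call.
    std::vector<const char*> pieces(count, nullptr);
    while (decode_step(pieces.data())) {
        for (int i = 0; i < count; i++) {
            if (pieces[i]) {
                printf("[seq %d] %s\n", i, pieces[i]);
                free((void*)pieces[i]); // decode_step hands out strdup'd strings
                pieces[i] = nullptr;
            }
        }
    }

    cleanup();
    return 0;
}
#endif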