// abcd / engine / batch_server.cpp
// Uploaded by Karan6933 ("Upload 17 files")
// Commit: 8d11530 (verified)
//
#include "llama.h"
#include <vector>
#include <string>
#include <cstring>
#include <cstdio>
#include <cstdlib>
// Process-wide llama.cpp handles: created by init_model(), released by cleanup().
static llama_model*   g_model{nullptr}; // loaded model weights
static llama_context* g_ctx{nullptr};   // inference context (owns the KV memory and logits)
static llama_sampler* g_smpl{nullptr};  // sampler chain (a single greedy sampler)
extern "C" {
// Initialize the model
bool init_model(const char* model_path) {
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = 0; // CPU only for now
g_model = llama_model_load_from_file(model_path, model_params);
if (!g_model) return false;
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 4096; // 256 per sequence for 16 users
ctx_params.n_batch = 512;
ctx_params.n_threads = 8;
ctx_params.n_threads_batch = 8;
ctx_params.n_seq_max = 16;
g_ctx = llama_init_from_model(g_model, ctx_params);
if (!g_ctx) return false;
auto sparams = llama_sampler_chain_default_params();
g_smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(g_smpl, llama_sampler_init_greedy());
return true;
}
// Helper to add a token to a batch manually
//
// Appends one token at row batch.n_tokens. It records the token id, its position, every
// sequence id it belongs to, and whether logits are wanted for it, then advances
// n_tokens by one. No capacity check is made. The caller must guarantee that the batch
// (from llama_batch_init) has a free row and room for seq_ids.size() ids per token.
void batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
    const auto row = batch.n_tokens;

    batch.token[row]    = id;
    batch.pos[row]      = pos;
    batch.n_seq_id[row] = seq_ids.size();

    llama_seq_id * dst = batch.seq_id[row];
    for (const llama_seq_id sid : seq_ids) {
        *dst++ = sid;
    }

    batch.logits[row] = logits;
    batch.n_tokens    = row + 1;
}
// Global state for streaming
// Written by start_batch() and advanced by decode_step(). The per-sequence vectors are
// indexed by prompt index, which is also the llama sequence id.
static int g_count{0};                                     // prompts in the current batch
static int g_step{0};                                      // decode steps completed so far
static int g_max_tokens{0};                                // step budget given to start_batch()
static std::vector<std::string> g_responses;               // NOTE(review): reset per batch but never read in this file
static std::vector<bool> g_active;                         // is sequence i still generating?
static std::vector<int> g_n_pos;                           // next position to write for sequence i
static std::vector<int> g_logits_idx;                      // batch row holding sequence i's latest logits (-1 = none)
static std::vector<std::vector<llama_token>> g_all_tokens; // tokenized prompts
static llama_batch g_batch{};                              // token batch reused for prefill and each step
static const llama_vocab* g_vocab{nullptr};                // vocab of g_model, fetched by start_batch()
// 1. Start a new batch (Prefill)
void start_batch(const char** prompts, int count, int max_tokens) {
if (!g_ctx || count == 0) return;
g_vocab = llama_model_get_vocab(g_model);
g_count = count;
g_max_tokens = max_tokens;
g_step = 0;
// Reset state
g_responses.assign(count, "");
g_active.assign(count, true);
g_n_pos.assign(count, 0);
g_logits_idx.assign(count, -1);
g_all_tokens.clear();
// Tokenize
for (int i = 0; i < count; i++) {
int n_prompt = -llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), NULL, 0, true, true);
std::vector<llama_token> tokens(n_prompt);
llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), tokens.data(), tokens.size(), true, true);
g_all_tokens.push_back(tokens);
}
// Clear KV cache
llama_memory_clear(llama_get_memory(g_ctx), true);
// Init batch
if (g_batch.token) llama_batch_free(g_batch); // Free if exists
g_batch = llama_batch_init(4096, 0, 1); // Larger batch for safety
// Prefill
g_batch.n_tokens = 0;
for (int i = 0; i < count; i++) {
for (size_t j = 0; j < g_all_tokens[i].size(); j++) {
bool is_last = (j == g_all_tokens[i].size() - 1);
if (is_last) g_logits_idx[i] = g_batch.n_tokens;
batch_add(g_batch, g_all_tokens[i][j], g_n_pos[i]++, { (llama_seq_id)i }, is_last);
}
}
// Decode Prefill
if (llama_decode(g_ctx, g_batch)) {
fprintf(stderr, "Failed to decode prefill\n");
}
}
// 2. Decode one step (Generate next token for all active sequences)
// Returns true if any sequence is still active, false if all done
bool decode_step(const char** results) {
if (g_step >= g_max_tokens) return false;
g_batch.n_tokens = 0;
bool any_active = false;
std::vector<int> next_logits_idx(g_count, -1);
int current_batch_pos = 0;
for (int i = 0; i < g_count; i++) {
results[i] = nullptr; // Default to null (no new token this step if inactive)
if (!g_active[i]) continue;
// Sample
llama_token id = llama_sampler_sample(g_smpl, g_ctx, g_logits_idx[i]);
llama_sampler_accept(g_smpl, id);
// Check EOG/limit
if (llama_vocab_is_eog(g_vocab, id) || g_n_pos[i] >= 4096) { // Hard limit matches n_ctx
g_active[i] = false;
continue;
}
// Decode token to string
static char buf[256]; // Static buffer for simplicity (not thread safe across batches, but fine here)
int n = llama_token_to_piece(g_vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) {
// error or empty
} else {
// Allocate new string for result to pass back to Python
// Python side must handle this appropriately or we just copy
// For ctypes `const char**`, we need a pointer that persists until next call
// We can use a thread_local buffer or just strdup and let Python manage?
// Ideally Python frees, but to keep it simple we use a persistent buffer in C++?
// No, let's just return a pointer to a static buffer is risky if multiple threads.
// But we are single threaded engine key.
// Safer: return strdup, Python bridge won't free it effortlessly though.
// Let's use a global vector of string buffers for the current step.
}
// Actually, we just need to return the piece.
// Let's rely on a persistent buffer for the current step.
// We can just strdup and leak if we don't free? No.
// Let's assume the Python side consumes it immediately.
results[i] = strdup(buf);
next_logits_idx[i] = current_batch_pos++;
batch_add(g_batch, id, g_n_pos[i]++, { (llama_seq_id)i }, true);
any_active = true;
}
if (!any_active) return false;
g_logits_idx = next_logits_idx;
if (llama_decode(g_ctx, g_batch)) {
return false;
}
g_step++;
return true;
}
// Cleanup
void cleanup() {
if (g_smpl) llama_sampler_free(g_smpl);
if (g_ctx) llama_free(g_ctx);
if (g_model) llama_model_free(g_model);
}
}