#include "llama.h" #include #include #include #include #include // Global context for the loaded model static llama_model* g_model = nullptr; static llama_context* g_ctx = nullptr; static llama_sampler* g_smpl = nullptr; extern "C" { // Initialize the model bool init_model(const char* model_path) { llama_backend_init(); llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = 0; // CPU only for now g_model = llama_model_load_from_file(model_path, model_params); if (!g_model) return false; llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = 4096; // 256 per sequence for 16 users ctx_params.n_batch = 512; ctx_params.n_threads = 8; ctx_params.n_threads_batch = 8; ctx_params.n_seq_max = 16; g_ctx = llama_init_from_model(g_model, ctx_params); if (!g_ctx) return false; auto sparams = llama_sampler_chain_default_params(); g_smpl = llama_sampler_chain_init(sparams); llama_sampler_chain_add(g_smpl, llama_sampler_init_greedy()); return true; } // Helper to add a token to a batch manually void batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { batch.token[batch.n_tokens] = id; batch.pos[batch.n_tokens] = pos; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } batch.logits[batch.n_tokens] = logits; batch.n_tokens++; } // Global state for streaming static int g_count = 0; static int g_step = 0; static int g_max_tokens = 0; static std::vector g_responses; static std::vector g_active; static std::vector g_n_pos; static std::vector g_logits_idx; static std::vector> g_all_tokens; static llama_batch g_batch; static const llama_vocab* g_vocab = nullptr; // 1. Start a new batch (Prefill) void start_batch(const char** prompts, int count, int max_tokens) { if (!g_ctx || count == 0) return; g_vocab = llama_model_get_vocab(g_model); g_count = count; g_max_tokens = max_tokens; g_step = 0; // Reset state g_responses.assign(count, ""); g_active.assign(count, true); g_n_pos.assign(count, 0); g_logits_idx.assign(count, -1); g_all_tokens.clear(); // Tokenize for (int i = 0; i < count; i++) { int n_prompt = -llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), NULL, 0, true, true); std::vector tokens(n_prompt); llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), tokens.data(), tokens.size(), true, true); g_all_tokens.push_back(tokens); } // Clear KV cache llama_memory_clear(llama_get_memory(g_ctx), true); // Init batch if (g_batch.token) llama_batch_free(g_batch); // Free if exists g_batch = llama_batch_init(4096, 0, 1); // Larger batch for safety // Prefill g_batch.n_tokens = 0; for (int i = 0; i < count; i++) { for (size_t j = 0; j < g_all_tokens[i].size(); j++) { bool is_last = (j == g_all_tokens[i].size() - 1); if (is_last) g_logits_idx[i] = g_batch.n_tokens; batch_add(g_batch, g_all_tokens[i][j], g_n_pos[i]++, { (llama_seq_id)i }, is_last); } } // Decode Prefill if (llama_decode(g_ctx, g_batch)) { fprintf(stderr, "Failed to decode prefill\n"); } } // 2. 
// 2. Decode one step (generate the next token for all active sequences)
// Returns true if any sequence is still active, false if all are done
bool decode_step(const char** results) {
    if (g_step >= g_max_tokens) return false;

    g_batch.n_tokens = 0;
    bool any_active = false;

    std::vector<int32_t> next_logits_idx(g_count, -1);
    int current_batch_pos = 0;

    for (int i = 0; i < g_count; i++) {
        results[i] = nullptr; // Default to null (no new token this step if inactive)
        if (!g_active[i]) continue;

        // Sample from the logits row this sequence produced in the last decode.
        // llama_sampler_sample both samples and accepts the token, so no
        // separate llama_sampler_accept call is needed.
        llama_token id = llama_sampler_sample(g_smpl, g_ctx, g_logits_idx[i]);

        // Check EOG/limit
        if (llama_vocab_is_eog(g_vocab, id) || g_n_pos[i] >= 4096) { // Hard limit matches n_ctx
            g_active[i] = false;
            continue;
        }

        // Detokenize. llama_token_to_piece does not null-terminate the buffer.
        char buf[256];
        int n = llama_token_to_piece(g_vocab, id, buf, sizeof(buf) - 1, 0, true);
        if (n < 0) {
            n = 0; // error or empty piece
        }
        buf[n] = '\0';

        // The ctypes bridge needs a pointer that stays valid after this call
        // returns, so hand back a strdup'd copy. The Python side is expected
        // to consume the piece immediately; these small allocations are
        // intentionally leaked to keep the bridge simple.
        results[i] = strdup(buf);

        next_logits_idx[i] = current_batch_pos++;
        batch_add(g_batch, id, g_n_pos[i]++, { (llama_seq_id) i }, true);
        any_active = true;
    }

    if (!any_active) return false;

    g_logits_idx = next_logits_idx;

    if (llama_decode(g_ctx, g_batch)) {
        return false;
    }

    g_step++;
    return true;
}

// Cleanup
void cleanup() {
    if (g_batch.token) llama_batch_free(g_batch);
    if (g_smpl)  llama_sampler_free(g_smpl);
    if (g_ctx)   llama_free(g_ctx);
    if (g_model) llama_model_free(g_model);
    llama_backend_free();
}

} // extern "C"
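
// A minimal sketch of driving the API above from C++. The intended consumer
// is a Python ctypes bridge, but the call sequence is the same. The model
// path and prompts are placeholders, and the DEMO_MAIN guard is only an
// assumption so this translation unit can still build as a shared library.
#ifdef DEMO_MAIN
int main() {
    if (!init_model("model.gguf")) { // hypothetical path
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const char* prompts[2] = { "Hello", "The capital of France is" };
    const char* results[2] = { nullptr, nullptr };

    // Prefill both prompts, then generate at most 32 tokens per sequence
    start_batch(prompts, 2, 32);

    // Each decode_step yields at most one piece per sequence;
    // finished sequences report nullptr
    while (decode_step(results)) {
        for (int i = 0; i < 2; i++) {
            if (results[i]) printf("[seq %d] %s\n", i, results[i]);
        }
    }

    cleanup();
    return 0;
}
#endif // DEMO_MAIN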