"""Python ctypes bindings for the batched inference engine (``libbatch.so``).

Importing this module loads the shared library, reads the prompt template and
initializes the model. All paths are relative to the current working
directory (resolved with ``os.path.abspath``), so the module must be imported
from a directory where ``../engine`` and ``../model`` exist.
"""

import codecs
import ctypes
import os

# Load the shared library (path resolved against the CWD, not this file).
LIB_PATH = os.path.abspath("../engine/libbatch.so")
if not os.path.exists(LIB_PATH):
    raise FileNotFoundError(
        f"Shared library not found at: {LIB_PATH}. Did you compile the engine?"
    )

lib = ctypes.CDLL(LIB_PATH)

# init_model(model_path) -> bool
lib.init_model.argtypes = [ctypes.c_char_p]
lib.init_model.restype = ctypes.c_bool

# start_batch(prompts, count, max_tokens) -> void  (prefill)
lib.start_batch.argtypes = [
    ctypes.POINTER(ctypes.c_char_p),  # prompts
    ctypes.c_int,  # count
    ctypes.c_int,  # max_tokens
]
lib.start_batch.restype = None

# decode_step(results) -> bool; fills one char* slot per prompt. The return
# value is treated below as "batch still active".
lib.decode_step.argtypes = [
    ctypes.POINTER(ctypes.c_char_p),  # results
]
lib.decode_step.restype = ctypes.c_bool

# Load template. Explicit encoding so decoding does not depend on the locale.
with open("../model/template.txt", "r", encoding="utf-8") as f:
    TEMPLATE = f.read()


def format_prompt(prompt: str) -> str:
    """Return TEMPLATE with every ``{{prompt}}`` placeholder replaced by *prompt*."""
    return TEMPLATE.replace("{{prompt}}", prompt)


# Initialize the model. Fail loudly: continuing with an uninitialized engine
# would only defer the failure to the first start_batch/decode_step call.
MODEL_PATH = os.path.abspath("../model/model.gguf").encode("utf-8")
if not lib.init_model(MODEL_PATH):
    raise RuntimeError(
        f"Failed to initialize model at {MODEL_PATH.decode('utf-8')}"
    )


def stream_batch(prompts, max_tokens: int = 256):
    """Run a batch of prompts through the engine, yielding output per decode step.

    Args:
        prompts: Iterable of raw prompt strings. Each one is wrapped with
            ``format_prompt`` before it is sent to the engine.
        max_tokens: Passed to ``start_batch`` as its ``max_tokens`` argument
            (default 256, the value that used to be hard-coded).

    Yields:
        One list per decode step, with one entry per prompt in input order.
        Each entry is the text decoded for that prompt in this step, or
        ``None`` when the engine left the slot NULL or empty. The final step
        (the one where ``decode_step`` reports inactive) is also yielded.

    Raises:
        UnicodeDecodeError: if the engine emits bytes that are not valid UTF-8.
    """
    prompts = list(prompts)  # accept any iterable; we need len() and stable order
    count = len(prompts)
    if count == 0:
        return  # nothing to run; don't call into the engine with an empty batch

    # Apply Ollama-style templates. `encoded` keeps the bytes objects referenced
    # for as long as the C side may read them.
    encoded = [format_prompt(p).encode("utf-8") for p in prompts]
    c_prompts = (ctypes.c_char_p * count)(*encoded)
    c_results = (ctypes.c_char_p * count)()

    # One incremental decoder per prompt. A token's bytes may end in the middle
    # of a multi-byte UTF-8 character, which a plain bytes.decode() rejects. The
    # decoder buffers the incomplete tail until a later step completes it.
    # NOTE(review): assumes decode_step writes per-step pieces, not the
    # cumulative text so far. TODO confirm against the engine.
    decoders = [codecs.getincrementaldecoder("utf-8")() for _ in range(count)]

    # 1. Start batch (prefill)
    lib.start_batch(c_prompts, count, max_tokens)

    # 2. Decode loop
    while True:
        active = lib.decode_step(c_results)

        step_output = []
        for i, decoder in enumerate(decoders):
            # Indexing a c_char_p array yields a bytes copy, or None for NULL.
            res = c_results[i]
            step_output.append(decoder.decode(res) if res else None)
        # NOTE: the C strings behind c_results are never released here. Freeing
        # them requires the engine's matching deallocator, which this file does
        # not expose, so they currently leak.
        yield step_output

        if not active:
            break