# abcd/api/bridge.py — ctypes bridge between the Python API layer and the C batch engine.
import ctypes
import os
# Load the shared library.
# Anchor the path to this file's directory rather than the process CWD so the
# module imports correctly no matter where the interpreter was launched from.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
LIB_PATH = os.path.abspath(os.path.join(_BASE_DIR, "..", "engine", "libbatch.so"))
if not os.path.exists(LIB_PATH):
    raise FileNotFoundError(f"Shared library not found at: {LIB_PATH}. Did you compile the engine?")
lib = ctypes.CDLL(LIB_PATH)

# Declare C function signatures so ctypes marshals arguments correctly.
# bool init_model(const char *model_path)
lib.init_model.argtypes = [ctypes.c_char_p]
lib.init_model.restype = ctypes.c_bool

# void start_batch(const char **prompts, int count, int max_tokens)
lib.start_batch.argtypes = [
    ctypes.POINTER(ctypes.c_char_p),  # prompts: array of NUL-terminated UTF-8 strings
    ctypes.c_int,                     # count: number of prompts in the batch
    ctypes.c_int                      # max_tokens: decode budget per sequence
]
lib.start_batch.restype = None

# bool decode_step(char **results)
# Presumably returns True while any sequence is still decoding — the loop in
# stream_batch treats a False return as "batch finished". TODO confirm in engine source.
lib.decode_step.argtypes = [
    ctypes.POINTER(ctypes.c_char_p)   # results: per-slot output strings for this step
]
lib.decode_step.restype = ctypes.c_bool
# Load the prompt template shipped alongside the model files.
# Anchored to this file's directory (not the CWD) and read as UTF-8 so the
# template is decoded identically regardless of the host locale.
_TEMPLATE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "model", "template.txt")
with open(_TEMPLATE_PATH, "r", encoding="utf-8") as f:
    TEMPLATE = f.read()
def format_prompt(prompt: str) -> str:
    """Render the user's text into the loaded chat template.

    Every occurrence of the ``{{prompt}}`` placeholder in the module-level
    TEMPLATE is substituted with *prompt*.
    """
    segments = TEMPLATE.split("{{prompt}}")
    return prompt.join(segments)
# Initialize the model once at import time.
# Path is anchored to this file's directory so importing from any CWD works;
# encoded to bytes because init_model takes a c_char_p.
MODEL_PATH = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "model", "model.gguf")
).encode("utf-8")
if not lib.init_model(MODEL_PATH):
    # Keep the original best-effort behavior (warn, don't crash the import),
    # but print the path as text instead of a bytes repr (b'...').
    print(f"Failed to initialize model at {MODEL_PATH.decode('utf-8')}")
def stream_batch(prompts, max_tokens=256):
    """Stream decode steps for a batch of prompts as a generator.

    Parameters
    ----------
    prompts : list[str]
        Raw user prompts; each is wrapped in the chat template before prefill.
    max_tokens : int, optional
        Per-sequence decode budget forwarded to the C engine. Defaults to 256,
        matching the previously hard-coded value.

    Yields
    ------
    list[str | None]
        One entry per prompt for each decode step: the UTF-8 text the engine
        produced for that slot this step, or None when the engine returned a
        NULL/empty result for that slot.
    """
    count = len(prompts)
    if count == 0:
        return  # empty batch: nothing to prefill, don't call into the engine

    # Apply Ollama-style templates before handing the strings to C.
    formatted_prompts = [format_prompt(p) for p in prompts]
    c_prompts = (ctypes.c_char_p * count)(*[p.encode("utf-8") for p in formatted_prompts])
    c_results = (ctypes.c_char_p * count)()

    # 1. Start Batch (Prefill): submit the whole batch to the engine.
    lib.start_batch(c_prompts, count, max_tokens)

    # 2. Decode Loop: one engine step per iteration; the final step's output
    # is still yielded before the loop exits on an inactive batch.
    while True:
        active = lib.decode_step(c_results)

        # Collect this step's per-slot results.
        step_output = []
        for i in range(count):
            res = c_results[i]
            if res:
                step_output.append(res.decode("utf-8"))
                # NOTE(review): the C side presumably allocates these result
                # strings and we never free them, so each step leaks a small
                # buffer. An engine-side free_results() would fix this — confirm
                # ownership with the engine code before adding one.
            else:
                step_output.append(None)
        yield step_output

        if not active:
            break