|
|
import ctypes |
|
|
import os |
|
|
|
|
|
|
|
|
# Absolute path of the compiled engine. The relative part is resolved against
# the process's current working directory, not against this file's location.
LIB_PATH = os.path.abspath("../engine/libbatch.so")

# Fail with an actionable message before handing a bad path to the loader.
if os.path.exists(LIB_PATH):
    lib = ctypes.CDLL(LIB_PATH)
else:
    raise FileNotFoundError(
        f"Shared library not found at: {LIB_PATH}. Did you compile the engine?"
    )
|
|
|
|
|
|
|
|
# Foreign signature of init_model: one C string argument, and the return
# value is converted to a Python bool.
lib.init_model.restype = ctypes.c_bool
lib.init_model.argtypes = [ctypes.c_char_p]
|
|
|
|
|
|
|
|
# Foreign signature of start_batch: a pointer to C strings followed by two C
# ints; no return value is read.
# NOTE(review): the ints are presumably the prompt count and a per-batch
# limit (stream_batch passes len(prompts) and 256) - confirm against the engine.
lib.start_batch.restype = None
lib.start_batch.argtypes = [
    ctypes.POINTER(ctypes.c_char_p),
    ctypes.c_int,
    ctypes.c_int,
]
|
|
|
|
|
# Foreign signature of decode_step: a pointer to C strings that the caller
# reads results from after the call; the return value is converted to a
# Python bool.
lib.decode_step.restype = ctypes.c_bool
lib.decode_step.argtypes = [ctypes.POINTER(ctypes.c_char_p)]
|
|
|
|
|
|
|
|
# Read the prompt template once at import time. The path is relative to the
# current working directory.
# The encoding is pinned to UTF-8 instead of the platform default because the
# prompts built from this template are encoded as UTF-8 before being passed to
# the engine; decoding the file with a different codec would corrupt any
# non-ASCII template text.
with open("../model/template.txt", "r", encoding="utf-8") as f:
    TEMPLATE = f.read()
|
|
|
|
|
def format_prompt(prompt: str) -> str:
    """Return the module-level ``TEMPLATE`` with its placeholder filled in.

    Args:
        prompt: Text substituted for every ``{{prompt}}`` marker in the
            template.

    Returns:
        The template text with all markers replaced. If the template holds no
        marker, it is returned unchanged.
    """
    placeholder = "{{prompt}}"
    filled = TEMPLATE.replace(placeholder, prompt)
    return filled
|
|
|
|
|
|
|
|
# Absolute model path (resolved against the current working directory),
# encoded to UTF-8 bytes so it can be passed as a c_char_p argument.
MODEL_PATH = os.path.abspath("../model/model.gguf").encode("utf-8")

# A falsy result from init_model is only reported on stdout; the import then
# continues.
# NOTE(review): MODEL_PATH is bytes, so the message shows its b'...' repr, and
# later engine calls still run after a failed init - confirm this is intended.
if not lib.init_model(MODEL_PATH):
    print(f"Failed to initialize model at {MODEL_PATH}")
|
|
|
|
|
def stream_batch(prompts):
    """Run a batch of prompts through the engine, yielding output step by step.

    Each prompt is wrapped with ``format_prompt``, encoded as UTF-8 and passed
    to ``lib.start_batch``. ``lib.decode_step`` is then called repeatedly, and
    after every call one list is yielded.

    Args:
        prompts: Sequence of prompt strings; must support ``len()``.

    Yields:
        A list with one entry per prompt, in prompt order. An entry is the
        text decoded for that prompt during the step, or ``None`` when the
        engine left that result slot null or empty. An entry may be an empty
        string when the step's bytes end in the middle of a multi-byte
        character; the completed character is delivered with a later step.
        The list for the step on which ``decode_step`` returns false is still
        yielded before the generator finishes.

    Raises:
        UnicodeDecodeError: If the engine returns bytes that are not valid
            UTF-8.
    """
    # Function-scope import keeps this block self-contained.
    import codecs

    count = len(prompts)

    encoded_prompts = [format_prompt(p).encode("utf-8") for p in prompts]
    # Both arrays are locals of this generator, so they stay alive for as long
    # as the generator itself does.
    c_prompts = (ctypes.c_char_p * count)(*encoded_prompts)
    c_results = (ctypes.c_char_p * count)()

    # One incremental decoder per prompt. Decoding each chunk on its own with
    # bytes.decode() raises UnicodeDecodeError when a multi-byte character is
    # split across two steps; an incremental decoder buffers the partial bytes
    # instead. For chunks that are complete UTF-8 the output is identical.
    # NOTE(review): assumes the engine can emit partial characters per step -
    # confirm. Bytes still buffered when decoding stops are discarded.
    new_decoder = codecs.getincrementaldecoder("utf-8")
    decoders = [new_decoder() for _ in range(count)]

    # NOTE(review): 256 is presumably a generation limit - confirm against the
    # engine's start_batch.
    lib.start_batch(c_prompts, count, 256)

    active = True
    while active:
        # decode_step refreshes c_results in place and reports whether any
        # further step is needed.
        active = lib.decode_step(c_results)

        # Reading a c_char_p array item gives bytes, or None for a null pointer.
        yield [
            decoder.decode(chunk) if chunk else None
            for decoder, chunk in zip(decoders, c_results)
        ]
|
|
|
|
|
|
|
|
|
|
|
|