import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from llama_cpp import Llama

logger = logging.getLogger(__name__)


class BatchInferenceEngine:
    """
    Pure Python batch inference engine using llama-cpp-python.
    Loads the model once and handles multiple concurrent requests efficiently.
    """

    def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = 4):
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self._model: Optional[Llama] = None
        # Serializes access to the model: a single Llama instance is not
        # safe to use from concurrent requests.
        self._lock = asyncio.Lock()

    def load(self):
        """Load the model once at startup."""
        logger.info(f"Loading model from {self.model_path}")
        self._model = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_batch=512,
            verbose=False,
        )
        logger.info("Model loaded successfully")

    async def generate_stream(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming generator for a single request.
        Runs the synchronous llama-cpp calls in a thread pool so token
        generation does not block the event loop.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded")

        loop = asyncio.get_running_loop()

        def _create_stream():
            return self._model.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop or [],
                stream=True,
            )

        # Hold the lock for the whole generation: a single Llama instance
        # cannot safely serve overlapping requests.
        async with self._lock:
            stream = await loop.run_in_executor(None, _create_stream)

            sentinel = object()
            while True:
                # next() on the stream is where the blocking token generation
                # actually happens, so run it in the executor as well.
                chunk = await loop.run_in_executor(None, next, stream, sentinel)
                if chunk is sentinel:
                    break
                if chunk.get("choices"):
                    delta = chunk["choices"][0].get("text", "")
                    if delta:
                        yield delta

    async def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 256,
        temperature: float = 0.7,
    ) -> List[str]:
        """
        Process multiple prompts.
        On CPU, prompts are processed sequentially to avoid contention.
        """
        results = []
        for prompt in prompts:
            chunks = []
            async for token in self.generate_stream(prompt, max_tokens, temperature):
                chunks.append(token)
            results.append("".join(chunks))
        return results


_engine: Optional[BatchInferenceEngine] = None


def get_engine() -> BatchInferenceEngine:
    if _engine is None:
        raise RuntimeError("Engine not initialized")
    return _engine


def init_engine(model_path: str, **kwargs) -> BatchInferenceEngine:
    global _engine
    _engine = BatchInferenceEngine(model_path, **kwargs)
    _engine.load()
    return _engine
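

# Usage sketch, not part of the engine itself. The model path and prompts
# below are only illustrative (assumption: some local GGUF model exists);
# it shows the intended call order: init_engine() once at startup, then
# get_engine() from request handlers.
if __name__ == "__main__":

    async def _demo() -> None:
        engine = get_engine()

        # Stream tokens for a single prompt.
        async for token in engine.generate_stream(
            "Explain KV caching in one sentence.", max_tokens=64
        ):
            print(token, end="", flush=True)
        print()

        # Batch several prompts; on CPU they run sequentially.
        for text in await engine.generate_batch(
            ["What is a GGUF file?", "Name one llama.cpp sampling parameter."],
            max_tokens=32,
        ):
            print(text)

    init_engine("models/model.Q4_K_M.gguf", n_ctx=2048, n_threads=4)
    asyncio.run(_demo())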