import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from llama_cpp import Llama

logger = logging.getLogger(__name__)


class BatchInferenceEngine:
    """
    Pure Python batch inference engine using llama-cpp-python.
    Loads the model once and serves multiple concurrent requests by
    serializing access to the single model instance.
    """

    def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = 4):
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self._model: Optional[Llama] = None
        # llama-cpp-python is not safe for concurrent calls on one Llama
        # instance, so a lock serializes generation requests.
        self._lock = asyncio.Lock()

    def load(self):
        """Load the model once at startup."""
        logger.info(f"Loading model from {self.model_path}")
        self._model = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_batch=512,
            verbose=False,
        )
        logger.info("Model loaded successfully")

    async def generate_stream(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming generator for a single request.
        Runs the blocking llama-cpp calls in the default thread pool so the
        event loop stays responsive while tokens are produced.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded")

        loop = asyncio.get_running_loop()

        def _start_stream():
            # Blocking call; returns a synchronous iterator of completion chunks.
            return self._model.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop or [],
                stream=True,  # Enable streaming
            )

        # Serialize access to the single model instance.
        async with self._lock:
            stream = await loop.run_in_executor(None, _start_stream)

            # Pull each chunk in the thread pool so iteration does not block
            # the event loop, and yield token text as it arrives.
            while True:
                chunk = await loop.run_in_executor(None, lambda: next(stream, None))
                if chunk is None:
                    break
                choices = chunk.get("choices") or []
                if choices:
                    delta = choices[0].get("text", "")
                    if delta:
                        yield delta

    async def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 256,
        temperature: float = 0.7,
    ) -> List[str]:
        """
        Process multiple prompts by streaming each one in turn.
        On CPU, prompts are handled sequentially to avoid contention.
        """
        results = []
        for prompt in prompts:
            chunks = []
            async for token in self.generate_stream(prompt, max_tokens, temperature):
                chunks.append(token)
            results.append("".join(chunks))
        return results


# Global singleton instance
_engine: Optional[BatchInferenceEngine] = None


def get_engine() -> BatchInferenceEngine:
    global _engine
    if _engine is None:
        raise RuntimeError("Engine not initialized")
    return _engine


def init_engine(model_path: str, **kwargs) -> BatchInferenceEngine:
    global _engine
    _engine = BatchInferenceEngine(model_path, **kwargs)
    _engine.load()
    return _engine
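

# Minimal usage sketch (an assumption for illustration, not part of the module's
# public contract): initialize the singleton engine with a local GGUF model path
# of your choosing, then consume the async token stream. The model path and
# prompt below are placeholders.
if __name__ == "__main__":
    async def _demo():
        engine = init_engine("models/model.gguf", n_ctx=2048, n_threads=4)
        async for token in engine.generate_stream("Q: What is 2 + 2?\nA:", max_tokens=32):
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())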