import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from llama_cpp import Llama

logger = logging.getLogger(__name__)

class BatchInferenceEngine:
    """
    Pure Python batch inference engine using llama-cpp-python.
    Loads the model once and handles multiple concurrent requests efficiently.
    """
    
    def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = 4):
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self._model: Optional[Llama] = None
        self._lock = asyncio.Lock()
        
    def load(self):
        """Load model once at startup"""
        logger.info(f"Loading model from {self.model_path}")
        self._model = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_batch=512,
            verbose=False
        )
        logger.info("Model loaded successfully")
        
    async def generate_stream(
        self, 
        prompt: str, 
        max_tokens: int = 256,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming generator for a single request.
        Uses thread pool to run sync llama-cpp in background.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded")
            
        loop = asyncio.get_running_loop()

        def _start_stream():
            # create_completion(stream=True) returns a synchronous generator
            return self._model.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop or [],
                stream=True,
            )

        _done = object()  # sentinel marking stream exhaustion

        # Serialize access: a single Llama context cannot safely run
        # overlapping completions from concurrent requests.
        async with self._lock:
            stream = await loop.run_in_executor(None, _start_stream)

            # Pull each chunk in the thread pool so per-token generation
            # never blocks the event loop.
            while True:
                chunk = await loop.run_in_executor(None, next, stream, _done)
                if chunk is _done:
                    break
                choices = chunk.get("choices") or []
                if choices:
                    delta = choices[0].get("text", "")
                    if delta:
                        yield delta
                    
    async def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 256,
        temperature: float = 0.7
    ) -> List[str]:
        """
        Process multiple prompts efficiently.
        On CPU, we process sequentially to avoid contention.
        """
        results = []
        for prompt in prompts:
            chunks = []
            async for token in self.generate_stream(prompt, max_tokens, temperature):
                chunks.append(token)
            results.append("".join(chunks))
        return results

# Global singleton instance
_engine: Optional[BatchInferenceEngine] = None

def get_engine() -> BatchInferenceEngine:
    global _engine
    if _engine is None:
        raise RuntimeError("Engine not initialized")
    return _engine

def init_engine(model_path: str, **kwargs) -> BatchInferenceEngine:
    global _engine
    _engine = BatchInferenceEngine(model_path, **kwargs)
    _engine.load()
    return _engine
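
# Usage sketch: a minimal smoke test, assuming a local GGUF model at the
# placeholder path below (hypothetical; point it at a real file to run).
if __name__ == "__main__":
    async def _demo():
        # Load the singleton engine once, then stream a completion token by token.
        init_engine("models/model.gguf", n_ctx=2048, n_threads=4)
        engine = get_engine()
        async for token in engine.generate_stream(
            "Q: What does llama.cpp do?\nA:", max_tokens=64
        ):
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())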