Karan6933 committed on
Commit c2f1385 · verified · 1 Parent(s): 0483cf5

Update api/main.py

Files changed (1)
  1. api/main.py +86 -73
api/main.py CHANGED
@@ -1,81 +1,94 @@
- from fastapi import FastAPI
- from fastapi.responses import StreamingResponse
- from batcher import BatchScheduler
- from bridge import stream_batch
  import asyncio
- import time

- app = FastAPI()
- scheduler = BatchScheduler(max_batch=8, max_wait_ms=30)
-
- # In-memory chat history (per process, for demo)
- chat_histories = {}
-
- @app.post("/chat")
- async def chat(prompt: str, session_id: str = "default"):
-     # Simple history management
-     if session_id not in chat_histories:
-         chat_histories[session_id] = []
-
-     # Contextual prompt construction
-     history = "\n".join(chat_histories[session_id])
-     if history:
-         full_prompt = f"{history}\n{prompt}"
-     else:
-         full_prompt = prompt
-
-     # Get the queue for this request
-     token_queue = await scheduler.add(full_prompt)
-
-     # Generator to yield tokens from the queue
-     async def response_generator():
-         full_response = []
-         while True:
-             token = await token_queue.get()
-             if token is None:
-                 break
-             yield token
-             full_response.append(token)
-
-         # After streaming is done, update history
-         # Note: This runs after the response closes, might need background task if strict
-         # But for generator, code continues after yield
-         response_text = "".join(full_response)
-         chat_histories[session_id].append(f"User: {prompt}")
-         chat_histories[session_id].append(f"AI: {response_text}")
-
-         # Keep history concise
-         if len(chat_histories[session_id]) > 10:
-             chat_histories[session_id] = chat_histories[session_id][-10:]

-     return StreamingResponse(response_generator(), media_type="text/plain")

- async def batch_loop():
-     print("Batch loop started...")
-     while True:
-         # Wait for a batch
-         batch = await scheduler.get_batch()
-         if not batch:
-             await asyncio.sleep(0.01)  # Short sleep if empty
-             continue
-
-         # Process batch
-         prompts, queues = zip(*batch)
-         print(f"Processing batch of {len(prompts)} prompts")
-
-         # Stream from C++ engine
-         # Iterate over the generator which yields step-by-step tokens
-         for step_tokens in stream_batch(prompts):
-             for q, token in zip(queues, step_tokens):
-                 if token is not None:
-                     q.put_nowait(token)
-             # Yield control to event loop to let FastAPI flush tokens
-             await asyncio.sleep(0)

-         # Signal done
-         for q in queues:
-             q.put_nowait(None)

- @app.on_event("startup")
- async def startup_event():
-     asyncio.create_task(batch_loop())
+ import os
  import asyncio
+ from contextlib import asynccontextmanager
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from typing import List, Optional
+ import logging
+
+ from engine import init_engine, get_engine
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ MODEL_PATH = os.getenv("MODEL_PATH", "model/model.gguf")
+ MODEL_URL = os.getenv("MODEL_URL", "https://huggingface.co/prithivMLmods/Nanbeige4.1-3B-f32-GGUF/resolve/main/Nanbeige4.1-3B.Q8_0.gguf")
+
+ class GenerateRequest(BaseModel):
+     prompt: str
+     max_tokens: int = 256
+     temperature: float = 0.7
+     stream: bool = True

+ class BatchRequest(BaseModel):
+     prompts: List[str]
+     max_tokens: int = 256
+     temperature: float = 0.7

+ def download_model():
+     """Download model if not exists"""
+     if not os.path.exists(MODEL_PATH):
+         os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
+         logger.info(f"Downloading model from {MODEL_URL}")
+         import urllib.request
+         urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
+         logger.info("Model downloaded")
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Startup
+     logger.info("Starting up...")
+     download_model()
+     init_engine(MODEL_PATH, n_ctx=4096, n_threads=4)
+     logger.info("Ready!")
+     yield
+     # Shutdown
+     logger.info("Shutting down...")
+
+ app = FastAPI(title="Nanbeige LLM API", lifespan=lifespan)
+
+ @app.post("/generate")
+ async def generate(req: GenerateRequest):
+     """Single prompt generation with streaming"""
+     engine = get_engine()
+
+     if req.stream:
+         async def stream_generator():
+             async for token in engine.generate_stream(
+                 req.prompt,
+                 max_tokens=req.max_tokens,
+                 temperature=req.temperature
+             ):
+                 yield token
+
+         return StreamingResponse(
+             stream_generator(),
+             media_type="text/plain"
+         )
+     else:
+         # Non-streaming: collect all tokens
+         chunks = []
+         async for token in engine.generate_stream(
+             req.prompt,
+             max_tokens=req.max_tokens,
+             temperature=req.temperature
+         ):
+             chunks.append(token)
+         return {"text": "".join(chunks)}

+ @app.post("/generate_batch")
+ async def generate_batch(req: BatchRequest):
+     """Batch generation (multiple prompts)"""
+     engine = get_engine()
+     results = await engine.generate_batch(
+         req.prompts,
+         max_tokens=req.max_tokens,
+         temperature=req.temperature
+     )
+     return {"results": results}

+ @app.get("/health")
+ async def health():
+     return {"status": "ok", "model_loaded": get_engine()._model is not None}