# app.py
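"""Minimal FastAPI server that serves a quantized Llama 3.2 1B Instruct GGUF model
with llama-cpp-python, CPU-only. Uses the model prefetched at Docker build time when
available, otherwise downloads it into a writable runtime cache."""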
import os, time
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# ---------- Minimal fixed config (fast on CPU) ----------
REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF" # 1B = much faster on CPU
FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"
# Build-time prefetch location (Dockerfile step put model here)
BUILD_DIR = "/app/models"
MODEL_PATH = os.path.join(BUILD_DIR, FILENAME)
# Writable runtime cache if the prebuilt file isn't present
RUNTIME_CACHE = "/tmp/hf_cache"
N_THREADS = min(4, os.cpu_count() or 2)  # cap threads to avoid oversubscribing small CPUs
N_BATCH = 8                              # small batch keeps memory use low
N_CTX = 2048                             # context window passed to llama.cpp
MAX_TOKENS = 16                          # hard cap on generated tokens (fast, but replies are very short)
TEMPERATURE = 0.7
TOP_P = 0.9
STOP = ["</s>", "<|eot_id|>"]            # Llama 3 end-of-turn token plus generic EOS

# ---------- App ----------
app = FastAPI(title="Simple Llama Server (1B fast)")
model = None
effective_model_path = None

class PromptRequest(BaseModel):
    prompt: str

@app.on_event("startup")  # deprecated in newer FastAPI in favor of lifespan handlers, but still works
def load_model():
    global model, effective_model_path
    # 1) If the model exists from the Docker build, use it directly (no writes)
    if os.path.isfile(MODEL_PATH):
        effective_model_path = MODEL_PATH
    else:
        # 2) Otherwise, download to a writable temp cache (NOT under /app)
        os.makedirs(RUNTIME_CACHE, exist_ok=True)
        effective_model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            cache_dir=RUNTIME_CACHE,
            local_dir_use_symlinks=False,  # deprecated (and ignored) in newer huggingface_hub releases
        )
    t0 = time.time()
    model = Llama(
        model_path=effective_model_path,
        chat_format="llama-3",
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        use_mmap=True,    # faster load
        n_gpu_layers=0,   # CPU only
        verbose=False,
    )
    print(f"[startup] loaded {effective_model_path} in {time.time()-t0:.2f}s")

@app.get("/health")
def health():
return {"ok": model is not None, "model_path": effective_model_path}
@app.post("/generate")
def generate(req: PromptRequest):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    prompt = (req.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt must be non-empty")
    t0 = time.time()
    out = model.create_chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        stop=STOP,
    )
    text = out["choices"][0]["message"]["content"]
    return JSONResponse({"response": text, "timing_sec": round(time.time()-t0, 2)})
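
# ---------- Local run (optional) ----------
# A minimal sketch for running this file directly; on Hugging Face Spaces the
# container's Dockerfile CMD normally starts the server, and the port used here
# (7860) is an assumption; adjust to your deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))

# Example requests (illustrative; host/port depend on where the server runs):
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!"}'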