# app.py
import os
import time
import threading
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# ---------------- Config (fixed defaults; can be overridden by env) ----------------
REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")
# Inference knobs (fixed for the Space; override via env only if needed)
N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
N_BATCH = int(os.getenv("N_BATCH", "64"))
N_CTX = int(os.getenv("N_CTX", "2048"))
# Fixed sampling
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
# System prompt (optional). Leave empty for pure user prompt.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()
# Safety margin for context budgeting (prompt + completion + overhead <= N_CTX)
CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
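# Illustrative override example (not part of the app): any of the knobs above can be
# set as environment variables before launch, e.g. in the Space settings or a local
# shell. The values below are hypothetical.
#
#   export N_CTX=4096
#   export MAX_TOKENS=512
#   export SYSTEM_PROMPT="You are a concise assistant."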
# ---------------- App scaffolding ----------------
app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
_model: Optional[Llama] = None
_model_lock = threading.Lock()
# ---------------- Model loader ----------------
def get_model() -> Llama:
    global _model
    if _model is not None:
        return _model
    os.makedirs(CACHE_DIR, exist_ok=True)
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        cache_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
    _model = Llama(
        model_path=local_path,
        chat_format="llama-3",  # ensures proper Llama-3 prompt templating
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _model
@app.on_event("startup")
def _warm():
    # Preload to avoid cold-start on first request
    get_model()
# ---------------- Schemas ----------------
class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="User prompt text only.")
# ---------------- Helpers ----------------
def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
    """
    Simple context budgeting: ensure tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
    If the prompt is over budget, truncate it from the start (keep the tail).
    """
    toks = model.tokenize(prompt.encode("utf-8"))
    budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)  # keep some minimal room
    if len(toks) <= budget:
        return prompt
    # Truncate from the front (keep the latest part)
    kept = model.detokenize(toks[-budget:])
    return kept.decode("utf-8", errors="ignore")
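# Worked example with the defaults above (illustrative): N_CTX=2048, MAX_TOKENS=256,
# CTX_SAFETY=128 gives budget = max(256, 2048 - 256 - 128) = 1664 prompt tokens,
# so a longer prompt keeps only its last 1664 tokens before templating.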
# ---------------- Endpoints ----------------
@app.get("/health")
def health():
try:
_ = get_model()
return {"ok": True}
except Exception as e:
return {"ok": False, "error": str(e)}
@app.get("/config")
def config():
return {
"repo_id": REPO_ID,
"filename": FILENAME,
"cache_dir": CACHE_DIR,
"n_threads": N_THREADS,
"n_batch": N_BATCH,
"n_ctx": N_CTX,
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
"top_p": TOP_P,
"stop_tokens": STOP_TOKENS,
"ctx_safety": CTX_SAFETY,
"has_system_prompt": bool(SYSTEM_PROMPT),
}
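# Illustrative checks (host/port are assumptions, not set by this file):
#   curl http://localhost:7860/health   -> {"ok": true} once the model is loaded
#   curl http://localhost:7860/config   -> the fixed settings above as JSON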
@app.post("/generate")
def generate(req: GenerateRequest):
"""
Non-streaming chat completion.
Accepts ONLY a prompt string; all other params are fixed in code.
"""
try:
if not req.prompt or not req.prompt.strip():
raise HTTPException(status_code=400, detail="prompt must be a non-empty string")
model = get_model()
user_prompt = req.prompt.strip()
fitted_prompt = _fit_prompt_to_context(model, user_prompt)
# Build messages (Llama-3 chat format). System is optional.
messages = []
if SYSTEM_PROMPT:
messages.append({"role": "system", "content": SYSTEM_PROMPT})
messages.append({"role": "user", "content": fitted_prompt})
t0 = time.time()
with _model_lock:
out = model.create_chat_completion(
messages=messages,
max_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
stop=STOP_TOKENS,
)
dt = time.time() - t0
text = out["choices"][0]["message"]["content"]
usage = out.get("usage", {}) # may include prompt_tokens/completion_tokens
return JSONResponse({
"ok": True,
"response": text,
"usage": usage,
"timing_sec": round(dt, 3),
"params": {
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
"top_p": TOP_P,
"stop": STOP_TOKENS,
"n_ctx": N_CTX,
},
"prompt_truncated": (fitted_prompt != user_prompt),
})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
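# Illustrative request (host/port are assumptions, not set by this file):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain context windows in one sentence."}'
# Minimal local entry point; a sketch assuming uvicorn is installed and that
# port 7860 (the usual Hugging Face Spaces port) is the intended one.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)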