# app.py
import os
import time
import threading
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# ---------------- Config (fixed defaults; can be overridden by env) ----------------
REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")
# Inference knobs (fixed for the Space; override via env only if needed)
N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
N_BATCH = int(os.getenv("N_BATCH", "64"))
N_CTX = int(os.getenv("N_CTX", "2048"))
# Fixed sampling
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
# System prompt (optional). Leave empty for pure user prompt.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()
# Safety margin for context budgeting (prompt + completion + overhead <= N_CTX)
CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
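# Sketch of how these could be overridden (hypothetical values) via the Space's
# variables/secrets or its Dockerfile; anything not set keeps the defaults above:
#   ENV N_CTX=4096 MAX_TOKENS=512 TEMPERATURE=0.2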
# ---------------- App scaffolding ----------------
app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
_model: Optional[Llama] = None
_model_lock = threading.Lock()
# ---------------- Model loader ----------------
def get_model() -> Llama:
    global _model
    if _model is not None:
        return _model
    os.makedirs(CACHE_DIR, exist_ok=True)
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        cache_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
    _model = Llama(
        model_path=local_path,
        chat_format="llama-3",  # ensures proper Llama-3 prompt templating
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _model

@app.on_event("startup")
def _warm():
    # Preload to avoid cold-start on first request
    get_model()
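# Note (from the defaults above): the GGUF file is downloaded into /app/models on
# the first load, and every later call reuses the in-memory _model singleton, so
# only the startup hook (or the very first request) pays the download/load cost.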
# ---------------- Schemas ----------------
class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="User prompt text only.")
# ---------------- Helpers ----------------
def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
    """
    Simple context budgeting: ensure tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
    If the prompt is over budget, truncate it from the start (keep the tail).
    """
    toks = model.tokenize(prompt.encode("utf-8"))
    budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)  # keep some minimal room
    if len(toks) <= budget:
        return prompt
    # Truncate from the front (keep the latest part)
    kept = model.detokenize(toks[-budget:])
    return kept.decode("utf-8", errors="ignore")
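# Worked example with the fixed defaults above (no env overrides assumed):
#   budget = max(256, 2048 - 256 - 128) = 1664 tokens
# so a prompt longer than 1664 tokens keeps only its last 1664 tokens.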
# ---------------- Endpoints ----------------
@app.get("/health")
def health():
try:
_ = get_model()
return {"ok": True}
except Exception as e:
return {"ok": False, "error": str(e)}
@app.get("/config")
def config():
return {
"repo_id": REPO_ID,
"filename": FILENAME,
"cache_dir": CACHE_DIR,
"n_threads": N_THREADS,
"n_batch": N_BATCH,
"n_ctx": N_CTX,
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
"top_p": TOP_P,
"stop_tokens": STOP_TOKENS,
"ctx_safety": CTX_SAFETY,
"has_system_prompt": bool(SYSTEM_PROMPT),
}
@app.post("/generate")
def generate(req: GenerateRequest):
"""
Non-streaming chat completion.
Accepts ONLY a prompt string; all other params are fixed in code.
"""
try:
if not req.prompt or not req.prompt.strip():
raise HTTPException(status_code=400, detail="prompt must be a non-empty string")
model = get_model()
user_prompt = req.prompt.strip()
fitted_prompt = _fit_prompt_to_context(model, user_prompt)
# Build messages (Llama-3 chat format). System is optional.
messages = []
if SYSTEM_PROMPT:
messages.append({"role": "system", "content": SYSTEM_PROMPT})
messages.append({"role": "user", "content": fitted_prompt})
t0 = time.time()
with _model_lock:
out = model.create_chat_completion(
messages=messages,
max_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
stop=STOP_TOKENS,
)
dt = time.time() - t0
text = out["choices"][0]["message"]["content"]
usage = out.get("usage", {}) # may include prompt_tokens/completion_tokens
return JSONResponse({
"ok": True,
"response": text,
"usage": usage,
"timing_sec": round(dt, 3),
"params": {
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
"top_p": TOP_P,
"stop": STOP_TOKENS,
"n_ctx": N_CTX,
},
"prompt_truncated": (fitted_prompt != user_prompt),
})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
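# ---------------- Example usage (sketch) ----------------
# The Space is assumed to start this app with uvicorn (e.g. from its Dockerfile);
# the commands below are only a local-run sketch, with port 7860 as an assumption:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Summarize what this API does in one sentence."}'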