# app.py
import os
import time
import threading
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# ---------------- Config (fixed defaults; can be overridden by env) ----------------
REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")

# Inference knobs (fixed for the Space; override via env only if needed)
N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
N_BATCH = int(os.getenv("N_BATCH", "64"))
N_CTX = int(os.getenv("N_CTX", "2048"))

# Fixed sampling
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")

# System prompt (optional). Leave empty for a pure user prompt.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()

# Safety margin for context budgeting (prompt + completion + overhead <= N_CTX)
CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
# ---------------- App scaffolding ----------------
app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")

_model: Optional[Llama] = None
_model_lock = threading.Lock()
# ---------------- Model loader ----------------
def get_model() -> Llama:
    global _model
    if _model is not None:
        return _model
    os.makedirs(CACHE_DIR, exist_ok=True)
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        cache_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
    _model = Llama(
        model_path=local_path,
        chat_format="llama-3",  # ensures proper Llama-3 prompt templating
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _model


@app.on_event("startup")  # startup hook assumed; without it _warm() is never called
def _warm():
    # Preload at startup to avoid a cold start on the first request
    get_model()
# ---------------- Schemas ----------------
class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="User prompt text only.")
# ---------------- Helpers ----------------
def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
    """
    Simple context budgeting: ensures tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
    If over budget, we truncate the prompt from the start (keep the tail).
    """
    toks = model.tokenize(prompt.encode("utf-8"))
    budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)  # keep some minimal room
    if len(toks) <= budget:
        return prompt
    # Truncate from the front (keep the latest part)
    kept = model.detokenize(toks[-budget:])
    return kept.decode("utf-8", errors="ignore")
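# Worked example with the defaults above:
#   budget = max(256, 2048 - 256 - 128) = 1664 prompt tokens,
#   so a longer prompt keeps only its last 1664 tokens.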
# ---------------- Endpoints ----------------
@app.get("/health")  # route path assumed from the function name
def health():
    try:
        _ = get_model()
        return {"ok": True}
    except Exception as e:
        return {"ok": False, "error": str(e)}


@app.get("/config")  # route path assumed from the function name
def config():
    return {
        "repo_id": REPO_ID,
        "filename": FILENAME,
        "cache_dir": CACHE_DIR,
        "n_threads": N_THREADS,
        "n_batch": N_BATCH,
        "n_ctx": N_CTX,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "stop_tokens": STOP_TOKENS,
        "ctx_safety": CTX_SAFETY,
        "has_system_prompt": bool(SYSTEM_PROMPT),
    }
@app.post("/generate")  # route path assumed from the function name
def generate(req: GenerateRequest):
    """
    Non-streaming chat completion.
    Accepts ONLY a prompt string; all other params are fixed in code.
    """
    try:
        if not req.prompt or not req.prompt.strip():
            raise HTTPException(status_code=400, detail="prompt must be a non-empty string")

        model = get_model()
        user_prompt = req.prompt.strip()
        fitted_prompt = _fit_prompt_to_context(model, user_prompt)

        # Build messages (Llama-3 chat format). System is optional.
        messages = []
        if SYSTEM_PROMPT:
            messages.append({"role": "system", "content": SYSTEM_PROMPT})
        messages.append({"role": "user", "content": fitted_prompt})

        t0 = time.time()
        with _model_lock:
            out = model.create_chat_completion(
                messages=messages,
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                stop=STOP_TOKENS,
            )
        dt = time.time() - t0

        text = out["choices"][0]["message"]["content"]
        usage = out.get("usage", {})  # may include prompt_tokens/completion_tokens

        return JSONResponse({
            "ok": True,
            "response": text,
            "usage": usage,
            "timing_sec": round(dt, 3),
            "params": {
                "max_tokens": MAX_TOKENS,
                "temperature": TEMPERATURE,
                "top_p": TOP_P,
                "stop": STOP_TOKENS,
                "n_ctx": N_CTX,
            },
            "prompt_truncated": (fitted_prompt != user_prompt),
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
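
# ---------------- Usage (sketch) ----------------
# A minimal client call, kept in comments so the module stays dependency-free.
# The base URL and port are assumptions (7860 is the usual Space port); the
# /generate path matches the route assumed above.
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/generate",   # replace with your Space URL
#       json={"prompt": "Explain what a GGUF file is in one paragraph."},
#       timeout=120,                        # CPU inference can be slow
#   )
#   r.raise_for_status()
#   data = r.json()
#   print(data["response"])          # generated text
#   print(data["usage"])             # token counts, if reported
#   print(data["prompt_truncated"])  # True if the prompt was trimmed to fit the context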