# app.py
import os
import time
import threading
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# ---------------- Config (fixed defaults; can be overridden by env) ----------------
REPO_ID   = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME  = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")

# Inference knobs (fixed for the Space; override via env only if needed)
N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
N_BATCH   = int(os.getenv("N_BATCH", "64"))
N_CTX     = int(os.getenv("N_CTX", "2048"))

# Fixed sampling
MAX_TOKENS   = int(os.getenv("MAX_TOKENS", "256"))
TEMPERATURE  = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P        = float(os.getenv("TOP_P", "0.9"))
STOP_TOKENS  = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")

# System prompt (optional). Leave empty for pure user prompt.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()

# Safety margin for context budgeting (prompt + completion + overhead <= N_CTX)
CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
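# Example of overriding the defaults at launch (illustrative only; the exact
# launcher, host, and port depend on the Dockerfile / Space setup and are
# assumptions here, not something this file defines):
#
#   N_CTX=4096 MAX_TOKENS=512 SYSTEM_PROMPT="You are a concise assistant." \
#       uvicorn app:app --host 0.0.0.0 --port 7860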

# ---------------- App scaffolding ----------------
app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
_model: Optional[Llama] = None
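# A single lock serializes inference: one llama.cpp context is not safe to
# drive from multiple threads at once, so /generate holds it per completion.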
_model_lock = threading.Lock()

# ---------------- Model loader ----------------
def get_model() -> Llama:
    global _model
    if _model is not None:
        return _model

    os.makedirs(CACHE_DIR, exist_ok=True)
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        cache_dir=CACHE_DIR,
    )

    _model = Llama(
        model_path=local_path,
        chat_format="llama-3",   # ensures proper Llama-3 prompt templating
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _model

@app.on_event("startup")
def _warm():
    # Preload to avoid cold-start on first request
    get_model()

# ---------------- Schemas ----------------
class GenerateRequest(BaseModel):
    prompt: str = Field(..., description="User prompt text only.")

# ---------------- Helpers ----------------
def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
    """
    Simple context budgeting: ensures tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
    If over budget, we truncate the prompt from the start (keep the tail).
    """
    toks = model.tokenize(prompt.encode("utf-8"))
    budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)  # keep some minimal room
    if len(toks) <= budget:
        return prompt
    # Truncate from the front (keep the latest part)
    kept = model.detokenize(toks[-budget:])
    # errors="ignore" never raises, so no try/except is needed here
    return kept.decode("utf-8", errors="ignore")
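# Worked example with the defaults above (illustration only, no extra logic):
#   N_CTX=2048, MAX_TOKENS=256, CTX_SAFETY=128
#   budget = max(256, 2048 - 256 - 128) = 1664 prompt tokens
# so a 2,000-token prompt keeps only its last 1,664 tokens.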

# ---------------- Endpoints ----------------
@app.get("/health")
def health():
    try:
        _ = get_model()
        return {"ok": True}
    except Exception as e:
        return {"ok": False, "error": str(e)}

@app.get("/config")
def config():
    return {
        "repo_id": REPO_ID,
        "filename": FILENAME,
        "cache_dir": CACHE_DIR,
        "n_threads": N_THREADS,
        "n_batch": N_BATCH,
        "n_ctx": N_CTX,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "stop_tokens": STOP_TOKENS,
        "ctx_safety": CTX_SAFETY,
        "has_system_prompt": bool(SYSTEM_PROMPT),
    }

@app.post("/generate")
def generate(req: GenerateRequest):
    """
    Non-streaming chat completion.
    Accepts ONLY a prompt string; all other params are fixed in code.
    """
    try:
        if not req.prompt or not req.prompt.strip():
            raise HTTPException(status_code=400, detail="prompt must be a non-empty string")

        model = get_model()

        user_prompt = req.prompt.strip()
        fitted_prompt = _fit_prompt_to_context(model, user_prompt)

        # Build messages (Llama-3 chat format). System is optional.
        messages = []
        if SYSTEM_PROMPT:
            messages.append({"role": "system", "content": SYSTEM_PROMPT})
        messages.append({"role": "user", "content": fitted_prompt})

        t0 = time.time()
        with _model_lock:
            out = model.create_chat_completion(
                messages=messages,
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                stop=STOP_TOKENS,
            )
        dt = time.time() - t0

        text = out["choices"][0]["message"]["content"]
        usage = out.get("usage", {})  # may include prompt_tokens/completion_tokens

        return JSONResponse({
            "ok": True,
            "response": text,
            "usage": usage,
            "timing_sec": round(dt, 3),
            "params": {
                "max_tokens": MAX_TOKENS,
                "temperature": TEMPERATURE,
                "top_p": TOP_P,
                "stop": STOP_TOKENS,
                "n_ctx": N_CTX,
            },
            "prompt_truncated": (fitted_prompt != user_prompt),
        })

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
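
# ---------------- Example requests (illustrative) ----------------
# The host and port are assumptions (7860 is the conventional Hugging Face
# Spaces port); point curl at wherever uvicorn actually serves this app.
#
#   curl -s http://localhost:7860/health
#   curl -s http://localhost:7860/config
#   curl -s -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain context windows in one short paragraph."}'
#
# A successful /generate response is JSON containing "response", "usage",
# "timing_sec", the fixed "params", and a "prompt_truncated" flag.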