omaryasserhassan committed
Commit 490925d · verified · 1 parent: 77fd55f

Update app.py

Files changed (1)
  1. app.py +35 -191
app.py CHANGED
@@ -1,217 +1,61 @@
  # app.py
  import os
- import time
- import threading
- from typing import Optional
-
  from fastapi import FastAPI, HTTPException
  from fastapi.responses import JSONResponse
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel
  from huggingface_hub import hf_hub_download
  from llama_cpp import Llama

  # ---------------- Config ----------------
- REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
- FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
-
- # BUILD-TIME PREFETCH LOCATION (your Dockerfile downloads here)
- BUILD_CACHE_DIR = "/app/models"
- BUILD_MODEL_PATH = os.path.join(BUILD_CACHE_DIR, FILENAME)
-
- # Preferred runtime cache (only used if model not found above)
- PREFERRED_CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")
-
- # Inference knobs (conservative for small CPU Spaces)
- N_THREADS = min(4, (os.cpu_count() or 2))
- N_BATCH = int(os.getenv("N_BATCH", "16"))  # safer than 32/64 on tiny CPUs
- N_CTX = int(os.getenv("N_CTX", "2048"))
+ REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+ FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
+ CACHE_DIR = "/app/models"

- # Sampling (keep short for latency)
- MAX_TOKENS = int(os.getenv("MAX_TOKENS", "48"))  # tighter → faster
- TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
- TOP_P = float(os.getenv("TOP_P", "0.9"))
- STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
-
- SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()
- CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
+ N_THREADS = min(4, os.cpu_count() or 2)
+ N_BATCH = 16
+ N_CTX = 2048
+ MAX_TOKENS = 64

  # ---------------- App ----------------
- app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
- _model: Optional[Llama] = None
- _model_lock = threading.Lock()
- _effective_model_path: Optional[str] = None
- _effective_cache_dir: Optional[str] = None
-
- def _select_writable_cache_dir(preferred: str) -> str:
-     candidates = [
-         preferred,
-         os.path.join(os.path.expanduser("~"), ".cache", "hf_models"),
-         "/tmp/hf_models",
-     ]
-     for d in candidates:
-         try:
-             os.makedirs(d, exist_ok=True)
-             test_file = os.path.join(d, ".w")
-             with open(test_file, "w") as f:
-                 f.write("ok")
-             os.remove(test_file)
-             return d
-         except Exception:
-             continue
-     raise RuntimeError("No writable cache directory found")
-
- def _resolve_model_path() -> str:
-     """
-     1) If the model file exists at build path (/app/models/...), use it (fast path).
-     2) Else, download into first writable cache dir and return that path.
-     """
-     global _effective_cache_dir
-     if os.path.isfile(BUILD_MODEL_PATH):
-         return BUILD_MODEL_PATH
+ app = FastAPI(title="Simple Llama Server")
+ model = None

-     if _effective_cache_dir is None:
-         _effective_cache_dir = _select_writable_cache_dir(PREFERRED_CACHE_DIR)
+ class PromptRequest(BaseModel):
+     prompt: str

+ # ---------------- Startup ----------------
+ @app.on_event("startup")
+ def load_model():
+     global model
+     os.makedirs(CACHE_DIR, exist_ok=True)
      local_path = hf_hub_download(
          repo_id=REPO_ID,
          filename=FILENAME,
-         cache_dir=_effective_cache_dir,
+         cache_dir=CACHE_DIR,
          local_dir_use_symlinks=False,
      )
-     return local_path
-
- # ---------------- Model loader ----------------
- def get_model() -> Llama:
-     global _model, _effective_model_path
-     if _model is not None:
-         return _model
-
-     # Resolve path without failing on /data permission
-     _effective_model_path = _resolve_model_path()
-
-     # llama.cpp init (CPU-friendly)
-     _model = Llama(
-         model_path=_effective_model_path,
+     model = Llama(
+         model_path=local_path,
          chat_format="llama-3",
          n_ctx=N_CTX,
          n_threads=N_THREADS,
          n_batch=N_BATCH,
-         use_mmap=True,   # faster load on CPU
-         n_gpu_layers=0,  # ensure pure CPU
          verbose=False,
      )
-     return _model
-
- @app.on_event("startup")
- def _warm_start():
-     get_model()  # force load at startup so first request is predictable
-
- # ---------------- Schemas ----------------
- class GenerateRequest(BaseModel):
-     prompt: str = Field(..., description="User prompt text only.")
-
- # ---------------- Helpers ----------------
- def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
-     """
-     Ensure tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
-     If over budget, truncate from the front (keep the tail).
-     """
-     toks = model.tokenize(prompt.encode("utf-8"))
-     budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)
-     if len(toks) <= budget:
-         return prompt
-     kept = model.detokenize(toks[-budget:])
-     try:
-         return kept.decode("utf-8", errors="ignore")
-     except Exception:
-         return kept.decode("utf-8", "ignore")
-
- # ---------------- Endpoints ----------------
- @app.get("/health")
- def health():
-     try:
-         _ = get_model()
-         return {
-             "ok": True,
-             "model_path": _effective_model_path,
-             "cache_dir": _effective_cache_dir,
-             "n_threads": N_THREADS,
-             "n_batch": N_BATCH,
-             "n_ctx": N_CTX
-         }
-     except Exception as e:
-         return {"ok": False, "error": str(e)}
-
- @app.get("/warmup")
- def warmup():
-     model = get_model()
-     messages = [{"role": "user", "content": "Say OK."}]
-     t0 = time.time()
-     with _model_lock:
-         out = model.create_chat_completion(
-             messages=messages,
-             max_tokens=8,
-             temperature=0.0,
-             top_p=1.0,
-             stop=STOP_TOKENS,
-         )
-     dt = time.time() - t0
-     text = out["choices"][0]["message"]["content"]
-     return {"ok": True, "ms": int(dt * 1000), "resp": text.strip()}

+ # ---------------- Endpoint ----------------
  @app.post("/generate")
- def generate(req: GenerateRequest):
-     """
-     Non-streaming chat completion.
-     Accepts ONLY a prompt string; all other params are fixed here.
-     """
-     try:
-         if not req.prompt or not req.prompt.strip():
-             raise HTTPException(status_code=400, detail="prompt must be a non-empty string")
-
-         model = get_model()
-         user_prompt = req.prompt.strip()
-         fitted_prompt = _fit_prompt_to_context(model, user_prompt)
-
-         messages = []
-         if SYSTEM_PROMPT:
-             messages.append({"role": "system", "content": SYSTEM_PROMPT})
-         messages.append({"role": "user", "content": fitted_prompt})
-
-         t0 = time.time()
-         with _model_lock:
-             out = model.create_chat_completion(
-                 messages=messages,
-                 max_tokens=MAX_TOKENS,
-                 temperature=TEMPERATURE,
-                 top_p=TOP_P,
-                 stop=STOP_TOKENS,
-             )
-         dt = time.time() - t0
-
-         text = out["choices"][0]["message"]["content"]
-         usage = out.get("usage", {})
-
-         return JSONResponse({
-             "ok": True,
-             "response": text,
-             "usage": usage,
-             "timing_sec": round(dt, 3),
-             "params": {
-                 "max_tokens": MAX_TOKENS,
-                 "temperature": TEMPERATURE,
-                 "top_p": TOP_P,
-                 "stop": STOP_TOKENS,
-                 "n_ctx": N_CTX,
-                 "n_batch": N_BATCH,
-                 "n_threads": N_THREADS,
-             },
-             "prompt_truncated": (fitted_prompt != user_prompt),
-             "effective_model_path": _effective_model_path,
-             "effective_cache_dir": _effective_cache_dir,
-         })
-
-     except HTTPException:
-         raise
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+ def generate(req: PromptRequest):
+     global model
+     if model is None:
+         raise HTTPException(status_code=500, detail="Model not loaded")
+
+     out = model.create_chat_completion(
+         messages=[{"role": "user", "content": req.prompt}],
+         max_tokens=MAX_TOKENS,
+         temperature=0.7,
+         top_p=0.9,
+         stop=["</s>", "<|eot_id|>"]
+     )
+     text = out["choices"][0]["message"]["content"]
+     return JSONResponse({"response": text})
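
For reference, a minimal client-side sketch of how the simplified /generate endpoint could be called once the app is running. The base URL and port (http://localhost:7860 here) are assumptions about how the Space or container is launched, not part of this commit; only the request body ({"prompt": ...}) and the response shape ({"response": ...}) come from the code above.

# client_example.py — hypothetical usage sketch, not part of this commit.
# Assumes the server above is reachable at http://localhost:7860; adjust the URL
# to wherever the Space/container actually serves the FastAPI app.
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"prompt": "Say hello in one short sentence."},
    timeout=120,  # first request can be slow on a small CPU instance
)
resp.raise_for_status()
print(resp.json()["response"])  # the new app returns {"response": "..."}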