omaryasserhassan committed
Commit 77fd55f · verified · 1 Parent(s): a510b92

Update app.py

Files changed (1)
  1. app.py +52 -47
app.py CHANGED
@@ -10,20 +10,24 @@ from pydantic import BaseModel, Field
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 
-# ---------------- Config (still overridable via env) ----------------
+# ---------------- Config ----------------
 REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
 FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
 
-# Preferred cache dir (may not be writable)
-CACHE_DIR = os.getenv("CACHE_DIR", "/data/models")
+# BUILD-TIME PREFETCH LOCATION (your Dockerfile downloads here)
+BUILD_CACHE_DIR = "/app/models"
+BUILD_MODEL_PATH = os.path.join(BUILD_CACHE_DIR, FILENAME)
 
-# Inference knobs (safer on small CPU Spaces)
-N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
-N_BATCH = int(os.getenv("N_BATCH", "32"))
+# Preferred runtime cache (only used if model not found above)
+PREFERRED_CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")
+
+# Inference knobs (conservative for small CPU Spaces)
+N_THREADS = min(4, (os.cpu_count() or 2))
+N_BATCH = int(os.getenv("N_BATCH", "16"))  # safer than 32/64 on tiny CPUs
 N_CTX = int(os.getenv("N_CTX", "2048"))
 
-# Fixed sampling (fast-ish defaults)
-MAX_TOKENS = int(os.getenv("MAX_TOKENS", "96"))
+# Sampling (keep short for latency)
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", "48"))  # tighter → faster
 TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
 TOP_P = float(os.getenv("TOP_P", "0.9"))
 STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
@@ -31,17 +35,14 @@ STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
 SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()
 CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))
 
-# ---------------- App scaffolding ----------------
+# ---------------- App ----------------
 app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
 _model: Optional[Llama] = None
 _model_lock = threading.Lock()
+_effective_model_path: Optional[str] = None
 _effective_cache_dir: Optional[str] = None
 
 def _select_writable_cache_dir(preferred: str) -> str:
-    """
-    Pick the first writable directory from a list of candidates.
-    Tries to mkdir and write a tiny file to confirm writability.
-    """
     candidates = [
         preferred,
         os.path.join(os.path.expanduser("~"), ".cache", "hf_models"),
@@ -50,24 +51,26 @@ def _select_writable_cache_dir(preferred: str) -> str:
     for d in candidates:
         try:
             os.makedirs(d, exist_ok=True)
-            test_path = os.path.join(d, ".write_test")
-            with open(test_path, "w") as f:
+            test_file = os.path.join(d, ".w")
+            with open(test_file, "w") as f:
                 f.write("ok")
-            os.remove(test_path)
+            os.remove(test_file)
             return d
         except Exception:
             continue
-    raise RuntimeError("No writable cache directory found among: " + ", ".join(candidates))
+    raise RuntimeError("No writable cache directory found")
 
-# ---------------- Model loader ----------------
-def get_model() -> Llama:
-    global _model, _effective_cache_dir
-    if _model is not None:
-        return _model
+def _resolve_model_path() -> str:
+    """
+    1) If the model file exists at build path (/app/models/...), use it (fast path).
+    2) Else, download into first writable cache dir and return that path.
+    """
+    global _effective_cache_dir
+    if os.path.isfile(BUILD_MODEL_PATH):
+        return BUILD_MODEL_PATH
 
-    # pick a writable cache dir (handles /data permission issues)
     if _effective_cache_dir is None:
-        _effective_cache_dir = _select_writable_cache_dir(CACHE_DIR)
+        _effective_cache_dir = _select_writable_cache_dir(PREFERRED_CACHE_DIR)
 
     local_path = hf_hub_download(
         repo_id=REPO_ID,
@@ -75,21 +78,33 @@ def get_model() -> Llama:
         cache_dir=_effective_cache_dir,
         local_dir_use_symlinks=False,
     )
+    return local_path
 
+# ---------------- Model loader ----------------
+def get_model() -> Llama:
+    global _model, _effective_model_path
+    if _model is not None:
+        return _model
+
+    # Resolve path without failing on /data permission
+    _effective_model_path = _resolve_model_path()
+
+    # llama.cpp init (CPU-friendly)
     _model = Llama(
-        model_path=local_path,
+        model_path=_effective_model_path,
         chat_format="llama-3",
         n_ctx=N_CTX,
         n_threads=N_THREADS,
         n_batch=N_BATCH,
+        use_mmap=True,   # faster load on CPU
+        n_gpu_layers=0,  # ensure pure CPU
         verbose=False,
     )
     return _model
 
 @app.on_event("startup")
 def _warm_start():
-    # Preload to avoid cold-start cost on first request
-    get_model()
+    get_model()  # force load at startup so first request is predictable
 
 # ---------------- Schemas ----------------
 class GenerateRequest(BaseModel):
@@ -116,28 +131,17 @@ def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
 def health():
     try:
         _ = get_model()
-        return {"ok": True, "cache_dir": _effective_cache_dir}
+        return {
+            "ok": True,
+            "model_path": _effective_model_path,
+            "cache_dir": _effective_cache_dir,
+            "n_threads": N_THREADS,
+            "n_batch": N_BATCH,
+            "n_ctx": N_CTX
+        }
     except Exception as e:
         return {"ok": False, "error": str(e)}
 
-@app.get("/config")
-def config():
-    return {
-        "repo_id": REPO_ID,
-        "filename": FILENAME,
-        "preferred_cache_dir": CACHE_DIR,
-        "effective_cache_dir": _effective_cache_dir,
-        "n_threads": N_THREADS,
-        "n_batch": N_BATCH,
-        "n_ctx": N_CTX,
-        "max_tokens": MAX_TOKENS,
-        "temperature": TEMPERATURE,
-        "top_p": TOP_P,
-        "stop_tokens": STOP_TOKENS,
-        "ctx_safety": CTX_SAFETY,
-        "has_system_prompt": bool(SYSTEM_PROMPT),
-    }
-
 @app.get("/warmup")
 def warmup():
     model = get_model()
@@ -159,7 +163,7 @@ def warmup():
 def generate(req: GenerateRequest):
     """
     Non-streaming chat completion.
-    Accepts ONLY a prompt string; all other params are fixed in code/env.
+    Accepts ONLY a prompt string; all other params are fixed here.
     """
     try:
         if not req.prompt or not req.prompt.strip():
@@ -203,6 +207,7 @@ def generate(req: GenerateRequest):
             "n_threads": N_THREADS,
         },
         "prompt_truncated": (fitted_prompt != user_prompt),
+        "effective_model_path": _effective_model_path,
         "effective_cache_dir": _effective_cache_dir,
     })
 
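Note: the updated config assumes the GGUF file is baked into the image at /app/models during the Docker build, so the runtime never has to download it. A minimal build-time prefetch sketch (a hypothetical helper script, not part of this commit; a Dockerfile RUN step would invoke it) using the same huggingface_hub call the app already relies on:

# prefetch_model.py: hypothetical build-time helper (not part of this commit)
# Example Dockerfile step:  RUN python prefetch_model.py
import os
from huggingface_hub import hf_hub_download

REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
TARGET_DIR = "/app/models"  # must match BUILD_CACHE_DIR in app.py

os.makedirs(TARGET_DIR, exist_ok=True)
# local_dir places the file at /app/models/<FILENAME>, which is exactly where
# BUILD_MODEL_PATH looks for it, so _resolve_model_path() takes the fast path.
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, local_dir=TARGET_DIR)
print("prefetched:", path)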
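Note: the public surface after this change is /health, /warmup, and /generate (the /config endpoint is removed), and /health now reports the resolved model path plus the thread/batch/context settings. A small smoke-test sketch, assuming the API is reachable on localhost port 7860, that /health is a GET route, and that /generate is a POST route taking a JSON body with a single prompt field (host, port, and HTTP methods are assumptions not visible in this diff):

# client_smoke_test.py: hypothetical usage sketch (host/port/methods assumed)
import requests

BASE = "http://localhost:7860"

# Health should now include model_path, cache_dir, n_threads, n_batch, n_ctx.
print(requests.get(f"{BASE}/health", timeout=120).json())

# /generate accepts only a prompt; sampling is fixed server-side (MAX_TOKENS defaults to 48).
resp = requests.post(
    f"{BASE}/generate",
    json={"prompt": "Say hello in one short sentence."},
    timeout=300,
)
print(resp.json())  # includes prompt_truncated, effective_model_path, effective_cache_dir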