omaryasserhassan committed
Commit 730089e · verified · 1 Parent(s): 5f65aa8

Update app.py

Files changed (1)
  1. app.py +32 -20
app.py CHANGED
@@ -7,14 +7,20 @@ from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 
 # ---------- Minimal fixed config (fast on CPU) ----------
-REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"  # <-- 1B model (faster)
+REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"  # 1B = much faster on CPU
 FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"
-CACHE_DIR = "/app/models"  # already prefetched in your build
+
+# Build-time prefetch location (Dockerfile step put model here)
+BUILD_DIR = "/app/models"
+MODEL_PATH = os.path.join(BUILD_DIR, FILENAME)
+
+# Writable runtime cache if the prebuilt file isn't present
+RUNTIME_CACHE = "/tmp/hf_cache"
 
 N_THREADS = min(4, os.cpu_count() or 2)
-N_BATCH = 8  # smaller = less thrash
+N_BATCH = 8
 N_CTX = 2048
-MAX_TOKENS = 16  # short, fast replies
+MAX_TOKENS = 16
 
 TEMPERATURE = 0.7
 TOP_P = 0.9
@@ -23,36 +29,44 @@ STOP = ["</s>", "<|eot_id|>"]
 # ---------- App ----------
 app = FastAPI(title="Simple Llama Server (1B fast)")
 model = None
+effective_model_path = None
 
 class PromptRequest(BaseModel):
     prompt: str
 
 @app.on_event("startup")
 def load_model():
-    global model
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    local_path = hf_hub_download(
-        repo_id=REPO_ID,
-        filename=FILENAME,
-        cache_dir=CACHE_DIR,
-        local_dir_use_symlinks=False,
-    )
+    global model, effective_model_path
+
+    # 1) If the model exists from the Docker build, use it directly (no writes)
+    if os.path.isfile(MODEL_PATH):
+        effective_model_path = MODEL_PATH
+    else:
+        # 2) Otherwise, download to a writable temp cache (NOT under /app)
+        os.makedirs(RUNTIME_CACHE, exist_ok=True)
+        effective_model_path = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=FILENAME,
+            cache_dir=RUNTIME_CACHE,
+            local_dir_use_symlinks=False,
+        )
+
     t0 = time.time()
     model = Llama(
-        model_path=local_path,
+        model_path=effective_model_path,
         chat_format="llama-3",
         n_ctx=N_CTX,
         n_threads=N_THREADS,
         n_batch=N_BATCH,
-        use_mmap=True,
-        n_gpu_layers=0,
+        use_mmap=True,     # faster load
+        n_gpu_layers=0,    # CPU only
         verbose=False,
     )
-    print(f"[startup] model loaded in {time.time()-t0:.2f}s from {local_path}")
+    print(f"[startup] loaded {effective_model_path} in {time.time()-t0:.2f}s")
 
 @app.get("/health")
 def health():
-    return {"ok": model is not None}
+    return {"ok": model is not None, "model_path": effective_model_path}
 
 @app.post("/generate")
 def generate(req: PromptRequest):
@@ -71,6 +85,4 @@ def generate(req: PromptRequest):
         stop=STOP,
     )
     text = out["choices"][0]["message"]["content"]
-    dt = time.time() - t0
-    print(f"[infer] tokens={MAX_TOKENS} took {dt:.2f}s, prompt_len_chars={len(prompt)}")
-    return JSONResponse({"response": text, "timing_sec": round(dt, 2)})
+    return JSONResponse({"response": text, "timing_sec": round(time.time()-t0, 2)})
 
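The last hunk only shows the tail of generate(); the chat-completion call itself is unchanged and therefore hidden as context. A minimal sketch of what the full handler plausibly looks like, assuming a single-turn user message (the exact prompt construction is not shown in this diff):

@app.post("/generate")
def generate(req: PromptRequest):
    prompt = req.prompt
    t0 = time.time()
    # Assumed single-turn chat; the real handler may add a system message.
    out = model.create_chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        stop=STOP,
    )
    text = out["choices"][0]["message"]["content"]
    return JSONResponse({"response": text, "timing_sec": round(time.time() - t0, 2)})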
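The new config only checks whether /app/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf already exists; the Docker build step that puts it there is not part of this commit. A sketch of a build-time prefetch script the Dockerfile could run (the script name and the RUN step are assumptions):

# prefetch_model.py, a hypothetical helper invoked at build time, e.g. RUN python prefetch_model.py
import os
from huggingface_hub import hf_hub_download

REPO_ID = "bartowski/Llama-3.2-1B-Instruct-GGUF"
FILENAME = "Llama-3.2-1B-Instruct-Q4_K_M.gguf"
BUILD_DIR = "/app/models"

os.makedirs(BUILD_DIR, exist_ok=True)
# local_dir places the file directly at /app/models/<FILENAME>,
# which is exactly what os.path.isfile(MODEL_PATH) expects at startup.
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, local_dir=BUILD_DIR)
print(f"prefetched {path}")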
 
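With the container running, the two endpoints can be exercised as below; the port is an assumption, so match it to whatever uvicorn or the Dockerfile actually binds. Note that MAX_TOKENS = 16 keeps replies deliberately short, so responses will be clipped.

import requests

BASE = "http://localhost:7860"  # assumed port; adjust to your run command

print(requests.get(f"{BASE}/health").json())
# if the prebuilt file was found:
# {'ok': True, 'model_path': '/app/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf'}

r = requests.post(f"{BASE}/generate", json={"prompt": "Say hello in one short sentence."})
print(r.json())  # {'response': '...', 'timing_sec': ...}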