omaryasserhassan committed (verified)
Commit 33f19bf · 1 Parent(s): b5dfa0f

Update app.py

Files changed (1)
  1. app.py +77 -26
app.py CHANGED
@@ -1,56 +1,63 @@
 import os
+import time
 from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 
-REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
-FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
-CACHE_DIR = "/app/models"  # matches Dockerfile pre-download
-os.makedirs(CACHE_DIR, exist_ok=True)
+# ---------------- Config ----------------
+REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
+CACHE_DIR = "/app/models"  # match your Dockerfile prefetch if you use it
 
-app = FastAPI()
+# Conservative CPU settings for Spaces (prevents stalls)
+N_THREADS = min(4, (os.cpu_count() or 2))  # don't over-thread on tiny CPUs
+N_BATCH = 64    # modest batch to avoid RAM thrash
+N_CTX = 2048    # enough for short prompts
+
+# --------------- FastAPI App ---------------
+app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API")
 _model = None
 
-def get_model():
+# --------------- Load Model ---------------
+def get_model() -> Llama:
     global _model
     if _model is not None:
         return _model
 
+    os.makedirs(CACHE_DIR, exist_ok=True)
     local_path = hf_hub_download(
         repo_id=REPO_ID,
         filename=FILENAME,
         cache_dir=CACHE_DIR,
         local_dir_use_symlinks=False,
     )
+
+    # IMPORTANT: use Llama-3 chat template
     _model = Llama(
         model_path=local_path,
-        n_ctx=2048,
-        n_threads=os.cpu_count() or 2,
-        n_batch=256,
+        chat_format="llama-3",  # <- ensures proper prompt templating
+        n_ctx=N_CTX,
+        n_threads=N_THREADS,
+        n_batch=N_BATCH,
         verbose=False
     )
     return _model
 
-class PromptRequest(BaseModel):
-    prompt: str
-    max_tokens: int = 256
-    temperature: float = 0.7
+# --------------- Schemas ----------------
+class ChatMessage(BaseModel):
+    role: str    # "system" | "user" | "assistant"
+    content: str
 
-@app.post("/generate")
-def generate_text(req: PromptRequest):
-    try:
-        model = get_model()
-        output = model(
-            req.prompt,
-            max_tokens=req.max_tokens,
-            temperature=req.temperature,
-            stop=["</s>"]
-        )
-        return {"ok": True, "response": output["choices"][0]["text"]}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+class ChatRequest(BaseModel):
+    messages: list[ChatMessage]
+    max_tokens: int = 128
+    temperature: float = 0.7
+    top_p: float = 0.9
+    stream: bool = False
 
+# --------------- Endpoints ---------------
 @app.get("/health")
 def health():
     try:
@@ -58,3 +65,47 @@ def health():
         return {"ok": True}
     except Exception as e:
         return {"ok": False, "error": str(e)}
+
+@app.post("/generate")
+def generate(req: ChatRequest):
+    """
+    Chat-completion endpoint with optional server-side streaming.
+    Uses Llama-3 chat template via chat_format="llama-3".
+    """
+    try:
+        model = get_model()
+
+        # Convert to llama.cpp message format
+        msgs = [{"role": m.role, "content": m.content} for m in req.messages]
+
+        if not req.stream:
+            out = model.create_chat_completion(
+                messages=msgs,
+                max_tokens=req.max_tokens,
+                temperature=req.temperature,
+                top_p=req.top_p,
+            )
+            text = out["choices"][0]["message"]["content"]
+            return JSONResponse({"ok": True, "response": text})
+
+        # --- Streaming mode ---
+        def token_stream():
+            start = time.time()
+            for chunk in model.create_chat_completion(
+                messages=msgs,
+                max_tokens=req.max_tokens,
+                temperature=req.temperature,
+                top_p=req.top_p,
+                stream=True,
+            ):
+                if "choices" in chunk and chunk["choices"]:
+                    delta = chunk["choices"][0]["delta"].get("content", "")
+                    if delta:
+                        yield delta
+            # small trailer to mark end (optional)
+            yield f"\n\n[done in {time.time()-start:.2f}s]"
+
+        return StreamingResponse(token_stream(), media_type="text/plain")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
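
For quick testing of the new /generate endpoint, a minimal client sketch (not part of this commit): it assumes the app is reachable at http://localhost:7860 (the usual Spaces port; adjust BASE_URL for your deployment) and uses the third-party requests library.

# client_example.py -- hypothetical client for the /generate endpoint above.
# BASE_URL and the requests dependency are assumptions, not part of this commit.
import requests

BASE_URL = "http://localhost:7860"  # assumption: change to your Space / server URL

payload = {
    "messages": [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "stream": False,
}

# Non-streaming: the endpoint returns {"ok": true, "response": "..."}
r = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
r.raise_for_status()
print(r.json()["response"])

# Streaming: set "stream": true and read the plain-text chunks as they arrive
payload["stream"] = True
with requests.post(f"{BASE_URL}/generate", json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)

Since the streaming branch returns plain text (media_type="text/plain"), reading the raw response body incrementally is enough; no SSE parsing is needed.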