Leonardo0711 committed
Commit b97fe5e · verified · 1 Parent(s): 4270a26

Update app.py

Files changed (1)
  1. app.py +123 -23
app.py CHANGED
@@ -14,16 +14,48 @@ from bs4 import BeautifulSoup
 from huggingface_hub import snapshot_download
 from llama_cpp import Llama
 
+
+# ------------------ Helpers ------------------
+def env_int(name: str, default: int) -> int:
+    try:
+        return int(os.getenv(name, "").strip() or default)
+    except Exception:
+        return default
+
+
+def env_float(name: str, default: float) -> float:
+    try:
+        return float(os.getenv(name, "").strip() or default)
+    except Exception:
+        return default
+
+
+def env_list(name: str) -> list[str]:
+    raw = os.getenv(name, "").strip()
+    return [x.strip() for x in raw.split(",") if x.strip()]
+
+
 # ------------------ Config ------------------
 MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-7B-Instruct-GGUF")
-# If you are short on RAM on CPU Basic: export MODEL_PATTERN=qwen2.5-7b-instruct-q3_k_m-*.gguf
+# If you are short on RAM on CPU Basic, use q3_k_m:
+#   export MODEL_PATTERN=qwen2.5-7b-instruct-q3_k_m-*.gguf
 MODEL_PATTERN = os.getenv("MODEL_PATTERN", "qwen2.5-7b-instruct-q4_k_m-*.gguf")
 
 # Models folder in /data (writable in Docker Spaces)
 MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
 MODELS_DIR.mkdir(parents=True, exist_ok=True)
 
-N_THREADS = max(1, (os.cpu_count() or 2) - 1)
+# Performance (overridable via environment variables)
+CPU_COUNT = os.cpu_count() or 2
+N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
+N_BATCH = env_int("N_BATCH", 64)  # lower this if RAM is tight
+N_CTX = env_int("N_CTX", 2048)  # 2048 works well on CPU Basic
+
+# Decoding / default length
+DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.6)
+DEF_TOP_P = env_float("LLM_TOP_P", 0.95)
+DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)  # typical reply length
+MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)  # hard cap
 
 SYSTEM_DEFAULT = textwrap.dedent("""\
 Eres Astrohunters-Guide, un asistente en español.
@@ -32,14 +64,19 @@ Eres Astrohunters-Guide, un asistente en español.
 - Si te paso una URL, lees su contenido y lo usas como contexto.
 """)
 
-ALLOWED_ORIGINS = [
-    # add your domain(s) here
-    "https://pruebas.nataliacoronel.com",
-    "https://*.nataliacoronel.com",
-    # during testing you can allow everything, but it is less secure:
-    os.getenv("ALLOW_ALL_ORIGINS", "") and "*",
-]
-ALLOWED_ORIGINS = [o for o in ALLOWED_ORIGINS if o]
+# CORS
+# Options:
+#   - ALLOW_ALL_ORIGINS=1 (less secure, useful for testing)
+#   - CORS_ORIGINS="https://domain1,https://domain2"
+allow_all = os.getenv("ALLOW_ALL_ORIGINS", "").strip() in ("1", "true", "yes")
+CORS_ORIGINS = env_list("CORS_ORIGINS")
+if not CORS_ORIGINS:
+    # convenient defaults for this deployment
+    CORS_ORIGINS = [
+        "https://pruebas.nataliacoronel.com",
+        "https://*.nataliacoronel.com",
+    ]
+
 
 # ------------------ Model download ------------------
 print(f"[Boot] Descargando {MODEL_REPO} patrón {MODEL_PATTERN} en {MODELS_DIR} ...")
@@ -54,17 +91,22 @@ if not candidates:
 MODEL_PATH = candidates[0]
 print(f"[Boot] Usando shard: {MODEL_PATH}")
 
+
 # ------------------ Load llama.cpp ------------------
 llm = Llama(
     model_path=MODEL_PATH,
-    n_ctx=4096,
+    n_ctx=N_CTX,
     n_threads=N_THREADS,
-    n_batch=256,
-    n_gpu_layers=0,
+    n_batch=N_BATCH,
+    n_gpu_layers=env_int("N_GPU_LAYERS", 0),
     verbose=False,
 )
 
+
+# ------------------ Utilities ------------------
 def fetch_url_text(url: str, max_chars: int = 6000) -> str:
+    if not url:
+        return ""
     try:
         r = requests.get(url, timeout=15)
         r.raise_for_status()
@@ -76,40 +118,85 @@ def fetch_url_text(url: str, max_chars: int = 6000) -> str:
     except Exception as e:
         return f"[No se pudo cargar {url}: {e}]"
 
-def run_llm(messages, temperature=0.6, top_p=0.95, max_tokens=768) -> str:
+
+def clamp_tokens(requested: Optional[int]) -> int:
+    if requested is None or requested <= 0:
+        return DEF_MAX_TOKENS
+    return max(1, min(requested, MAX_TOKENS_CAP))
+
+
+def run_llm(
+    messages,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+) -> str:
     out = llm.create_chat_completion(
         messages=messages,
-        temperature=temperature,
-        top_p=top_p,
-        max_tokens=max_tokens,
+        temperature=DEF_TEMPERATURE if temperature is None else float(temperature),
+        top_p=DEF_TOP_P if top_p is None else float(top_p),
+        max_tokens=clamp_tokens(max_tokens),
         stream=False,
     )
     return out["choices"][0]["message"]["content"].strip()
 
+
 # ------------------ FastAPI ------------------
 app = FastAPI(title="Astrohunters LLM API", docs_url="/docs", redoc_url=None)
 
-if ALLOWED_ORIGINS:
+if allow_all:
     app.add_middleware(
         CORSMiddleware,
-        allow_origins=ALLOWED_ORIGINS,
+        allow_origins=["*"],
+        allow_credentials=False,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+else:
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=CORS_ORIGINS,
         allow_credentials=True,
         allow_methods=["*"],
        allow_headers=["*"],
     )
 
+
+# --------- Input schemas ---------
 class PredictIn(BaseModel):
     prompt: str
     system: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+
 
 class PredictURLIn(BaseModel):
     prompt: str
     url: Optional[str] = None
     system: Optional[str] = None
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+
 
+# --------- Endpoints ---------
 @app.get("/healthz")
 def healthz():
-    return {"ok": True, "model": os.path.basename(MODEL_PATH), "threads": N_THREADS}
+    return {
+        "ok": True,
+        "model": os.path.basename(MODEL_PATH),
+        "threads": N_THREADS,
+        "n_batch": N_BATCH,
+        "n_ctx": N_CTX,
+        "defaults": {
+            "temperature": DEF_TEMPERATURE,
+            "top_p": DEF_TOP_P,
+            "max_tokens": DEF_MAX_TOKENS,
+            "max_tokens_cap": MAX_TOKENS_CAP,
+        },
+    }
+
 
 @app.get("/")
 def root():
@@ -118,15 +205,22 @@ def root():
         "endpoints": ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"],
     }
 
+
 @app.post("/run_predict")
 def run_predict(body: PredictIn):
     messages = [
         {"role": "system", "content": body.system or SYSTEM_DEFAULT},
         {"role": "user", "content": body.prompt},
     ]
-    reply = run_llm(messages, max_tokens=512)
+    reply = run_llm(
+        messages,
+        temperature=body.temperature,
+        top_p=body.top_p,
+        max_tokens=body.max_tokens,
+    )
     return {"reply": reply}
 
+
 @app.post("/run_predict_with_url")
 def run_predict_with_url(body: PredictURLIn):
     web_ctx = fetch_url_text(body.url) if body.url else ""
@@ -135,9 +229,15 @@ def run_predict_with_url(body: PredictURLIn):
         {"role": "system", "content": body.system or SYSTEM_DEFAULT},
         {"role": "user", "content": user_msg},
     ]
-    reply = run_llm(messages, max_tokens=700)
+    reply = run_llm(
+        messages,
+        temperature=body.temperature,
+        top_p=body.top_p,
+        max_tokens=body.max_tokens,
+    )
     return {"reply": reply}
 
+
 if __name__ == "__main__":
-    import uvicorn, os
+    import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
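
A quick way to exercise the new per-request parameters introduced by this commit is a small client script. The sketch below is illustrative and not part of the commit: BASE_URL is a placeholder for the deployed Space, and the values shown for max_tokens and temperature are arbitrary examples of the new optional request fields. Omitted fields fall back to the LLM_* defaults, and max_tokens is clamped server-side to LLM_MAX_TOKENS_CAP.

# Hypothetical client sketch (not part of the commit); BASE_URL is a placeholder.
import requests

BASE_URL = "http://localhost:7860"  # replace with the Space URL

# /run_predict now accepts optional max_tokens, temperature and top_p.
r = requests.post(
    f"{BASE_URL}/run_predict",
    json={
        "prompt": "Resume en dos frases qué es un exoplaneta.",
        "max_tokens": 200,   # clamped to LLM_MAX_TOKENS_CAP (default 320)
        "temperature": 0.4,
    },
    timeout=120,
)
print(r.json()["reply"])

# /healthz now also reports n_ctx, n_batch and the decoding defaults.
print(requests.get(f"{BASE_URL}/healthz", timeout=10).json())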