Leonardo0711 committed
Commit 2d96246 · verified · 1 Parent(s): e283af9

Update app.py

Files changed (1):
  1. app.py +79 -82
app.py CHANGED
@@ -3,41 +3,58 @@
 
 import os, glob, textwrap
 from pathlib import Path
-from threading import Lock
+from typing import Optional
 
-from fastapi import FastAPI, Body
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import HTMLResponse, JSONResponse
+from pydantic import BaseModel
 
-from huggingface_hub import snapshot_download
-from llama_cpp import Llama
 import requests
 from bs4 import BeautifulSoup
+from huggingface_hub import snapshot_download
+from llama_cpp import Llama
 
-# ===== Model folder (do NOT use /app) =====
-MODELS_DIR = Path(os.getenv("MODELS_DIR", "/tmp/models"))
+# ------------------ Config ------------------
+MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-7B-Instruct-GGUF")
+# If you run short of RAM on CPU Basic: export MODEL_PATTERN=qwen2.5-7b-instruct-q3_k_m-*.gguf
+MODEL_PATTERN = os.getenv("MODEL_PATTERN", "qwen2.5-7b-instruct-q4_k_m-*.gguf")
+
+# Models folder under /data (writable in Docker Spaces)
+MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
 MODELS_DIR.mkdir(parents=True, exist_ok=True)
 
-# ===== Model (GGUF) =====
-MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-7B-Instruct-GGUF")
-# For CPU Basic you can set MODEL_PATTERN=qwen2.5-7b-instruct-q3_k_m-*.gguf in Variables
-MODEL_PATTERN = os.getenv("MODEL_PATTERN", "qwen2.5-7b-instruct-q4_k_m-*.gguf")
+N_THREADS = max(1, (os.cpu_count() or 2) - 1)
+
+SYSTEM_DEFAULT = textwrap.dedent("""\
+You are Astrohunters-Guide, a Spanish-language assistant.
+- You answer precisely and never make up data.
+- You can explain exoplanet results (period, duration, depth, SNR, radius).
+- If given a URL, you read its content and use it as context.
+""")
 
+ALLOWED_ORIGINS = [
+    # add your domain(s) here
+    "https://pruebas.nataliacoronel.com",
+    "https://*.nataliacoronel.com",
+    # during testing you can allow everything, though it is less secure:
+    os.getenv("ALLOW_ALL_ORIGINS", "") and "*",
+]
+ALLOWED_ORIGINS = [o for o in ALLOWED_ORIGINS if o]
+
+# ------------------ Model download ------------------
 print(f"[Boot] Downloading {MODEL_REPO} pattern {MODEL_PATTERN} into {MODELS_DIR} ...")
 snapshot_dir = snapshot_download(
     repo_id=MODEL_REPO,
     local_dir=str(MODELS_DIR),
     allow_patterns=[MODEL_PATTERN],
 )
-candidates = sorted(glob.glob(str(MODELS_DIR / MODEL_PATTERN)))
+candidates = sorted(glob.glob(str(Path(snapshot_dir) / MODEL_PATTERN)))
 if not candidates:
     raise FileNotFoundError(f"No shards for {MODEL_PATTERN} in {snapshot_dir}")
 MODEL_PATH = candidates[0]
 print(f"[Boot] Using shard: {MODEL_PATH}")
 
-# Thread count safe for CPU Basic
-N_THREADS = max(1, (os.cpu_count() or 2) - 1)
-
+# ------------------ Load LLaMA.cpp ------------------
 llm = Llama(
     model_path=MODEL_PATH,
     n_ctx=4096,
@@ -46,101 +63,81 @@ llm = Llama(
     n_gpu_layers=0,
     verbose=False,
 )
-_llm_lock = Lock()
-
-SYSTEM_DEFAULT = textwrap.dedent("""\
-You are Astrohunters-Guide, a Spanish-language assistant.
-- You answer precisely and never make up data.
-- You can explain exoplanet results (period, duration, depth, SNR, radius).
-- If given a URL, you read its content and use it as context.
-""")
 
 def fetch_url_text(url: str, max_chars: int = 6000) -> str:
     try:
         r = requests.get(url, timeout=15)
         r.raise_for_status()
         soup = BeautifulSoup(r.text, "html.parser")
-        for t in soup(["script", "style", "noscript"]): t.remove()
+        for t in soup(["script", "style", "noscript"]):
+            t.decompose()
         txt = " ".join(soup.get_text(separator=" ").split())
         return txt[:max_chars]
     except Exception as e:
         return f"[Could not load {url}: {e}]"
 
 def run_llm(messages, temperature=0.6, top_p=0.95, max_tokens=768) -> str:
-    with _llm_lock:
-        out = llm.create_chat_completion(
-            messages=messages,
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-            stream=False,
-        )
+    out = llm.create_chat_completion(
+        messages=messages,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=max_tokens,
+        stream=False,
+    )
     return out["choices"][0]["message"]["content"].strip()
 
-# ===== FastAPI =====
-app = FastAPI(title="Astrohunters LLM API", version="1.0.0")
-
-# CORS (adjust ALLOWED_ORIGINS in Settings → Variables to restrict it to your domain)
-ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",")
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=ALLOWED_ORIGINS,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# ------------------ FastAPI ------------------
+app = FastAPI(title="Astrohunters LLM API", docs_url="/docs", redoc_url=None)
+
+if ALLOWED_ORIGINS:
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=ALLOWED_ORIGINS,
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+class PredictIn(BaseModel):
+    prompt: str
+    system: Optional[str] = None
+
+class PredictURLIn(BaseModel):
+    prompt: str
+    url: Optional[str] = None
+    system: Optional[str] = None
 
 @app.get("/healthz")
 def healthz():
-    return {"ok": True}
+    return {"ok": True, "model": os.path.basename(MODEL_PATH), "threads": N_THREADS}
+
+@app.get("/")
+def root():
+    return {
+        "name": "Astrohunters LLM API",
+        "endpoints": ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"],
+    }
 
 @app.post("/run_predict")
-def run_predict(body: dict = Body(...)):
-    prompt = body.get("prompt", "")
-    system = body.get("system", "")
+def run_predict(body: PredictIn):
     messages = [
-        {"role": "system", "content": system or SYSTEM_DEFAULT},
-        {"role": "user", "content": prompt},
+        {"role": "system", "content": body.system or SYSTEM_DEFAULT},
+        {"role": "user", "content": body.prompt},
     ]
     reply = run_llm(messages, max_tokens=512)
     return {"reply": reply}
 
 @app.post("/run_predict_with_url")
-def run_predict_with_url(body: dict = Body(...)):
-    prompt = body.get("prompt", "")
-    url = body.get("url", "")
-    system = body.get("system", "")
-    web_ctx = fetch_url_text(url) if url else ""
-    user_msg = prompt if not web_ctx else f"{prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
+def run_predict_with_url(body: PredictURLIn):
+    web_ctx = fetch_url_text(body.url) if body.url else ""
+    user_msg = body.prompt if not web_ctx else f"{body.prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
     messages = [
-        {"role": "system", "content": system or SYSTEM_DEFAULT},
+        {"role": "system", "content": body.system or SYSTEM_DEFAULT},
         {"role": "user", "content": user_msg},
     ]
     reply = run_llm(messages, max_tokens=700)
     return {"reply": reply}
 
-# Minimal test page
-@app.get("/", response_class=HTMLResponse)
-def home():
-    return """
-    <!doctype html>
-    <html>
-    <head><meta charset="utf-8"><title>Astrohunters LLM API</title></head>
-    <body style="font-family:system-ui;max-width:800px;margin:40px auto">
-      <h2>🛰️ Astrohunters LLM API</h2>
-      <p>Endpoints: <code>/healthz</code>, <code>/run_predict</code>, <code>/run_predict_with_url</code>, and <a href="/docs">/docs</a> (Swagger).</p>
-      <textarea id="q" rows="4" style="width:100%" placeholder="Type your question..."></textarea>
-      <button id="btn">Ask</button>
-      <pre id="out"></pre>
-      <script>
-        document.getElementById('btn').onclick = async () => {
-          const r = await fetch('/run_predict', {
-            method:'POST', headers:{'Content-Type':'application/json'},
-            body: JSON.stringify({prompt: document.getElementById('q').value})
-          });
-          const j = await r.json();
-          document.getElementById('out').textContent = j.reply || JSON.stringify(j,null,2);
-        };
-      </script>
-    </body></html>
-    """
+if __name__ == "__main__":
+    import uvicorn, os
+    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
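
A note on the new CORS block: Starlette's CORSMiddleware compares allow_origins entries as literal strings (only "*" is special-cased), so the "https://*.nataliacoronel.com" entry will not actually match subdomains. If wildcard subdomains are wanted, allow_origin_regex is the supported mechanism. A minimal sketch of that alternative; the regex is an illustration, not part of this commit:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # exact origins are matched literally
    allow_origins=["https://pruebas.nataliacoronel.com"],
    # illustrative pattern for https://<any-subdomain>.nataliacoronel.com
    allow_origin_regex=r"https://.*\.nataliacoronel\.com",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)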
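
With the endpoints now validated by the PredictIn and PredictURLIn models, a quick smoke test of the request schema could look like the sketch below; the base URL is a placeholder for your Space (locally the app defaults to port 7860):

import requests

BASE = "http://localhost:7860"  # placeholder; replace with your Space URL

# /run_predict takes {"prompt": ..., "system": optional override}
r = requests.post(f"{BASE}/run_predict", json={"prompt": "¿Qué es la profundidad de tránsito?"})
r.raise_for_status()
print(r.json()["reply"])

# /run_predict_with_url additionally fetches a page and injects it as [CONTEXTO_WEB]
r = requests.post(
    f"{BASE}/run_predict_with_url",
    json={"prompt": "Resume esta página", "url": "https://example.com"},
)
r.raise_for_status()
print(r.json()["reply"])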