Leonardo0711 commited on
Commit
950bb84
·
verified ·
1 Parent(s): ef31123

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -40
app.py CHANGED
@@ -3,7 +3,7 @@
3
 
4
  import os, glob, textwrap
5
  from pathlib import Path
6
- from typing import Optional
7
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -22,76 +22,128 @@ def env_int(name: str, default: int) -> int:
22
  except Exception:
23
  return default
24
 
25
-
26
  def env_float(name: str, default: float) -> float:
27
  try:
28
  return float(os.getenv(name, "").strip() or default)
29
  except Exception:
30
  return default
31
 
32
-
33
- def env_list(name: str) -> list[str]:
34
  raw = os.getenv(name, "").strip()
35
  return [x.strip() for x in raw.split(",") if x.strip()]
36
 
37
 
38
  # ------------------ Config ------------------
39
- # ⇩⇩ Cambiado a Qwen 2.5 3B (mejor para CPU Basic gratuito)
40
  MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
41
- # Opción rápida/ligera: q3_k_m si vas muy justo de RAM:
42
- # MODEL_PATTERN=qwen2.5-3b-instruct-q3_k_m-*.gguf
43
- MODEL_PATTERN = os.getenv("MODEL_PATTERN", "qwen2.5-3b-instruct-q4_k_m-*.gguf")
44
 
45
  # Carpeta de modelos en /data (escribible en Docker Spaces)
46
  MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
47
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
48
 
49
- # Rendimiento (overrides por variables de entorno)
50
  CPU_COUNT = os.cpu_count() or 2
51
  N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
52
- N_BATCH = env_int("N_BATCH", 64) # bajar si vas justo de RAM
53
- N_CTX = env_int("N_CTX", 1536) # 1536-2048 ok; menos = más rápido
54
 
55
- # Decodificación / longitud por defecto
56
- DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4) # un poco más bajo para menos alucinación
57
  DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
58
- DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160) # longitud típica
59
- MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320) # tope duro
60
 
61
  SYSTEM_DEFAULT = textwrap.dedent("""\
62
  Eres Astrohunters-Guide, un asistente en español.
63
  - Respondes con precisión y sin inventar datos.
64
  - Sabes explicar resultados de exoplanetas (período, duración, profundidad, SNR, radio).
65
- - Si te paso una URL, lees su contenido y lo usas como contexto.
66
  """)
67
 
68
 
69
  # CORS
70
- # Opciones:
71
- # - ALLOW_ALL_ORIGINS=1 (menos seguro, útil en pruebas)
72
- # - CORS_ORIGINS="https://dominio1,https://dominio2"
73
  allow_all = os.getenv("ALLOW_ALL_ORIGINS", "").strip() in ("1", "true", "yes")
74
  CORS_ORIGINS = env_list("CORS_ORIGINS")
75
  if not CORS_ORIGINS:
76
- # defaults cómodos para tu caso
77
  CORS_ORIGINS = [
78
  "https://pruebas.nataliacoronel.com",
79
  "https://*.nataliacoronel.com",
80
  ]
81
 
82
 
83
- # ------------------ Descarga del modelo ------------------
84
- print(f"[Boot] Descargando {MODEL_REPO} patrón {MODEL_PATTERN} en {MODELS_DIR} ...")
85
- snapshot_dir = snapshot_download(
86
- repo_id=MODEL_REPO,
87
- local_dir=str(MODELS_DIR),
88
- allow_patterns=[MODEL_PATTERN],
89
- )
90
- candidates = sorted(glob.glob(str(Path(snapshot_dir) / MODEL_PATTERN)))
91
- if not candidates:
92
- raise FileNotFoundError(f"No hay shards para {MODEL_PATTERN} en {snapshot_dir}")
93
- MODEL_PATH = candidates[0]
94
- print(f"[Boot] Usando shard: {MODEL_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  # ------------------ Carga LLaMA.cpp ------------------
@@ -120,13 +172,11 @@ def fetch_url_text(url: str, max_chars: int = 6000) -> str:
120
  except Exception as e:
121
  return f"[No se pudo cargar {url}: {e}]"
122
 
123
-
124
  def clamp_tokens(requested: Optional[int]) -> int:
125
  if requested is None or requested <= 0:
126
  return DEF_MAX_TOKENS
127
  return max(1, min(requested, MAX_TOKENS_CAP))
128
 
129
-
130
  def run_llm(
131
  messages,
132
  temperature: Optional[float] = None,
@@ -143,7 +193,6 @@ def run_llm(
143
  try:
144
  return out["choices"][0]["message"]["content"].strip()
145
  except Exception:
146
- # fallback defensivo
147
  return str(out)[:1000]
148
 
149
 
@@ -176,7 +225,6 @@ class PredictIn(BaseModel):
176
  temperature: Optional[float] = None
177
  top_p: Optional[float] = None
178
 
179
-
180
  class PredictURLIn(BaseModel):
181
  prompt: str
182
  url: Optional[str] = None
@@ -203,7 +251,6 @@ def healthz():
203
  },
204
  }
205
 
206
-
207
  @app.get("/")
208
  def root():
209
  return {
@@ -211,7 +258,6 @@ def root():
211
  "endpoints": ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"],
212
  }
213
 
214
-
215
  @app.post("/run_predict")
216
  def run_predict(body: PredictIn):
217
  messages = [
@@ -226,7 +272,6 @@ def run_predict(body: PredictIn):
226
  )
227
  return {"reply": reply}
228
 
229
-
230
  @app.post("/run_predict_with_url")
231
  def run_predict_with_url(body: PredictURLIn):
232
  web_ctx = fetch_url_text(body.url) if body.url else ""
@@ -243,7 +288,6 @@ def run_predict_with_url(body: PredictURLIn):
243
  )
244
  return {"reply": reply}
245
 
246
-
247
  if __name__ == "__main__":
248
  import uvicorn
249
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
 
3
 
4
  import os, glob, textwrap
5
  from pathlib import Path
6
+ from typing import Optional, List
7
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
 
22
  except Exception:
23
  return default
24
 
 
25
  def env_float(name: str, default: float) -> float:
26
  try:
27
  return float(os.getenv(name, "").strip() or default)
28
  except Exception:
29
  return default
30
 
31
+ def env_list(name: str) -> List[str]:
 
32
  raw = os.getenv(name, "").strip()
33
  return [x.strip() for x in raw.split(",") if x.strip()]
34
 
35
 
36
  # ------------------ Config ------------------
37
+ # ⇩⇩ modelo 3B (mejor para Spaces CPU gratis)
38
  MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
39
+
40
+ # Si defines MODEL_PATTERN lo respetamos; si no, probamos varios patrones típicos.
41
+ PRIMARY_PATTERN = os.getenv("MODEL_PATTERN", "").strip()
42
 
43
  # Carpeta de modelos en /data (escribible en Docker Spaces)
44
  MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
45
  MODELS_DIR.mkdir(parents=True, exist_ok=True)
46
 
47
+ # Rendimiento
48
  CPU_COUNT = os.cpu_count() or 2
49
  N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
50
+ N_BATCH = env_int("N_BATCH", 64)
51
+ N_CTX = env_int("N_CTX", 1536) # 15362048 ok en CPU basic
52
 
53
+ # Decodificación / longitudes
54
+ DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4) # un poco más bajo menos alucinación
55
  DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
56
+ DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)
57
+ MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)
58
 
59
  SYSTEM_DEFAULT = textwrap.dedent("""\
60
  Eres Astrohunters-Guide, un asistente en español.
61
  - Respondes con precisión y sin inventar datos.
62
  - Sabes explicar resultados de exoplanetas (período, duración, profundidad, SNR, radio).
63
+ - Si te paso una URL, lees su contenido y la usas como contexto.
64
  """)
65
 
66
 
67
  # CORS
 
 
 
68
  allow_all = os.getenv("ALLOW_ALL_ORIGINS", "").strip() in ("1", "true", "yes")
69
  CORS_ORIGINS = env_list("CORS_ORIGINS")
70
  if not CORS_ORIGINS:
 
71
  CORS_ORIGINS = [
72
  "https://pruebas.nataliacoronel.com",
73
  "https://*.nataliacoronel.com",
74
  ]
75
 
76
 
77
+ # ------------------ Resolución robusta del archivo GGUF ------------------
78
+ def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str:
79
+ """
80
+ Descarga sólo los archivos que necesitamos probando varios patrones comunes de Qwen 3B.
81
+ Devuelve la ruta al GGUF elegido o lanza FileNotFoundError.
82
+ """
83
+ # Patrones preferidos (ordenados por calidad/viabilidad en CPU gratuita)
84
+ patterns = []
85
+ if primary_pattern:
86
+ patterns.append(primary_pattern)
87
+
88
+ # 3B suele venir sin sufijo -00001-of-00001
89
+ patterns += [
90
+ "qwen2.5-3b-instruct-q4_k_m-*.gguf",
91
+ "qwen2.5-3b-instruct-q4_k_m.gguf",
92
+ "qwen2.5-3b-instruct-q4_0-*.gguf",
93
+ "qwen2.5-3b-instruct-q4_0.gguf",
94
+ "qwen2.5-3b-instruct-q3_k_m-*.gguf",
95
+ "qwen2.5-3b-instruct-q3_k_m.gguf",
96
+ ]
97
+ # Como último recurso (no deseable porque puede bajar más de un archivo):
98
+ # patterns.append("*.gguf")
99
+
100
+ # 1) Intento de descarga con allow_patterns = lista de patrones
101
+ print(f"[Boot] Descargando {repo} con patrones: {patterns}")
102
+ snapshot_dir = snapshot_download(
103
+ repo_id=repo,
104
+ local_dir=str(models_dir),
105
+ allow_patterns=patterns,
106
+ )
107
+
108
+ # 2) Buscar candidatos en el snapshot por prioridad
109
+ def glob_once(pat: str) -> List[str]:
110
+ return sorted(glob.glob(str(Path(snapshot_dir) / pat)))
111
+
112
+ all_candidates: List[str] = []
113
+ for pat in patterns:
114
+ cs = glob_once(pat)
115
+ if cs:
116
+ all_candidates.extend(cs)
117
+
118
+ # Filtro por 'instruct' y '3b' primero
119
+ def score(path: str) -> tuple:
120
+ p = Path(path).name.lower()
121
+ # prioridad por quant y por coincidencia "instruct" / "3b"
122
+ quant_order = ["q4_k_m", "q4_0", "q3_k_m", "q5_k_m", "q3_0"]
123
+ q_idx = next((i for i, q in enumerate(quant_order) if q in p), 99)
124
+ instruct_bonus = 0 if "instruct" in p else 50
125
+ size_bonus = 0 # opcional: podrías usar tamaño
126
+ return (instruct_bonus, q_idx, size_bonus, p)
127
+
128
+ all_candidates = sorted(set(all_candidates), key=score)
129
+ if not all_candidates:
130
+ # Intenta listar qué hay para debug
131
+ existing = sorted(glob.glob(str(Path(snapshot_dir) / "*.gguf")))
132
+ raise FileNotFoundError(
133
+ "No se encontró ningún GGUF en el repo con los patrones probados.\n"
134
+ f"Repo: {repo}\n"
135
+ f"Snapshot: {snapshot_dir}\n"
136
+ f"Intentados: {patterns}\n"
137
+ f"Encontrados (*.gguf): {[Path(x).name for x in existing]}"
138
+ )
139
+
140
+ chosen = all_candidates[0]
141
+ print(f"[Boot] Usando GGUF: {chosen}")
142
+ return chosen
143
+
144
+
145
+ print(f"[Boot] Preparando modelo en {MODELS_DIR} ...")
146
+ MODEL_PATH = resolve_model_path(MODEL_REPO, MODELS_DIR, PRIMARY_PATTERN)
147
 
148
 
149
  # ------------------ Carga LLaMA.cpp ------------------
 
172
  except Exception as e:
173
  return f"[No se pudo cargar {url}: {e}]"
174
 
 
175
  def clamp_tokens(requested: Optional[int]) -> int:
176
  if requested is None or requested <= 0:
177
  return DEF_MAX_TOKENS
178
  return max(1, min(requested, MAX_TOKENS_CAP))
179
 
 
180
  def run_llm(
181
  messages,
182
  temperature: Optional[float] = None,
 
193
  try:
194
  return out["choices"][0]["message"]["content"].strip()
195
  except Exception:
 
196
  return str(out)[:1000]
197
 
198
 
 
225
  temperature: Optional[float] = None
226
  top_p: Optional[float] = None
227
 
 
228
  class PredictURLIn(BaseModel):
229
  prompt: str
230
  url: Optional[str] = None
 
251
  },
252
  }
253
 
 
254
  @app.get("/")
255
  def root():
256
  return {
 
258
  "endpoints": ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"],
259
  }
260
 
 
261
  @app.post("/run_predict")
262
  def run_predict(body: PredictIn):
263
  messages = [
 
272
  )
273
  return {"reply": reply}
274
 
 
275
  @app.post("/run_predict_with_url")
276
  def run_predict_with_url(body: PredictURLIn):
277
  web_ctx = fetch_url_text(body.url) if body.url else ""
 
288
  )
289
  return {"reply": reply}
290
 
 
291
  if __name__ == "__main__":
292
  import uvicorn
293
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))