Leonardo0711 commited on
Commit
bed0ded
·
verified ·
1 Parent(s): 950bb84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -43
app.py CHANGED
@@ -4,6 +4,7 @@
4
  import os, glob, textwrap
5
  from pathlib import Path
6
  from typing import Optional, List
 
7
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -18,27 +19,27 @@ from llama_cpp import Llama
18
  # ------------------ Helpers ------------------
19
  def env_int(name: str, default: int) -> int:
20
  try:
21
- return int(os.getenv(name, "").strip() or default)
22
  except Exception:
23
  return default
24
 
25
  def env_float(name: str, default: float) -> float:
26
  try:
27
- return float(os.getenv(name, "").strip() or default)
28
  except Exception:
29
  return default
30
 
31
  def env_list(name: str) -> List[str]:
32
- raw = os.getenv(name, "").strip()
33
  return [x.strip() for x in raw.split(",") if x.strip()]
34
 
35
 
36
  # ------------------ Config ------------------
37
- # ⇩⇩ modelo 3B (mejor para Spaces CPU gratis)
38
  MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
39
 
40
  # Si defines MODEL_PATTERN lo respetamos; si no, probamos varios patrones típicos.
41
- PRIMARY_PATTERN = os.getenv("MODEL_PATTERN", "").strip()
42
 
43
  # Carpeta de modelos en /data (escribible en Docker Spaces)
44
  MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
@@ -47,14 +48,15 @@ MODELS_DIR.mkdir(parents=True, exist_ok=True)
47
  # Rendimiento
48
  CPU_COUNT = os.cpu_count() or 2
49
  N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
50
- N_BATCH = env_int("N_BATCH", 64)
51
- N_CTX = env_int("N_CTX", 1536) # 1536–2048 ok en CPU basic
 
52
 
53
  # Decodificación / longitudes
54
- DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4) # un poco más bajo → menos alucinación
55
- DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
56
- DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)
57
- MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)
58
 
59
  SYSTEM_DEFAULT = textwrap.dedent("""\
60
  Eres Astrohunters-Guide, un asistente en español.
@@ -65,23 +67,23 @@ Eres Astrohunters-Guide, un asistente en español.
65
 
66
 
67
  # CORS
68
- allow_all = os.getenv("ALLOW_ALL_ORIGINS", "").strip() in ("1", "true", "yes")
69
  CORS_ORIGINS = env_list("CORS_ORIGINS")
70
- if not CORS_ORIGINS:
 
 
71
  CORS_ORIGINS = [
72
  "https://pruebas.nataliacoronel.com",
73
- "https://*.nataliacoronel.com",
74
  ]
75
 
76
 
77
  # ------------------ Resolución robusta del archivo GGUF ------------------
78
  def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str:
79
  """
80
- Descarga sólo los archivos que necesitamos probando varios patrones comunes de Qwen 3B.
81
  Devuelve la ruta al GGUF elegido o lanza FileNotFoundError.
82
  """
83
- # Patrones preferidos (ordenados por calidad/viabilidad en CPU gratuita)
84
- patterns = []
85
  if primary_pattern:
86
  patterns.append(primary_pattern)
87
 
@@ -94,10 +96,7 @@ def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str
94
  "qwen2.5-3b-instruct-q3_k_m-*.gguf",
95
  "qwen2.5-3b-instruct-q3_k_m.gguf",
96
  ]
97
- # Como último recurso (no deseable porque puede bajar más de un archivo):
98
- # patterns.append("*.gguf")
99
 
100
- # 1) Intento de descarga con allow_patterns = lista de patrones
101
  print(f"[Boot] Descargando {repo} con patrones: {patterns}")
102
  snapshot_dir = snapshot_download(
103
  repo_id=repo,
@@ -105,36 +104,28 @@ def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str
105
  allow_patterns=patterns,
106
  )
107
 
108
- # 2) Buscar candidatos en el snapshot por prioridad
109
  def glob_once(pat: str) -> List[str]:
110
  return sorted(glob.glob(str(Path(snapshot_dir) / pat)))
111
 
112
  all_candidates: List[str] = []
113
  for pat in patterns:
114
- cs = glob_once(pat)
115
- if cs:
116
- all_candidates.extend(cs)
117
 
118
- # Filtro por 'instruct' y '3b' primero
119
  def score(path: str) -> tuple:
120
  p = Path(path).name.lower()
121
- # prioridad por quant y por coincidencia "instruct" / "3b"
122
  quant_order = ["q4_k_m", "q4_0", "q3_k_m", "q5_k_m", "q3_0"]
123
  q_idx = next((i for i, q in enumerate(quant_order) if q in p), 99)
124
  instruct_bonus = 0 if "instruct" in p else 50
125
- size_bonus = 0 # opcional: podrías usar tamaño
126
- return (instruct_bonus, q_idx, size_bonus, p)
127
 
128
  all_candidates = sorted(set(all_candidates), key=score)
129
  if not all_candidates:
130
- # Intenta listar qué hay para debug
131
  existing = sorted(glob.glob(str(Path(snapshot_dir) / "*.gguf")))
132
  raise FileNotFoundError(
133
  "No se encontró ningún GGUF en el repo con los patrones probados.\n"
134
- f"Repo: {repo}\n"
135
- f"Snapshot: {snapshot_dir}\n"
136
- f"Intentados: {patterns}\n"
137
- f"Encontrados (*.gguf): {[Path(x).name for x in existing]}"
138
  )
139
 
140
  chosen = all_candidates[0]
@@ -152,10 +143,13 @@ llm = Llama(
152
  n_ctx=N_CTX,
153
  n_threads=N_THREADS,
154
  n_batch=N_BATCH,
155
- n_gpu_layers=env_int("N_GPU_LAYERS", 0),
156
  verbose=False,
157
  )
158
 
 
 
 
159
 
160
  # ------------------ Utilidades ------------------
161
  def fetch_url_text(url: str, max_chars: int = 6000) -> str:
@@ -183,13 +177,14 @@ def run_llm(
183
  top_p: Optional[float] = None,
184
  max_tokens: Optional[int] = None,
185
  ) -> str:
186
- out = llm.create_chat_completion(
187
- messages=messages,
188
- temperature=DEF_TEMPERATURE if temperature is None else float(temperature),
189
- top_p=DEF_TOP_P if top_p is None else float(top_p),
190
- max_tokens=clamp_tokens(max_tokens),
191
- stream=False,
192
- )
 
193
  try:
194
  return out["choices"][0]["message"]["content"].strip()
195
  except Exception:
@@ -262,7 +257,7 @@ def root():
262
  def run_predict(body: PredictIn):
263
  messages = [
264
  {"role": "system", "content": body.system or SYSTEM_DEFAULT},
265
- {"role": "user", "content": body.prompt},
266
  ]
267
  reply = run_llm(
268
  messages,
@@ -278,7 +273,7 @@ def run_predict_with_url(body: PredictURLIn):
278
  user_msg = body.prompt if not web_ctx else f"{body.prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
279
  messages = [
280
  {"role": "system", "content": body.system or SYSTEM_DEFAULT},
281
- {"role": "user", "content": user_msg},
282
  ]
283
  reply = run_llm(
284
  messages,
 
4
  import os, glob, textwrap
5
  from pathlib import Path
6
  from typing import Optional, List
7
+ from threading import Lock # ← para serializar la inferencia
8
 
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
 
19
  # ------------------ Helpers ------------------
20
  def env_int(name: str, default: int) -> int:
21
  try:
22
+ return int((os.getenv(name, "") or "").strip() or default)
23
  except Exception:
24
  return default
25
 
26
  def env_float(name: str, default: float) -> float:
27
  try:
28
+ return float((os.getenv(name, "") or "").strip() or default)
29
  except Exception:
30
  return default
31
 
32
  def env_list(name: str) -> List[str]:
33
+ raw = (os.getenv(name, "") or "").strip()
34
  return [x.strip() for x in raw.split(",") if x.strip()]
35
 
36
 
37
  # ------------------ Config ------------------
38
+ # ⇩⇩ modelo 3B (más liviano para Spaces CPU gratis)
39
  MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
40
 
41
  # Si defines MODEL_PATTERN lo respetamos; si no, probamos varios patrones típicos.
42
+ PRIMARY_PATTERN = (os.getenv("MODEL_PATTERN", "") or "").strip()
43
 
44
  # Carpeta de modelos en /data (escribible en Docker Spaces)
45
  MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
 
48
  # Rendimiento
49
  CPU_COUNT = os.cpu_count() or 2
50
  N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
51
+ N_BATCH = env_int("N_BATCH", 64)
52
+ N_CTX = env_int("N_CTX", 1536) # 1536–2048 ok en CPU basic
53
+ N_GPU_LAYERS = env_int("N_GPU_LAYERS", 0)
54
 
55
  # Decodificación / longitudes
56
+ DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4) # más bajo → menos alucinación
57
+ DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
58
+ DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)
59
+ MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)
60
 
61
  SYSTEM_DEFAULT = textwrap.dedent("""\
62
  Eres Astrohunters-Guide, un asistente en español.
 
67
 
68
 
69
  # CORS
70
+ allow_all = (os.getenv("ALLOW_ALL_ORIGINS", "") or "").strip().lower() in ("1","true","yes")
71
  CORS_ORIGINS = env_list("CORS_ORIGINS")
72
+ if not CORS_ORIGINS and not allow_all:
73
+ # Nota: CORSMiddleware no soporta comodines tipo *.dominio;
74
+ # si necesitas eso, usa ALLOW_ALL_ORIGINS=1 durante pruebas.
75
  CORS_ORIGINS = [
76
  "https://pruebas.nataliacoronel.com",
 
77
  ]
78
 
79
 
80
  # ------------------ Resolución robusta del archivo GGUF ------------------
81
  def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str:
82
  """
83
+ Descarga sólo los archivos necesarios probando varios patrones comunes de Qwen 3B.
84
  Devuelve la ruta al GGUF elegido o lanza FileNotFoundError.
85
  """
86
+ patterns: List[str] = []
 
87
  if primary_pattern:
88
  patterns.append(primary_pattern)
89
 
 
96
  "qwen2.5-3b-instruct-q3_k_m-*.gguf",
97
  "qwen2.5-3b-instruct-q3_k_m.gguf",
98
  ]
 
 
99
 
 
100
  print(f"[Boot] Descargando {repo} con patrones: {patterns}")
101
  snapshot_dir = snapshot_download(
102
  repo_id=repo,
 
104
  allow_patterns=patterns,
105
  )
106
 
 
107
  def glob_once(pat: str) -> List[str]:
108
  return sorted(glob.glob(str(Path(snapshot_dir) / pat)))
109
 
110
  all_candidates: List[str] = []
111
  for pat in patterns:
112
+ all_candidates += glob_once(pat)
 
 
113
 
 
114
  def score(path: str) -> tuple:
115
  p = Path(path).name.lower()
116
+ # prioridad por quant y coincidencia instruct
117
  quant_order = ["q4_k_m", "q4_0", "q3_k_m", "q5_k_m", "q3_0"]
118
  q_idx = next((i for i, q in enumerate(quant_order) if q in p), 99)
119
  instruct_bonus = 0 if "instruct" in p else 50
120
+ return (instruct_bonus, q_idx, p)
 
121
 
122
  all_candidates = sorted(set(all_candidates), key=score)
123
  if not all_candidates:
 
124
  existing = sorted(glob.glob(str(Path(snapshot_dir) / "*.gguf")))
125
  raise FileNotFoundError(
126
  "No se encontró ningún GGUF en el repo con los patrones probados.\n"
127
+ f"Repo: {repo}\nSnapshot: {snapshot_dir}\n"
128
+ f"Intentados: {patterns}\nEncontrados (*.gguf): {[Path(x).name for x in existing]}"
 
 
129
  )
130
 
131
  chosen = all_candidates[0]
 
143
  n_ctx=N_CTX,
144
  n_threads=N_THREADS,
145
  n_batch=N_BATCH,
146
+ n_gpu_layers=N_GPU_LAYERS,
147
  verbose=False,
148
  )
149
 
150
+ # Bloqueo global para evitar concurrencia en CPU Basic
151
+ LLM_LOCK = Lock()
152
+
153
 
154
  # ------------------ Utilidades ------------------
155
  def fetch_url_text(url: str, max_chars: int = 6000) -> str:
 
177
  top_p: Optional[float] = None,
178
  max_tokens: Optional[int] = None,
179
  ) -> str:
180
+ with LLM_LOCK: # ← serializa la llamada al modelo
181
+ out = llm.create_chat_completion(
182
+ messages=messages,
183
+ temperature=DEF_TEMPERATURE if temperature is None else float(temperature),
184
+ top_p=DEF_TOP_P if top_p is None else float(top_p),
185
+ max_tokens=clamp_tokens(max_tokens),
186
+ stream=False,
187
+ )
188
  try:
189
  return out["choices"][0]["message"]["content"].strip()
190
  except Exception:
 
257
  def run_predict(body: PredictIn):
258
  messages = [
259
  {"role": "system", "content": body.system or SYSTEM_DEFAULT},
260
+ {"role": "user", "content": body.prompt},
261
  ]
262
  reply = run_llm(
263
  messages,
 
273
  user_msg = body.prompt if not web_ctx else f"{body.prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
274
  messages = [
275
  {"role": "system", "content": body.system or SYSTEM_DEFAULT},
276
+ {"role": "user", "content": user_msg},
277
  ]
278
  reply = run_llm(
279
  messages,