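# NOTE: the three lines below are residue copied from the Hugging Face web UI
# (apparently a commit header), not program text; kept as comments so the file parses.
# Leonardo0711's picture
# Update app.py
# bed0ded verified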
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, glob, textwrap
from pathlib import Path
from typing import Optional, List
from threading import Lock # ← para serializar la inferencia
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
from bs4 import BeautifulSoup
from huggingface_hub import snapshot_download
from llama_cpp import Llama
# ------------------ Helpers ------------------
def env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an ``int``.

    Surrounding whitespace is ignored. When the variable is unset or blank,
    ``int(default)`` is returned; if anything fails (e.g. the value is not an
    integer literal), *default* is returned unchanged.
    """
    try:
        text = (os.getenv(name, "") or "").strip()
        if not text:
            return int(default)
        return int(text)
    except Exception:
        return default
def env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a ``float``.

    Blank or unset values yield ``float(default)``; any parsing failure
    yields *default* unchanged.
    """
    try:
        text = (os.getenv(name, "") or "").strip()
        return float(text) if text else float(default)
    except Exception:
        return default
def env_list(name: str) -> List[str]:
    """Split environment variable *name* on commas into trimmed, non-empty items.

    An unset or blank variable yields an empty list.
    """
    items: List[str] = []
    for piece in (os.getenv(name, "") or "").strip().split(","):
        piece = piece.strip()
        if piece:
            items.append(piece)
    return items
# ------------------ Config ------------------
# Hugging Face Hub repo holding the GGUF model. The 3B default was chosen by
# the author as light enough for the free CPU tier of Spaces.
MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
# Optional filename glob. If MODEL_PATTERN is set it is tried first;
# resolve_model_path() also tries several common fallback patterns.
PRIMARY_PATTERN = (os.getenv("MODEL_PATTERN", "") or "").strip()
# Local model folder under /data (writable in Docker Spaces); created at import time.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
MODELS_DIR.mkdir(parents=True, exist_ok=True)
# Performance settings passed to llama.cpp. CPU_COUNT falls back to 2 when
# os.cpu_count() returns None; the default thread count leaves one core free.
CPU_COUNT = os.cpu_count() or 2
N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
N_BATCH = env_int("N_BATCH", 64)
N_CTX = env_int("N_CTX", 1536) # author's note: 1536-2048 works on the basic CPU tier
N_GPU_LAYERS = env_int("N_GPU_LAYERS", 0)
# Decoding defaults / output lengths (used when a request leaves them unset).
DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4) # lower -> less hallucination (author's note)
DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)
# Upper bound that clamp_tokens() applies to a requested max_tokens.
MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)
# Default system prompt (Spanish), used when a request's `system` is empty/None.
SYSTEM_DEFAULT = textwrap.dedent("""\
Eres Astrohunters-Guide, un asistente en español.
- Respondes con precisión y sin inventar datos.
- Sabes explicar resultados de exoplanetas (período, duración, profundidad, SNR, radio).
- Si te paso una URL, lees su contenido y la usas como contexto.
""")
# CORS. ALLOW_ALL_ORIGINS is truthy for "1", "true" or "yes" (case-insensitive).
allow_all = (os.getenv("ALLOW_ALL_ORIGINS", "") or "").strip().lower() in ("1","true","yes")
# Explicit comma-separated origin list from the CORS_ORIGINS env var.
CORS_ORIGINS = env_list("CORS_ORIGINS")
if not CORS_ORIGINS and not allow_all:
    # Author's note: CORSMiddleware does not support wildcards such as
    # *.domain; if you need that, use ALLOW_ALL_ORIGINS=1 while testing.
    # NOTE(review): Starlette's CORSMiddleware also has an allow_origin_regex
    # option that may cover this case - confirm before relying on allow-all.
    CORS_ORIGINS = [
        "https://pruebas.nataliacoronel.com",
    ]
# ------------------ Resolución robusta del archivo GGUF ------------------
def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str:
    """Download a single GGUF model file from *repo* and return its local path.

    Filename patterns are tried strictly in order: *primary_pattern* first
    (when non-empty), then common Qwen2.5-3B-Instruct quantizations. Each
    pattern is downloaded on its own, and the search stops at the first one
    that yields a ``.gguf`` file. As a result only one quantization is
    fetched, and an explicit MODEL_PATTERN always wins over the built-in
    fallbacks.

    Args:
        repo: Hugging Face Hub repo id.
        models_dir: Directory passed to ``snapshot_download(local_dir=...)``.
        primary_pattern: Optional glob tried before the built-in patterns.

    Returns:
        Path (``str``) of the chosen GGUF file. When a pattern matches
        several files (e.g. a split upload), the one with the best
        ``score`` is chosen; ties are broken by file name.

    Raises:
        FileNotFoundError: if no pattern produced any ``.gguf`` file.
    """
    patterns: List[str] = []
    if primary_pattern:
        patterns.append(primary_pattern)
    # The 3B repo usually ships single files (no "-00001-of-0000N" suffix),
    # but both the split and the single-file form are tried per quantization.
    patterns += [
        "qwen2.5-3b-instruct-q4_k_m-*.gguf",
        "qwen2.5-3b-instruct-q4_k_m.gguf",
        "qwen2.5-3b-instruct-q4_0-*.gguf",
        "qwen2.5-3b-instruct-q4_0.gguf",
        "qwen2.5-3b-instruct-q3_k_m-*.gguf",
        "qwen2.5-3b-instruct-q3_k_m.gguf",
    ]

    def score(path: str) -> tuple:
        # Rank files matched by one pattern: "instruct" files first, then the
        # quantization order below, then the file name (which puts part 00001
        # of a split upload ahead of the later parts).
        name = Path(path).name.lower()
        quant_order = ["q4_k_m", "q4_0", "q3_k_m", "q5_k_m", "q3_0"]
        q_idx = next((i for i, q in enumerate(quant_order) if q in name), 99)
        instruct_bonus = 0 if "instruct" in name else 50
        return (instruct_bonus, q_idx, name)

    print(f"[Boot] Descargando {repo} con patrones: {patterns}")
    snapshot_dir = str(models_dir)
    for pat in patterns:
        # One pattern per download: passing all patterns at once would fetch
        # every matching quantization (several GB) just to keep one of them.
        snapshot_dir = snapshot_download(
            repo_id=repo,
            local_dir=str(models_dir),
            allow_patterns=[pat],
        )
        # Only GGUF files are loadable; a broad user pattern may match others.
        matches = [
            p for p in glob.glob(str(Path(snapshot_dir) / pat))
            if p.lower().endswith(".gguf")
        ]
        if matches:
            chosen = min(matches, key=score)
            print(f"[Boot] Usando GGUF: {chosen}")
            return chosen

    existing = sorted(glob.glob(str(Path(snapshot_dir) / "*.gguf")))
    raise FileNotFoundError(
        "No se encontró ningún GGUF en el repo con los patrones probados.\n"
        f"Repo: {repo}\nSnapshot: {snapshot_dir}\n"
        f"Intentados: {patterns}\nEncontrados (*.gguf): {[Path(x).name for x in existing]}"
    )
print(f"[Boot] Preparando modelo en {MODELS_DIR} ...")
# Runs at import time: the model file is resolved (and downloaded if needed)
# before the FastAPI app further down is created.
MODEL_PATH = resolve_model_path(MODEL_REPO, MODELS_DIR, PRIMARY_PATTERN)
# ------------------ Load the model with llama.cpp ------------------
# One module-level model instance, shared by every request through run_llm().
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=N_CTX,
    n_threads=N_THREADS,
    n_batch=N_BATCH,
    n_gpu_layers=N_GPU_LAYERS,
    verbose=False,
)
# Global lock that run_llm() holds around each model call, so requests never
# run inference concurrently (the author targets the CPU Basic tier).
LLM_LOCK = Lock()
# ------------------ Utilidades ------------------
def fetch_url_text(url: str, max_chars: int = 6000, max_bytes: int = 2_000_000) -> str:
    """Fetch *url* and return its visible text, whitespace-collapsed.

    Args:
        url: Address to fetch. Empty/None returns "" without any request.
        max_chars: Maximum length of the returned text.
        max_bytes: Maximum number of response-body bytes read. The URL comes
            from API clients, so the body is streamed and capped instead of
            being loaded into memory whole.

    Returns:
        Up to *max_chars* characters of page text with <script>, <style> and
        <noscript> removed. On any failure, returns a bracketed
        "[No se pudo cargar ...]" message instead; callers use either string
        as model context.
    """
    if not url:
        return ""
    try:
        with requests.get(url, timeout=15, stream=True) as resp:
            resp.raise_for_status()
            chunks = []
            received = 0
            for chunk in resp.iter_content(chunk_size=65536):
                chunks.append(chunk)
                received += len(chunk)
                if received > max_bytes:
                    break
            # requests falls back to ISO-8859-1 for text/* responses that do
            # not declare a charset, which garbles UTF-8 pages. Trust its
            # encoding only when the server declared one; otherwise let
            # BeautifulSoup detect it from the bytes / <meta> tags.
            content_type = resp.headers.get("Content-Type", "").lower()
            declared = resp.encoding if "charset=" in content_type else None
        raw = b"".join(chunks)
        if received > max_bytes:
            raw = raw[:max_bytes]
            # Cut before the last "<" so the tail never splits a multi-byte
            # character ("<" cannot occur inside one in ASCII-compatible
            # encodings such as UTF-8), which would confuse charset detection.
            cut = raw.rfind(b"<")
            if cut > 0:
                raw = raw[:cut]
        soup = BeautifulSoup(raw, "html.parser", from_encoding=declared)
        for t in soup(["script", "style", "noscript"]):
            t.decompose()
        txt = " ".join(soup.get_text(separator=" ").split())
        return txt[:max_chars]
    except Exception as e:
        # Deliberate best-effort: the error text is returned as context
        # rather than failing the whole request.
        return f"[No se pudo cargar {url}: {e}]"
def clamp_tokens(
    requested: Optional[int],
    default: Optional[int] = None,
    cap: Optional[int] = None,
) -> int:
    """Return the max_tokens value to use for one generation.

    Args:
        requested: Client-requested token budget. None or <= 0 means "use
            the default".
        default: Fallback budget; None means DEF_MAX_TOKENS.
        cap: Hard upper bound; None means MAX_TOKENS_CAP.

    Returns:
        The requested value, or *default* when none was given. In both
        cases the result is limited to *cap* and is never below 1. The
        default is capped too, so an LLM_MAX_TOKENS larger than
        LLM_MAX_TOKENS_CAP can no longer bypass the cap.
    """
    if default is None:
        default = DEF_MAX_TOKENS
    if cap is None:
        cap = MAX_TOKENS_CAP
    value = default if requested is None or requested <= 0 else requested
    return max(1, min(value, cap))
def run_llm(
    messages,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """Run one non-streaming chat completion and return the assistant text.

    Args:
        messages: Chat messages, i.e. a list of {"role", "content"} dicts
            like the ones the endpoints in this file build.
        temperature: Sampling temperature; None means DEF_TEMPERATURE.
        top_p: Nucleus-sampling threshold; None means DEF_TOP_P.
        max_tokens: Requested token budget, passed through clamp_tokens().

    Returns:
        The stripped message content of the first choice. If the response
        does not have the expected shape, returns the first 1000 characters
        of its string form instead.
    """
    temp = DEF_TEMPERATURE if temperature is None else float(temperature)
    nucleus = DEF_TOP_P if top_p is None else float(top_p)
    budget = clamp_tokens(max_tokens)
    # Serialize access to the single shared model instance.
    with LLM_LOCK:
        completion = llm.create_chat_completion(
            messages=messages,
            temperature=temp,
            top_p=nucleus,
            max_tokens=budget,
            stream=False,
        )
    try:
        content = completion["choices"][0]["message"]["content"]
        return content.strip()
    except Exception:
        return str(completion)[:1000]
# ------------------ FastAPI ------------------
app = FastAPI(title="Astrohunters LLM API", docs_url="/docs", redoc_url=None)
# One CORS configuration covers both modes:
# - allow_all: any origin is accepted and credentials are disabled;
# - otherwise: only CORS_ORIGINS is accepted and credentials are enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if allow_all else CORS_ORIGINS,
    allow_credentials=not allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --------- Esquemas de entrada ---------
class PredictIn(BaseModel):
    # Request body for POST /run_predict. `#` comments are used instead of a
    # docstring because pydantic would publish a docstring in the schema.
    prompt: str  # user message sent to the model
    system: Optional[str] = None  # replaces SYSTEM_DEFAULT when non-empty
    max_tokens: Optional[int] = None  # via clamp_tokens(): None/<=0 -> default, else capped at MAX_TOKENS_CAP
    temperature: Optional[float] = None  # None -> DEF_TEMPERATURE
    top_p: Optional[float] = None  # None -> DEF_TOP_P
class PredictURLIn(BaseModel):
    # Request body for POST /run_predict_with_url: the same fields as
    # PredictIn plus an optional `url` whose page text is added as context.
    prompt: str  # user message sent to the model
    url: Optional[str] = None  # if set, fetched and appended under [CONTEXTO_WEB]
    system: Optional[str] = None  # replaces SYSTEM_DEFAULT when non-empty
    max_tokens: Optional[int] = None  # via clamp_tokens(): None/<=0 -> default, else capped at MAX_TOKENS_CAP
    temperature: Optional[float] = None  # None -> DEF_TEMPERATURE
    top_p: Optional[float] = None  # None -> DEF_TOP_P
# --------- Endpoints ---------
@app.get("/healthz")
def healthz():
    """Health check: report the loaded model file and the effective settings."""
    sampling = dict(
        temperature=DEF_TEMPERATURE,
        top_p=DEF_TOP_P,
        max_tokens=DEF_MAX_TOKENS,
        max_tokens_cap=MAX_TOKENS_CAP,
    )
    return dict(
        ok=True,
        model=os.path.basename(MODEL_PATH),
        threads=N_THREADS,
        n_batch=N_BATCH,
        n_ctx=N_CTX,
        defaults=sampling,
    )
@app.get("/")
def root():
    """Describe the service and list its public endpoints."""
    endpoints = ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"]
    return {"name": "Astrohunters LLM API", "endpoints": endpoints}
@app.post("/run_predict")
def run_predict(body: PredictIn):
    """Answer ``body.prompt`` with the local model, without web context.

    Returns a JSON object of the form {"reply": <model text>}.
    """
    system_prompt = body.system or SYSTEM_DEFAULT
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": body.prompt},
    ]
    answer = run_llm(
        conversation,
        temperature=body.temperature,
        top_p=body.top_p,
        max_tokens=body.max_tokens,
    )
    return {"reply": answer}
@app.post("/run_predict_with_url")
def run_predict_with_url(body: PredictURLIn):
    """Answer ``body.prompt``, optionally adding the text of ``body.url`` as context.

    When a URL is given, its fetched text (or fetch-error message) is appended
    to the prompt under a [CONTEXTO_WEB] marker. Returns {"reply": <model text>}.
    """
    web_ctx = ""
    if body.url:
        web_ctx = fetch_url_text(body.url)
    if web_ctx:
        user_msg = f"{body.prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
    else:
        user_msg = body.prompt
    conversation = [
        {"role": "system", "content": body.system or SYSTEM_DEFAULT},
        {"role": "user", "content": user_msg},
    ]
    answer = run_llm(
        conversation,
        temperature=body.temperature,
        top_p=body.top_p,
        max_tokens=body.max_tokens,
    )
    return {"reply": answer}
if __name__ == "__main__":
    # Local/dev entry point; the port comes from $PORT (default 7860).
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)