#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Astrohunters LLM API.

FastAPI service that downloads a GGUF model from the Hugging Face Hub,
loads it through llama.cpp and exposes chat-style prediction endpoints,
optionally enriching the prompt with text scraped from a URL.

Importing this module has side effects: it creates MODELS_DIR, downloads
the model file if needed and loads it into memory.
"""

import glob
import os
import textwrap
from pathlib import Path
from threading import Lock  # used to serialize inference
from typing import List, Optional

import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from llama_cpp import Llama
from pydantic import BaseModel


# ------------------ Helpers ------------------
def env_int(name: str, default: int) -> int:
    """Return environment variable *name* as an int.

    Falls back to *default* when the variable is unset, blank or not a
    valid integer literal.
    """
    try:
        return int((os.getenv(name, "") or "").strip() or default)
    except (TypeError, ValueError):
        return default


def env_float(name: str, default: float) -> float:
    """Return environment variable *name* as a float.

    Falls back to *default* when the variable is unset, blank or not a
    valid float literal.
    """
    try:
        return float((os.getenv(name, "") or "").strip() or default)
    except (TypeError, ValueError):
        return default


def env_list(name: str) -> List[str]:
    """Return environment variable *name* split on commas.

    Items are stripped and empty items are dropped; an unset or blank
    variable yields an empty list.
    """
    raw = (os.getenv(name, "") or "").strip()
    return [item.strip() for item in raw.split(",") if item.strip()]


# ------------------ Config ------------------
# 3B model (lighter, intended for free CPU Spaces).
MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")

# If MODEL_PATTERN is set it is tried first; otherwise several typical
# file-name patterns are tried.
PRIMARY_PATTERN = (os.getenv("MODEL_PATTERN", "") or "").strip()

# Model folder under /data (writable in Docker Spaces).
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/data/models"))
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Performance
CPU_COUNT = os.cpu_count() or 2
N_THREADS = env_int("N_THREADS", max(1, CPU_COUNT - 1))
N_BATCH = env_int("N_BATCH", 64)
N_CTX = env_int("N_CTX", 1536)  # 1536-2048 is OK on CPU basic
N_GPU_LAYERS = env_int("N_GPU_LAYERS", 0)

# Decoding / lengths
DEF_TEMPERATURE = env_float("LLM_TEMPERATURE", 0.4)  # lower -> less hallucination
DEF_TOP_P = env_float("LLM_TOP_P", 0.9)
DEF_MAX_TOKENS = env_int("LLM_MAX_TOKENS", 160)
MAX_TOKENS_CAP = env_int("LLM_MAX_TOKENS_CAP", 320)

SYSTEM_DEFAULT = textwrap.dedent("""\
Eres Astrohunters-Guide, un asistente en español.
- Respondes con precisión y sin inventar datos.
- Sabes explicar resultados de exoplanetas (período, duración, profundidad, SNR, radio).
- Si te paso una URL, lees su contenido y la usas como contexto.
""")

# CORS
allow_all = (os.getenv("ALLOW_ALL_ORIGINS", "") or "").strip().lower() in ("1", "true", "yes")
CORS_ORIGINS = env_list("CORS_ORIGINS")
if not CORS_ORIGINS and not allow_all:
    # Note: the allow_origins list does not take wildcard patterns such as
    # *.domain; if you need that, use ALLOW_ALL_ORIGINS=1 while testing.
    CORS_ORIGINS = [
        "https://pruebas.nataliacoronel.com",
    ]


# ------------------ Robust resolution of the GGUF file ------------------
def _gguf_score(path: str) -> tuple:
    """Sort key for GGUF candidates: instruct builds first, then quant preference, then name."""
    name = Path(path).name.lower()
    quant_order = ["q4_k_m", "q4_0", "q3_k_m", "q5_k_m", "q3_0"]
    q_idx = next((i for i, q in enumerate(quant_order) if q in name), 99)
    instruct_bonus = 0 if "instruct" in name else 50
    return (instruct_bonus, q_idx, name)


def resolve_model_path(repo: str, models_dir: Path, primary_pattern: str) -> str:
    """Download a GGUF file from *repo* into *models_dir* and return its path.

    Patterns are tried one at a time in priority order: *primary_pattern*
    first (when non-empty), then common Qwen 3B quantization file names.
    The search stops at the first pattern that yields a file, so only the
    files for that pattern are fetched and an explicit primary pattern
    always wins over the built-in defaults.

    Args:
        repo: Hugging Face Hub repository id.
        models_dir: Local directory the files are downloaded into.
        primary_pattern: Optional glob pattern to try before the defaults.

    Returns:
        Path of the chosen GGUF file. When a pattern matches several files
        they are ordered with ``_gguf_score`` and the first one is returned.

    Raises:
        FileNotFoundError: If no pattern matches any file.
    """
    patterns: List[str] = []
    if primary_pattern:
        patterns.append(primary_pattern)

    # 3B usually ships without the -00001-of-00001 suffix.
    patterns += [
        "qwen2.5-3b-instruct-q4_k_m-*.gguf",
        "qwen2.5-3b-instruct-q4_k_m.gguf",
        "qwen2.5-3b-instruct-q4_0-*.gguf",
        "qwen2.5-3b-instruct-q4_0.gguf",
        "qwen2.5-3b-instruct-q3_k_m-*.gguf",
        "qwen2.5-3b-instruct-q3_k_m.gguf",
    ]

    print(f"[Boot] Descargando {repo} con patrones: {patterns}")

    snapshot_dir = str(models_dir)
    for pat in patterns:
        # One pattern per call: passing all patterns at once would download
        # every matching quantization, not just the one that ends up used.
        snapshot_dir = snapshot_download(
            repo_id=repo,
            local_dir=str(models_dir),
            allow_patterns=[pat],
        )
        matches = sorted(set(glob.glob(str(Path(snapshot_dir) / pat))), key=_gguf_score)
        if matches:
            chosen = matches[0]
            print(f"[Boot] Usando GGUF: {chosen}")
            return chosen

    existing = sorted(glob.glob(str(Path(snapshot_dir) / "*.gguf")))
    raise FileNotFoundError(
        "No se encontró ningún GGUF en el repo con los patrones probados.\n"
        f"Repo: {repo}\nSnapshot: {snapshot_dir}\n"
        f"Intentados: {patterns}\nEncontrados (*.gguf): {[Path(x).name for x in existing]}"
    )


print(f"[Boot] Preparando modelo en {MODELS_DIR} ...")
MODEL_PATH = resolve_model_path(MODEL_REPO, MODELS_DIR, PRIMARY_PATTERN)

# ------------------ llama.cpp loading ------------------
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=N_CTX,
    n_threads=N_THREADS,
    n_batch=N_BATCH,
    n_gpu_layers=N_GPU_LAYERS,
    verbose=False,
)

# Global lock to avoid concurrent inference on CPU Basic.
LLM_LOCK = Lock()


# ------------------ Utilities ------------------
def fetch_url_text(url: str, max_chars: int = 6000) -> str:
    """Fetch *url* and return its visible text, truncated to *max_chars*.

    Script, style and noscript elements are removed and whitespace runs are
    collapsed to single spaces. An empty *url* yields an empty string. Any
    failure is reported as a bracketed message string instead of raising,
    so callers always get text back.
    """
    if not url:
        return ""
    # NOTE(review): the URL comes from the request body and is fetched
    # server-side with no host restrictions (SSRF exposure) — confirm this
    # is acceptable for the deployment or add an allow-list.
    # NOTE(review): max_chars is not tied to N_CTX; a long page plus the
    # prompt may not fit the model context — verify against the configured
    # context size.
    try:
        r = requests.get(url, timeout=15)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, "html.parser")
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        txt = " ".join(soup.get_text(separator=" ").split())
        return txt[:max_chars]
    except Exception as e:  # deliberate best-effort: surface the error as text
        return f"[No se pudo cargar {url}: {e}]"


def clamp_tokens(requested: Optional[int]) -> int:
    """Return the number of tokens to generate for a request.

    A missing or non-positive *requested* falls back to DEF_MAX_TOKENS.
    The result is always limited to MAX_TOKENS_CAP (including the default,
    so a misconfigured default cannot exceed the cap) and is at least 1.
    """
    if requested is None or requested <= 0:
        requested = DEF_MAX_TOKENS
    return max(1, min(requested, MAX_TOKENS_CAP))


def run_llm(
    messages,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """Run one chat completion and return the reply text.

    Args:
        messages: Chat messages passed to ``llm.create_chat_completion``.
        temperature: Sampling temperature; DEF_TEMPERATURE when None.
        top_p: Nucleus sampling value; DEF_TOP_P when None.
        max_tokens: Requested length, normalized by ``clamp_tokens``.

    Returns:
        The stripped content of the first choice, or the first 1000
        characters of the raw response's string form when it does not have
        the expected structure.
    """
    with LLM_LOCK:  # serializes the call into the model
        out = llm.create_chat_completion(
            messages=messages,
            temperature=DEF_TEMPERATURE if temperature is None else float(temperature),
            top_p=DEF_TOP_P if top_p is None else float(top_p),
            max_tokens=clamp_tokens(max_tokens),
            stream=False,
        )
    try:
        return out["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError, TypeError, AttributeError):
        # Unexpected response structure (or content is None): return a
        # bounded dump of the raw response instead of failing the request.
        return str(out)[:1000]


# ------------------ FastAPI ------------------
app = FastAPI(title="Astrohunters LLM API", docs_url="/docs", redoc_url=None)

if allow_all:
    # Wildcard origin is configured without credentials.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


# --------- Input schemas ---------
class PredictIn(BaseModel):
    """Request body for /run_predict."""

    prompt: str  # user message
    system: Optional[str] = None  # system prompt; SYSTEM_DEFAULT when empty
    max_tokens: Optional[int] = None  # normalized by clamp_tokens
    temperature: Optional[float] = None  # DEF_TEMPERATURE when None
    top_p: Optional[float] = None  # DEF_TOP_P when None


class PredictURLIn(BaseModel):
    """Request body for /run_predict_with_url."""

    prompt: str  # user message
    url: Optional[str] = None  # page whose text is appended as context
    system: Optional[str] = None  # system prompt; SYSTEM_DEFAULT when empty
    max_tokens: Optional[int] = None  # normalized by clamp_tokens
    temperature: Optional[float] = None  # DEF_TEMPERATURE when None
    top_p: Optional[float] = None  # DEF_TOP_P when None


# --------- Endpoints ---------
@app.get("/healthz")
def healthz():
    """Report the loaded model file name and the effective runtime settings."""
    return {
        "ok": True,
        "model": os.path.basename(MODEL_PATH),
        "threads": N_THREADS,
        "n_batch": N_BATCH,
        "n_ctx": N_CTX,
        "defaults": {
            "temperature": DEF_TEMPERATURE,
            "top_p": DEF_TOP_P,
            "max_tokens": DEF_MAX_TOKENS,
            "max_tokens_cap": MAX_TOKENS_CAP,
        },
    }


@app.get("/")
def root():
    """Return the service name and the list of available endpoints."""
    return {
        "name": "Astrohunters LLM API",
        "endpoints": ["/healthz", "/run_predict", "/run_predict_with_url", "/docs"],
    }


@app.post("/run_predict")
def run_predict(body: PredictIn):
    """Answer ``body.prompt`` with the model and return ``{"reply": ...}``."""
    messages = [
        {"role": "system", "content": body.system or SYSTEM_DEFAULT},
        {"role": "user", "content": body.prompt},
    ]
    reply = run_llm(
        messages,
        temperature=body.temperature,
        top_p=body.top_p,
        max_tokens=body.max_tokens,
    )
    return {"reply": reply}


@app.post("/run_predict_with_url")
def run_predict_with_url(body: PredictURLIn):
    """Answer ``body.prompt`` using the text of ``body.url`` as extra context.

    When a URL is given, its extracted text is appended to the prompt under
    a ``[CONTEXTO_WEB]`` marker; otherwise the prompt is sent unchanged.
    """
    web_ctx = fetch_url_text(body.url) if body.url else ""
    user_msg = body.prompt if not web_ctx else f"{body.prompt}\n\n[CONTEXTO_WEB]\n{web_ctx}"
    messages = [
        {"role": "system", "content": body.system or SYSTEM_DEFAULT},
        {"role": "user", "content": user_msg},
    ]
    reply = run_llm(
        messages,
        temperature=body.temperature,
        top_p=body.top_p,
        max_tokens=body.max_tokens,
    )
    return {"reply": reply}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))