import uuid
import threading
import logging
from typing import Optional
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from llama_cpp import Llama
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager
# --- Setup ---
logging.basicConfig(level=logging.INFO)
# --- MODEL MAP (Using the smarter Phi-3) ---
MODEL_MAP = {
    "light": {
        "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
        "filename": "Phi-3-mini-4k-instruct-q4.gguf"  # 2.13 GB
    },
    "medium": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf"  # 4.08 GB
    },
    "heavy": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf"  # 4.78 GB
    }
}
# --- GLOBAL CACHE & LOCKS ---
llm_cache = {}
model_lock = threading.Lock() # Ensures only one model loads at a time
llm_lock = threading.Lock() # Ensures only one generation job runs at a time
JOBS = {} # Our in-memory "database" for background jobs
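# Shape of each JOBS record, as used by the worker and the endpoints below:
#   JOBS[job_id] = {
#       "status": "pending" | "processing" | "complete" | "error",
#       "result": None | str,  # generated text on success, error message on failure
#   }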
# --- Helper: Load Model ---
def get_llm_instance(choice: str) -> Optional[Llama]:
    with model_lock:
        if choice not in MODEL_MAP:
            logging.error(f"Invalid model choice: {choice}")
            return None
        if choice in llm_cache:
            logging.info(f"Using cached model: {choice}")
            return llm_cache[choice]
        model_info = MODEL_MAP[choice]
        repo_id = model_info["repo_id"]
        filename = model_info["filename"]
        try:
            logging.info(f"Downloading model: {filename} from {repo_id}")
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            logging.info(f"Model downloaded to: {model_path}")
            logging.info("Loading model into memory...")
            llm = Llama(
                model_path=model_path,
                n_ctx=4096,
                n_threads=2,
                n_gpu_layers=0,
                verbose=True
            )
            # Keep at most one model resident: evict any previously cached
            # model before storing the new one.
            llm_cache.clear()
            llm_cache[choice] = llm
            logging.info(f"Model {choice} loaded successfully.")
            return llm
        except Exception as e:
            logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
            return None
# --- Helper: The Background AI Task ---
def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
    """
    Runs in a separate thread and performs the long-running AI generation,
    updating the JOBS record for job_id as it progresses.
    """
    try:
        logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
        with llm_lock:
            logging.info(f"Job {job_id}: Lock acquired. Loading model.")
            llm = get_llm_instance(model_choice)
            if llm is None:
                raise Exception("Model could not be loaded.")
            JOBS[job_id]["status"] = "processing"
            logging.info(f"Job {job_id}: Processing prompt...")
            output = llm(
                prompt,
                max_tokens=512,
                stop=["<|user|>", "<|endoftext|>", "user:"],
                echo=False
            )
            generated_text = output["choices"][0]["text"].strip()
            JOBS[job_id]["status"] = "complete"
            JOBS[job_id]["result"] = generated_text
            logging.info(f"Job {job_id}: Complete.")
    except Exception as e:
        logging.error(f"Job {job_id}: Failed. Error: {e}")
        JOBS[job_id]["status"] = "error"
        JOBS[job_id]["result"] = str(e)
    finally:
        # llm_lock (if it was acquired) has already been released by the `with` block.
        logging.info(f"Job {job_id}: LLM lock released.")
# --- FastAPI App & Lifespan ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.info("Server starting up... Pre-loading 'light' model.")
    get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
    yield
    logging.info("Server shutting down...")
    llm_cache.clear()
app = FastAPI(lifespan=lifespan)
# --- CORS configuration ---
# Explicitly allow the GitHub Pages frontend plus local development origins.
origins = [
    "https://fugthchat.github.io",  # live site
    "http://localhost",             # local testing
    "http://127.0.0.1:5500"         # local testing
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Data Models ---
class SubmitPrompt(BaseModel):
    prompt: str
    model_choice: str
# --- API Endpoints ---
@app.get("/")
def get_status():
"""This is the 'wake up' and status check endpoint."""
loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
return {
"status": "AI server is online",
"model_loaded": loaded_model,
"models": list(MODEL_MAP.keys())
}
@app.post("/submit_job")
async def submit_job(prompt: SubmitPrompt):
"""
Instantly accepts a job and starts it in the background.
"""
job_id = str(uuid.uuid4())
JOBS[job_id] = {"status": "pending", "result": None}
thread = threading.Thread(
target=run_generation_in_background,
args=(job_id, prompt.model_choice, prompt.prompt)
)
thread.start()
logging.info(f"Job {job_id} submitted.")
return {"job_id": job_id}
@app.get("/get_job_status/{job_id}")
async def get_job_status(job_id: str):
"""
Allows the frontend to check on a job.
"""
job = JOBS.get(job_id)
if job is None:
return JSONResponse(status_code=404, content={"error": "Job not found."})
if job["status"] in ["complete", "error"]:
result = job
del JOBS[job_id] # Clean up
return result
return {"status": job["status"]}
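# --- Example client flow (illustrative sketch, not executed by this app) ---
# The endpoint paths match the routes defined above; the base URL and the use
# of the `requests` library are assumptions made for this example.
#
#   import time, requests
#   BASE = "https://your-space.hf.space"  # hypothetical deployment URL
#   job = requests.post(f"{BASE}/submit_job",
#                       json={"prompt": "Hello!", "model_choice": "light"}).json()
#   while True:
#       state = requests.get(f"{BASE}/get_job_status/{job['job_id']}").json()
#       if state.get("status") in ("complete", "error"):
#           print(state.get("result"))
#           break
#       time.sleep(2)  # keep polling while the background thread works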