import os
import uuid
import threading
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from llama_cpp import Llama
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager

# --- Setup ---
logging.basicConfig(level=logging.INFO)

# --- MODEL MAP (Using the smarter Phi-3) ---
MODEL_MAP = {
    "light": {
        "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
        "filename": "Phi-3-mini-4k-instruct-q4.gguf"  # 2.13 GB
    },
    "medium": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf"  # 4.08 GB
    },
    "heavy": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf"  # 4.78 GB
    }
}

# --- GLOBAL CACHE & LOCKS ---
llm_cache = {}
model_lock = threading.Lock()  # Ensures only one model loads at a time
llm_lock = threading.Lock()    # Ensures only one generation job runs at a time
JOBS = {}                      # Our in-memory "database" for background jobs


# --- Helper: Load Model ---
def get_llm_instance(choice: str) -> Llama:
    with model_lock:
        if choice not in MODEL_MAP:
            logging.error(f"Invalid model choice: {choice}")
            return None

        if choice in llm_cache:
            logging.info(f"Using cached model: {choice}")
            return llm_cache[choice]

        model_info = MODEL_MAP[choice]
        repo_id = model_info["repo_id"]
        filename = model_info["filename"]

        try:
            logging.info(f"Downloading model: {filename} from {repo_id}")
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            logging.info(f"Model downloaded to: {model_path}")

            logging.info("Loading model into memory...")
            llm = Llama(
                model_path=model_path,
                n_ctx=4096,
                n_threads=2,
                n_gpu_layers=0,
                verbose=True
            )

            llm_cache.clear()
            llm_cache[choice] = llm
            logging.info(f"Model {choice} loaded successfully.")
            return llm
        except Exception as e:
            logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
            return None


# --- Helper: The Background AI Task ---
def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
    """
    This function runs in a separate thread.
    It performs the long-running AI generation.
    """
    global JOBS
    try:
        logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
        with llm_lock:
            logging.info(f"Job {job_id}: Lock acquired. Loading model.")
            llm = get_llm_instance(model_choice)
            if llm is None:
                raise Exception("Model could not be loaded.")

            JOBS[job_id]["status"] = "processing"
            logging.info(f"Job {job_id}: Processing prompt...")

            output = llm(
                prompt,
                max_tokens=512,
                stop=["<|user|>", "<|endoftext|>", "user:"],
                echo=False
            )
            generated_text = output["choices"][0]["text"].strip()

            JOBS[job_id]["status"] = "complete"
            JOBS[job_id]["result"] = generated_text
            logging.info(f"Job {job_id}: Complete.")
    except Exception as e:
        logging.error(f"Job {job_id}: Failed. Error: {e}")
        JOBS[job_id]["status"] = "error"
        JOBS[job_id]["result"] = str(e)
    finally:
        logging.info(f"Job {job_id}: LLM lock released.")


# --- FastAPI App & Lifespan ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.info("Server starting up... Pre-loading 'light' model.")
    get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
    yield
    logging.info("Server shutting down...")
    llm_cache.clear()


app = FastAPI(lifespan=lifespan)

# --- !!! THIS IS THE CORS FIX !!! ---
# We are explicitly adding your GitHub Pages URL
origins = [
    "https://fugthchat.github.io",  # <-- YOUR LIVE SITE
    "http://localhost",             # For local testing
    "http://127.0.0.1:5500"         # For local testing
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- END OF CORS FIX ---


# --- API Data Models ---
class SubmitPrompt(BaseModel):
    prompt: str
    model_choice: str


# --- API Endpoints ---
@app.get("/")
def get_status():
    """This is the 'wake up' and status check endpoint."""
    loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
    return {
        "status": "AI server is online",
        "model_loaded": loaded_model,
        "models": list(MODEL_MAP.keys())
    }


@app.post("/submit_job")
async def submit_job(prompt: SubmitPrompt):
    """
    Instantly accepts a job and starts it in the background.
    """
    job_id = str(uuid.uuid4())
    JOBS[job_id] = {"status": "pending", "result": None}

    thread = threading.Thread(
        target=run_generation_in_background,
        args=(job_id, prompt.model_choice, prompt.prompt)
    )
    thread.start()

    logging.info(f"Job {job_id} submitted.")
    return {"job_id": job_id}


@app.get("/get_job_status/{job_id}")
async def get_job_status(job_id: str):
    """
    Allows the frontend to check on a job.
    """
    job = JOBS.get(job_id)
    if job is None:
        return JSONResponse(status_code=404, content={"error": "Job not found."})

    if job["status"] in ["complete", "error"]:
        result = job
        del JOBS[job_id]  # Clean up
        return result

    return {"status": job["status"]}
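

# --- Usage sketch (client side, not part of the server) ---
# A minimal sketch of how a frontend or script could drive the submit/poll flow
# above. This is illustrative only, not the project's actual frontend: the base
# URL is a hypothetical local deployment, and it assumes the third-party
# `requests` package is installed.
#
#   import requests, time
#
#   BASE = "http://localhost:8000"  # hypothetical host/port
#
#   # Submit a job; the server returns a job_id immediately.
#   r = requests.post(f"{BASE}/submit_job",
#                     json={"prompt": "Hello!", "model_choice": "light"})
#   job_id = r.json()["job_id"]
#
#   # Poll until the job finishes or fails.
#   while True:
#       job = requests.get(f"{BASE}/get_job_status/{job_id}").json()
#       if job.get("status") in ("complete", "error"):
#           print(job["result"])   # generated text, or the error message
#           break
#       time.sleep(2)              # poll every few seconds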