import uuid
import threading
import logging
from typing import Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO)
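
# Selectable model tiers; each entry names a GGUF file on the Hugging Face Hub that is
# downloaded on first use via hf_hub_download.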
MODEL_MAP = {
    "light": {
        "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
        "filename": "Phi-3-mini-4k-instruct-q4.gguf"
    },
    "medium": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf"
    },
    "heavy": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf"
    }
}

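# Shared state: llm_cache holds at most one loaded model at a time, model_lock guards
# loading, llm_lock serialises generation, and JOBS tracks background jobs keyed by job_id.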
llm_cache = {}
model_lock = threading.Lock()
llm_lock = threading.Lock()
JOBS = {}


def get_llm_instance(choice: str) -> Optional[Llama]:
    """Return the requested model, downloading and loading it on first use."""
    with model_lock:
        if choice not in MODEL_MAP:
            logging.error(f"Invalid model choice: {choice}")
            return None
        if choice in llm_cache:
            logging.info(f"Using cached model: {choice}")
            return llm_cache[choice]

        model_info = MODEL_MAP[choice]
        repo_id = model_info["repo_id"]
        filename = model_info["filename"]

        try:
            logging.info(f"Downloading model: {filename} from {repo_id}")
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            logging.info(f"Model downloaded to: {model_path}")

            logging.info("Loading model into memory...")
            llm = Llama(
                model_path=model_path,
                n_ctx=4096,
                n_threads=2,
                n_gpu_layers=0,  # CPU-only inference
                verbose=True
            )

            # Evict any previously loaded model so only one stays resident in memory.
            llm_cache.clear()
            llm_cache[choice] = llm
            logging.info(f"Model {choice} loaded successfully.")
            return llm

        except Exception as e:
            logging.critical(f"Failed to download/load model {filename}. Error: {e}", exc_info=True)
            return None


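# Each entry in JOBS moves through: "pending" -> "processing" -> "complete" or "error".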
def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
    """
    Runs in a separate worker thread: loads the requested model, performs the
    long-running generation, and records the result in JOBS.
    """
    try:
        logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
        with llm_lock:
            logging.info(f"Job {job_id}: Lock acquired. Loading model.")
            llm = get_llm_instance(model_choice)
            if llm is None:
                raise RuntimeError("Model could not be loaded.")

            JOBS[job_id]["status"] = "processing"
            logging.info(f"Job {job_id}: Processing prompt...")

            output = llm(
                prompt,
                max_tokens=512,
                stop=["<|user|>", "<|endoftext|>", "user:"],
                echo=False
            )

            generated_text = output["choices"][0]["text"].strip()

            JOBS[job_id]["status"] = "complete"
            JOBS[job_id]["result"] = generated_text
            logging.info(f"Job {job_id}: Complete.")

    except Exception as e:
        logging.error(f"Job {job_id}: Failed. Error: {e}")
        JOBS[job_id]["status"] = "error"
        JOBS[job_id]["result"] = str(e)
    finally:
        logging.info(f"Job {job_id}: LLM lock released.")


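# Pre-load the lightweight model at startup so the first request does not pay the
# download/load cost.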
@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.info("Server starting up... Pre-loading 'light' model.")
    get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
    yield
    logging.info("Server shutting down...")
    llm_cache.clear()


app = FastAPI(lifespan=lifespan)

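# CORS: allow the GitHub Pages frontend and local development origins to call this API.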
origins = [
    "https://fugthchat.github.io",
    "http://localhost",
    "http://127.0.0.1:5500"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


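# Request body for /submit_job.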
class SubmitPrompt(BaseModel):
    prompt: str
    model_choice: str


@app.get("/") |
|
|
def get_status(): |
|
|
"""This is the 'wake up' and status check endpoint.""" |
|
|
loaded_model = list(llm_cache.keys())[0] if llm_cache else "None" |
|
|
return { |
|
|
"status": "AI server is online", |
|
|
"model_loaded": loaded_model, |
|
|
"models": list(MODEL_MAP.keys()) |
|
|
} |
|
|
|
|
|
@app.post("/submit_job")
async def submit_job(prompt: SubmitPrompt):
    """Accepts a job immediately and starts generation in a background thread."""
    job_id = str(uuid.uuid4())
    JOBS[job_id] = {"status": "pending", "result": None}

    thread = threading.Thread(
        target=run_generation_in_background,
        args=(job_id, prompt.model_choice, prompt.prompt)
    )
    thread.start()

    logging.info(f"Job {job_id} submitted.")
    return {"job_id": job_id}


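# Example client flow (illustrative; the base URL depends on where the server is deployed):
#   POST /submit_job {"prompt": "...", "model_choice": "light"}  -> {"job_id": "<uuid>"}
#   GET  /get_job_status/<job_id>                                -> {"status": "pending"} or {"status": "processing"}
#   GET  /get_job_status/<job_id> (once finished)                -> {"status": "complete", "result": "..."} (returned once)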
@app.get("/get_job_status/{job_id}") |
|
|
async def get_job_status(job_id: str): |
|
|
""" |
|
|
Allows the frontend to check on a job. |
|
|
""" |
|
|
job = JOBS.get(job_id) |
|
|
|
|
|
if job is None: |
|
|
return JSONResponse(status_code=404, content={"error": "Job not found."}) |
|
|
|
|
|
if job["status"] in ["complete", "error"]: |
|
|
result = job |
|
|
del JOBS[job_id] |
|
|
return result |
|
|
|
|
|
return {"status": job["status"]} |
|
|
|
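

# Local run helper (a sketch, not part of the original service): assumes `uvicorn` is
# installed; a production deployment may start the app with its own command instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)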