import uuid
import threading
import logging
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# --- Setup ---
logging.basicConfig(level=logging.INFO)

# --- MODEL MAP (light: Phi-3 mini; medium/heavy: DeepSeek 7B Chat at Q4/Q5) ---
MODEL_MAP = {
    "light": {
        "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
        "filename": "Phi-3-mini-4k-instruct-q4.gguf" # 2.13 GB
    },
    "medium": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf" # 4.08 GB
    },
    "heavy": {
        "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
        "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf" # 4.78 GB
    }
}

# --- GLOBAL CACHE & LOCKS ---
llm_cache = {} 
model_lock = threading.Lock() # Ensures only one model loads at a time
llm_lock = threading.Lock() # Ensures only one generation job runs at a time
JOBS = {} # Our in-memory "database" for background jobs
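# NOTE: jobs live only in this process's memory; entries are deleted once the
# result is fetched (see get_job_status) and are lost on a server restart.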

# --- Helper: Load Model ---
def get_llm_instance(choice: str) -> Optional[Llama]:
    with model_lock:
        if choice not in MODEL_MAP:
            logging.error(f"Invalid model choice: {choice}")
            return None
        if choice in llm_cache:
            logging.info(f"Using cached model: {choice}")
            return llm_cache[choice]

        model_info = MODEL_MAP[choice]
        repo_id = model_info["repo_id"]
        filename = model_info["filename"]
        
        try:
            logging.info(f"Downloading model: {filename} from {repo_id}")
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            logging.info(f"Model downloaded to: {model_path}")
            
            logging.info("Loading model into memory...")
            llm = Llama(
                model_path=model_path,
                n_ctx=4096,       
                n_threads=2,      
                n_gpu_layers=0,   
                verbose=True
            )
            
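            # Keep at most one model resident: evict any previously cached model
            # before storing the newly loaded one.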
            llm_cache.clear() 
            llm_cache[choice] = llm
            logging.info(f"Model {choice} loaded successfully.")
            return llm
            
        except Exception as e:
            logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
            return None

# --- Helper: The Background AI Task ---
def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
    """
    This function runs in a separate thread.
    It performs the long-running AI generation.
    """
    try:
        logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
        with llm_lock:
            logging.info(f"Job {job_id}: Lock acquired. Loading model.")
            llm = get_llm_instance(model_choice)
            if llm is None:
                raise Exception("Model could not be loaded.")
            
            JOBS[job_id]["status"] = "processing"
            logging.info(f"Job {job_id}: Processing prompt...")
            
            output = llm(
                prompt,
                max_tokens=512,
                stop=["<|user|>", "<|endoftext|>", "user:"],
                echo=False
            )
            
            generated_text = output["choices"][0]["text"].strip()
            
            JOBS[job_id]["status"] = "complete"
            JOBS[job_id]["result"] = generated_text
            logging.info(f"Job {job_id}: Complete.")
            
    except Exception as e:
        logging.error(f"Job {job_id}: Failed. Error: {e}")
        JOBS[job_id]["status"] = "error"
        JOBS[job_id]["result"] = str(e)
    finally:
        logging.info(f"Job {job_id}: LLM lock released.")


# --- FastAPI App & Lifespan ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.info("Server starting up... Pre-loading 'light' model.")
    get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
    yield
    logging.info("Server shutting down...")
    llm_cache.clear()

app = FastAPI(lifespan=lifespan)

# --- CORS configuration ---
# Explicitly allow the GitHub Pages frontend plus local development origins.
origins = [
    "https://fugthchat.github.io",  # Live GitHub Pages site
    "http://localhost",             # Local testing
    "http://127.0.0.1:5500",        # Local testing
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins, 
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- End of CORS configuration ---

# --- API Data Models ---
class SubmitPrompt(BaseModel):
    prompt: str
    model_choice: str

# --- API Endpoints ---
@app.get("/")
def get_status():
    """This is the 'wake up' and status check endpoint."""
    loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
    return {
        "status": "AI server is online",
        "model_loaded": loaded_model,
        "models": list(MODEL_MAP.keys()) 
    }

@app.post("/submit_job")
async def submit_job(prompt: SubmitPrompt):
    """
    Instantly accepts a job and starts it in the background.
    """
    job_id = str(uuid.uuid4())
    JOBS[job_id] = {"status": "pending", "result": None}
    
    thread = threading.Thread(
        target=run_generation_in_background,
        args=(job_id, prompt.model_choice, prompt.prompt)
    )
    thread.start()
    
    logging.info(f"Job {job_id} submitted.")
    return {"job_id": job_id}

@app.get("/get_job_status/{job_id}")
async def get_job_status(job_id: str):
    """
    Allows the frontend to check on a job.
    """
    job = JOBS.get(job_id)
    
    if job is None:
        return JSONResponse(status_code=404, content={"error": "Job not found."})
    
    if job["status"] in ["complete", "error"]:
        result = job
        del JOBS[job_id] # Clean up
        return result
        
    return {"status": job["status"]}
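
# --- Usage sketch (assumption: the server is reachable at http://localhost:8000) ---
# A minimal client-side polling flow for the endpoints above; illustrative only and
# not part of the server. It assumes the third-party `requests` package is installed.
#
#   import time
#   import requests
#
#   BASE = "http://localhost:8000"  # assumed local deployment URL
#
#   # Submit a job; the server responds immediately with a job_id.
#   job = requests.post(f"{BASE}/submit_job", json={
#       "prompt": "Hello!",
#       "model_choice": "light",
#   }).json()
#
#   # Poll until the job is complete (or failed), then print the result.
#   while True:
#       status = requests.get(f"{BASE}/get_job_status/{job['job_id']}").json()
#       if status.get("status") in ("complete", "error"):
#           print(status.get("result"))
#           break
#       time.sleep(2)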