fugthchat committed
Commit c8d6e8a · verified · 1 parent: ca47c49

Update app.py

Files changed (1)
  1. app.py +140 -94
app.py CHANGED
@@ -1,57 +1,134 @@
  import os
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from llama_cpp import Llama
  from fastapi.middleware.cors import CORSMiddleware
  from huggingface_hub import hf_hub_download
- import logging
- import threading
  from contextlib import asynccontextmanager

- # Set up logging
  logging.basicConfig(level=logging.INFO)

- # --- NEW, SMARTER MODEL MAP ---
- # We are swapping to better storytelling models
  MODEL_MAP = {
      "light": {
          "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
-         "filename": "Phi-3-mini-4k-instruct-q4.gguf"  # 2.13 GB - MUCH smarter
      },
      "medium": {
          "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
-         "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf"  # 4.08 GB - High Quality
      },
      "heavy": {
          "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
-         "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf"  # 4.78 GB - Best Quality
      }
  }

- # --- GLOBAL CACHE & LOCK ---
  llm_cache = {}
- model_lock = threading.Lock()

- # --- LIFESPAN FUNCTION ---
  @asynccontextmanager
  async def lifespan(app: FastAPI):
-     # This code runs ON STARTUP
-     logging.info("Server starting up... Acquiring lock to pre-load 'light' model (Phi-3).")
-     with model_lock:
-         get_llm_instance("light")
-     logging.info("Server is ready and 'light' model (Phi-3) is loaded.")
-
      yield
-
-     # This code runs ON SHUTDOWN
      logging.info("Server shutting down...")
      llm_cache.clear()

- # Pass the lifespan function to FastAPI
  app = FastAPI(lifespan=lifespan)
-
- # --- CORS ---
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -60,53 +137,15 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # --- Helper Function to Load Model ---
- def get_llm_instance(choice: str) -> Llama:
-     if choice not in MODEL_MAP:
-         logging.error(f"Invalid model choice: {choice}")
-         return None
-
-     if choice in llm_cache:
-         logging.info(f"Using cached model: {choice}")
-         return llm_cache[choice]
-
-     model_info = MODEL_MAP[choice]
-     repo_id = model_info["repo_id"]
-     filename = model_info["filename"]
-
-     try:
-         logging.info(f"Downloading model: {filename} from {repo_id}")
-         model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-         logging.info(f"Model downloaded to: {model_path}")
-
-         logging.info("Loading model into memory...")
-         llm = Llama(
-             model_path=model_path,
-             n_ctx=4096,
-             n_threads=2,
-             n_gpu_layers=0,
-             verbose=True
-         )
-
-         llm_cache.clear()
-         llm_cache[choice] = llm
-         logging.info(f"Model {choice} loaded successfully.")
-         return llm
-
-     except Exception as e:
-         logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
-         return None
-
- # --- API Data Models (SIMPLIFIED) ---
- class StoryPrompt(BaseModel):
      prompt: str
      model_choice: str
-     feedback: str = ""
-     story_memory: str = ""

  # --- API Endpoints ---
  @app.get("/")
  def get_status():
      loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
      return {
          "status": "AI server is online",
@@ -114,35 +153,42 @@ def get_status():
          "models": list(MODEL_MAP.keys())
      }

- @app.post("/generate")
- async def generate_story(prompt: StoryPrompt):
-     logging.info("Request received. Waiting to acquire model lock...")
-     with model_lock:
-         logging.info("Lock acquired. Processing request.")
-         try:
-             llm = get_llm_instance(prompt.model_choice)
-             if llm is None:
-                 logging.error(f"Failed to get model for choice: {prompt.model_choice}")
-                 return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
-
-             # We trust the frontend to build the full prompt
-             final_prompt = prompt.prompt
-
-             logging.info(f"Generating with {prompt.model_choice}...")
-             output = llm(
-                 final_prompt,
-                 max_tokens=512,
-                 stop=["<|user|>", "<|endoftext|>", "user:"],  # Added stop tokens for Phi-3
-                 echo=False
-             )
-
-             generated_text = output["choices"][0]["text"].strip()
-             logging.info("Generation complete.")
-
-             return {"story_text": generated_text}

-         except Exception as e:
-             logging.error(f"An internal error occurred during generation: {e}", exc_info=True)
-             return JSONResponse(status_code=500, content={"error": "An unexpected error occurred."})
-         finally:
-             logging.info("Releasing model lock.")

app.py (after change; added lines marked with +):
  import os
+ import uuid
+ import threading
+ import logging
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from llama_cpp import Llama
  from fastapi.middleware.cors import CORSMiddleware
  from huggingface_hub import hf_hub_download
  from contextlib import asynccontextmanager

+ # --- Setup ---
  logging.basicConfig(level=logging.INFO)

+ # --- Model Map (Using the smarter Phi-3) ---
  MODEL_MAP = {
      "light": {
          "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
+         "filename": "Phi-3-mini-4k-instruct-q4.gguf"  # 2.13 GB
      },
      "medium": {
          "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
+         "filename": "deepseek-llm-7b-chat.Q4_K_M.gguf"  # 4.08 GB
      },
      "heavy": {
          "repo_id": "TheBloke/DeepSeek-LLM-7B-Chat-GGUF",
+         "filename": "deepseek-llm-7b-chat.Q5_K_M.gguf"  # 4.78 GB
      }
  }

+ # --- Global Caches & Locks ---
  llm_cache = {}
+ model_lock = threading.Lock()  # Ensures only one model loads at a time
+ llm_lock = threading.Lock()    # Ensures only one generation job runs at a time
+
+ # This is our new in-memory "database" for jobs.
+ # It will hold the status and results of background tasks.
+ JOBS = {}
+
+ # --- Helper: Load Model ---
+ def get_llm_instance(choice: str) -> Llama:
+     with model_lock:
+         if choice not in MODEL_MAP:
+             logging.error(f"Invalid model choice: {choice}")
+             return None
+
+         if choice in llm_cache:
+             logging.info(f"Using cached model: {choice}")
+             return llm_cache[choice]
+
+         model_info = MODEL_MAP[choice]
+         repo_id = model_info["repo_id"]
+         filename = model_info["filename"]
+
+         try:
+             logging.info(f"Downloading model: {filename} from {repo_id}")
+             model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+             logging.info(f"Model downloaded to: {model_path}")
+
+             logging.info("Loading model into memory...")
+             llm = Llama(
+                 model_path=model_path,
+                 n_ctx=4096,
+                 n_threads=2,
+                 n_gpu_layers=0,
+                 verbose=True
+             )
+
+             llm_cache.clear()
+             llm_cache[choice] = llm
+             logging.info(f"Model {choice} loaded successfully.")
+             return llm
+
+         except Exception as e:
+             logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
+             return None

+ # --- Helper: The Background AI Task ---
+ def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
+     """
+     This function runs in a separate thread.
+     It performs the long-running AI generation.
+     """
+     global JOBS
+     try:
+         # Acquire the lock. If another job is running, this will wait.
+         logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
+         with llm_lock:
+             logging.info(f"Job {job_id}: Lock acquired. Loading model.")
+             llm = get_llm_instance(model_choice)
+             if llm is None:
+                 raise Exception("Model could not be loaded.")
+
+             JOBS[job_id]["status"] = "processing"
+             logging.info(f"Job {job_id}: Processing prompt...")
+
+             output = llm(
+                 prompt,
+                 max_tokens=512,
+                 stop=["<|user|>", "<|endoftext|>", "user:"],
+                 echo=False
+             )
+
+             generated_text = output["choices"][0]["text"].strip()
+
+             # Save the result and mark as complete
+             JOBS[job_id]["status"] = "complete"
+             JOBS[job_id]["result"] = generated_text
+             logging.info(f"Job {job_id}: Complete.")
+
+     except Exception as e:
+         logging.error(f"Job {job_id}: Failed. Error: {e}")
+         JOBS[job_id]["status"] = "error"
+         JOBS[job_id]["result"] = str(e)
+     finally:
+         # The lock is automatically released by the 'with' statement
+         logging.info(f"Job {job_id}: LLM lock released.")
+
+
+ # --- FastAPI App & Lifespan ---
  @asynccontextmanager
  async def lifespan(app: FastAPI):
+     logging.info("Server starting up... Pre-loading 'light' model.")
+     get_llm_instance("light")
+     logging.info("Server is ready and 'light' model is loaded.")
      yield
      logging.info("Server shutting down...")
      llm_cache.clear()

  app = FastAPI(lifespan=lifespan)
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],

      allow_headers=["*"],
  )

+ # --- API Data Models ---
+ class SubmitPrompt(BaseModel):
      prompt: str
      model_choice: str

  # --- API Endpoints ---
  @app.get("/")
  def get_status():
+     """This is the 'wake up' and status check endpoint."""
      loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
      return {
          "status": "AI server is online",

          "models": list(MODEL_MAP.keys())
      }

+ @app.post("/submit_job")
+ async def submit_job(prompt: SubmitPrompt):
+     """
+     NEW: Instantly accepts a job and starts it in the background.
+     """
+     job_id = str(uuid.uuid4())
+
+     # Store the job as "pending"
+     JOBS[job_id] = {"status": "pending", "result": None}
+
+     # Start the background thread
+     thread = threading.Thread(
+         target=run_generation_in_background,
+         args=(job_id, prompt.model_choice, prompt.prompt)
+     )
+     thread.start()
+
+     logging.info(f"Job {job_id} submitted.")
+     # Return the Job ID to the user immediately
+     return {"job_id": job_id}

+ @app.get("/get_job_status/{job_id}")
+ async def get_job_status(job_id: str):
+     """
+     NEW: Allows the frontend to check on a job.
+     """
+     job = JOBS.get(job_id)
+
+     if job is None:
+         return JSONResponse(status_code=404, content={"error": "Job not found."})
+
+     # If the job is done, send the result and remove it from memory
+     if job["status"] in ["complete", "error"]:
+         result = job
+         del JOBS[job_id]  # Clean up
+         return result
+
+     # If not done, just send the current status
+     return {"status": job["status"]}
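For anyone wiring a frontend to the new endpoints, here is a minimal client-side sketch of the submit-then-poll flow. It is not part of the commit: the base URL is a placeholder, the requests library is assumed, and the prompt string uses an assumed Phi-3-style chat template (the backend expects the client to send the fully formatted prompt). The polling interval is an arbitrary choice.

import time
import requests  # assumed HTTP client, not part of the commit

BASE_URL = "http://localhost:7860"  # placeholder; replace with the deployed server URL

# 1. Submit a job; the server responds immediately with a job_id.
payload = {
    # The server trusts the client to build the full prompt (Phi-3-style template assumed here).
    "prompt": "<|user|>\nTell me a short story about a lighthouse keeper.<|end|>\n<|assistant|>\n",
    "model_choice": "light",
}
job_id = requests.post(f"{BASE_URL}/submit_job", json=payload).json()["job_id"]

# 2. Poll /get_job_status until the job reports "complete" or "error".
#    Note: the server deletes the job after the first poll that sees a final status,
#    so the result must be read from that same response.
while True:
    resp = requests.get(f"{BASE_URL}/get_job_status/{job_id}")
    if resp.status_code == 404:
        print("Job not found (it may have already been collected).")
        break
    job = resp.json()
    if job.get("status") in ("complete", "error"):
        print(job["status"], "->", job.get("result"))
        break
    time.sleep(2)  # statuses seen before completion: "pending", "processing"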