Update app.py
app.py CHANGED

@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from huggingface_hub import hf_hub_download
 import logging
 import threading
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -32,8 +32,7 @@ MODEL_MAP = {
 llm_cache = {}
 model_lock = threading.Lock()
 
-# ---
-# This replaces the old @app.on_event("startup")
+# --- LIFESPAN FUNCTION ---
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # This code runs ON STARTUP
@@ -60,7 +59,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# --- Helper Function to Load Model
+# --- Helper Function to Load Model ---
 def get_llm_instance(choice: str) -> Llama:
     if choice not in MODEL_MAP:
         logging.error(f"Invalid model choice: {choice}")
@@ -97,12 +96,18 @@ def get_llm_instance(choice: str) -> Llama:
         logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
         return None
 
-# --- API Data Models (
+# --- API Data Models (SIMPLIFIED) ---
+# We only need the full prompt and the model choice
+# The frontend will build the prompt.
 class StoryPrompt(BaseModel):
     prompt: str
+    model_choice: str
+
+    # These are no longer used by the backend, but we include them
+    # so the frontend's request doesn't fail
     feedback: str
     story_memory: str
-    model_choice: str
+
 
 # --- API Endpoints ---
 
@@ -117,6 +122,10 @@ def get_status():
 
 @app.post("/generate")
 async def generate_story(prompt: StoryPrompt):
+    """
+    Main generation endpoint.
+    This is now much simpler.
+    """
     logging.info("Request received. Waiting to acquire model lock...")
     with model_lock:
         logging.info("Lock acquired. Processing request.")
@@ -126,20 +135,10 @@ async def generate_story(prompt: StoryPrompt):
             logging.error(f"Failed to get model for choice: {prompt.model_choice}")
             return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
 
-        #
-        #
-
-
-{prompt.story_memory}
-
-Here is the part I just wrote or want to continue from:
-{prompt.prompt}
-
-Please use this feedback to guide the next chapter:
-{prompt.feedback}
-
-Generate the next part of the story.<|endoftext|>
-<|assistant|>"""
+        # --- THIS IS THE FIX ---
+        # We trust the frontend and use the prompt exactly as it was sent.
+        # We no longer re-format it.
+        final_prompt = prompt.prompt
 
         logging.info(f"Generating with {prompt.model_choice}...")
         output = llm(
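
For context on the second hunk: the lifespan function shown there is FastAPI's replacement for the deprecated @app.on_event("startup") hook, which the removed comment referenced. A minimal sketch of the pattern, assuming the startup step simply pre-warms a model cache (the actual startup/shutdown bodies are not visible in this diff):

# Minimal sketch of the FastAPI lifespan pattern; the startup/shutdown
# bodies here are assumptions, not code from this commit.
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Code before `yield` runs ON STARTUP, e.g. pre-loading a default model.
    yield
    # Code after `yield` runs ON SHUTDOWN, e.g. clearing the llm_cache.

app = FastAPI(lifespan=lifespan)

Passing the function via FastAPI(lifespan=...) is what actually wires it up; the decorator alone has no effect.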
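
The hunks touching get_llm_instance only show its first and last lines. A plausible reconstruction of the lazy, cached loader they imply is sketched below; the repo ID, filename, and Llama parameters are placeholders, not values from this commit:

# Hypothetical sketch of the cached GGUF loader; MODEL_MAP values and
# Llama settings are placeholders.
import logging
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

MODEL_MAP = {
    "default": ("example-org/example-model-GGUF", "model-q4_k_m.gguf"),  # placeholder
}
llm_cache = {}

def get_llm_instance(choice: str):
    if choice not in MODEL_MAP:
        logging.error(f"Invalid model choice: {choice}")
        return None
    if choice in llm_cache:
        return llm_cache[choice]
    repo_id, filename = MODEL_MAP[choice]
    try:
        # Download (or reuse the locally cached file) and load it with llama-cpp-python.
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        llm_cache[choice] = Llama(model_path=model_path, n_ctx=2048)
        return llm_cache[choice]
    except Exception as e:
        logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
        return None

Caching the Llama instance keyed by choice means each model is downloaded and loaded at most once per process, while model_lock in the endpoint keeps concurrent requests from loading or running a model at the same time.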
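
Putting the changed pieces together, the simplified request model and the pass-through /generate path might look roughly like this; max_tokens, the response key, and the missing-model branch are assumptions rather than code from this commit:

# Rough sketch of the simplified endpoint; generation parameters and the
# response shape are assumptions.
import logging
import threading
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

app = FastAPI()
model_lock = threading.Lock()

class StoryPrompt(BaseModel):
    prompt: str
    model_choice: str
    # Unused by the backend, kept so the frontend's request body still validates.
    feedback: str
    story_memory: str

@app.post("/generate")
async def generate_story(prompt: StoryPrompt):
    logging.info("Request received. Waiting to acquire model lock...")
    with model_lock:  # mirrors the commit: one generation at a time
        llm = get_llm_instance(prompt.model_choice)  # loader sketched above
        if llm is None:
            return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
        # Use the prompt exactly as the frontend built it; no re-formatting.
        final_prompt = prompt.prompt
        output = llm(final_prompt, max_tokens=512)  # max_tokens is an assumption
        return {"story": output["choices"][0]["text"]}  # response key is an assumption

Because the frontend now builds the full prompt string, presumably including the story memory, feedback, and chat-template markers that the removed block used to add, the backend no longer needs to know anything about the template format.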
|