from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple, Optional
import os
import logging
from threading import Lock

import psutil
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="bartowski/Dolphin3.0-Llama3.2-1B-GGUF",
    description="Chat API serving bartowski/Dolphin3.0-Llama3.2-1B-GGUF via llama.cpp.",
    version="1.0",
    docs_url="/docs",
    redoc_url=None,
)

# Local directory for the GGUF file; it is downloaded on first startup if missing.
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_REPO_ID = "bartowski/Dolphin3.0-Llama3.2-1B-GGUF"
MODEL_FILENAME = "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf"

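# Optional: the model file can be pre-fetched so the first startup does not block
# on the download. A minimal sketch, assuming a recent huggingface_hub CLI is
# installed (exact flags may vary between versions):
#
#   huggingface-cli download bartowski/Dolphin3.0-Llama3.2-1B-GGUF \
#       Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf --local-dir ./models
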
# A single Llama instance is shared across requests. Concurrent generation on one
# llama.cpp context is not safe, so llm_lock serializes inference in the /chat handler.
llm = None
llm_lock = Lock()

def load_model():
    """Load the GGUF model into memory (CPU-only) if it has not been loaded yet."""
    global llm
    model_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    if llm is None:
        logger.info("Loading model into memory...")
        # Prefer the physical core count; fall back to logical cores if psutil
        # cannot determine it.
        num_threads = psutil.cpu_count(logical=False) or os.cpu_count()
        logger.info(f"Using {num_threads} threads for inference.")
        try:
            llm = Llama(
                model_path=model_path,
                n_gpu_layers=0,  # CPU-only inference
                n_batch=64,
                n_ctx=2048,  # context window in tokens
                n_threads=num_threads,
                n_threads_batch=num_threads,
                verbose=False,
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Error loading model: {e}", exc_info=True)
            raise

class ChatRequest(BaseModel):
    message: str
    history: Optional[List[Tuple[str, str]]] = []
    system_prompt: Optional[str] = "be good"
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 60
    repeat_penalty: Optional[float] = 1.1

class ChatResponse(BaseModel):
    response: str

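# Example /chat request body (every field except "message" is optional and falls
# back to the defaults defined in ChatRequest above):
#
#   {
#       "message": "What is the capital of France?",
#       "history": [["Hi", "Hello! How can I help you?"]],
#       "system_prompt": "You are a helpful assistant.",
#       "max_tokens": 256,
#       "temperature": 0.7
#   }
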
@app.on_event("startup")
async def startup_event():
    logger.info("Initializing model...")
    model_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
    if not os.path.exists(model_path):
        logger.info(f"Model not found locally. Downloading {MODEL_FILENAME} from Hugging Face...")
        try:
            hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=MODEL_FILENAME,
                local_dir=MODEL_DIR,
            )
            logger.info("Model downloaded successfully.")
        except Exception as e:
            logger.error(f"Error downloading model: {e}")
            raise
    load_model()
    logger.info("Server ready.")

@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    global llm
    # Check readiness outside the try block so the 503 is not swallowed by the
    # generic exception handler below and turned into a 500.
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not initialized yet.")

    try:
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=request.system_prompt,
            predefined_messages_formatter_type=MessagesFormatterType.CHATML,
            debug_output=False,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = request.temperature
        settings.top_k = request.top_k
        settings.top_p = request.top_p
        settings.max_tokens = request.max_tokens
        settings.repeat_penalty = request.repeat_penalty

        # Replay prior turns so the model sees the whole conversation, not just
        # the latest message.
        messages = BasicChatHistory()
        for user_msg, assistant_msg in request.history:
            messages.add_message({"role": Roles.user, "content": user_msg})
            messages.add_message({"role": Roles.assistant, "content": assistant_msg})

        logger.info("Generating response...")
        # Serialize access to the shared Llama instance.
        with llm_lock:
            response = agent.get_chat_response(
                request.message,
                llm_sampling_settings=settings,
                chat_history=messages,
                print_output=False,
            )
        logger.info(f"Response generated: {response[:100]}...")

        return {"response": response}

    except Exception as e:
        logger.error(f"Error during chat: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

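# Example call once the server is up (assuming it listens on localhost:7860, the
# default in the __main__ block at the bottom of this file):
#
#   curl -X POST http://localhost:7860/chat \
#       -H "Content-Type: application/json" \
#       -d '{"message": "Hello!", "max_tokens": 128}'
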
@app.get("/health")
def health_check():
    return {"status": "healthy"}

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
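
# The app can also be served with the uvicorn CLI instead of running this file
# directly, e.g. (module name is an assumption; adjust to this file's name):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860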