|
|
import time
import uuid
from typing import List, Optional, Literal

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from gradio_client import Client
from pydantic import BaseModel
|
|
|
|
|
|
|
|
# Shared client for the Serverless-TextGen-Hub gradio Space, created once at
# import time (construction may hit the network to fetch the Space's API spec).
gr_client = Client("Nymbo/Serverless-TextGen-Hub")
|
|
|
|
|
|
|
|
def ask(user_prompt, system_prompt, model, max_tokens=512, temperature=0.7, top_p=0.95):
    """Send a single-turn prompt to the Space's ``/bot`` endpoint.

    Previously the sampling parameters were hard-coded even though the HTTP
    API (``ChatRequest``) accepts them; they are now keyword arguments whose
    defaults match the old constants, so existing 3-argument callers behave
    identically.

    Args:
        user_prompt: The user's message text.
        system_prompt: System instruction sent alongside the conversation.
        model: Model identifier; the Space expects it both as ``custom_model``
            and ``selected_model``.
        max_tokens: Generation cap (default 512, the former hard-coded value).
        temperature: Sampling temperature (default 0.7).
        top_p: Nucleus-sampling threshold (default 0.95).

    Returns:
        Whatever ``/bot`` returns. NOTE(review): gradio chat endpoints often
        return the updated history (list of [user, assistant] pairs) rather
        than a bare string — confirm against the Space and extract the reply
        text if so.
    """
    result = gr_client.predict(
        # Single-turn: only the current user message is sent; no prior context.
        history=[[user_prompt, None]],
        system_msg=system_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        freq_penalty=0,
        seed=-1,  # -1 = let the Space pick a random seed
        custom_model=model,
        search_term="",
        selected_model=model,
        api_name="/bot"
    )
    return result
|
|
|
|
|
|
|
|
# ASGI application exposing an OpenAI-compatible surface over the gradio Space.
app = FastAPI()
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """One chat turn in an OpenAI-style ``messages`` array."""

    # Sender of the turn; mirrors the OpenAI chat roles.
    role: Literal["user", "assistant", "system"]
    # Plain-text body of the turn.
    content: str
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible subset)."""

    # Model identifier forwarded to the gradio Space.
    model: str
    # Full conversation; the handler uses only the latest user message and
    # the first system message.
    messages: List[Message]
    # NOTE(review): the three sampling fields below are accepted but not
    # currently forwarded to the Space by the handler — confirm intent.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 512
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
def chat_completion(request: ChatRequest):
    """OpenAI-compatible chat completion endpoint.

    Extracts the latest user message and the first system message from the
    request, forwards them to the gradio Space via ``ask``, and wraps the
    reply in an OpenAI ``chat.completion`` response envelope.

    Declared as a sync ``def`` (was ``async def``): ``gr_client.predict`` is a
    blocking call, and FastAPI runs sync endpoints in its threadpool, so the
    event loop is no longer stalled for the duration of the upstream request.
    HTTP behavior for clients is unchanged.

    Raises:
        HTTPException: 400 when the request contains no user message
            (previously a 200 response carrying ``{"error": ...}``, which
            OpenAI-style clients would misread as success).
    """
    # Latest user turn wins; earlier user messages are not forwarded.
    user_msg = next(
        (m.content for m in reversed(request.messages) if m.role == "user"),
        None,
    )
    # First system message wins; fall back to a generic instruction.
    system_msg = next(
        (m.content for m in request.messages if m.role == "system"),
        "You are a helpful AI assistant.",
    )

    if not user_msg:
        raise HTTPException(status_code=400, detail="User message not found.")

    # NOTE(review): request.temperature/top_p/max_tokens are not forwarded
    # here; doing so requires ask() to accept them. Also, depending on the
    # Space, the return value may be the updated chat history rather than a
    # plain string — verify and extract the reply text if needed.
    assistant_reply = ask(user_msg, system_msg, request.model)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": assistant_reply,
                },
                "finish_reason": "stop",
            }
        ],
        # Token accounting is not available from the Space; report zeros.
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Dev entry point: serve on all interfaces at port 7860 with auto-reload
    # (reload requires the app to be passed as an import string, as here).
    # NOTE(review): the import string assumes this file is named
    # local_openai_server.py — confirm the actual filename.
    uvicorn.run("local_openai_server:app", host="0.0.0.0", port=7860, reload=True)
|
|
|