from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Literal
from gradio_client import Client
import uvicorn
import time
import uuid

# ==== Gradio Client initialization ====
gr_client = Client("Nymbo/Serverless-TextGen-Hub")

# ==== Query the upstream text-generation Space ====
def ask(user_prompt, system_prompt, model, temperature=0.7, top_p=0.95, max_tokens=512):
    # The /bot endpoint takes chat history as a list of [user, assistant] pairs;
    # the assistant slot is None so the Space generates the reply.
    result = gr_client.predict(
        history=[[user_prompt, None]],
        system_msg=system_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        freq_penalty=0,
        seed=-1,
        custom_model=model,
        search_term="",
        selected_model=model,
        api_name="/bot"
    )
    return result
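
# A quick sanity-check sketch for ask(), runnable from a Python shell once this
# module is importable. The model identifier below is illustrative only; which
# models Nymbo/Serverless-TextGen-Hub actually accepts is not confirmed here.
#
#   from main import ask
#   reply = ask("Hello!", "You are a helpful AI assistant.",
#               "meta-llama/Llama-3.2-3B-Instruct")
#   print(reply)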

# ==== FastAPI application ====
app = FastAPI()

# ==== Request/response models ====
class Message(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 512
    # further parameters can be added as needed

@app.post("/v1/chat/completions")
async def chat_completion(request: ChatRequest):
    # Extract the most recent user message and the first system message
    user_msg = next((m.content for m in reversed(request.messages) if m.role == "user"), None)
    system_msg = next((m.content for m in request.messages if m.role == "system"), "You are a helpful AI assistant.")
    if not user_msg:
        raise HTTPException(status_code=400, detail="User message not found.")
    # Query the model, forwarding the sampling parameters from the request
    assistant_reply = ask(
        user_msg,
        system_msg,
        request.model,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )
    # Build an OpenAI-style chat completion response
    response = {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": assistant_reply
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": 0,  # token counting could be added if needed
            "completion_tokens": 0,
            "total_tokens": 0
        }
    }
    return response

# ==== Server startup ====
if __name__ == "__main__":
    # This file is main.py, so the uvicorn import string must be "main:app"
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
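
# A minimal usage sketch against the running server (assuming it is reachable
# on localhost:7860; the model name is illustrative, not confirmed by this repo):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "meta-llama/Llama-3.2-3B-Instruct",
#          "messages": [{"role": "user", "content": "Hello!"}]}'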