AUXteam's picture
Upload folder using huggingface_hub
1397957 verified
raw
history blame
6.6 kB
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import json
from ..session import Session, SessionInfo, SessionCreate, Message, SessionPrompt
from ..session.prompt import PromptInput
from ..core.storage import NotFoundError
from ..core.auth import AuthUser, optional_auth, require_auth
from ..core.quota import check_quota, increment_usage
from ..core.supabase import is_enabled as supabase_enabled
from ..provider import get_provider
# Router that groups every endpoint in this module under the "/session" prefix
# and tags them "Session".
router = APIRouter(prefix="/session", tags=["Session"])
class MessageRequest(BaseModel):
    """Request body for posting a message to a session.

    Only ``content`` is required. ``send_message`` copies every field,
    unchanged, into a ``PromptInput``.
    """

    # Message text (required).
    content: str

    # Optional provider / model selection; None when the client omits them.
    provider_id: Optional[str] = None
    model_id: Optional[str] = None

    # Optional generation settings; None when the client omits them.
    system: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    # Defaults to True when the client does not send it.
    tools_enabled: bool = True

    # Optional step controls, passed through as-is.
    auto_continue: Optional[bool] = None
    max_steps: Optional[int] = None
class SessionUpdate(BaseModel):
    """Request body for patching a session.

    ``update_session`` drops fields that are None before applying the update,
    so an omitted ``title`` leaves the stored title alone.
    """

    # New title for the session; None means "not provided".
    title: Optional[str] = None
class GenerateTitleRequest(BaseModel):
    """Request body for the generate-title endpoint."""

    # User message the title is derived from (required).
    message: str

    # Optional model override; ``generate_title`` substitutes its own default
    # when this is missing or empty.
    model_id: Optional[str] = None
@router.get("/", response_model=List[SessionInfo])
async def list_sessions(
    limit: Optional[int] = Query(None, description="Maximum number of sessions to return"),
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Return the sessions produced by ``Session.list``.

    ``limit`` is forwarded as given (None when the query parameter is absent),
    together with the caller's id, or None when the auth dependency yields no
    user.
    """
    owner_id = None
    if user:
        owner_id = user.id

    sessions = await Session.list(limit, owner_id)
    return sessions
@router.post("/", response_model=SessionInfo)
async def create_session(
    data: Optional[SessionCreate] = None,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Create a session through ``Session.create`` and return its info.

    ``data`` may be omitted entirely, in which case None is forwarded. The
    caller's id is forwarded as well, or None when the auth dependency yields
    no user.
    """
    owner_id = None
    if user:
        owner_id = user.id

    created = await Session.create(data, owner_id)
    return created
@router.get("/{session_id}", response_model=SessionInfo)
async def get_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Fetch one session by id.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    try:
        info = await Session.get(session_id, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return info
@router.patch("/{session_id}", response_model=SessionInfo)
async def update_session(
    session_id: str,
    updates: SessionUpdate,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Apply the non-None fields of ``updates`` to a session.

    Raises:
        HTTPException: 404 when ``Session.update`` raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    # Keep only the fields the client actually supplied, so that omitted
    # (None) fields are not sent to Session.update.
    changes = {}
    for field, value in updates.model_dump().items():
        if value is not None:
            changes[field] = value

    try:
        updated = await Session.update(session_id, changes, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return updated
@router.delete("/{session_id}")
async def delete_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Delete a session and report success.

    Returns:
        ``{"success": True}`` once ``Session.delete`` has completed.

    Raises:
        HTTPException: 404 when ``Session.delete`` raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    try:
        await Session.delete(session_id, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return {"success": True}
@router.get("/{session_id}/message")
async def list_messages(
    session_id: str,
    limit: Optional[int] = Query(None, description="Maximum number of messages to return"),
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Return the messages of a session via ``Message.list``.

    ``Session.get`` is called first purely as an existence check; its result
    is discarded.

    Raises:
        HTTPException: 404 when either call raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    try:
        # Existence check only; the returned session info is not used.
        await Session.get(session_id, owner_id)
        messages = await Message.list(session_id, limit, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return messages
@router.post("/{session_id}/message")
async def send_message(
    session_id: str,
    request: MessageRequest,
    # The dependency is chosen once, when this module is imported: the quota
    # check when Supabase is enabled, plain optional auth otherwise. The body
    # treats `user` as possibly None, hence Optional.
    user: Optional[AuthUser] = Depends(check_quota) if supabase_enabled() else Depends(optional_auth)
):
    """Send a message to a session and stream the reply as server-sent events.

    Each chunk yielded by ``SessionPrompt.prompt`` is emitted as one
    ``data: <json>`` event; the stream ends with ``data: [DONE]``.

    Args:
        session_id: Id of the session to post to.
        request: Message content plus optional generation settings, copied
            field-for-field into a ``PromptInput``.
        user: Result of the auth/quota dependency; may be None.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    user_id = user.id if user else None

    # Reject unknown sessions before any streaming starts, so the client gets
    # a regular 404 rather than a broken event stream.
    try:
        await Session.get(session_id, user_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    prompt_input = PromptInput(
        content=request.content,
        provider_id=request.provider_id,
        model_id=request.model_id,
        system=request.system,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        tools_enabled=request.tools_enabled,
        auto_continue=request.auto_continue,
        max_steps=request.max_steps,
    )

    async def generate():
        total_input = 0
        total_output = 0
        async for chunk in SessionPrompt.prompt(session_id, prompt_input, user_id):
            if chunk.usage:
                # `or 0` also covers a key that is present with a None value,
                # which `.get(key, 0)` would return as None and make `+=`
                # raise TypeError in the middle of the stream.
                total_input += chunk.usage.get("input_tokens") or 0
                total_output += chunk.usage.get("output_tokens") or 0
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"

        # NOTE(review): usage is recorded only when the stream runs to
        # completion; if the loop above raises or the generator is closed
        # early, this line is never reached.
        if user_id and supabase_enabled():
            await increment_usage(user_id, total_input, total_output)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
    )
@router.post("/{session_id}/abort")
async def abort_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth)
):
    """Cancel the in-flight prompt of a session.

    The session is first looked up with the caller's id, exactly like the
    other session endpoints in this module, so that a request cannot cancel
    a session that ``Session.get`` does not return for this caller.

    Args:
        session_id: Id of the session whose prompt should be cancelled.
        user: Result of the optional auth dependency; may be None.

    Returns:
        ``{"cancelled": <value returned by SessionPrompt.cancel>}``.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    user_id = user.id if user else None

    # Same lookup guard as send_message / generate_title.
    try:
        await Session.get(session_id, user_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    cancelled = SessionPrompt.cancel(session_id)
    return {"cancelled": cancelled}
@router.post("/{session_id}/generate-title")
async def generate_title(
    session_id: str,
    request: GenerateTitleRequest,
    user: Optional[AuthUser] = Depends(optional_auth)
):
    """Generate a session title from the first message and store it.

    Args:
        session_id: Id of the session to title.
        request: The user message to summarise and an optional model id.
        user: Result of the optional auth dependency; may be None.

    Returns:
        ``{"title": <generated title>}``; the same title is written to the
        session via ``Session.update``.

    Raises:
        HTTPException: 404 when the session is not found, 503 when the
            "litellm" provider is unavailable, 500 when generating or saving
            the title fails.
    """
    user_id = user.id if user else None

    # Confirm the session exists before spending a model call on it.
    try:
        await Session.get(session_id, user_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    # The title is generated through the LiteLLM provider.
    model_id = request.model_id or "gemini/gemini-2.0-flash"
    provider = get_provider("litellm")
    if not provider:
        raise HTTPException(status_code=503, detail="LiteLLM provider not available")

    # Korean prompt, restored from mis-decoded (mojibake) text. In English:
    #   "Look at the following user message and generate a short title.
    #    The title must be at most 10 characters; output only the title,
    #    without quotation marks.
    #    User message: "<first 200 characters of the message>"
    #    Title:"
    prompt = f"""다음 사용자 메시지를 보고 짧은 제목을 생성해주세요.
제목은 10자 이내, 따옴표 없이 제목만 출력.
사용자 메시지: "{request.message[:200]}"
제목:"""

    try:
        result = await provider.complete(model_id, prompt, max_tokens=50)
        # Hard cap at 30 characters regardless of what the model returns.
        title = result.strip()[:30]
        # Persist the generated title on the session.
        await Session.update(session_id, {"title": title}, user_id)
        return {"title": title}
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500, keeping the
        # original exception chained as the cause.
        raise HTTPException(status_code=500, detail=f"Failed to generate title: {str(e)}") from e