Spaces:
Sleeping
Sleeping
File size: 6,601 Bytes
1397957 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import json
from ..session import Session, SessionInfo, SessionCreate, Message, SessionPrompt
from ..session.prompt import PromptInput
from ..core.storage import NotFoundError
from ..core.auth import AuthUser, optional_auth, require_auth
from ..core.quota import check_quota, increment_usage
from ..core.supabase import is_enabled as supabase_enabled
from ..provider import get_provider
# Router that collects every endpoint defined in this module; each route path
# below is registered under the "/session" prefix and tagged "Session".
router = APIRouter(prefix="/session", tags=["Session"])
class MessageRequest(BaseModel):
    """Request body for POST /session/{session_id}/message.

    send_message copies every field below, unchanged, into a PromptInput.
    """

    # The only required field.
    content: str
    # Optional overrides. How None is interpreted is decided by PromptInput /
    # SessionPrompt, which are defined outside this module.
    provider_id: Optional[str] = None
    model_id: Optional[str] = None
    system: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Defaults to True when the client omits it.
    tools_enabled: bool = True
    auto_continue: Optional[bool] = None
    max_steps: Optional[int] = None
class SessionUpdate(BaseModel):
    """Request body for PATCH /session/{session_id}.

    update_session drops fields whose value is None before calling
    Session.update, so a missing (or null) title is not forwarded.
    """

    title: Optional[str] = None
class GenerateTitleRequest(BaseModel):
    """Request body for POST /session/{session_id}/generate-title."""

    # generate_title embeds only the first 200 characters in its prompt.
    message: str
    # generate_title falls back to "gemini/gemini-2.0-flash" when this is falsy.
    model_id: Optional[str] = None
@router.get("/", response_model=List[SessionInfo])
async def list_sessions(
    limit: Optional[int] = Query(None, description="Maximum number of sessions to return"),
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """List sessions.

    Passes ``limit`` and the caller's id (``None`` when ``optional_auth``
    yields no user) to ``Session.list`` and returns its result.
    """
    owner_id = None
    if user:
        owner_id = user.id
    return await Session.list(limit, owner_id)
@router.post("/", response_model=SessionInfo)
async def create_session(
    data: Optional[SessionCreate] = None,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Create a session.

    ``data`` may be omitted; it is handed to ``Session.create`` as-is together
    with the caller's id (``None`` when ``optional_auth`` yields no user).
    """
    creator_id = None
    if user:
        creator_id = user.id
    created = await Session.create(data, creator_id)
    return created
@router.get("/{session_id}", response_model=SessionInfo)
async def get_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Fetch one session by id.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    viewer_id = None
    if user:
        viewer_id = user.id

    try:
        found = await Session.get(session_id, viewer_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return found
@router.patch("/{session_id}", response_model=SessionInfo)
async def update_session(
    session_id: str,
    updates: SessionUpdate,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Apply a partial update to a session.

    Only fields of ``updates`` whose value is not ``None`` are forwarded to
    ``Session.update``.

    Raises:
        HTTPException: 404 when ``Session.update`` raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    # Keep just the non-None fields so unset values are never written.
    changes = {}
    for field, value in updates.model_dump().items():
        if value is not None:
            changes[field] = value

    try:
        updated = await Session.update(session_id, changes, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return updated
@router.delete("/{session_id}")
async def delete_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Delete a session and report success.

    Returns:
        ``{"success": True}`` once ``Session.delete`` has completed.

    Raises:
        HTTPException: 404 when ``Session.delete`` raises ``NotFoundError``.
    """
    owner_id = None
    if user:
        owner_id = user.id

    try:
        await Session.delete(session_id, owner_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return {"success": True}
@router.get("/{session_id}/message")
async def list_messages(
    session_id: str,
    limit: Optional[int] = Query(None, description="Maximum number of messages to return"),
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """List the messages of a session.

    Raises:
        HTTPException: 404 when either the session lookup or
            ``Message.list`` raises ``NotFoundError``.
    """
    viewer_id = None
    if user:
        viewer_id = user.id

    try:
        # The lookup result is discarded; it is awaited only so that a
        # NotFoundError surfaces before any messages are listed.
        await Session.get(session_id, viewer_id)
        messages = await Message.list(session_id, limit, viewer_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
    return messages
@router.post("/{session_id}/message")
async def send_message(
    session_id: str,
    request: MessageRequest,
    user: AuthUser = Depends(check_quota) if supabase_enabled() else Depends(optional_auth)
):
    """Run a prompt against a session and stream the output as server-sent events.

    The auth dependency is picked once, when this function is defined:
    ``check_quota`` if ``supabase_enabled()`` is true at that moment,
    otherwise ``optional_auth``.

    Each chunk produced by ``SessionPrompt.prompt`` is emitted as one
    ``data:`` event holding the JSON of ``chunk.model_dump()``; a final
    ``data: [DONE]`` event closes the stream.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    caller_id = user.id if user else None

    # Resolve the session up front so a missing one yields a plain 404
    # instead of a stream that starts and then breaks.
    try:
        await Session.get(session_id, caller_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    prompt_input = PromptInput(
        content=request.content,
        provider_id=request.provider_id,
        model_id=request.model_id,
        system=request.system,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        tools_enabled=request.tools_enabled,
        auto_continue=request.auto_continue,
        max_steps=request.max_steps,
    )

    async def event_stream():
        # Running token totals across all chunks that report usage.
        input_tokens = 0
        output_tokens = 0

        async for chunk in SessionPrompt.prompt(session_id, prompt_input, caller_id):
            usage = chunk.usage
            if usage:
                input_tokens += usage.get("input_tokens", 0)
                output_tokens += usage.get("output_tokens", 0)
            payload = json.dumps(chunk.model_dump())
            yield f"data: {payload}\n\n"

        # NOTE(review): usage is recorded only after the prompt iterator is
        # exhausted; a stream that stops early never reaches this point.
        if caller_id and supabase_enabled():
            await increment_usage(caller_id, input_tokens, output_tokens)

        yield "data: [DONE]\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.post("/{session_id}/abort")
async def abort_session(
    session_id: str,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Cancel the running prompt of a session, if there is one.

    The session is first looked up with the caller's id, the same guard every
    other ``/{session_id}`` endpoint in this module applies, so a request
    cannot cancel a session that ``Session.get`` does not resolve for the
    caller.

    Returns:
        ``{"cancelled": <value returned by SessionPrompt.cancel>}``.

    Raises:
        HTTPException: 404 when ``Session.get`` raises ``NotFoundError``.
    """
    user_id = user.id if user else None

    # Previously this endpoint cancelled by raw id with no lookup at all.
    try:
        await Session.get(session_id, user_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") from None

    cancelled = SessionPrompt.cancel(session_id)
    return {"cancelled": cancelled}
@router.post("/{session_id}/generate-title")
async def generate_title(
    session_id: str,
    request: GenerateTitleRequest,
    user: Optional[AuthUser] = Depends(optional_auth),
):
    """Generate a session title from the first user message and store it.

    The first 200 characters of ``request.message`` are embedded in a Korean
    prompt sent to the "litellm" provider. The stripped completion, cut to
    30 characters, is saved via ``Session.update`` and returned.

    Returns:
        ``{"title": <generated title>}``.

    Raises:
        HTTPException: 404 when the session is not found (on lookup or on
            update), 503 when no "litellm" provider is available, 500 for
            any other failure while generating or saving the title.
    """
    user_id = user.id if user else None

    # Confirm the session exists before spending a model call on it.
    try:
        await Session.get(session_id, user_id)
    except NotFoundError:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") from None

    # Generate the title through the LiteLLM provider.
    model_id = request.model_id or "gemini/gemini-2.0-flash"
    provider = get_provider("litellm")
    if not provider:
        raise HTTPException(status_code=503, detail="LiteLLM provider not available")

    # Prompt text restored from mis-decoded UTF-8. Continuation lines stay at
    # column 0 because they are part of the string sent to the model.
    prompt = f"""다음 사용자 메시지를 보고 짧은 제목을 생성해주세요.
제목은 10자 이내, 따옴표 없이 제목만 출력.
사용자 메시지: "{request.message[:200]}"
제목:"""

    try:
        result = await provider.complete(model_id, prompt, max_tokens=50)
        title = result.strip()[:30]
        # Persist the new title on the session.
        await Session.update(session_id, {"title": title}, user_id)
    except NotFoundError:
        # The session vanished between the lookup and the update; report it
        # the same way the other endpoints do rather than as a 500.
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") from None
    except Exception as e:
        # Endpoint boundary: turn any provider/storage failure into a 500
        # while keeping the original exception as the cause.
        raise HTTPException(status_code=500, detail=f"Failed to generate title: {e}") from e
    return {"title": title}
|