AUXteam's picture
Upload folder using huggingface_hub
1397957 verified
raw
history blame
2.91 kB
from typing import Dict, Any, List, Optional, Callable, Awaitable, Protocol, runtime_checkable
from pydantic import BaseModel
from abc import ABC, abstractmethod
from datetime import datetime
class ToolContext(BaseModel):
    """Context object passed as ``ctx`` to a tool's ``execute`` method.

    Carries the identifiers of the session and message the call belongs to,
    an optional tool-call identifier, and the name of the invoking agent.
    """

    session_id: str  # required session identifier
    message_id: str  # required message identifier
    tool_call_id: Optional[str] = None  # optional; None when no tool-call id is given
    agent: str = "default"  # agent name; falls back to "default"
class ToolResult(BaseModel):
    """Result value returned by a tool's ``execute`` method.

    Fields:
        title: title string for the result.
        output: textual output produced by the tool.
        metadata: free-form extra data; empty dict by default.
        truncated: whether ``output`` was truncated; False by default.
        original_length: length of the output before truncation; 0 by default.

    NOTE(review): nothing in this file sets ``truncated`` / ``original_length``
    (``BaseTool.truncate_output`` only returns a string), so presumably each
    tool is expected to fill them in itself — confirm with the tool
    implementations.
    """

    title: str
    output: str
    # NOTE(review): mutable ``{}`` default — pydantic is expected to copy field
    # defaults per instance, so this is likely not shared state; verify for the
    # pinned pydantic version, or make it explicit with Field(default_factory=dict).
    metadata: Dict[str, Any] = {}
    truncated: bool = False
    original_length: int = 0
@runtime_checkable
class Tool(Protocol):
    """Structural (duck-typed) interface that every tool satisfies.

    Any object exposing ``id``, ``description`` and ``parameters`` properties
    plus an async ``execute(args, ctx)`` method conforms; no inheritance is
    required. The ``@runtime_checkable`` decorator allows ``isinstance(obj, Tool)``
    checks. Such checks look only for the presence of the members, not
    their signatures or return types.
    """

    @property
    def id(self) -> str:
        """Identifier of the tool."""
        ...

    @property
    def description(self) -> str:
        """Human-readable description of the tool."""
        ...

    @property
    def parameters(self) -> Dict[str, Any]:
        """Parameter schema of the tool as a dict (presumably JSON-Schema-like — confirm)."""
        ...

    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
        """Run the tool with ``args`` in context ``ctx`` and return a ToolResult."""
        ...
class BaseTool(ABC):
    """Abstract base class for concrete tools.

    Subclasses provide ``id``, ``description``, ``parameters`` and an async
    ``execute``. This base adds schema export, output truncation, and simple
    status / timing bookkeeping.
    """

    # Outputs longer than this many characters are cut by truncate_output().
    MAX_OUTPUT_LENGTH = 50000

    def __init__(self):
        # Every tool starts out "pending" with neither timestamp recorded.
        self.status: str = "pending"
        self.time_start: Optional[datetime] = None
        self.time_end: Optional[datetime] = None

    @property
    @abstractmethod
    def id(self) -> str:
        """Identifier of the tool; exported as ``name`` by get_schema()."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable description; exported as ``description`` by get_schema()."""

    @property
    @abstractmethod
    def parameters(self) -> Dict[str, Any]:
        """Parameter schema dict; exported as ``parameters`` by get_schema()."""

    @abstractmethod
    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
        """Run the tool with ``args`` under ``ctx`` and return its ToolResult."""

    def get_schema(self) -> Dict[str, Any]:
        """Return ``{"name": id, "description": description, "parameters": parameters}``."""
        schema: Dict[str, Any] = {"name": self.id}
        schema["description"] = self.description
        schema["parameters"] = self.parameters
        return schema

    def truncate_output(self, output: str) -> str:
        """Cap ``output`` at MAX_OUTPUT_LENGTH characters.

        Output within the limit is returned unchanged. Longer output keeps its
        first MAX_OUTPUT_LENGTH characters followed by a truncation notice.
        The returned string can therefore exceed the limit by the length of
        that notice.
        """
        limit = self.MAX_OUTPUT_LENGTH
        if len(output) > limit:
            return output[:limit] + "\n\n[Output truncated...]"
        return output

    def update_status(self, status: str) -> None:
        """Store ``status`` and stamp start/end times the first time they apply.

        Expected statuses are "pending", "running", "completed" and "error",
        though any string is stored. Switching to "running" records
        ``time_start`` if it is still unset. Switching to "completed" or
        "error" records ``time_end`` if it is still unset. Timestamps are naive
        ``datetime.now()`` values and are never overwritten once set.
        """
        self.status = status
        if status == "running":
            if self.time_start is None:
                self.time_start = datetime.now()
        elif status in ("completed", "error"):
            if self.time_end is None:
                self.time_end = datetime.now()
from .registry import get_registry
def register_tool(tool: BaseTool) -> None:
    """Register ``tool`` (compatibility wrapper that delegates to the ToolRegistry)."""
    registry = get_registry()
    registry.register(tool)
def get_tool(tool_id: str) -> Optional[BaseTool]:
    """Look up a tool by ``tool_id`` (compatibility wrapper over the ToolRegistry).

    Returns whatever the registry's ``get`` returns for that id.
    """
    registry = get_registry()
    return registry.get(tool_id)
def list_tools() -> List[BaseTool]:
    """Return the registered tools (compatibility wrapper over the ToolRegistry)."""
    registry = get_registry()
    tools = registry.list()
    return tools
def get_tools_schema() -> List[Dict[str, Any]]:
    """Return the schema list of registered tools (compatibility wrapper over the ToolRegistry)."""
    registry = get_registry()
    schemas = registry.get_schema()
    return schemas