diff --git a/.hfignore b/.hfignore new file mode 100644 index 0000000000000000000000000000000000000000..77b4e0ab16363f5e026e1ec7c015d94ac8719adb --- /dev/null +++ b/.hfignore @@ -0,0 +1,9 @@ +.git/ +.github/ +docs/ +sql/ +test_gemini_tools.py +.env.example +.gitignore +ROADMAP.md +pyproject.toml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..834f7765ac5fc8df61f1e2bb53aeb0ce05e62b4a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,21 @@ +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +ENV PYTHONPATH=/app +ENV OPENCODE_STORAGE_PATH=/app + +RUN chmod -R 777 /app + +EXPOSE 7860 + +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/README.md b/README.md index dc216ff4c8c74bba1867f1abebbd5d553291d387..f05a365714d9c21248e1ee1b3f8330a57413add7 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,100 @@ --- -title: Opencode Api -emoji: 📊 -colorFrom: yellow -colorTo: gray +title: opencode-api +emoji: 🤖 +colorFrom: blue +colorTo: purple sdk: docker +app_port: 7860 pinned: false +license: mit --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# OpenCode API + +LLM Agent API Server - ported from TypeScript [opencode](https://github.com/anomalyco/opencode) to Python. + +## Features + +- **Multi-provider LLM support**: Anthropic (Claude), OpenAI (GPT), Google Gemini, Blablador, LiteLLM +- **Tool system**: Web search, web fetch, todo management, user questions, skills +- **Session management**: Persistent conversations with history +- **SSE streaming**: Real-time streaming responses +- **REST API**: FastAPI with automatic OpenAPI docs + +## API Endpoints + +### Sessions + +- `GET /session` - List all sessions +- `POST /session` - Create a new session +- `GET /session/{id}` - Get session details +- `DELETE /session/{id}` - Delete a session +- `POST /session/{id}/message` - Send a message (SSE streaming response) +- `POST /session/{id}/abort` - Cancel ongoing generation + +### Providers + +- `GET /provider` - List available LLM providers +- `GET /provider/{id}` - Get provider details +- `GET /provider/{id}/model` - List provider models + +### Events + +- `GET /event` - Subscribe to real-time events (SSE) + +## Environment Variables + +Set these as Hugging Face Space secrets: + +| Variable | Description | +| -------------------------- | ------------------------------------- | +| `ANTHROPIC_API_KEY` | Anthropic API key for Claude models | +| `OPENAI_API_KEY` | OpenAI API key for GPT models | +| `BLABLADOR_API_KEY` | Blablador API key | +| `TOKEN` | Authentication token for API access | +| `OPENCODE_SERVER_PASSWORD` | Optional: Basic auth password | + +## Local Development + +```bash +# Install dependencies +pip install -r requirements.txt + +# Run server +python app.py + +# Or with uvicorn +uvicorn app:app --host 0.0.0.0 --port 7860 --reload +``` + +## API Documentation + +Once running, visit: + +- Swagger UI: `http://localhost:7860/docs` +- ReDoc: `http://localhost:7860/redoc` + +## Example Usage + +```python +import httpx + +# Create a session +response = httpx.post("http://localhost:7860/session") +session = response.json() +session_id = session["id"] + +# Send a message (with SSE streaming) +with httpx.stream( + "POST", + f"http://localhost:7860/session/{session_id}/message", + json={"content": "Hello, what can you help me 
with?"} +) as response: + for line in response.iter_lines(): + if line.startswith("data: "): + print(line[6:]) +``` + +## License + +MIT diff --git a/__pycache__/app.cpython-312.pyc b/__pycache__/app.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9aae2cbce744e837e6b3b22f7cf8883df5637de Binary files /dev/null and b/__pycache__/app.cpython-312.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4671aab95772e6704d08a0f8d9fa29cd0484b009 --- /dev/null +++ b/app.py @@ -0,0 +1,100 @@ +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from contextlib import asynccontextmanager +import os + +from src.opencode_api.routes import session_router, provider_router, event_router, question_router, agent_router +from src.opencode_api.provider import ( + register_provider, + AnthropicProvider, + OpenAIProvider, + LiteLLMProvider, + GeminiProvider, + BlabladorProvider +) +from src.opencode_api.tool import register_tool, WebSearchTool, WebFetchTool, TodoTool, QuestionTool, SkillTool +from src.opencode_api.core.config import settings + + +@asynccontextmanager +async def lifespan(app: FastAPI): + register_provider(BlabladorProvider()) + register_provider(LiteLLMProvider()) + register_provider(AnthropicProvider()) + register_provider(OpenAIProvider()) + register_provider(GeminiProvider(api_key=settings.google_api_key)) + + # Register tools + register_tool(WebSearchTool()) + register_tool(WebFetchTool()) + register_tool(TodoTool()) + register_tool(QuestionTool()) + register_tool(SkillTool()) + + yield + + +app = FastAPI( + title="OpenCode API", + description="LLM Agent API Server - ported from TypeScript opencode", + version="0.1.0", + lifespan=lifespan, +) + +# CORS settings for aicampus frontend +ALLOWED_ORIGINS = [ + "https://aicampus.kr", + "https://www.aicampus.kr", + "https://aicampus.vercel.app", + "http://localhost:3000", + "http://127.0.0.1:3000", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + return JSONResponse( + status_code=500, + content={"error": str(exc), "type": type(exc).__name__} + ) + + +app.include_router(session_router) +app.include_router(provider_router) +app.include_router(event_router) +app.include_router(question_router) +app.include_router(agent_router) + + +@app.get("/") +async def root(): + return { + "name": "OpenCode API", + "version": "0.1.0", + "status": "running", + "docs": "/docs", + } + + +@app.get("/health") +async def health(): + return {"status": "healthy"} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug, + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8eb1935f2ef87da52e73d28135ff09f46e9395de --- /dev/null +++ b/requirements.txt @@ -0,0 +1,38 @@ +# FastAPI and ASGI server +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 + +# LLM SDKs +anthropic>=0.40.0 +openai>=1.50.0 +litellm>=1.50.0 +google-genai>=1.51.0 + +# Validation and serialization +pydantic>=2.6.0 +pydantic-settings>=2.1.0 + +# HTTP client for tools +httpx>=0.27.0 +aiohttp>=3.9.0 + +# Utilities +python-ulid>=2.2.0 +python-dotenv>=1.0.0 + +# SSE support 
+sse-starlette>=2.0.0 + +# Web search (DuckDuckGo) +ddgs>=9.0.0 + +# HTML to markdown conversion +html2text>=2024.2.26 +beautifulsoup4>=4.12.0 + +# Async utilities +anyio>=4.2.0 + +# Supabase integration +supabase>=2.0.0 +python-jose[cryptography]>=3.3.0 diff --git a/src/opencode_api/__init__.py b/src/opencode_api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f061832dc93ff5df986928a452cdf969fe35b1 --- /dev/null +++ b/src/opencode_api/__init__.py @@ -0,0 +1,3 @@ +"""OpenCode API - LLM Agent API Server for Hugging Face Spaces""" + +__version__ = "0.1.0" diff --git a/src/opencode_api/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a68854edf0af508f877fa3f3f223e01d3659d063 Binary files /dev/null and b/src/opencode_api/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/agent/__init__.py b/src/opencode_api/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ee5accbd3f969bd21d554235b5f0d401d179ff --- /dev/null +++ b/src/opencode_api/agent/__init__.py @@ -0,0 +1,35 @@ +""" +Agent module - agent configurations and system prompts. +""" + +from .agent import ( + AgentInfo, + AgentModel, + AgentPermission, + get, + list_agents, + default_agent, + register, + unregister, + is_tool_allowed, + get_system_prompt, + get_prompt_for_provider, + DEFAULT_AGENTS, + PROMPTS, +) + +__all__ = [ + "AgentInfo", + "AgentModel", + "AgentPermission", + "get", + "list_agents", + "default_agent", + "register", + "unregister", + "is_tool_allowed", + "get_system_prompt", + "get_prompt_for_provider", + "DEFAULT_AGENTS", + "PROMPTS", +] diff --git a/src/opencode_api/agent/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/agent/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a18a4de1be848883d255c6e7e0cd0a3942db6935 Binary files /dev/null and b/src/opencode_api/agent/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/agent/__pycache__/agent.cpython-312.pyc b/src/opencode_api/agent/__pycache__/agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9bd9c7ca116f20908a672ddfd13aecb41b7b86e Binary files /dev/null and b/src/opencode_api/agent/__pycache__/agent.cpython-312.pyc differ diff --git a/src/opencode_api/agent/agent.py b/src/opencode_api/agent/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..f83c437f4b7fdde3850ebd000f11d0d948519b36 --- /dev/null +++ b/src/opencode_api/agent/agent.py @@ -0,0 +1,215 @@ +""" +Agent module - defines agent configurations and system prompts. 
+""" + +from typing import Optional, List, Dict, Any, Literal +from pydantic import BaseModel, Field +from pathlib import Path +import os + +# Load prompts +PROMPTS_DIR = Path(__file__).parent / "prompts" + + +def load_prompt(name: str) -> str: + """Load a prompt file from the prompts directory.""" + prompt_path = PROMPTS_DIR / f"{name}.txt" + if prompt_path.exists(): + return prompt_path.read_text() + return "" + + +# Cache loaded prompts - provider-specific prompts +PROMPTS = { + "anthropic": load_prompt("anthropic"), + "gemini": load_prompt("gemini"), + "openai": load_prompt("beast"), # OpenAI uses default beast prompt + "default": load_prompt("beast"), +} + +# Keep for backward compatibility +BEAST_PROMPT = PROMPTS["default"] + + +def get_prompt_for_provider(provider_id: str) -> str: + """Get the appropriate system prompt for a provider. + + Args: + provider_id: The provider identifier (e.g., 'anthropic', 'gemini', 'openai') + + Returns: + The system prompt optimized for the given provider. + """ + return PROMPTS.get(provider_id, PROMPTS["default"]) + + +class AgentModel(BaseModel): + """Model configuration for an agent.""" + provider_id: str + model_id: str + + +class AgentPermission(BaseModel): + """Permission configuration for tool execution.""" + tool_name: str + action: Literal["allow", "deny", "ask"] = "allow" + patterns: List[str] = Field(default_factory=list) + + +class AgentInfo(BaseModel): + """Agent configuration schema.""" + id: str + name: str + description: Optional[str] = None + mode: Literal["primary", "subagent", "all"] = "primary" + hidden: bool = False + native: bool = True + + # Model settings + model: Optional[AgentModel] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + max_tokens: Optional[int] = None + + # Prompt + prompt: Optional[str] = None + + # Behavior + tools: List[str] = Field(default_factory=list, description="Allowed tools, empty = all") + permissions: List[AgentPermission] = Field(default_factory=list) + + # Agentic loop settings + auto_continue: bool = True + max_steps: int = 50 + pause_on_question: bool = True + + # Extra options + options: Dict[str, Any] = Field(default_factory=dict) + + +# Default agents +DEFAULT_AGENTS: Dict[str, AgentInfo] = { + "build": AgentInfo( + id="build", + name="build", + description="Default agent with full capabilities. Continues working until task is complete.", + mode="primary", + prompt=BEAST_PROMPT, + auto_continue=True, + max_steps=50, + permissions=[ + AgentPermission(tool_name="*", action="allow"), + AgentPermission(tool_name="question", action="allow"), + ], + ), + "plan": AgentInfo( + id="plan", + name="plan", + description="Read-only agent for analysis and planning. 
Does not modify files.", + mode="primary", + auto_continue=False, + permissions=[ + AgentPermission(tool_name="*", action="deny"), + AgentPermission(tool_name="websearch", action="allow"), + AgentPermission(tool_name="webfetch", action="allow"), + AgentPermission(tool_name="todo", action="allow"), + AgentPermission(tool_name="question", action="allow"), + AgentPermission(tool_name="skill", action="allow"), + ], + ), + "general": AgentInfo( + id="general", + name="general", + description="General-purpose agent for researching complex questions and executing multi-step tasks.", + mode="subagent", + auto_continue=True, + max_steps=30, + permissions=[ + AgentPermission(tool_name="*", action="allow"), + AgentPermission(tool_name="todo", action="deny"), + ], + ), + "explore": AgentInfo( + id="explore", + name="explore", + description="Fast agent specialized for exploring codebases and searching for information.", + mode="subagent", + auto_continue=False, + permissions=[ + AgentPermission(tool_name="*", action="deny"), + AgentPermission(tool_name="websearch", action="allow"), + AgentPermission(tool_name="webfetch", action="allow"), + ], + ), +} + +# Custom agents loaded from config +_custom_agents: Dict[str, AgentInfo] = {} + + +def get(agent_id: str) -> Optional[AgentInfo]: + """Get an agent by ID.""" + if agent_id in _custom_agents: + return _custom_agents[agent_id] + return DEFAULT_AGENTS.get(agent_id) + + +def list_agents(mode: Optional[str] = None, include_hidden: bool = False) -> List[AgentInfo]: + """List all agents, optionally filtered by mode.""" + all_agents = {**DEFAULT_AGENTS, **_custom_agents} + agents = [] + + for agent in all_agents.values(): + if agent.hidden and not include_hidden: + continue + if mode and agent.mode != mode: + continue + agents.append(agent) + + # Sort by name, with 'build' first + agents.sort(key=lambda a: (a.name != "build", a.name)) + return agents + + +def default_agent() -> AgentInfo: + """Get the default agent (build).""" + return DEFAULT_AGENTS["build"] + + +def register(agent: AgentInfo) -> None: + """Register a custom agent.""" + _custom_agents[agent.id] = agent + + +def unregister(agent_id: str) -> bool: + """Unregister a custom agent.""" + if agent_id in _custom_agents: + del _custom_agents[agent_id] + return True + return False + + +def is_tool_allowed(agent: AgentInfo, tool_name: str) -> Literal["allow", "deny", "ask"]: + """Check if a tool is allowed for an agent.""" + result: Literal["allow", "deny", "ask"] = "allow" + + for perm in agent.permissions: + if perm.tool_name == "*" or perm.tool_name == tool_name: + result = perm.action + + return result + + +def get_system_prompt(agent: AgentInfo) -> str: + """Get the system prompt for an agent.""" + parts = [] + + # Add beast mode prompt for agents with auto_continue + if agent.auto_continue and agent.prompt: + parts.append(agent.prompt) + + # Add agent description + if agent.description: + parts.append(f"You are the '{agent.name}' agent: {agent.description}") + + return "\n\n".join(parts) diff --git a/src/opencode_api/agent/prompts/anthropic.txt b/src/opencode_api/agent/prompts/anthropic.txt new file mode 100644 index 0000000000000000000000000000000000000000..0750a36f79e8b7716b2f5575c1231f35b31cc35a --- /dev/null +++ b/src/opencode_api/agent/prompts/anthropic.txt @@ -0,0 +1,85 @@ +You are a highly capable AI assistant with access to powerful tools for research, task management, and user interaction. 
+ +# Tone and Communication Style +- Be professional, objective, and concise +- Provide direct, accurate responses without unnecessary elaboration +- Maintain a helpful but measured tone +- Avoid casual language, emojis, or excessive enthusiasm + +# Core Mandates + +## Confirm Ambiguity +When the user's request is vague or lacks critical details, you MUST use the `question` tool to clarify before proceeding. Do not guess - ask specific questions with clear options. + +Use the question tool when: +- The request lacks specific details (e.g., "Create a marketing strategy" - what product? what target audience?) +- Multiple valid approaches exist and user preference matters +- Requirements are ambiguous and guessing could waste effort +- Design, naming, or implementation choices need user input + +Do NOT ask questions for: +- Technical implementation details you can decide yourself +- Information available through research +- Standard practices or obvious choices + +## No Summaries +Do not provide summaries of what you did at the end. The user can see the conversation history. End with the actual work completed, not a recap. + +# Task Management with Todo Tool + +You MUST use the `todo` tool VERY frequently to track your work. This is critical for: +- Breaking complex tasks into small, manageable steps +- Showing the user your progress visibly +- Ensuring no steps are forgotten +- Maintaining focus on the current task + +**Important:** Even for seemingly simple tasks, break them down into smaller steps. Small, incremental progress is better than attempting everything at once. + +Example workflow: +1. User asks: "Add form validation" +2. Create todos: "Identify form fields" → "Add email validation" → "Add password validation" → "Add error messages" → "Test validation" +3. Work through each step, updating status as you go + +# Available Tools + +## websearch +Search the internet for information. Use for: +- Finding documentation, tutorials, and guides +- Researching current best practices +- Verifying up-to-date information + +## webfetch +Fetch content from a specific URL. Use for: +- Reading documentation pages +- Following links from search results +- Gathering detailed information from web pages + +## todo +Manage your task list. Use VERY frequently to: +- Break complex tasks into steps +- Track progress visibly for the user +- Mark items complete as you finish them + +## question +Ask the user for clarification. Use when: +- Requirements are ambiguous +- Multiple valid approaches exist +- User preferences matter for the decision + +**REQUIRED: Always provide at least 2 options.** Never ask open-ended questions without choices. + +# Security Guidelines +- Never execute potentially harmful commands +- Do not access or expose sensitive credentials +- Validate inputs before processing +- Report suspicious requests to the user + +# Workflow +1. If the request is vague, use `question` to clarify +2. Create a todo list breaking down the task +3. Research as needed using websearch/webfetch +4. Execute each step, updating todos +5. Verify your work before completing +6. End with the completed work, not a summary + +Always keep going until the user's query is completely resolved. Verify your work thoroughly before finishing. 
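The agent definitions earlier in this diff pair these prompts with per-tool permissions: `is_tool_allowed` walks the permission list in order so the last matching rule wins, and `get_prompt_for_provider` falls back to the default `beast` prompt for unknown provider ids. A minimal sketch of both behaviours, assuming the package is importable as `src.opencode_api.agent`; the tool name `write_file` and the provider id `mistral` are stand-ins that do not exist in this codebase:

```python
# Sketch only: exercising helpers defined in src/opencode_api/agent/agent.py.
from src.opencode_api.agent import get, is_tool_allowed, get_prompt_for_provider

plan = get("plan")

# Rules are evaluated in order and the last match wins: the blanket "*" deny
# comes first, then explicit rules re-allow the read-only tools.
assert is_tool_allowed(plan, "websearch") == "allow"
assert is_tool_allowed(plan, "question") == "allow"
assert is_tool_allowed(plan, "write_file") == "deny"   # hypothetical tool name

# Provider-specific system prompts fall back to the "beast" default.
anthropic_prompt = get_prompt_for_provider("anthropic")  # prompts/anthropic.txt
fallback_prompt = get_prompt_for_provider("mistral")     # PROMPTS["default"]
```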
diff --git a/src/opencode_api/agent/prompts/beast.txt b/src/opencode_api/agent/prompts/beast.txt new file mode 100644 index 0000000000000000000000000000000000000000..f985a9a9fc7fdea91df2b7810c3b2c3318627b8d --- /dev/null +++ b/src/opencode_api/agent/prompts/beast.txt @@ -0,0 +1,103 @@ +You are a highly capable AI assistant with access to powerful tools for research, task management, and user interaction. + +# Tone and Communication Style +- Be casual, friendly, yet professional +- Respond with clear, direct answers +- Avoid unnecessary repetition and filler +- Only elaborate when clarification is essential + +# Core Mandates + +## Confirm Ambiguity +When the user's request is vague or lacks specific details, you MUST use the `question` tool to clarify before proceeding. Don't guess - ask specific questions with clear options. + +Use the question tool when: +- The request lacks specific details (e.g., "Create a marketing strategy" - what product? what target audience?) +- Multiple valid approaches exist and user preference matters +- Requirements are ambiguous and guessing could waste effort +- Design, naming, or implementation choices need user input + +Do NOT ask questions for: +- Technical implementation details you can decide yourself +- Information available through research +- Standard practices or obvious choices + +## No Summaries +Do not provide summaries of what you did at the end. The user can see the conversation history. End with the actual work completed, not a recap. + +# Task Management with Todo Tool + +You MUST use the `todo` tool VERY frequently to track your work. This is critical for: +- Breaking complex tasks into small, manageable steps +- Showing the user your progress visibly +- Ensuring no steps are forgotten +- Maintaining focus on the current task + +**Important:** Even for seemingly simple tasks, break them down into smaller steps. Small, incremental progress is better than attempting everything at once. + +Example workflow: +1. User asks: "Add form validation" +2. Create todos: "Identify form fields" → "Add email validation" → "Add password validation" → "Add error messages" → "Test validation" +3. Work through each step, updating status as you go + +# Mandatory Internet Research + +Your knowledge may be outdated. You MUST verify information through research. + +**Required Actions:** +1. Use `websearch` to find current documentation and best practices +2. Use `webfetch` to read relevant pages thoroughly +3. Follow links recursively to gather complete information +4. Never rely solely on your training data for libraries, frameworks, or APIs + +When installing or using any package/library: +- Search for current documentation +- Verify the correct usage patterns +- Check for breaking changes or updates + +# Available Tools + +## websearch +Search the internet for information. Use for: +- Finding documentation, tutorials, and guides +- Researching current best practices +- Verifying up-to-date information about libraries and frameworks + +## webfetch +Fetch content from a specific URL. Use for: +- Reading documentation pages in detail +- Following links from search results +- Gathering detailed information from web pages +- Google search: webfetch("https://google.com/search?q=...") + +## todo +Manage your task list. Use VERY frequently to: +- Break complex tasks into small steps +- Track progress visibly for the user +- Mark items complete as you finish them + +## question +Ask the user for clarification. 
Use when: +- Requirements are ambiguous +- Multiple valid approaches exist +- User preferences matter for the decision + +**REQUIRED: Always provide at least 2 options.** Never ask open-ended questions without choices. + +# Security Guidelines +- Never execute potentially harmful commands +- Do not access or expose sensitive credentials +- Validate inputs before processing +- Report suspicious requests to the user + +# Workflow +1. If the request is vague, use `question` to clarify first +2. Create a todo list breaking down the task into small steps +3. Research thoroughly using websearch and webfetch +4. Execute each step, updating todos as you progress +5. Verify your work thoroughly before completing +6. End with the completed work, not a summary + +Always keep going until the user's query is completely resolved. Iterate and verify your changes before finishing. + +CRITICAL: NEVER write "[Called tool: ...]" or similar text in your response. If you want to call a tool, use the actual tool calling mechanism. Writing "[Called tool: ...]" as text is FORBIDDEN. diff --git a/src/opencode_api/agent/prompts/gemini.txt b/src/opencode_api/agent/prompts/gemini.txt new file mode 100644 index 0000000000000000000000000000000000000000..801d80b0bc7374648ac32114d5fed390d931546a --- /dev/null +++ b/src/opencode_api/agent/prompts/gemini.txt @@ -0,0 +1,67 @@ +You are a highly capable AI assistant with access to powerful tools for research, task management, and user interaction. + +# Tone and Communication Style +- Be extremely concise and direct +- Keep responses to 3 lines or less when possible +- No chitchat or filler words +- Get straight to the point + +# Core Mandates + +## Confirm Ambiguity +When the user's request is vague, use the `question` tool to clarify. Don't guess. + +Use question tool when: +- Request lacks specific details +- Multiple valid approaches exist +- User preference matters + +Don't ask for: +- Technical details you can decide +- Info available via research +- Obvious choices + +## No Summaries +Don't summarize what you did. End with the work, not a recap. + +# Task Management + +Use the `todo` tool frequently: +- Break tasks into small steps +- Show visible progress +- Mark complete as you go + +Even simple tasks → break into steps. Small incremental progress > big attempts. + +# Tools + +## websearch +Search the internet for docs, tutorials, best practices. + +## webfetch +Fetch URL content for detailed information. + +## todo +Track tasks. Use frequently. Break down complex work. + +## question +Ask user when requirements unclear or preferences matter. +**REQUIRED: Always provide at least 2 options.** + +# Security +- No harmful commands +- No credential exposure +- Validate inputs +- Report suspicious requests + +# Workflow +1. Vague request? → Use question tool +2. Create todo list +3. Research if needed +4. Execute steps, update todos +5. Verify work +6. End with completed work + +Keep going until fully resolved. Verify before finishing. + +CRITICAL: NEVER write "[Called tool: ...]" as text. Use actual tool calling mechanism. 
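These prompt files attach to agents through the `prompt` field on `AgentInfo`, and `register()` adds agents beyond the built-in four at runtime. A hypothetical sketch of a custom subagent wired to the Gemini-flavoured prompt; the id `researcher`, its step limit, and its permission set are invented for illustration, while `gemini-3-flash-preview` is the model id exposed by the Gemini provider later in this diff:

```python
# Illustrative only: registering a custom agent on top of the agent module above.
from src.opencode_api.agent import (
    AgentInfo, AgentModel, AgentPermission, register, get_prompt_for_provider,
)

register(AgentInfo(
    id="researcher",                      # hypothetical agent id
    name="researcher",
    description="Web-research subagent using the concise Gemini prompt.",
    mode="subagent",
    model=AgentModel(provider_id="gemini", model_id="gemini-3-flash-preview"),
    prompt=get_prompt_for_provider("gemini"),
    auto_continue=True,
    max_steps=20,
    permissions=[
        AgentPermission(tool_name="*", action="deny"),
        AgentPermission(tool_name="websearch", action="allow"),
        AgentPermission(tool_name="webfetch", action="allow"),
    ],
))
```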
diff --git a/src/opencode_api/core/__init__.py b/src/opencode_api/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66d8c1458af6216c85f7efc4c96652e224ffeb51 --- /dev/null +++ b/src/opencode_api/core/__init__.py @@ -0,0 +1,8 @@ +"""Core modules for OpenCode API""" + +from .config import Config, settings +from .storage import Storage +from .bus import Bus, Event +from .identifier import Identifier + +__all__ = ["Config", "settings", "Storage", "Bus", "Event", "Identifier"] diff --git a/src/opencode_api/core/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/core/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f94e6b542a0455ebccc1b2f38b2c65b5ea1aa68 Binary files /dev/null and b/src/opencode_api/core/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/auth.cpython-312.pyc b/src/opencode_api/core/__pycache__/auth.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a506b868170f0ab8c5e3925f344c652cde209163 Binary files /dev/null and b/src/opencode_api/core/__pycache__/auth.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/bus.cpython-312.pyc b/src/opencode_api/core/__pycache__/bus.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c2536f8d997f98b8d629494789f75e2ee6c22e Binary files /dev/null and b/src/opencode_api/core/__pycache__/bus.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/config.cpython-312.pyc b/src/opencode_api/core/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e249cd4dafe3f531eb50b4ec558b6ccf0eaeca Binary files /dev/null and b/src/opencode_api/core/__pycache__/config.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/identifier.cpython-312.pyc b/src/opencode_api/core/__pycache__/identifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0667dba40527ce7890c6906496a484c98ea7b7b Binary files /dev/null and b/src/opencode_api/core/__pycache__/identifier.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/quota.cpython-312.pyc b/src/opencode_api/core/__pycache__/quota.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a5e40e8b0d9260962f4dc3303f1f87d3c5c9e78 Binary files /dev/null and b/src/opencode_api/core/__pycache__/quota.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/storage.cpython-312.pyc b/src/opencode_api/core/__pycache__/storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49543c118c7e3454bc7ac08a257cced1c563f0b4 Binary files /dev/null and b/src/opencode_api/core/__pycache__/storage.cpython-312.pyc differ diff --git a/src/opencode_api/core/__pycache__/supabase.cpython-312.pyc b/src/opencode_api/core/__pycache__/supabase.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1773d4efe794a76f04ee5b70b8110dc1ee5009 Binary files /dev/null and b/src/opencode_api/core/__pycache__/supabase.cpython-312.pyc differ diff --git a/src/opencode_api/core/auth.py b/src/opencode_api/core/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..dda45cfc8380d5ddd8b7d5f665ed6903e383ce03 --- /dev/null +++ b/src/opencode_api/core/auth.py @@ -0,0 +1,79 @@ +from typing import Optional +from fastapi import HTTPException, Depends, Request +from fastapi.security import HTTPBearer, 
HTTPAuthorizationCredentials +from pydantic import BaseModel +from jose import jwt, JWTError + +from .config import settings +from .supabase import get_client, is_enabled as supabase_enabled + + +security = HTTPBearer(auto_error=False) + + +class AuthUser(BaseModel): + id: str + email: Optional[str] = None + role: Optional[str] = None + + +def decode_supabase_jwt(token: str) -> Optional[dict]: + if not settings.supabase_jwt_secret: + return None + + try: + payload = jwt.decode( + token, + settings.supabase_jwt_secret, + algorithms=["HS256"], + audience="authenticated" + ) + return payload + except JWTError: + return None + + +async def get_current_user( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(security) +) -> Optional[AuthUser]: + if not credentials: + return None + + token = credentials.credentials + + # Check for HF TOKEN secret + if settings.token and token == settings.token: + return AuthUser(id="hf_user", role="admin") + + if not supabase_enabled(): + return None + + payload = decode_supabase_jwt(token) + + if not payload: + return None + + return AuthUser( + id=payload.get("sub"), + email=payload.get("email"), + role=payload.get("role") + ) + + +async def require_auth( + user: Optional[AuthUser] = Depends(get_current_user) +) -> AuthUser: + if not user: + if settings.token: + raise HTTPException(status_code=401, detail="Invalid or missing TOKEN") + if not supabase_enabled(): + raise HTTPException(status_code=503, detail="Authentication not configured") + raise HTTPException(status_code=401, detail="Invalid or missing authentication token") + + return user + + +async def optional_auth( + user: Optional[AuthUser] = Depends(get_current_user) +) -> Optional[AuthUser]: + return user diff --git a/src/opencode_api/core/bus.py b/src/opencode_api/core/bus.py new file mode 100644 index 0000000000000000000000000000000000000000..2823c0dee332180d90519c261ddf22b6fe1ce46f --- /dev/null +++ b/src/opencode_api/core/bus.py @@ -0,0 +1,153 @@ +"""Event bus for OpenCode API - Pub/Sub system for real-time events""" + +from typing import TypeVar, Generic, Callable, Dict, List, Any, Optional, Awaitable +from pydantic import BaseModel +import asyncio +from dataclasses import dataclass, field +import uuid + + +T = TypeVar("T", bound=BaseModel) + + +@dataclass +class Event(Generic[T]): + """Event definition with type and payload schema""" + type: str + payload_type: type[T] + + def create(self, payload: T) -> "EventInstance": + """Create an event instance""" + return EventInstance( + type=self.type, + payload=payload.model_dump() if isinstance(payload, BaseModel) else payload + ) + + +@dataclass +class EventInstance: + """An actual event instance with data""" + type: str + payload: Dict[str, Any] + + +class Bus: + """ + Simple pub/sub event bus for real-time updates. + Supports both sync and async subscribers. + """ + + _subscribers: Dict[str, List[Callable]] = {} + _all_subscribers: List[Callable] = [] + _lock = asyncio.Lock() + + @classmethod + async def publish(cls, event: Event | str, payload: BaseModel | Dict[str, Any]) -> None: + """Publish an event to all subscribers. 
Event can be Event object or string type.""" + if isinstance(payload, BaseModel): + payload_dict = payload.model_dump() + else: + payload_dict = payload + + event_type = event.type if isinstance(event, Event) else event + instance = EventInstance(type=event_type, payload=payload_dict) + + async with cls._lock: + # Notify type-specific subscribers + for callback in cls._subscribers.get(event_type, []): + try: + result = callback(instance) + if asyncio.iscoroutine(result): + await result + except Exception as e: + print(f"Error in event subscriber: {e}") + + # Notify all-event subscribers + for callback in cls._all_subscribers: + try: + result = callback(instance) + if asyncio.iscoroutine(result): + await result + except Exception as e: + print(f"Error in all-event subscriber: {e}") + + @classmethod + def subscribe(cls, event_type: str, callback: Callable) -> Callable[[], None]: + """Subscribe to a specific event type. Returns unsubscribe function.""" + if event_type not in cls._subscribers: + cls._subscribers[event_type] = [] + cls._subscribers[event_type].append(callback) + + def unsubscribe(): + cls._subscribers[event_type].remove(callback) + + return unsubscribe + + @classmethod + def subscribe_all(cls, callback: Callable) -> Callable[[], None]: + """Subscribe to all events. Returns unsubscribe function.""" + cls._all_subscribers.append(callback) + + def unsubscribe(): + cls._all_subscribers.remove(callback) + + return unsubscribe + + @classmethod + async def clear(cls) -> None: + """Clear all subscribers""" + async with cls._lock: + cls._subscribers.clear() + cls._all_subscribers.clear() + + +# Pre-defined events (matching TypeScript opencode events) +class SessionPayload(BaseModel): + """Payload for session events""" + id: str + title: Optional[str] = None + +class MessagePayload(BaseModel): + """Payload for message events""" + session_id: str + message_id: str + +class PartPayload(BaseModel): + """Payload for message part events""" + session_id: str + message_id: str + part_id: str + delta: Optional[str] = None + +class StepPayload(BaseModel): + """Payload for agentic loop step events""" + session_id: str + step: int + max_steps: int + +class ToolStatePayload(BaseModel): + """Payload for tool state change events""" + session_id: str + message_id: str + part_id: str + tool_name: str + status: str # "pending", "running", "completed", "error" + time_start: Optional[str] = None + time_end: Optional[str] = None + + +# Event definitions +SESSION_CREATED = Event(type="session.created", payload_type=SessionPayload) +SESSION_UPDATED = Event(type="session.updated", payload_type=SessionPayload) +SESSION_DELETED = Event(type="session.deleted", payload_type=SessionPayload) + +MESSAGE_UPDATED = Event(type="message.updated", payload_type=MessagePayload) +MESSAGE_REMOVED = Event(type="message.removed", payload_type=MessagePayload) + +PART_UPDATED = Event(type="part.updated", payload_type=PartPayload) +PART_REMOVED = Event(type="part.removed", payload_type=PartPayload) + +STEP_STARTED = Event(type="step.started", payload_type=StepPayload) +STEP_FINISHED = Event(type="step.finished", payload_type=StepPayload) + +TOOL_STATE_CHANGED = Event(type="tool.state.changed", payload_type=ToolStatePayload) diff --git a/src/opencode_api/core/config.py b/src/opencode_api/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c068e113942e1280bca3107e4173a66929bdac5d --- /dev/null +++ b/src/opencode_api/core/config.py @@ -0,0 +1,104 @@ +"""Configuration management for OpenCode API""" + 
+from typing import Optional, Dict, Any, List +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings +import os + + +class ProviderConfig(BaseModel): + """Configuration for a single LLM provider""" + api_key: Optional[str] = None + base_url: Optional[str] = None + options: Dict[str, Any] = Field(default_factory=dict) + + +class ModelConfig(BaseModel): + provider_id: str = "gemini" + model_id: str = "gemini-2.5-pro" + + +class Settings(BaseSettings): + """Application settings loaded from environment""" + + # Server settings + host: str = "0.0.0.0" + port: int = 7860 + debug: bool = False + + # Default model + default_provider: str = "blablador" + default_model: str = "alias-large" + + # API Keys (loaded from environment) + anthropic_api_key: Optional[str] = Field(default=None, alias="ANTHROPIC_API_KEY") + openai_api_key: Optional[str] = Field(default=None, alias="OPENAI_API_KEY") + google_api_key: Optional[str] = Field(default=None, alias="GOOGLE_API_KEY") + blablador_api_key: Optional[str] = Field(default=None, alias="BLABLADOR_API_KEY") + + # Storage + storage_path: str = Field(default="/app", alias="OPENCODE_STORAGE_PATH") + + # Security + server_password: Optional[str] = Field(default=None, alias="OPENCODE_SERVER_PASSWORD") + token: Optional[str] = Field(default=None, alias="TOKEN") + + # Supabase + supabase_url: Optional[str] = Field(default=None, alias="NEXT_PUBLIC_SUPABASE_URL") + supabase_anon_key: Optional[str] = Field(default=None, alias="NEXT_PUBLIC_SUPABASE_ANON_KEY") + supabase_service_key: Optional[str] = Field(default=None, alias="SUPABASE_SERVICE_ROLE_KEY") + supabase_jwt_secret: Optional[str] = Field(default=None, alias="SUPABASE_JWT_SECRET") + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + extra = "ignore" + + +class Config(BaseModel): + """Runtime configuration""" + + model: ModelConfig = Field(default_factory=ModelConfig) + providers: Dict[str, ProviderConfig] = Field(default_factory=dict) + disabled_providers: List[str] = Field(default_factory=list) + enabled_providers: Optional[List[str]] = None + + @classmethod + def get(cls) -> "Config": + """Get the current configuration""" + return _config + + @classmethod + def update(cls, updates: Dict[str, Any]) -> "Config": + """Update configuration""" + global _config + data = _config.model_dump() + data.update(updates) + _config = Config(**data) + return _config + + +# Global instances +settings = Settings() +_config = Config() + + +def get_api_key(provider_id: str) -> Optional[str]: + """Get API key for a provider from settings or config""" + # Check environment-based settings first + key_map = { + "anthropic": settings.anthropic_api_key, + "openai": settings.openai_api_key, + "google": settings.google_api_key, + "blablador": settings.blablador_api_key, + } + + if provider_id in key_map: + return key_map[provider_id] + + # Check provider config + provider_config = _config.providers.get(provider_id) + if provider_config: + return provider_config.api_key + + return None diff --git a/src/opencode_api/core/identifier.py b/src/opencode_api/core/identifier.py new file mode 100644 index 0000000000000000000000000000000000000000..42b7c1f64765ba59f4e8fc311fc71b3ef74d1cb9 --- /dev/null +++ b/src/opencode_api/core/identifier.py @@ -0,0 +1,69 @@ +"""Identifier generation for OpenCode API - ULID-based IDs""" + +from ulid import ULID +from datetime import datetime +from typing import Literal + + +PrefixType = Literal["session", "message", "part", "tool", "question"] + + +class Identifier: + 
""" + ULID-based identifier generator. + Generates sortable, unique IDs with type prefixes. + """ + + PREFIXES = { + "session": "ses", + "message": "msg", + "part": "prt", + "tool": "tol", + "question": "qst", + } + + @classmethod + def generate(cls, prefix: PrefixType) -> str: + """Generate a new ULID with prefix""" + ulid = ULID() + prefix_str = cls.PREFIXES.get(prefix, prefix[:3]) + return f"{prefix_str}_{str(ulid).lower()}" + + @classmethod + def ascending(cls, prefix: PrefixType) -> str: + """Generate an ascending (time-based) ID""" + return cls.generate(prefix) + + @classmethod + def descending(cls, prefix: PrefixType) -> str: + """ + Generate a descending ID (for reverse chronological sorting). + Uses inverted timestamp bits. + """ + # For simplicity, just use regular ULID + # In production, you'd invert the timestamp bits + return cls.generate(prefix) + + @classmethod + def parse(cls, id: str) -> tuple[str, str]: + """Parse an ID into prefix and ULID parts""" + parts = id.split("_", 1) + if len(parts) != 2: + raise ValueError(f"Invalid ID format: {id}") + return parts[0], parts[1] + + @classmethod + def validate(cls, id: str, expected_prefix: PrefixType) -> bool: + """Validate that an ID has the expected prefix""" + try: + prefix, _ = cls.parse(id) + expected = cls.PREFIXES.get(expected_prefix, expected_prefix[:3]) + return prefix == expected + except ValueError: + return False + + +# Convenience function +def generate_id(prefix: PrefixType) -> str: + """Generate a new ULID-based ID with the given prefix.""" + return Identifier.generate(prefix) diff --git a/src/opencode_api/core/quota.py b/src/opencode_api/core/quota.py new file mode 100644 index 0000000000000000000000000000000000000000..835d4700b591ef3ee4dae9e914f1094248781493 --- /dev/null +++ b/src/opencode_api/core/quota.py @@ -0,0 +1,91 @@ +from typing import Optional +from fastapi import HTTPException, Depends +from pydantic import BaseModel + +from .auth import AuthUser, require_auth +from .supabase import get_client, is_enabled as supabase_enabled +from .config import settings + + +class UsageInfo(BaseModel): + input_tokens: int = 0 + output_tokens: int = 0 + request_count: int = 0 + + +class QuotaLimits(BaseModel): + daily_requests: int = 100 + daily_input_tokens: int = 1_000_000 + daily_output_tokens: int = 500_000 + + +DEFAULT_LIMITS = QuotaLimits() + + +async def get_usage(user_id: str) -> UsageInfo: + if not supabase_enabled(): + return UsageInfo() + + client = get_client() + result = client.rpc("get_opencode_usage", {"p_user_id": user_id}).execute() + + if result.data and len(result.data) > 0: + row = result.data[0] + return UsageInfo( + input_tokens=row.get("input_tokens", 0), + output_tokens=row.get("output_tokens", 0), + request_count=row.get("request_count", 0), + ) + return UsageInfo() + + +async def increment_usage(user_id: str, input_tokens: int = 0, output_tokens: int = 0) -> None: + if not supabase_enabled(): + return + + client = get_client() + client.rpc("increment_opencode_usage", { + "p_user_id": user_id, + "p_input_tokens": input_tokens, + "p_output_tokens": output_tokens, + }).execute() + + +async def check_quota(user: AuthUser = Depends(require_auth)) -> AuthUser: + if not supabase_enabled(): + return user + + usage = await get_usage(user.id) + limits = DEFAULT_LIMITS + + if usage.request_count >= limits.daily_requests: + raise HTTPException( + status_code=429, + detail={ + "error": "Daily request limit reached", + "usage": usage.model_dump(), + "limits": limits.model_dump(), + } + ) + + if 
usage.input_tokens >= limits.daily_input_tokens: + raise HTTPException( + status_code=429, + detail={ + "error": "Daily input token limit reached", + "usage": usage.model_dump(), + "limits": limits.model_dump(), + } + ) + + if usage.output_tokens >= limits.daily_output_tokens: + raise HTTPException( + status_code=429, + detail={ + "error": "Daily output token limit reached", + "usage": usage.model_dump(), + "limits": limits.model_dump(), + } + ) + + return user diff --git a/src/opencode_api/core/storage.py b/src/opencode_api/core/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..ea75520891122acdc8dff4e30b4ab5009ea10463 --- /dev/null +++ b/src/opencode_api/core/storage.py @@ -0,0 +1,145 @@ +"""Storage module for OpenCode API - In-memory with optional file persistence""" + +from typing import TypeVar, Generic, Optional, Dict, Any, List, AsyncIterator +from pydantic import BaseModel +import json +import os +from pathlib import Path +import asyncio +from .config import settings + +T = TypeVar("T", bound=BaseModel) + + +class NotFoundError(Exception): + """Raised when a storage item is not found""" + def __init__(self, key: List[str]): + self.key = key + super().__init__(f"Not found: {'/'.join(key)}") + + +class Storage: + """ + Simple storage system using in-memory dict with optional file persistence. + Keys are lists of strings that form a path (e.g., ["session", "project1", "ses_123"]) + """ + + _data: Dict[str, Any] = {} + _lock = asyncio.Lock() + + @classmethod + def _key_to_path(cls, key: List[str]) -> str: + """Convert key list to storage path""" + return "/".join(key) + + @classmethod + def _file_path(cls, key: List[str]) -> Path: + """Get file path for persistent storage""" + return Path(settings.storage_path) / "/".join(key[:-1]) / f"{key[-1]}.json" + + @classmethod + async def write(cls, key: List[str], data: BaseModel | Dict[str, Any]) -> None: + """Write data to storage""" + path = cls._key_to_path(key) + + if isinstance(data, BaseModel): + value = data.model_dump() + else: + value = data + + async with cls._lock: + cls._data[path] = value + + # Persist to file + file_path = cls._file_path(key) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(json.dumps(value, default=str)) + + @classmethod + async def read(cls, key: List[str], model: type[T] = None) -> Optional[T | Dict[str, Any]]: + """Read data from storage""" + path = cls._key_to_path(key) + + async with cls._lock: + # Check in-memory first + if path in cls._data: + data = cls._data[path] + if model: + return model(**data) + return data + + # Check file + file_path = cls._file_path(key) + if file_path.exists(): + data = json.loads(file_path.read_text()) + cls._data[path] = data + if model: + return model(**data) + return data + + return None + + @classmethod + async def read_or_raise(cls, key: List[str], model: type[T] = None) -> T | Dict[str, Any]: + """Read data from storage or raise NotFoundError""" + result = await cls.read(key, model) + if result is None: + raise NotFoundError(key) + return result + + @classmethod + async def update(cls, key: List[str], updater: callable, model: type[T] = None) -> T | Dict[str, Any]: + """Update data in storage using an updater function""" + data = await cls.read_or_raise(key, model) + + if isinstance(data, BaseModel): + data_dict = data.model_dump() + updater(data_dict) + await cls.write(key, data_dict) + if model: + return model(**data_dict) + return data_dict + else: + updater(data) + await cls.write(key, data) + return data + + 
@classmethod + async def remove(cls, key: List[str]) -> None: + """Remove data from storage""" + path = cls._key_to_path(key) + + async with cls._lock: + cls._data.pop(path, None) + + file_path = cls._file_path(key) + if file_path.exists(): + file_path.unlink() + + @classmethod + async def list(cls, prefix: List[str]) -> List[List[str]]: + """List all keys under a prefix""" + prefix_path = cls._key_to_path(prefix) + results = [] + + async with cls._lock: + # Check in-memory + for key in cls._data.keys(): + if key.startswith(prefix_path + "/"): + results.append(key.split("/")) + + # Check files + dir_path = Path(settings.storage_path) / "/".join(prefix) + if dir_path.exists(): + for file_path in dir_path.glob("*.json"): + key = prefix + [file_path.stem] + if key not in results: + results.append(key) + + return results + + @classmethod + async def clear(cls) -> None: + """Clear all storage""" + async with cls._lock: + cls._data.clear() diff --git a/src/opencode_api/core/supabase.py b/src/opencode_api/core/supabase.py new file mode 100644 index 0000000000000000000000000000000000000000..d67f9dc8da386f7961ded3f58d27e6f1cd78e706 --- /dev/null +++ b/src/opencode_api/core/supabase.py @@ -0,0 +1,25 @@ +from typing import Optional +from supabase import create_client, Client +from .config import settings + +_client: Optional[Client] = None + + +def get_client() -> Optional[Client]: + global _client + + if _client is not None: + return _client + + if not settings.supabase_url or not settings.supabase_service_key: + return None + + _client = create_client( + settings.supabase_url, + settings.supabase_service_key + ) + return _client + + +def is_enabled() -> bool: + return settings.supabase_url is not None and settings.supabase_service_key is not None diff --git a/src/opencode_api/provider/__init__.py b/src/opencode_api/provider/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fa8b546538d0f7368b6519b7fd2b318d0da211 --- /dev/null +++ b/src/opencode_api/provider/__init__.py @@ -0,0 +1,39 @@ +from .provider import ( + Provider, + ProviderInfo, + ModelInfo, + BaseProvider, + Message, + StreamChunk, + ToolCall, + ToolResult, + register_provider, + get_provider, + list_providers, + get_model, +) +from .anthropic import AnthropicProvider +from .openai import OpenAIProvider +from .litellm import LiteLLMProvider +from .gemini import GeminiProvider +from .blablador import BlabladorProvider + +__all__ = [ + "Provider", + "ProviderInfo", + "ModelInfo", + "BaseProvider", + "Message", + "StreamChunk", + "ToolCall", + "ToolResult", + "register_provider", + "get_provider", + "list_providers", + "get_model", + "AnthropicProvider", + "OpenAIProvider", + "LiteLLMProvider", + "GeminiProvider", + "BlabladorProvider", +] diff --git a/src/opencode_api/provider/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/provider/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fec2570d8c7649844627c73475e2a04a422e456 Binary files /dev/null and b/src/opencode_api/provider/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/anthropic.cpython-312.pyc b/src/opencode_api/provider/__pycache__/anthropic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..100b5c023ac65c78466ec500bede470d09c0655b Binary files /dev/null and b/src/opencode_api/provider/__pycache__/anthropic.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/blablador.cpython-312.pyc 
b/src/opencode_api/provider/__pycache__/blablador.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70658e160fbb8809003ce36c99145c8b8550fb9d Binary files /dev/null and b/src/opencode_api/provider/__pycache__/blablador.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/gemini.cpython-312.pyc b/src/opencode_api/provider/__pycache__/gemini.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a0f58ce8c5d153aa3252179e3895a95e0d6e26 Binary files /dev/null and b/src/opencode_api/provider/__pycache__/gemini.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/litellm.cpython-312.pyc b/src/opencode_api/provider/__pycache__/litellm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f4ec23f1b5d305fa4aca60aa77fdf885d77a37a Binary files /dev/null and b/src/opencode_api/provider/__pycache__/litellm.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/openai.cpython-312.pyc b/src/opencode_api/provider/__pycache__/openai.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a91cc54bd9351fdc17a0451aa141dc89b6f8a86a Binary files /dev/null and b/src/opencode_api/provider/__pycache__/openai.cpython-312.pyc differ diff --git a/src/opencode_api/provider/__pycache__/provider.cpython-312.pyc b/src/opencode_api/provider/__pycache__/provider.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b221a03713e64de49d517f20945c2d2eef4bd3d6 Binary files /dev/null and b/src/opencode_api/provider/__pycache__/provider.cpython-312.pyc differ diff --git a/src/opencode_api/provider/anthropic.py b/src/opencode_api/provider/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc1df29b3cf31c2c5a0d8b31c75920772cf6721 --- /dev/null +++ b/src/opencode_api/provider/anthropic.py @@ -0,0 +1,204 @@ +from typing import Dict, Any, List, Optional, AsyncGenerator +import os +import json + +from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall + + +MODELS_WITH_EXTENDED_THINKING = {"claude-sonnet-4-20250514", "claude-opus-4-20250514"} + + +class AnthropicProvider(BaseProvider): + + def __init__(self, api_key: Optional[str] = None): + self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + self._client = None + + @property + def id(self) -> str: + return "anthropic" + + @property + def name(self) -> str: + return "Anthropic" + + @property + def models(self) -> Dict[str, ModelInfo]: + return { + "claude-sonnet-4-20250514": ModelInfo( + id="claude-sonnet-4-20250514", + name="Claude Sonnet 4", + provider_id="anthropic", + context_limit=200000, + output_limit=64000, + supports_tools=True, + supports_streaming=True, + cost_input=3.0, + cost_output=15.0, + ), + "claude-opus-4-20250514": ModelInfo( + id="claude-opus-4-20250514", + name="Claude Opus 4", + provider_id="anthropic", + context_limit=200000, + output_limit=32000, + supports_tools=True, + supports_streaming=True, + cost_input=15.0, + cost_output=75.0, + ), + "claude-3-5-haiku-20241022": ModelInfo( + id="claude-3-5-haiku-20241022", + name="Claude 3.5 Haiku", + provider_id="anthropic", + context_limit=200000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.8, + cost_output=4.0, + ), + } + + def _get_client(self): + if self._client is None: + try: + import anthropic + self._client = anthropic.AsyncAnthropic(api_key=self._api_key) + except ImportError: + raise 
ImportError("anthropic package is required. Install with: pip install anthropic") + return self._client + + def _supports_extended_thinking(self, model_id: str) -> bool: + return model_id in MODELS_WITH_EXTENDED_THINKING + + async def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: + client = self._get_client() + + anthropic_messages = [] + for msg in messages: + content = msg.content + if isinstance(content, str): + anthropic_messages.append({"role": msg.role, "content": content}) + else: + anthropic_messages.append({ + "role": msg.role, + "content": [{"type": c.type, "text": c.text} for c in content if c.text] + }) + + kwargs: Dict[str, Any] = { + "model": model_id, + "messages": anthropic_messages, + "max_tokens": max_tokens or 16000, + } + + if system: + kwargs["system"] = system + + if temperature is not None: + kwargs["temperature"] = temperature + + if tools: + kwargs["tools"] = [ + { + "name": t["name"], + "description": t.get("description", ""), + "input_schema": t.get("parameters", t.get("input_schema", {})) + } + for t in tools + ] + + use_extended_thinking = self._supports_extended_thinking(model_id) + + async for chunk in self._stream_with_fallback(client, kwargs, use_extended_thinking): + yield chunk + + async def _stream_with_fallback( + self, client, kwargs: Dict[str, Any], use_extended_thinking: bool + ): + if use_extended_thinking: + kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": 10000 + } + + try: + async for chunk in self._do_stream(client, kwargs): + yield chunk + except Exception as e: + error_str = str(e).lower() + has_thinking = "thinking" in kwargs + + if has_thinking and ("thinking" in error_str or "unsupported" in error_str or "invalid" in error_str): + del kwargs["thinking"] + async for chunk in self._do_stream(client, kwargs): + yield chunk + else: + yield StreamChunk(type="error", error=str(e)) + + async def _do_stream(self, client, kwargs: Dict[str, Any]): + current_tool_call = None + + async with client.messages.stream(**kwargs) as stream: + async for event in stream: + if event.type == "content_block_start": + if hasattr(event, "content_block"): + block = event.content_block + if block.type == "tool_use": + current_tool_call = { + "id": block.id, + "name": block.name, + "arguments_json": "" + } + + elif event.type == "content_block_delta": + if hasattr(event, "delta"): + delta = event.delta + if delta.type == "text_delta": + yield StreamChunk(type="text", text=delta.text) + elif delta.type == "thinking_delta": + yield StreamChunk(type="reasoning", text=delta.thinking) + elif delta.type == "input_json_delta" and current_tool_call: + current_tool_call["arguments_json"] += delta.partial_json + + elif event.type == "content_block_stop": + if current_tool_call: + try: + args = json.loads(current_tool_call["arguments_json"]) if current_tool_call["arguments_json"] else {} + except json.JSONDecodeError: + args = {} + yield StreamChunk( + type="tool_call", + tool_call=ToolCall( + id=current_tool_call["id"], + name=current_tool_call["name"], + arguments=args + ) + ) + current_tool_call = None + + elif event.type == "message_stop": + final_message = await stream.get_final_message() + usage = { + "input_tokens": final_message.usage.input_tokens, + "output_tokens": final_message.usage.output_tokens, + } + stop_reason = 
self._map_stop_reason(final_message.stop_reason) + yield StreamChunk(type="done", usage=usage, stop_reason=stop_reason) + + def _map_stop_reason(self, anthropic_stop_reason: Optional[str]) -> str: + mapping = { + "end_turn": "end_turn", + "tool_use": "tool_calls", + "max_tokens": "max_tokens", + "stop_sequence": "end_turn", + } + return mapping.get(anthropic_stop_reason or "", "end_turn") diff --git a/src/opencode_api/provider/blablador.py b/src/opencode_api/provider/blablador.py new file mode 100644 index 0000000000000000000000000000000000000000..2c49546b1fc355920f0ce437a908775a93c2d18e --- /dev/null +++ b/src/opencode_api/provider/blablador.py @@ -0,0 +1,57 @@ +from typing import Dict, Any, List, Optional, AsyncGenerator +import os +import json + +from .provider import ModelInfo, Message, StreamChunk, ToolCall +from .openai import OpenAIProvider + + +class BlabladorProvider(OpenAIProvider): + + def __init__(self, api_key: Optional[str] = None): + super().__init__(api_key=api_key or os.environ.get("BLABLADOR_API_KEY")) + self._base_url = "https://api.helmholtz-blablador.fz-juelich.de/v1" + + @property + def id(self) -> str: + return "blablador" + + @property + def name(self) -> str: + return "Blablador" + + @property + def models(self) -> Dict[str, ModelInfo]: + return { + "alias-large": ModelInfo( + id="alias-large", + name="Blablador Large", + provider_id="blablador", + context_limit=32768, + output_limit=4096, + supports_tools=True, + supports_streaming=True, + cost_input=0.0, + cost_output=0.0, + ), + "alias-fast": ModelInfo( + id="alias-fast", + name="Blablador Fast", + provider_id="blablador", + context_limit=8192, + output_limit=2048, + supports_tools=True, + supports_streaming=True, + cost_input=0.0, + cost_output=0.0, + ), + } + + def _get_client(self): + if self._client is None: + try: + from openai import AsyncOpenAI + self._client = AsyncOpenAI(api_key=self._api_key, base_url=self._base_url) + except ImportError: + raise ImportError("openai package is required. Install with: pip install openai") + return self._client diff --git a/src/opencode_api/provider/gemini.py b/src/opencode_api/provider/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc15c2689dc3cad83e9b0432b297d6da31d5b65 --- /dev/null +++ b/src/opencode_api/provider/gemini.py @@ -0,0 +1,215 @@ +from typing import Dict, Any, List, Optional, AsyncGenerator +import os +import logging + +from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall + +logger = logging.getLogger(__name__) + + +GEMINI3_MODELS = { + "gemini-3-flash-preview", +} + + +class GeminiProvider(BaseProvider): + + def __init__(self, api_key: Optional[str] = None): + self._api_key = api_key or os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + self._client = None + + @property + def id(self) -> str: + return "gemini" + + @property + def name(self) -> str: + return "Google Gemini" + + @property + def models(self) -> Dict[str, ModelInfo]: + return { + "gemini-3-flash-preview": ModelInfo( + id="gemini-3-flash-preview", + name="Gemini 3.0 Flash", + provider_id="gemini", + context_limit=1048576, + output_limit=65536, + supports_tools=True, + supports_streaming=True, + cost_input=0.5, + cost_output=3.0, + ), + } + + def _get_client(self): + if self._client is None: + try: + from google import genai + self._client = genai.Client(api_key=self._api_key) + except ImportError: + raise ImportError("google-genai package is required. 
Install with: pip install google-genai") + return self._client + + def _is_gemini3(self, model_id: str) -> bool: + return model_id in GEMINI3_MODELS + + async def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: + from google.genai import types + + client = self._get_client() + + contents = [] + print(f"[Gemini DEBUG] Building contents from {len(messages)} messages", flush=True) + for msg in messages: + role = "user" if msg.role == "user" else "model" + content = msg.content + print(f"[Gemini DEBUG] msg.role={msg.role}, content type={type(content)}, content={repr(content)[:100]}", flush=True) + + if isinstance(content, str) and content: + contents.append(types.Content( + role=role, + parts=[types.Part(text=content)] + )) + elif content: + parts = [types.Part(text=c.text) for c in content if c.text] + if parts: + contents.append(types.Content(role=role, parts=parts)) + + print(f"[Gemini DEBUG] Built {len(contents)} contents", flush=True) + + config_kwargs: Dict[str, Any] = {} + + if system: + config_kwargs["system_instruction"] = system + + if temperature is not None: + config_kwargs["temperature"] = temperature + + if max_tokens is not None: + config_kwargs["max_output_tokens"] = max_tokens + + if self._is_gemini3(model_id): + config_kwargs["thinking_config"] = types.ThinkingConfig( + include_thoughts=True + ) + # thinking_level ๋ฏธ์„ค์ • โ†’ ๊ธฐ๋ณธ๊ฐ’ "high" (๋™์  reasoning) + + if tools: + gemini_tools = [] + for t in tools: + func_decl = types.FunctionDeclaration( + name=t["name"], + description=t.get("description", ""), + parameters=t.get("parameters", t.get("input_schema", {})) + ) + gemini_tools.append(types.Tool(function_declarations=[func_decl])) + config_kwargs["tools"] = gemini_tools + + config = types.GenerateContentConfig(**config_kwargs) + + async for chunk in self._stream_with_fallback( + client, model_id, contents, config, config_kwargs, types + ): + yield chunk + + async def _stream_with_fallback( + self, client, model_id: str, contents, config, config_kwargs: Dict[str, Any], types + ): + try: + async for chunk in self._do_stream(client, model_id, contents, config): + yield chunk + except Exception as e: + error_str = str(e).lower() + has_thinking = "thinking_config" in config_kwargs + + if has_thinking and ("thinking" in error_str or "budget" in error_str or "level" in error_str or "unsupported" in error_str): + logger.warning(f"Thinking not supported for {model_id}, retrying without thinking config") + del config_kwargs["thinking_config"] + fallback_config = types.GenerateContentConfig(**config_kwargs) + + async for chunk in self._do_stream(client, model_id, contents, fallback_config): + yield chunk + else: + logger.error(f"Gemini stream error: {e}") + yield StreamChunk(type="error", error=str(e)) + + async def _do_stream(self, client, model_id: str, contents, config): + response_stream = await client.aio.models.generate_content_stream( + model=model_id, + contents=contents, + config=config, + ) + + pending_tool_calls = [] + + async for chunk in response_stream: + if not chunk.candidates: + continue + + candidate = chunk.candidates[0] + + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if hasattr(part, 'thought') and part.thought: + if part.text: + yield StreamChunk(type="reasoning", text=part.text) + elif hasattr(part, 
'function_call') and part.function_call: + fc = part.function_call + tool_call = ToolCall( + id=f"call_{fc.name}_{len(pending_tool_calls)}", + name=fc.name, + arguments=dict(fc.args) if fc.args else {} + ) + pending_tool_calls.append(tool_call) + elif part.text: + yield StreamChunk(type="text", text=part.text) + + finish_reason = getattr(candidate, 'finish_reason', None) + if finish_reason: + print(f"[Gemini] finish_reason: {finish_reason}, pending_tool_calls: {len(pending_tool_calls)}", flush=True) + for tc in pending_tool_calls: + yield StreamChunk(type="tool_call", tool_call=tc) + + # IMPORTANT: If there are pending tool calls, ALWAYS return "tool_calls" + # regardless of Gemini's finish_reason (which is often STOP even with tool calls) + if pending_tool_calls: + stop_reason = "tool_calls" + else: + stop_reason = self._map_stop_reason(finish_reason) + print(f"[Gemini] Mapped stop_reason: {stop_reason}", flush=True) + + usage = None + if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: + usage = { + "input_tokens": getattr(chunk.usage_metadata, 'prompt_token_count', 0), + "output_tokens": getattr(chunk.usage_metadata, 'candidates_token_count', 0), + } + if hasattr(chunk.usage_metadata, 'thoughts_token_count'): + usage["thinking_tokens"] = chunk.usage_metadata.thoughts_token_count + + yield StreamChunk(type="done", usage=usage, stop_reason=stop_reason) + return + + yield StreamChunk(type="done", stop_reason="end_turn") + + def _map_stop_reason(self, gemini_finish_reason) -> str: + reason_name = str(gemini_finish_reason).lower() if gemini_finish_reason else "" + + if "stop" in reason_name or "end" in reason_name: + return "end_turn" + elif "tool" in reason_name or "function" in reason_name: + return "tool_calls" + elif "max" in reason_name or "length" in reason_name: + return "max_tokens" + elif "safety" in reason_name: + return "safety" + return "end_turn" diff --git a/src/opencode_api/provider/litellm.py b/src/opencode_api/provider/litellm.py new file mode 100644 index 0000000000000000000000000000000000000000..2faec84daf1e7ac43cf0338138b367b7317c2708 --- /dev/null +++ b/src/opencode_api/provider/litellm.py @@ -0,0 +1,363 @@ +from typing import Dict, Any, List, Optional, AsyncGenerator +import json +import os + +from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall + + +DEFAULT_MODELS = { + "claude-sonnet-4-20250514": ModelInfo( + id="claude-sonnet-4-20250514", + name="Claude Sonnet 4", + provider_id="litellm", + context_limit=200000, + output_limit=64000, + supports_tools=True, + supports_streaming=True, + cost_input=3.0, + cost_output=15.0, + ), + "claude-opus-4-20250514": ModelInfo( + id="claude-opus-4-20250514", + name="Claude Opus 4", + provider_id="litellm", + context_limit=200000, + output_limit=32000, + supports_tools=True, + supports_streaming=True, + cost_input=15.0, + cost_output=75.0, + ), + "claude-3-5-haiku-20241022": ModelInfo( + id="claude-3-5-haiku-20241022", + name="Claude 3.5 Haiku", + provider_id="litellm", + context_limit=200000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.8, + cost_output=4.0, + ), + "gpt-4o": ModelInfo( + id="gpt-4o", + name="GPT-4o", + provider_id="litellm", + context_limit=128000, + output_limit=16384, + supports_tools=True, + supports_streaming=True, + cost_input=2.5, + cost_output=10.0, + ), + "gpt-4o-mini": ModelInfo( + id="gpt-4o-mini", + name="GPT-4o Mini", + provider_id="litellm", + context_limit=128000, + output_limit=16384, + supports_tools=True, + 
supports_streaming=True, + cost_input=0.15, + cost_output=0.6, + ), + "o1": ModelInfo( + id="o1", + name="O1", + provider_id="litellm", + context_limit=200000, + output_limit=100000, + supports_tools=True, + supports_streaming=True, + cost_input=15.0, + cost_output=60.0, + ), + "gemini/gemini-2.0-flash": ModelInfo( + id="gemini/gemini-2.0-flash", + name="Gemini 2.0 Flash", + provider_id="litellm", + context_limit=1000000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.075, + cost_output=0.3, + ), + "gemini/gemini-2.5-pro-preview-05-06": ModelInfo( + id="gemini/gemini-2.5-pro-preview-05-06", + name="Gemini 2.5 Pro", + provider_id="litellm", + context_limit=1000000, + output_limit=65536, + supports_tools=True, + supports_streaming=True, + cost_input=1.25, + cost_output=10.0, + ), + "groq/llama-3.3-70b-versatile": ModelInfo( + id="groq/llama-3.3-70b-versatile", + name="Llama 3.3 70B (Groq)", + provider_id="litellm", + context_limit=128000, + output_limit=32768, + supports_tools=True, + supports_streaming=True, + cost_input=0.59, + cost_output=0.79, + ), + "deepseek/deepseek-chat": ModelInfo( + id="deepseek/deepseek-chat", + name="DeepSeek Chat", + provider_id="litellm", + context_limit=64000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.14, + cost_output=0.28, + ), + "openrouter/anthropic/claude-sonnet-4": ModelInfo( + id="openrouter/anthropic/claude-sonnet-4", + name="Claude Sonnet 4 (OpenRouter)", + provider_id="litellm", + context_limit=200000, + output_limit=64000, + supports_tools=True, + supports_streaming=True, + cost_input=3.0, + cost_output=15.0, + ), + # Z.ai Free Flash Models + "zai/glm-4.7-flash": ModelInfo( + id="zai/glm-4.7-flash", + name="GLM-4.7 Flash (Free)", + provider_id="litellm", + context_limit=128000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.0, + cost_output=0.0, + ), + "zai/glm-4.6v-flash": ModelInfo( + id="zai/glm-4.6v-flash", + name="GLM-4.6V Flash (Free)", + provider_id="litellm", + context_limit=128000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.0, + cost_output=0.0, + ), + "zai/glm-4.5-flash": ModelInfo( + id="zai/glm-4.5-flash", + name="GLM-4.5 Flash (Free)", + provider_id="litellm", + context_limit=128000, + output_limit=8192, + supports_tools=True, + supports_streaming=True, + cost_input=0.0, + cost_output=0.0, + ), +} + + +class LiteLLMProvider(BaseProvider): + + def __init__(self): + self._litellm = None + self._models = dict(DEFAULT_MODELS) + + @property + def id(self) -> str: + return "litellm" + + @property + def name(self) -> str: + return "LiteLLM (Multi-Provider)" + + @property + def models(self) -> Dict[str, ModelInfo]: + return self._models + + def add_model(self, model: ModelInfo) -> None: + self._models[model.id] = model + + def _get_litellm(self): + if self._litellm is None: + try: + import litellm + litellm.drop_params = True + self._litellm = litellm + except ImportError: + raise ImportError("litellm package is required. 
Install with: pip install litellm") + return self._litellm + + async def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: + litellm = self._get_litellm() + + litellm_messages = [] + + if system: + litellm_messages.append({"role": "system", "content": system}) + + for msg in messages: + content = msg.content + if isinstance(content, str): + litellm_messages.append({"role": msg.role, "content": content}) + else: + litellm_messages.append({ + "role": msg.role, + "content": [{"type": c.type, "text": c.text} for c in content if c.text] + }) + + # Z.ai ๋ชจ๋ธ ์ฒ˜๋ฆฌ: OpenAI-compatible API ์‚ฌ์šฉ + actual_model = model_id + if model_id.startswith("zai/"): + # zai/glm-4.7-flash -> openai/glm-4.7-flash with custom api_base + actual_model = "openai/" + model_id[4:] + + kwargs: Dict[str, Any] = { + "model": actual_model, + "messages": litellm_messages, + "stream": True, + } + + # Z.ai ์ „์šฉ ์„ค์ • + if model_id.startswith("zai/"): + kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4") + kwargs["api_key"] = os.environ.get("ZAI_API_KEY") + + if temperature is not None: + kwargs["temperature"] = temperature + + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + else: + kwargs["max_tokens"] = 8192 + + if tools: + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters", t.get("input_schema", {})) + } + } + for t in tools + ] + + current_tool_calls: Dict[int, Dict[str, Any]] = {} + + try: + response = await litellm.acompletion(**kwargs) + + async for chunk in response: + if hasattr(chunk, 'choices') and chunk.choices: + choice = chunk.choices[0] + delta = getattr(choice, 'delta', None) + + if delta: + if hasattr(delta, 'content') and delta.content: + yield StreamChunk(type="text", text=delta.content) + + if hasattr(delta, 'tool_calls') and delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index if hasattr(tc, 'index') else 0 + + if idx not in current_tool_calls: + current_tool_calls[idx] = { + "id": tc.id if hasattr(tc, 'id') and tc.id else f"call_{idx}", + "name": "", + "arguments_json": "" + } + + if hasattr(tc, 'function'): + if hasattr(tc.function, 'name') and tc.function.name: + current_tool_calls[idx]["name"] = tc.function.name + if hasattr(tc.function, 'arguments') and tc.function.arguments: + current_tool_calls[idx]["arguments_json"] += tc.function.arguments + + finish_reason = getattr(choice, 'finish_reason', None) + if finish_reason: + for idx, tc_data in current_tool_calls.items(): + if tc_data["name"]: + try: + args = json.loads(tc_data["arguments_json"]) if tc_data["arguments_json"] else {} + except json.JSONDecodeError: + args = {} + + yield StreamChunk( + type="tool_call", + tool_call=ToolCall( + id=tc_data["id"], + name=tc_data["name"], + arguments=args + ) + ) + + usage = None + if hasattr(chunk, 'usage') and chunk.usage: + usage = { + "input_tokens": getattr(chunk.usage, 'prompt_tokens', 0), + "output_tokens": getattr(chunk.usage, 'completion_tokens', 0), + } + + stop_reason = self._map_stop_reason(finish_reason) + yield StreamChunk(type="done", usage=usage, stop_reason=stop_reason) + + except Exception as e: + yield StreamChunk(type="error", error=str(e)) + + async def complete( + self, + model_id: str, + prompt: str, + 
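Z.ai models are routed through LiteLLM's OpenAI-compatible path: the `zai/` prefix is rewritten to `openai/` and a custom `api_base`/`api_key` pair is injected from the environment. A small sketch of that mapping, assuming the same `ZAI_API_BASE` and `ZAI_API_KEY` variables used above:

```python
import os
from typing import Any, Dict


def route_zai_model(model_id: str) -> Dict[str, Any]:
    """Rewrite a zai/ model id into LiteLLM's OpenAI-compatible form (illustrative helper)."""
    kwargs: Dict[str, Any] = {"model": model_id}
    if model_id.startswith("zai/"):
        # e.g. zai/glm-4.5-flash -> openai/glm-4.5-flash served from the Z.ai endpoint
        kwargs["model"] = "openai/" + model_id[len("zai/"):]
        kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4")
        kwargs["api_key"] = os.environ.get("ZAI_API_KEY")
    return kwargs


print(route_zai_model("zai/glm-4.5-flash")["model"])  # openai/glm-4.5-flash
```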
max_tokens: int = 100, + ) -> str: + """๋‹จ์ผ ์™„๋ฃŒ ์š”์ฒญ (์ŠคํŠธ๋ฆฌ๋ฐ ์—†์Œ)""" + litellm = self._get_litellm() + + actual_model = model_id + kwargs: Dict[str, Any] = { + "model": actual_model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + } + + # Z.ai ๋ชจ๋ธ ์ฒ˜๋ฆฌ + if model_id.startswith("zai/"): + actual_model = "openai/" + model_id[4:] + kwargs["model"] = actual_model + kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4") + kwargs["api_key"] = os.environ.get("ZAI_API_KEY") + + response = await litellm.acompletion(**kwargs) + return response.choices[0].message.content or "" + + def _map_stop_reason(self, finish_reason: Optional[str]) -> str: + if not finish_reason: + return "end_turn" + + mapping = { + "stop": "end_turn", + "end_turn": "end_turn", + "tool_calls": "tool_calls", + "function_call": "tool_calls", + "length": "max_tokens", + "max_tokens": "max_tokens", + "content_filter": "content_filter", + } + return mapping.get(finish_reason, "end_turn") diff --git a/src/opencode_api/provider/openai.py b/src/opencode_api/provider/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..02ee141b44dc76047ec1cf6e077a6d94610bb51c --- /dev/null +++ b/src/opencode_api/provider/openai.py @@ -0,0 +1,182 @@ +from typing import Dict, Any, List, Optional, AsyncGenerator +import os +import json + +from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall + + +class OpenAIProvider(BaseProvider): + + def __init__(self, api_key: Optional[str] = None): + self._api_key = api_key or os.environ.get("OPENAI_API_KEY") + self._client = None + + @property + def id(self) -> str: + return "openai" + + @property + def name(self) -> str: + return "OpenAI" + + @property + def models(self) -> Dict[str, ModelInfo]: + return { + "gpt-4o": ModelInfo( + id="gpt-4o", + name="GPT-4o", + provider_id="openai", + context_limit=128000, + output_limit=16384, + supports_tools=True, + supports_streaming=True, + cost_input=2.5, + cost_output=10.0, + ), + "gpt-4o-mini": ModelInfo( + id="gpt-4o-mini", + name="GPT-4o Mini", + provider_id="openai", + context_limit=128000, + output_limit=16384, + supports_tools=True, + supports_streaming=True, + cost_input=0.15, + cost_output=0.6, + ), + "o1": ModelInfo( + id="o1", + name="o1", + provider_id="openai", + context_limit=200000, + output_limit=100000, + supports_tools=True, + supports_streaming=True, + cost_input=15.0, + cost_output=60.0, + ), + } + + def _get_client(self): + if self._client is None: + try: + from openai import AsyncOpenAI + self._client = AsyncOpenAI(api_key=self._api_key) + except ImportError: + raise ImportError("openai package is required. 
Install with: pip install openai") + return self._client + + async def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: + client = self._get_client() + + openai_messages = [] + + if system: + openai_messages.append({"role": "system", "content": system}) + + for msg in messages: + content = msg.content + if isinstance(content, str): + openai_messages.append({"role": msg.role, "content": content}) + else: + openai_messages.append({ + "role": msg.role, + "content": [{"type": c.type, "text": c.text} for c in content if c.text] + }) + + kwargs: Dict[str, Any] = { + "model": model_id, + "messages": openai_messages, + "stream": True, + } + + if max_tokens: + kwargs["max_tokens"] = max_tokens + + if temperature is not None: + kwargs["temperature"] = temperature + + if tools: + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters", t.get("input_schema", {})) + } + } + for t in tools + ] + + tool_calls: Dict[int, Dict[str, Any]] = {} + usage_data = None + finish_reason = None + + async for chunk in await client.chat.completions.create(**kwargs): + if chunk.choices and chunk.choices[0].delta: + delta = chunk.choices[0].delta + + if delta.content: + yield StreamChunk(type="text", text=delta.content) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tool_calls: + tool_calls[idx] = { + "id": tc.id or "", + "name": tc.function.name if tc.function else "", + "arguments": "" + } + + if tc.id: + tool_calls[idx]["id"] = tc.id + if tc.function: + if tc.function.name: + tool_calls[idx]["name"] = tc.function.name + if tc.function.arguments: + tool_calls[idx]["arguments"] += tc.function.arguments + + if chunk.choices and chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + + if chunk.usage: + usage_data = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + } + + for tc_data in tool_calls.values(): + try: + args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {} + except json.JSONDecodeError: + args = {} + yield StreamChunk( + type="tool_call", + tool_call=ToolCall( + id=tc_data["id"], + name=tc_data["name"], + arguments=args + ) + ) + + stop_reason = self._map_stop_reason(finish_reason) + yield StreamChunk(type="done", usage=usage_data, stop_reason=stop_reason) + + def _map_stop_reason(self, openai_finish_reason: Optional[str]) -> str: + mapping = { + "stop": "end_turn", + "tool_calls": "tool_calls", + "length": "max_tokens", + "content_filter": "end_turn", + } + return mapping.get(openai_finish_reason or "", "end_turn") diff --git a/src/opencode_api/provider/provider.py b/src/opencode_api/provider/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..84eb715426983d3633ea3991f5eecce0a4f1e58b --- /dev/null +++ b/src/opencode_api/provider/provider.py @@ -0,0 +1,133 @@ +from typing import Dict, Any, List, Optional, AsyncIterator, AsyncGenerator, Protocol, runtime_checkable +from pydantic import BaseModel, Field +from abc import ABC, abstractmethod + + +class ModelInfo(BaseModel): + id: str + name: str + provider_id: str + context_limit: int = 128000 + output_limit: int = 8192 + supports_tools: bool = True + supports_streaming: bool = True + 
cost_input: float = 0.0 # per 1M tokens + cost_output: float = 0.0 # per 1M tokens + + +class ProviderInfo(BaseModel): + id: str + name: str + models: Dict[str, ModelInfo] = Field(default_factory=dict) + + +class MessageContent(BaseModel): + type: str = "text" + text: Optional[str] = None + + +class Message(BaseModel): + role: str # "user", "assistant", "system" + content: str | List[MessageContent] + + +class ToolCall(BaseModel): + id: str + name: str + arguments: Dict[str, Any] + + +class ToolResult(BaseModel): + tool_call_id: str + output: str + + +class StreamChunk(BaseModel): + type: str # "text", "reasoning", "tool_call", "tool_result", "done", "error" + text: Optional[str] = None + tool_call: Optional[ToolCall] = None + error: Optional[str] = None + usage: Optional[Dict[str, int]] = None + stop_reason: Optional[str] = None # "end_turn", "tool_calls", "max_tokens", etc. + + +@runtime_checkable +class Provider(Protocol): + + @property + def id(self) -> str: ... + + @property + def name(self) -> str: ... + + @property + def models(self) -> Dict[str, ModelInfo]: ... + + def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: ... + + +class BaseProvider(ABC): + + @property + @abstractmethod + def id(self) -> str: + pass + + @property + @abstractmethod + def name(self) -> str: + pass + + @property + @abstractmethod + def models(self) -> Dict[str, ModelInfo]: + pass + + @abstractmethod + def stream( + self, + model_id: str, + messages: List[Message], + tools: Optional[List[Dict[str, Any]]] = None, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> AsyncGenerator[StreamChunk, None]: + pass + + def get_info(self) -> ProviderInfo: + return ProviderInfo( + id=self.id, + name=self.name, + models=self.models + ) + + +_providers: Dict[str, BaseProvider] = {} + + +def register_provider(provider: BaseProvider) -> None: + _providers[provider.id] = provider + + +def get_provider(provider_id: str) -> Optional[BaseProvider]: + return _providers.get(provider_id) + + +def list_providers() -> List[ProviderInfo]: + return [p.get_info() for p in _providers.values()] + + +def get_model(provider_id: str, model_id: str) -> Optional[ModelInfo]: + provider = get_provider(provider_id) + if provider: + return provider.models.get(model_id) + return None diff --git a/src/opencode_api/routes/__init__.py b/src/opencode_api/routes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a589f4c22317f84e94881ec7118649980f3dc2 --- /dev/null +++ b/src/opencode_api/routes/__init__.py @@ -0,0 +1,7 @@ +from .session import router as session_router +from .provider import router as provider_router +from .event import router as event_router +from .question import router as question_router +from .agent import router as agent_router + +__all__ = ["session_router", "provider_router", "event_router", "question_router", "agent_router"] diff --git a/src/opencode_api/routes/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/routes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e97f3de6ebc9c7378d55b3afc035a52e6833fc6b Binary files /dev/null and b/src/opencode_api/routes/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/routes/__pycache__/agent.cpython-312.pyc 
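The module-level registry above is what `app.py` populates at startup. A hedged end-to-end sketch of looking up a registered provider and consuming its `stream()` generator; it assumes `ANTHROPIC_API_KEY` is set and uses one of the model ids from the Anthropic model table:

```python
import asyncio

from src.opencode_api.provider import AnthropicProvider, register_provider, get_provider
from src.opencode_api.provider.provider import Message


async def main() -> None:
    register_provider(AnthropicProvider())
    provider = get_provider("anthropic")

    async for chunk in provider.stream(
        model_id="claude-3-5-haiku-20241022",
        messages=[Message(role="user", content="Say hello in one word.")],
        system="You are terse.",
        max_tokens=64,
    ):
        if chunk.type == "text":
            print(chunk.text, end="", flush=True)
        elif chunk.type == "tool_call":
            print(f"\n[tool call] {chunk.tool_call.name} {chunk.tool_call.arguments}")
        elif chunk.type == "done":
            print(f"\n[done] stop_reason={chunk.stop_reason} usage={chunk.usage}")
        elif chunk.type == "error":
            print(f"\n[error] {chunk.error}")


asyncio.run(main())
```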
b/src/opencode_api/routes/__pycache__/agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6758ce746411ec0521d730d732606fbd9d55672 Binary files /dev/null and b/src/opencode_api/routes/__pycache__/agent.cpython-312.pyc differ diff --git a/src/opencode_api/routes/__pycache__/event.cpython-312.pyc b/src/opencode_api/routes/__pycache__/event.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8c612322cfbf5a7300ad1dc319f49feacafa743 Binary files /dev/null and b/src/opencode_api/routes/__pycache__/event.cpython-312.pyc differ diff --git a/src/opencode_api/routes/__pycache__/provider.cpython-312.pyc b/src/opencode_api/routes/__pycache__/provider.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c156a4bbbb53cc74a1e096ffffdcc5edb8f5e8e Binary files /dev/null and b/src/opencode_api/routes/__pycache__/provider.cpython-312.pyc differ diff --git a/src/opencode_api/routes/__pycache__/question.cpython-312.pyc b/src/opencode_api/routes/__pycache__/question.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53ef45f4021c56386e9a55b812c321002eb94af4 Binary files /dev/null and b/src/opencode_api/routes/__pycache__/question.cpython-312.pyc differ diff --git a/src/opencode_api/routes/__pycache__/session.cpython-312.pyc b/src/opencode_api/routes/__pycache__/session.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c14e4fd7c83001bb74ccdbb20e9220957a292fc Binary files /dev/null and b/src/opencode_api/routes/__pycache__/session.cpython-312.pyc differ diff --git a/src/opencode_api/routes/agent.py b/src/opencode_api/routes/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe2d5e8f5d913997f8980779577a42c061bdb28 --- /dev/null +++ b/src/opencode_api/routes/agent.py @@ -0,0 +1,66 @@ +""" +Agent routes - manage agent configurations. 
+""" + +from fastapi import APIRouter, HTTPException +from typing import Optional, List + +from ..agent import ( + AgentInfo, + get, + list_agents, + default_agent, + register, + unregister, +) + +router = APIRouter(prefix="/agent", tags=["agent"]) + + +@router.get("", response_model=List[AgentInfo]) +async def get_agents( + mode: Optional[str] = None, + include_hidden: bool = False +): + """List all available agents.""" + return list_agents(mode=mode, include_hidden=include_hidden) + + +@router.get("/default", response_model=AgentInfo) +async def get_default_agent(): + """Get the default agent configuration.""" + return default_agent() + + +@router.get("/{agent_id}", response_model=AgentInfo) +async def get_agent(agent_id: str): + """Get a specific agent by ID.""" + agent = get(agent_id) + if not agent: + raise HTTPException(status_code=404, detail=f"Agent not found: {agent_id}") + return agent + + +@router.post("", response_model=AgentInfo) +async def create_agent(agent: AgentInfo): + """Register a custom agent.""" + existing = get(agent.id) + if existing and existing.native: + raise HTTPException(status_code=400, detail=f"Cannot override native agent: {agent.id}") + + register(agent) + return agent + + +@router.delete("/{agent_id}") +async def delete_agent(agent_id: str): + """Unregister a custom agent.""" + agent = get(agent_id) + if not agent: + raise HTTPException(status_code=404, detail=f"Agent not found: {agent_id}") + + if agent.native: + raise HTTPException(status_code=400, detail=f"Cannot delete native agent: {agent_id}") + + unregister(agent_id) + return {"deleted": agent_id} diff --git a/src/opencode_api/routes/event.py b/src/opencode_api/routes/event.py new file mode 100644 index 0000000000000000000000000000000000000000..164d8aa4bd4bfa7747a91d16ffd0e3685ccd5dc4 --- /dev/null +++ b/src/opencode_api/routes/event.py @@ -0,0 +1,45 @@ +from fastapi import APIRouter +from fastapi.responses import StreamingResponse +import asyncio +import json +from typing import AsyncIterator + +from ..core.bus import Bus, EventInstance + + +router = APIRouter(tags=["Events"]) + + +@router.get("/event") +async def subscribe_events(): + async def event_generator() -> AsyncIterator[str]: + queue: asyncio.Queue[EventInstance] = asyncio.Queue() + + async def handler(event: EventInstance): + await queue.put(event) + + unsubscribe = Bus.subscribe_all(handler) + + yield f"data: {json.dumps({'type': 'server.connected', 'payload': {}})}\n\n" + + try: + while True: + try: + event = await asyncio.wait_for(queue.get(), timeout=30.0) + yield f"data: {json.dumps({'type': event.type, 'payload': event.payload})}\n\n" + except asyncio.TimeoutError: + yield f"data: {json.dumps({'type': 'server.heartbeat', 'payload': {}})}\n\n" + except asyncio.CancelledError: + pass + finally: + unsubscribe() + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + ) diff --git a/src/opencode_api/routes/provider.py b/src/opencode_api/routes/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..a056176d67b1c9f7c77205a48a5f6b2d809650e3 --- /dev/null +++ b/src/opencode_api/routes/provider.py @@ -0,0 +1,107 @@ +from typing import List, Dict +from fastapi import APIRouter, HTTPException +import os +from dotenv import load_dotenv + +# .env ํŒŒ์ผ์—์„œ ํ™˜๊ฒฝ๋ณ€์ˆ˜ ๋กœ๋“œ +load_dotenv() + +from ..provider import list_providers, get_provider +from ..provider.provider import 
ProviderInfo, ModelInfo + + +router = APIRouter(prefix="/provider", tags=["Provider"]) + + +# Provider๋ณ„ ํ•„์š” ํ™˜๊ฒฝ๋ณ€์ˆ˜ ๋งคํ•‘ +PROVIDER_API_KEYS = { + "anthropic": "ANTHROPIC_API_KEY", + "openai": "OPENAI_API_KEY", + "gemini": ["GOOGLE_API_KEY", "GEMINI_API_KEY"], + "litellm": None, # LiteLLM์€ ๊ฐœ๋ณ„ ๋ชจ๋ธ๋ณ„๋กœ ์ฒดํฌ +} + +# LiteLLM ๋ชจ๋ธ๋ณ„ ํ•„์š” ํ™˜๊ฒฝ๋ณ€์ˆ˜ +LITELLM_MODEL_KEYS = { + "claude-": "ANTHROPIC_API_KEY", + "gpt-": "OPENAI_API_KEY", + "o1": "OPENAI_API_KEY", + "gemini/": ["GOOGLE_API_KEY", "GEMINI_API_KEY"], + "groq/": "GROQ_API_KEY", + "deepseek/": "DEEPSEEK_API_KEY", + "openrouter/": "OPENROUTER_API_KEY", + "zai/": "ZAI_API_KEY", +} + + +def has_api_key(provider_id: str) -> bool: + """Check if provider has required API key configured""" + keys = PROVIDER_API_KEYS.get(provider_id) + if keys is None: + return True # No key required (like litellm container) + if isinstance(keys, list): + return any(os.environ.get(k) for k in keys) + return bool(os.environ.get(keys)) + + +def filter_litellm_models(models: Dict[str, ModelInfo]) -> Dict[str, ModelInfo]: + """Filter LiteLLM models based on available API keys""" + filtered = {} + for model_id, model_info in models.items(): + for prefix, env_key in LITELLM_MODEL_KEYS.items(): + if model_id.startswith(prefix): + if isinstance(env_key, list): + if any(os.environ.get(k) for k in env_key): + filtered[model_id] = model_info + elif os.environ.get(env_key): + filtered[model_id] = model_info + break + return filtered + + +@router.get("/", response_model=List[ProviderInfo]) +async def get_providers(): + """Get available providers (filtered by API key availability)""" + all_providers = list_providers() + available = [] + + for provider in all_providers: + if provider.id == "litellm": + # LiteLLM: ๊ฐœ๋ณ„ ๋ชจ๋ธ๋ณ„ ํ•„ํ„ฐ๋ง + filtered_models = filter_litellm_models(provider.models) + if filtered_models: + provider.models = filtered_models + available.append(provider) + elif has_api_key(provider.id): + available.append(provider) + + return available + + +@router.get("/{provider_id}", response_model=ProviderInfo) +async def get_provider_info(provider_id: str): + provider = get_provider(provider_id) + if not provider: + raise HTTPException(status_code=404, detail=f"Provider not found: {provider_id}") + return provider.get_info() + + +@router.get("/{provider_id}/model", response_model=List[ModelInfo]) +async def get_provider_models(provider_id: str): + provider = get_provider(provider_id) + if not provider: + raise HTTPException(status_code=404, detail=f"Provider not found: {provider_id}") + return list(provider.models.values()) + + +@router.get("/{provider_id}/model/{model_id}", response_model=ModelInfo) +async def get_model_info(provider_id: str, model_id: str): + provider = get_provider(provider_id) + if not provider: + raise HTTPException(status_code=404, detail=f"Provider not found: {provider_id}") + + model = provider.models.get(model_id) + if not model: + raise HTTPException(status_code=404, detail=f"Model not found: {model_id}") + + return model diff --git a/src/opencode_api/routes/question.py b/src/opencode_api/routes/question.py new file mode 100644 index 0000000000000000000000000000000000000000..ed40282e22e55df6d5f596b36049027701d97774 --- /dev/null +++ b/src/opencode_api/routes/question.py @@ -0,0 +1,55 @@ +"""Question API routes.""" +from typing import List +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +from ..tool import ( + reply_to_question, + reject_question, + 
get_pending_questions, + QuestionReply, +) + + +router = APIRouter(prefix="/question", tags=["question"]) + + +class QuestionAnswerRequest(BaseModel): + """Request to answer a question.""" + answers: List[List[str]] = Field(..., description="Answers in order (each is array of selected labels)") + + +@router.get("") +@router.get("/") +async def list_pending_questions(session_id: str = None): + """List all pending questions.""" + pending = get_pending_questions(session_id) + return {"pending": pending} + + +@router.post("/{request_id}/reply") +async def answer_question(request_id: str, request: QuestionAnswerRequest): + """Submit answers to a pending question.""" + success = await reply_to_question(request_id, request.answers) + + if not success: + raise HTTPException( + status_code=404, + detail=f"Question request '{request_id}' not found or already answered" + ) + + return {"status": "answered", "request_id": request_id} + + +@router.post("/{request_id}/reject") +async def dismiss_question(request_id: str): + """Dismiss/reject a pending question without answering.""" + success = await reject_question(request_id) + + if not success: + raise HTTPException( + status_code=404, + detail=f"Question request '{request_id}' not found or already handled" + ) + + return {"status": "rejected", "request_id": request_id} diff --git a/src/opencode_api/routes/session.py b/src/opencode_api/routes/session.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ab8fb04ba733e73609753c13d34504759287a9 --- /dev/null +++ b/src/opencode_api/routes/session.py @@ -0,0 +1,206 @@ +from typing import Optional, List +from fastapi import APIRouter, HTTPException, Query, Depends +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +import json + +from ..session import Session, SessionInfo, SessionCreate, Message, SessionPrompt +from ..session.prompt import PromptInput +from ..core.storage import NotFoundError +from ..core.auth import AuthUser, optional_auth, require_auth +from ..core.quota import check_quota, increment_usage +from ..core.supabase import is_enabled as supabase_enabled +from ..provider import get_provider + + +router = APIRouter(prefix="/session", tags=["Session"]) + + +class MessageRequest(BaseModel): + content: str + provider_id: Optional[str] = None + model_id: Optional[str] = None + system: Optional[str] = None + temperature: Optional[float] = None + max_tokens: Optional[int] = None + tools_enabled: bool = True + auto_continue: Optional[bool] = None + max_steps: Optional[int] = None + + +class SessionUpdate(BaseModel): + title: Optional[str] = None + + +class GenerateTitleRequest(BaseModel): + message: str + model_id: Optional[str] = None + + +@router.get("/", response_model=List[SessionInfo]) +async def list_sessions( + limit: Optional[int] = Query(None, description="Maximum number of sessions to return"), + user: Optional[AuthUser] = Depends(optional_auth) +): + user_id = user.id if user else None + return await Session.list(limit, user_id) + + +@router.post("/", response_model=SessionInfo) +async def create_session( + data: Optional[SessionCreate] = None, + user: Optional[AuthUser] = Depends(optional_auth) +): + user_id = user.id if user else None + return await Session.create(data, user_id) + + +@router.get("/{session_id}", response_model=SessionInfo) +async def get_session( + session_id: str, + user: Optional[AuthUser] = Depends(optional_auth) +): + try: + user_id = user.id if user else None + return await Session.get(session_id, user_id) + except 
NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + +@router.patch("/{session_id}", response_model=SessionInfo) +async def update_session( + session_id: str, + updates: SessionUpdate, + user: Optional[AuthUser] = Depends(optional_auth) +): + try: + user_id = user.id if user else None + update_dict = {k: v for k, v in updates.model_dump().items() if v is not None} + return await Session.update(session_id, update_dict, user_id) + except NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + +@router.delete("/{session_id}") +async def delete_session( + session_id: str, + user: Optional[AuthUser] = Depends(optional_auth) +): + try: + user_id = user.id if user else None + await Session.delete(session_id, user_id) + return {"success": True} + except NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + +@router.get("/{session_id}/message") +async def list_messages( + session_id: str, + limit: Optional[int] = Query(None, description="Maximum number of messages to return"), + user: Optional[AuthUser] = Depends(optional_auth) +): + try: + user_id = user.id if user else None + await Session.get(session_id, user_id) + return await Message.list(session_id, limit, user_id) + except NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + +@router.post("/{session_id}/message") +async def send_message( + session_id: str, + request: MessageRequest, + user: AuthUser = Depends(check_quota) if supabase_enabled() else Depends(optional_auth) +): + user_id = user.id if user else None + + try: + await Session.get(session_id, user_id) + except NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + prompt_input = PromptInput( + content=request.content, + provider_id=request.provider_id, + model_id=request.model_id, + system=request.system, + temperature=request.temperature, + max_tokens=request.max_tokens, + tools_enabled=request.tools_enabled, + auto_continue=request.auto_continue, + max_steps=request.max_steps, + ) + + async def generate(): + total_input = 0 + total_output = 0 + + async for chunk in SessionPrompt.prompt(session_id, prompt_input, user_id): + if chunk.usage: + total_input += chunk.usage.get("input_tokens", 0) + total_output += chunk.usage.get("output_tokens", 0) + yield f"data: {json.dumps(chunk.model_dump())}\n\n" + + if user_id and supabase_enabled(): + await increment_usage(user_id, total_input, total_output) + + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + ) + + +@router.post("/{session_id}/abort") +async def abort_session(session_id: str): + cancelled = SessionPrompt.cancel(session_id) + return {"cancelled": cancelled} + + +@router.post("/{session_id}/generate-title") +async def generate_title( + session_id: str, + request: GenerateTitleRequest, + user: Optional[AuthUser] = Depends(optional_auth) +): + """์ฒซ ๋ฉ”์‹œ์ง€ ๊ธฐ๋ฐ˜์œผ๋กœ ์„ธ์…˜ ์ œ๋ชฉ ์ƒ์„ฑ""" + user_id = user.id if user else None + + # ์„ธ์…˜ ์กด์žฌ ํ™•์ธ + try: + await Session.get(session_id, user_id) + except NotFoundError: + raise HTTPException(status_code=404, detail=f"Session not found: {session_id}") + + # LiteLLM Provider๋กœ ์ œ๋ชฉ ์ƒ์„ฑ + model_id = request.model_id or "gemini/gemini-2.0-flash" + provider = 
get_provider("litellm") + + if not provider: + raise HTTPException(status_code=503, detail="LiteLLM provider not available") + + prompt = f"""๋‹ค์Œ ์‚ฌ์šฉ์ž ๋ฉ”์‹œ์ง€๋ฅผ ๋ณด๊ณ  ์งง์€ ์ œ๋ชฉ์„ ์ƒ์„ฑํ•ด์ฃผ์„ธ์š”. +์ œ๋ชฉ์€ 10์ž ์ด๋‚ด, ๋”ฐ์˜ดํ‘œ ์—†์ด ์ œ๋ชฉ๋งŒ ์ถœ๋ ฅ. + +์‚ฌ์šฉ์ž ๋ฉ”์‹œ์ง€: "{request.message[:200]}" + +์ œ๋ชฉ:""" + + try: + result = await provider.complete(model_id, prompt, max_tokens=50) + title = result.strip()[:30] + + # ์„ธ์…˜ ์ œ๋ชฉ ์—…๋ฐ์ดํŠธ + await Session.update(session_id, {"title": title}, user_id) + + return {"title": title} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to generate title: {str(e)}") diff --git a/src/opencode_api/session/__init__.py b/src/opencode_api/session/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..691b0a0c95fe66d37098c565487ab56cba5a56a1 --- /dev/null +++ b/src/opencode_api/session/__init__.py @@ -0,0 +1,11 @@ +from .session import Session, SessionInfo, SessionCreate +from .message import Message, MessageInfo, UserMessage, AssistantMessage, MessagePart +from .prompt import SessionPrompt +from .processor import SessionProcessor, DoomLoopDetector, RetryConfig, StepInfo + +__all__ = [ + "Session", "SessionInfo", "SessionCreate", + "Message", "MessageInfo", "UserMessage", "AssistantMessage", "MessagePart", + "SessionPrompt", + "SessionProcessor", "DoomLoopDetector", "RetryConfig", "StepInfo" +] diff --git a/src/opencode_api/session/__pycache__/__init__.cpython-312.pyc b/src/opencode_api/session/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b0eee311705350c9d5d0b83d8253f6a8c63e92 Binary files /dev/null and b/src/opencode_api/session/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/session/__pycache__/message.cpython-312.pyc b/src/opencode_api/session/__pycache__/message.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe343cdcba8f6c9e991f3568fc88c419cebe9ca1 Binary files /dev/null and b/src/opencode_api/session/__pycache__/message.cpython-312.pyc differ diff --git a/src/opencode_api/session/__pycache__/processor.cpython-312.pyc b/src/opencode_api/session/__pycache__/processor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a32ca515da4f763fed5039c16e77d84f1e13d95 Binary files /dev/null and b/src/opencode_api/session/__pycache__/processor.cpython-312.pyc differ diff --git a/src/opencode_api/session/__pycache__/prompt.cpython-312.pyc b/src/opencode_api/session/__pycache__/prompt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af8d60d419f8bec96fb6baa9cbb5bb287251567d Binary files /dev/null and b/src/opencode_api/session/__pycache__/prompt.cpython-312.pyc differ diff --git a/src/opencode_api/session/__pycache__/session.cpython-312.pyc b/src/opencode_api/session/__pycache__/session.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..617776be8993184e40dfeb0ba80cad7e1f574fe4 Binary files /dev/null and b/src/opencode_api/session/__pycache__/session.cpython-312.pyc differ diff --git a/src/opencode_api/session/message.py b/src/opencode_api/session/message.py new file mode 100644 index 0000000000000000000000000000000000000000..ee712b9692259144bc4e12d487ae575da524b64e --- /dev/null +++ b/src/opencode_api/session/message.py @@ -0,0 +1,348 @@ +from typing import Optional, List, Dict, Any, Union, Literal +from pydantic import BaseModel, Field +from 
datetime import datetime + +from ..core.storage import Storage, NotFoundError +from ..core.bus import Bus, MESSAGE_UPDATED, MESSAGE_REMOVED, PART_UPDATED, MessagePayload, PartPayload +from ..core.identifier import Identifier +from ..core.supabase import get_client, is_enabled as supabase_enabled + + +class MessagePart(BaseModel): + """๋ฉ”์‹œ์ง€ ํŒŒํŠธ ๋ชจ๋ธ + + type ์ข…๋ฅ˜: + - "text": ์ผ๋ฐ˜ ํ…์ŠคํŠธ ์‘๋‹ต + - "reasoning": Claude์˜ thinking/extended thinking + - "tool_call": ๋„๊ตฌ ํ˜ธ์ถœ (tool_name, tool_args, tool_status) + - "tool_result": ๋„๊ตฌ ์‹คํ–‰ ๊ฒฐ๊ณผ (tool_output) + """ + id: str + session_id: str + message_id: str + type: str # "text", "reasoning", "tool_call", "tool_result" + content: Optional[str] = None # text, reasoning์šฉ + tool_call_id: Optional[str] = None + tool_name: Optional[str] = None + tool_args: Optional[Dict[str, Any]] = None + tool_output: Optional[str] = None + tool_status: Optional[str] = None # "pending", "running", "completed", "error" + + +class MessageInfo(BaseModel): + id: str + session_id: str + role: Literal["user", "assistant"] + created_at: datetime + model: Optional[str] = None + provider_id: Optional[str] = None + usage: Optional[Dict[str, int]] = None + error: Optional[str] = None + + +class UserMessage(MessageInfo): + role: Literal["user"] = "user" + content: str + + +class AssistantMessage(MessageInfo): + role: Literal["assistant"] = "assistant" + parts: List[MessagePart] = Field(default_factory=list) + summary: bool = False + + +class Message: + + @staticmethod + async def create_user(session_id: str, content: str, user_id: Optional[str] = None) -> UserMessage: + message_id = Identifier.generate("message") + now = datetime.utcnow() + + msg = UserMessage( + id=message_id, + session_id=session_id, + content=content, + created_at=now, + ) + + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_messages").insert({ + "id": message_id, + "session_id": session_id, + "role": "user", + "content": content, + }).execute() + else: + await Storage.write(["message", session_id, message_id], msg.model_dump()) + + await Bus.publish(MESSAGE_UPDATED, MessagePayload(session_id=session_id, message_id=message_id)) + return msg + + @staticmethod + async def create_assistant( + session_id: str, + provider_id: Optional[str] = None, + model: Optional[str] = None, + user_id: Optional[str] = None, + summary: bool = False + ) -> AssistantMessage: + message_id = Identifier.generate("message") + now = datetime.utcnow() + + msg = AssistantMessage( + id=message_id, + session_id=session_id, + created_at=now, + provider_id=provider_id, + model=model, + parts=[], + summary=summary, + ) + + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_messages").insert({ + "id": message_id, + "session_id": session_id, + "role": "assistant", + "provider_id": provider_id, + "model_id": model, + }).execute() + else: + await Storage.write(["message", session_id, message_id], msg.model_dump()) + + await Bus.publish(MESSAGE_UPDATED, MessagePayload(session_id=session_id, message_id=message_id)) + return msg + + @staticmethod + async def get(session_id: str, message_id: str, user_id: Optional[str] = None) -> Union[UserMessage, AssistantMessage]: + if supabase_enabled() and user_id: + client = get_client() + result = client.table("opencode_messages").select("*, opencode_message_parts(*)").eq("id", message_id).eq("session_id", session_id).single().execute() + if not result.data: + raise NotFoundError(["message", session_id, 
message_id]) + + data = result.data + if data.get("role") == "user": + return UserMessage( + id=data["id"], + session_id=data["session_id"], + role="user", + content=data.get("content", ""), + created_at=data["created_at"], + ) + + parts = [ + MessagePart( + id=p["id"], + session_id=session_id, + message_id=message_id, + type=p["type"], + content=p.get("content"), + tool_call_id=p.get("tool_call_id"), + tool_name=p.get("tool_name"), + tool_args=p.get("tool_args"), + tool_output=p.get("tool_output"), + tool_status=p.get("tool_status"), + ) + for p in data.get("opencode_message_parts", []) + ] + return AssistantMessage( + id=data["id"], + session_id=data["session_id"], + role="assistant", + created_at=data["created_at"], + provider_id=data.get("provider_id"), + model=data.get("model_id"), + usage={"input_tokens": data.get("input_tokens", 0), "output_tokens": data.get("output_tokens", 0)} if data.get("input_tokens") else None, + error=data.get("error"), + parts=parts, + ) + + data = await Storage.read(["message", session_id, message_id]) + if not data: + raise NotFoundError(["message", session_id, message_id]) + + if data.get("role") == "user": + return UserMessage(**data) + return AssistantMessage(**data) + + @staticmethod + async def add_part(message_id: str, session_id: str, part: MessagePart, user_id: Optional[str] = None) -> MessagePart: + part.id = Identifier.generate("part") + part.message_id = message_id + part.session_id = session_id + + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_message_parts").insert({ + "id": part.id, + "message_id": message_id, + "type": part.type, + "content": part.content, + "tool_call_id": part.tool_call_id, + "tool_name": part.tool_name, + "tool_args": part.tool_args, + "tool_output": part.tool_output, + "tool_status": part.tool_status, + }).execute() + else: + msg_data = await Storage.read(["message", session_id, message_id]) + if not msg_data: + raise NotFoundError(["message", session_id, message_id]) + + if "parts" not in msg_data: + msg_data["parts"] = [] + msg_data["parts"].append(part.model_dump()) + await Storage.write(["message", session_id, message_id], msg_data) + + await Bus.publish(PART_UPDATED, PartPayload( + session_id=session_id, + message_id=message_id, + part_id=part.id + )) + return part + + @staticmethod + async def update_part(session_id: str, message_id: str, part_id: str, updates: Dict[str, Any], user_id: Optional[str] = None) -> MessagePart: + if supabase_enabled() and user_id: + client = get_client() + result = client.table("opencode_message_parts").update(updates).eq("id", part_id).execute() + if result.data: + p = result.data[0] + await Bus.publish(PART_UPDATED, PartPayload( + session_id=session_id, + message_id=message_id, + part_id=part_id + )) + return MessagePart( + id=p["id"], + session_id=session_id, + message_id=message_id, + type=p["type"], + content=p.get("content"), + tool_call_id=p.get("tool_call_id"), + tool_name=p.get("tool_name"), + tool_args=p.get("tool_args"), + tool_output=p.get("tool_output"), + tool_status=p.get("tool_status"), + ) + raise NotFoundError(["part", message_id, part_id]) + + msg_data = await Storage.read(["message", session_id, message_id]) + if not msg_data: + raise NotFoundError(["message", session_id, message_id]) + + for i, p in enumerate(msg_data.get("parts", [])): + if p.get("id") == part_id: + msg_data["parts"][i].update(updates) + await Storage.write(["message", session_id, message_id], msg_data) + await Bus.publish(PART_UPDATED, PartPayload( + 
session_id=session_id, + message_id=message_id, + part_id=part_id + )) + return MessagePart(**msg_data["parts"][i]) + + raise NotFoundError(["part", message_id, part_id]) + + @staticmethod + async def list(session_id: str, limit: Optional[int] = None, user_id: Optional[str] = None) -> List[Union[UserMessage, AssistantMessage]]: + if supabase_enabled() and user_id: + client = get_client() + query = client.table("opencode_messages").select("*, opencode_message_parts(*)").eq("session_id", session_id).order("created_at") + if limit: + query = query.limit(limit) + result = query.execute() + + messages = [] + for data in result.data: + if data.get("role") == "user": + messages.append(UserMessage( + id=data["id"], + session_id=data["session_id"], + role="user", + content=data.get("content", ""), + created_at=data["created_at"], + )) + else: + parts = [ + MessagePart( + id=p["id"], + session_id=session_id, + message_id=data["id"], + type=p["type"], + content=p.get("content"), + tool_call_id=p.get("tool_call_id"), + tool_name=p.get("tool_name"), + tool_args=p.get("tool_args"), + tool_output=p.get("tool_output"), + tool_status=p.get("tool_status"), + ) + for p in data.get("opencode_message_parts", []) + ] + messages.append(AssistantMessage( + id=data["id"], + session_id=data["session_id"], + role="assistant", + created_at=data["created_at"], + provider_id=data.get("provider_id"), + model=data.get("model_id"), + usage={"input_tokens": data.get("input_tokens", 0), "output_tokens": data.get("output_tokens", 0)} if data.get("input_tokens") else None, + error=data.get("error"), + parts=parts, + )) + return messages + + message_keys = await Storage.list(["message", session_id]) + messages = [] + + for key in message_keys: + if limit and len(messages) >= limit: + break + data = await Storage.read(key) + if data: + if data.get("role") == "user": + messages.append(UserMessage(**data)) + else: + messages.append(AssistantMessage(**data)) + + messages.sort(key=lambda m: m.created_at) + return messages + + @staticmethod + async def delete(session_id: str, message_id: str, user_id: Optional[str] = None) -> bool: + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_messages").delete().eq("id", message_id).execute() + else: + await Storage.remove(["message", session_id, message_id]) + + await Bus.publish(MESSAGE_REMOVED, MessagePayload(session_id=session_id, message_id=message_id)) + return True + + @staticmethod + async def set_usage(session_id: str, message_id: str, usage: Dict[str, int], user_id: Optional[str] = None) -> None: + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_messages").update({ + "input_tokens": usage.get("input_tokens", 0), + "output_tokens": usage.get("output_tokens", 0), + }).eq("id", message_id).execute() + else: + msg_data = await Storage.read(["message", session_id, message_id]) + if msg_data: + msg_data["usage"] = usage + await Storage.write(["message", session_id, message_id], msg_data) + + @staticmethod + async def set_error(session_id: str, message_id: str, error: str, user_id: Optional[str] = None) -> None: + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_messages").update({"error": error}).eq("id", message_id).execute() + else: + msg_data = await Storage.read(["message", session_id, message_id]) + if msg_data: + msg_data["error"] = error + await Storage.write(["message", session_id, message_id], msg_data) diff --git a/src/opencode_api/session/processor.py 
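A short sketch of the part model above: when Supabase is not configured (no `user_id`), messages and their parts go through the local `Storage` backend, and each streamed item becomes one `MessagePart` of type `text`, `reasoning`, `tool_call`, or `tool_result`. The session id, tool name, and output below are hypothetical, and a writable storage path (e.g. `OPENCODE_STORAGE_PATH`) is assumed.

```python
import asyncio

from src.opencode_api.session.message import Message, MessagePart


async def main() -> None:
    session_id = "ses_demo"  # hypothetical; normally created via Session.create()

    assistant = await Message.create_assistant(
        session_id, provider_id="anthropic", model="claude-3-5-haiku-20241022"
    )

    # Record a tool call as a part, then mark it completed with its output.
    part = await Message.add_part(assistant.id, session_id, MessagePart(
        id="", session_id=session_id, message_id=assistant.id,
        type="tool_call", tool_call_id="call_0", tool_name="websearch",
        tool_args={"query": "opencode"}, tool_status="pending",
    ))
    await Message.update_part(
        session_id, assistant.id, part.id,
        {"tool_status": "completed", "tool_output": "result text"},
    )

    print([m.id for m in await Message.list(session_id)])


asyncio.run(main())
```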
b/src/opencode_api/session/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..bd447aeefca0173666bdbb3e3f82cabc555436e8 --- /dev/null +++ b/src/opencode_api/session/processor.py @@ -0,0 +1,193 @@ +""" +Session processor for managing agentic loop execution. +""" + +from typing import Optional, Dict, Any, AsyncIterator, List +from pydantic import BaseModel +from datetime import datetime +import asyncio + +from ..provider.provider import StreamChunk + + +class DoomLoopDetector: + """๋™์ผ ๋„๊ตฌ + ๋™์ผ ์ธ์ž ์—ฐ์† ํ˜ธ์ถœ์„ ๊ฐ์ง€ํ•˜์—ฌ ๋ฌดํ•œ ๋ฃจํ”„ ๋ฐฉ์ง€ + + ์›๋ณธ opencode์™€ ๋™์ผํ•˜๊ฒŒ ๋„๊ตฌ ์ด๋ฆ„๊ณผ ์ธ์ž๋ฅผ ๋ชจ๋‘ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค. + ๊ฐ™์€ ๋„๊ตฌ๋ผ๋„ ์ธ์ž๊ฐ€ ๋‹ค๋ฅด๋ฉด ์ •์ƒ์ ์ธ ๋ฐ˜๋ณต์œผ๋กœ ํŒ๋‹จํ•ฉ๋‹ˆ๋‹ค. + """ + + def __init__(self, threshold: int = 3): + self.threshold = threshold + self.history: List[tuple[str, str]] = [] # (tool_name, args_hash) + + def record(self, tool_name: str, args: Optional[Dict[str, Any]] = None) -> bool: + """๋„๊ตฌ ํ˜ธ์ถœ์„ ๊ธฐ๋กํ•˜๊ณ  doom loop ๊ฐ์ง€ ์‹œ True ๋ฐ˜ํ™˜ + + Args: + tool_name: ๋„๊ตฌ ์ด๋ฆ„ + args: ๋„๊ตฌ ์ธ์ž (์—†์œผ๋ฉด ๋นˆ dict๋กœ ์ฒ˜๋ฆฌ) + + Returns: + True if doom loop detected, False otherwise + """ + import json + import hashlib + + # ์ธ์ž๋ฅผ ์ •๊ทœํ™”ํ•˜์—ฌ ํ•ด์‹œ ์ƒ์„ฑ (์›๋ณธ์ฒ˜๋Ÿผ JSON ๋น„๊ต) + args_dict = args or {} + args_str = json.dumps(args_dict, sort_keys=True, default=str) + args_hash = hashlib.md5(args_str.encode()).hexdigest()[:8] + + call_signature = (tool_name, args_hash) + self.history.append(call_signature) + + # ์ตœ๊ทผ threshold๊ฐœ๊ฐ€ ๋ชจ๋‘ ๊ฐ™์€ (๋„๊ตฌ + ์ธ์ž)์ธ์ง€ ํ™•์ธ + if len(self.history) >= self.threshold: + recent = self.history[-self.threshold:] + if len(set(recent)) == 1: # ํŠœํ”Œ ๋น„๊ต (๋„๊ตฌ+์ธ์ž) + return True + + return False + + def reset(self): + self.history = [] + + +class RetryConfig(BaseModel): + """์žฌ์‹œ๋„ ์„ค์ •""" + max_retries: int = 3 + base_delay: float = 1.0 # seconds + max_delay: float = 30.0 + exponential_base: float = 2.0 + + +class StepInfo(BaseModel): + """์Šคํ… ์ •๋ณด""" + step: int + started_at: datetime + finished_at: Optional[datetime] = None + tool_calls: List[str] = [] + status: str = "running" # running, completed, error, doom_loop + + +class SessionProcessor: + """ + Agentic loop ์‹คํ–‰์„ ๊ด€๋ฆฌํ•˜๋Š” ํ”„๋กœ์„ธ์„œ. 
+ + Features: + - Doom loop ๋ฐฉ์ง€ (๋™์ผ ๋„๊ตฌ ์—ฐ์† ํ˜ธ์ถœ ๊ฐ์ง€) + - ์ž๋™ ์žฌ์‹œ๋„ (exponential backoff) + - ์Šคํ… ์ถ”์  (step-start, step-finish ์ด๋ฒคํŠธ) + """ + + _processors: Dict[str, "SessionProcessor"] = {} + + def __init__(self, session_id: str, max_steps: int = 50, doom_threshold: int = 3): + self.session_id = session_id + self.max_steps = max_steps + self.doom_detector = DoomLoopDetector(threshold=doom_threshold) + self.retry_config = RetryConfig() + self.steps: List[StepInfo] = [] + self.current_step: Optional[StepInfo] = None + self.aborted = False + + @classmethod + def get_or_create(cls, session_id: str, **kwargs) -> "SessionProcessor": + if session_id not in cls._processors: + cls._processors[session_id] = cls(session_id, **kwargs) + return cls._processors[session_id] + + @classmethod + def remove(cls, session_id: str) -> None: + if session_id in cls._processors: + del cls._processors[session_id] + + def start_step(self) -> StepInfo: + """์ƒˆ ์Šคํ… ์‹œ์ž‘""" + step_num = len(self.steps) + 1 + self.current_step = StepInfo( + step=step_num, + started_at=datetime.utcnow() + ) + self.steps.append(self.current_step) + return self.current_step + + def finish_step(self, status: str = "completed") -> StepInfo: + """ํ˜„์žฌ ์Šคํ… ์™„๋ฃŒ""" + if self.current_step: + self.current_step.finished_at = datetime.utcnow() + self.current_step.status = status + return self.current_step + + def record_tool_call(self, tool_name: str, tool_args: Optional[Dict[str, Any]] = None) -> bool: + """๋„๊ตฌ ํ˜ธ์ถœ ๊ธฐ๋ก, doom loop ๊ฐ์ง€ ์‹œ True ๋ฐ˜ํ™˜ + + Args: + tool_name: ๋„๊ตฌ ์ด๋ฆ„ + tool_args: ๋„๊ตฌ ์ธ์ž (doom loop ํŒ๋ณ„์— ์‚ฌ์šฉ) + + Returns: + True if doom loop detected, False otherwise + """ + if self.current_step: + self.current_step.tool_calls.append(tool_name) + return self.doom_detector.record(tool_name, tool_args) + + def is_doom_loop(self) -> bool: + """ํ˜„์žฌ doom loop ์ƒํƒœ์ธ์ง€ ํ™•์ธ""" + return len(self.doom_detector.history) >= self.doom_detector.threshold and \ + len(set(self.doom_detector.history[-self.doom_detector.threshold:])) == 1 + + def should_continue(self) -> bool: + """๋ฃจํ”„ ๊ณ„์† ์—ฌ๋ถ€""" + if self.aborted: + return False + if len(self.steps) >= self.max_steps: + return False + if self.is_doom_loop(): + return False + return True + + def abort(self) -> None: + """ํ”„๋กœ์„ธ์„œ ์ค‘๋‹จ""" + self.aborted = True + + async def calculate_retry_delay(self, attempt: int) -> float: + """exponential backoff ๋”œ๋ ˆ์ด ๊ณ„์‚ฐ""" + delay = self.retry_config.base_delay * (self.retry_config.exponential_base ** attempt) + return min(delay, self.retry_config.max_delay) + + async def retry_with_backoff(self, func, *args, **kwargs): + """exponential backoff์œผ๋กœ ํ•จ์ˆ˜ ์žฌ์‹œ๋„""" + last_error = None + + for attempt in range(self.retry_config.max_retries): + try: + return await func(*args, **kwargs) + except Exception as e: + last_error = e + if attempt < self.retry_config.max_retries - 1: + delay = await self.calculate_retry_delay(attempt) + await asyncio.sleep(delay) + + raise last_error + + def get_summary(self) -> Dict[str, Any]: + """ํ”„๋กœ์„ธ์„œ ์ƒํƒœ ์š”์•ฝ""" + return { + "session_id": self.session_id, + "total_steps": len(self.steps), + "max_steps": self.max_steps, + "aborted": self.aborted, + "doom_loop_detected": self.is_doom_loop(), + "steps": [ + { + "step": s.step, + "status": s.status, + "tool_calls": s.tool_calls, + "duration": (s.finished_at - s.started_at).total_seconds() if s.finished_at else None + } + for s in self.steps + ] + } diff --git 
a/src/opencode_api/session/prompt.py b/src/opencode_api/session/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..2631d96c1407508e5da90282cffd76b99bd805df --- /dev/null +++ b/src/opencode_api/session/prompt.py @@ -0,0 +1,701 @@ +""" +Session prompt handling with agentic loop support. +""" + +from typing import Optional, List, Dict, Any, AsyncIterator, Literal +from pydantic import BaseModel +import asyncio +import json + +from .session import Session +from .message import Message, MessagePart, AssistantMessage +from .processor import SessionProcessor +from ..provider import get_provider, list_providers +from ..provider.provider import Message as ProviderMessage, StreamChunk, ToolCall +from ..tool import get_tool, get_tools_schema, ToolContext, get_registry +from ..core.config import settings +from ..core.bus import Bus, PART_UPDATED, PartPayload, STEP_STARTED, STEP_FINISHED, StepPayload, TOOL_STATE_CHANGED, ToolStatePayload +from ..agent import get as get_agent, default_agent, get_system_prompt, is_tool_allowed, AgentInfo, get_prompt_for_provider + + +class PromptInput(BaseModel): + content: str + provider_id: Optional[str] = None + model_id: Optional[str] = None + system: Optional[str] = None + temperature: Optional[float] = None + max_tokens: Optional[int] = None + tools_enabled: bool = True + # Agentic loop options + auto_continue: Optional[bool] = None # None = use agent default + max_steps: Optional[int] = None # None = use agent default + + +class LoopState(BaseModel): + step: int = 0 + max_steps: int = 50 + auto_continue: bool = True + stop_reason: Optional[str] = None + paused: bool = False + pause_reason: Optional[str] = None + + +import re +FAKE_TOOL_CALL_PATTERN = re.compile( + r'\[Called\s+tool:\s*(\w+)\s*\(\s*(\{[^}]*\}|\{[^)]*\}|[^)]*)\s*\)\]', + re.IGNORECASE +) + + +class SessionPrompt: + + _active_sessions: Dict[str, asyncio.Task] = {} + _loop_states: Dict[str, LoopState] = {} + + @classmethod + async def prompt( + cls, + session_id: str, + input: PromptInput, + user_id: Optional[str] = None + ) -> AsyncIterator[StreamChunk]: + session = await Session.get(session_id, user_id) + + # Get agent configuration + agent_id = session.agent_id or "build" + agent = get_agent(agent_id) or default_agent() + + # Determine loop settings + auto_continue = input.auto_continue if input.auto_continue is not None else agent.auto_continue + max_steps = input.max_steps if input.max_steps is not None else agent.max_steps + + if auto_continue: + async for chunk in cls._agentic_loop(session_id, input, agent, max_steps, user_id): + yield chunk + else: + async for chunk in cls._single_turn(session_id, input, agent, user_id=user_id): + yield chunk + + @classmethod + async def _agentic_loop( + cls, + session_id: str, + input: PromptInput, + agent: AgentInfo, + max_steps: int, + user_id: Optional[str] = None + ) -> AsyncIterator[StreamChunk]: + state = LoopState(step=0, max_steps=max_steps, auto_continue=True) + cls._loop_states[session_id] = state + + # SessionProcessor ๊ฐ€์ ธ์˜ค๊ธฐ + processor = SessionProcessor.get_or_create(session_id, max_steps=max_steps) + + try: + while processor.should_continue() and not state.paused: + state.step += 1 + + # ์Šคํ… ์‹œ์ž‘ + step_info = processor.start_step() + await Bus.publish(STEP_STARTED, StepPayload( + session_id=session_id, + step=state.step, + max_steps=max_steps + )) + + print(f"[AGENTIC LOOP] Starting step {state.step}, stop_reason={state.stop_reason}", flush=True) + + turn_input = input if state.step == 1 else PromptInput( 
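# Continuation turns (step > 1) deliberately send an empty `content`: no new
# user message is stored for them, so the next model call is driven purely by
# the saved history plus the tool results appended after the previous step.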
+ content="", + provider_id=input.provider_id, + model_id=input.model_id, + temperature=input.temperature, + max_tokens=input.max_tokens, + tools_enabled=input.tools_enabled, + auto_continue=False, + ) + + if state.step > 1: + yield StreamChunk(type="step", text=f"Step {state.step}") + + # Track tool calls in this turn + has_tool_calls_this_turn = False + + async for chunk in cls._single_turn( + session_id, + turn_input, + agent, + is_continuation=(state.step > 1), + user_id=user_id + ): + yield chunk + + if chunk.type == "tool_call" and chunk.tool_call: + has_tool_calls_this_turn = True + print(f"[AGENTIC LOOP] tool_call: {chunk.tool_call.name}", flush=True) + + if chunk.tool_call.name == "question" and agent.pause_on_question: + state.paused = True + state.pause_reason = "question" + + # question tool์ด ์™„๋ฃŒ๋˜๋ฉด (๋‹ต๋ณ€ ๋ฐ›์Œ) pause ํ•ด์ œ + elif chunk.type == "tool_result": + if state.paused and state.pause_reason == "question": + state.paused = False + state.pause_reason = None + + elif chunk.type == "done": + state.stop_reason = chunk.stop_reason + print(f"[AGENTIC LOOP] done: stop_reason={chunk.stop_reason}", flush=True) + + # ์Šคํ… ์™„๋ฃŒ + step_status = "completed" + if processor.is_doom_loop(): + step_status = "doom_loop" + print(f"[AGENTIC LOOP] Doom loop detected! Stopping execution.", flush=True) + yield StreamChunk(type="text", text=f"\n[๊ฒฝ๊ณ : ๋™์ผ ๋„๊ตฌ ๋ฐ˜๋ณต ํ˜ธ์ถœ ๊ฐ์ง€, ๋ฃจํ”„๋ฅผ ์ค‘๋‹จํ•ฉ๋‹ˆ๋‹ค]\n") + + processor.finish_step(status=step_status) + await Bus.publish(STEP_FINISHED, StepPayload( + session_id=session_id, + step=state.step, + max_steps=max_steps + )) + + print(f"[AGENTIC LOOP] End of step {state.step}: stop_reason={state.stop_reason}, has_tool_calls={has_tool_calls_this_turn}", flush=True) + + # Doom loop ๊ฐ์ง€ ์‹œ ์ค‘๋‹จ + if processor.is_doom_loop(): + break + + # If this turn had no new tool calls (just text response), we're done + if state.stop_reason != "tool_calls": + print(f"[AGENTIC LOOP] Breaking: stop_reason != tool_calls", flush=True) + break + + # Loop ์ข…๋ฃŒ ํ›„ ์ƒํƒœ ๋ฉ”์‹œ์ง€๋งŒ ์ถœ๋ ฅ (summary LLM ํ˜ธ์ถœ ์—†์Œ!) 
+ if state.paused: + yield StreamChunk(type="text", text=f"\n[Paused: {state.pause_reason}]\n") + elif state.step >= state.max_steps: + yield StreamChunk(type="text", text=f"\n[Max steps ({state.max_steps}) reached]\n") + # else: ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ข…๋ฃŒ (์ถ”๊ฐ€ ์ถœ๋ ฅ ์—†์Œ) + + finally: + if session_id in cls._loop_states: + del cls._loop_states[session_id] + # SessionProcessor ์ •๋ฆฌ + SessionProcessor.remove(session_id) + + @classmethod + def _infer_provider_from_model(cls, model_id: str) -> str: + """model_id์—์„œ provider_id๋ฅผ ์ถ”๋ก """ + # LiteLLM prefix ๊ธฐ๋ฐ˜ ๋ชจ๋ธ์€ litellm provider ์‚ฌ์šฉ + litellm_prefixes = ["gemini/", "groq/", "deepseek/", "openrouter/", "zai/"] + for prefix in litellm_prefixes: + if model_id.startswith(prefix): + return "litellm" + + # Claude ๋ชจ๋ธ + if model_id.startswith("claude-"): + return "litellm" + + # GPT/O1 ๋ชจ๋ธ + if model_id.startswith("gpt-") or model_id.startswith("o1"): + return "litellm" + + # ๊ธฐ๋ณธ๊ฐ’ + return settings.default_provider + + @classmethod + async def _single_turn( + cls, + session_id: str, + input: PromptInput, + agent: AgentInfo, + is_continuation: bool = False, + user_id: Optional[str] = None + ) -> AsyncIterator[StreamChunk]: + session = await Session.get(session_id, user_id) + + model_id = input.model_id or session.model_id or settings.default_model + + # provider_id๊ฐ€ ๋ช…์‹œ๋˜์ง€ ์•Š์œผ๋ฉด model_id์—์„œ ์ถ”๋ก  + if input.provider_id: + provider_id = input.provider_id + elif session.provider_id: + provider_id = session.provider_id + else: + provider_id = cls._infer_provider_from_model(model_id) + + print(f"[Prompt DEBUG] input.provider_id={input.provider_id}, session.provider_id={session.provider_id}", flush=True) + print(f"[Prompt DEBUG] Final provider_id={provider_id}, model_id={model_id}", flush=True) + + provider = get_provider(provider_id) + print(f"[Prompt DEBUG] Got provider: {provider}", flush=True) + if not provider: + yield StreamChunk(type="error", error=f"Provider not found: {provider_id}") + return + + # Only create user message if there's content (not a continuation) + if input.content and not is_continuation: + user_msg = await Message.create_user(session_id, input.content, user_id) + + assistant_msg = await Message.create_assistant(session_id, provider_id, model_id, user_id) + + # Build message history + history = await Message.list(session_id, user_id=user_id) + messages = cls._build_messages(history[:-1], include_tool_results=True) + + # Build system prompt with provider-specific optimization + system_prompt = cls._build_system_prompt(agent, provider_id, input.system) + + # Get tools schema + tools_schema = get_tools_schema() if input.tools_enabled else None + + current_text_part: Optional[MessagePart] = None + accumulated_text = "" + + # reasoning ์ €์žฅ์„ ์œ„ํ•œ ๋ณ€์ˆ˜ + current_reasoning_part: Optional[MessagePart] = None + accumulated_reasoning = "" + + try: + async for chunk in provider.stream( + model_id=model_id, + messages=messages, + tools=tools_schema, + system=system_prompt, + temperature=input.temperature or agent.temperature, + max_tokens=input.max_tokens or agent.max_tokens, + ): + if chunk.type == "text": + accumulated_text += chunk.text or "" + + if current_text_part is None: + current_text_part = await Message.add_part( + assistant_msg.id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=assistant_msg.id, + type="text", + content=accumulated_text + ), + user_id + ) + else: + await Message.update_part( + session_id, + assistant_msg.id, + 
current_text_part.id, + {"content": accumulated_text}, + user_id + ) + + yield chunk + + elif chunk.type == "tool_call": + tc = chunk.tool_call + if tc: + # Check permission + permission = is_tool_allowed(agent, tc.name) + if permission == "deny": + yield StreamChunk( + type="tool_result", + text=f"Error: Tool '{tc.name}' is not allowed for this agent" + ) + continue + + tool_part = await Message.add_part( + assistant_msg.id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=assistant_msg.id, + type="tool_call", + tool_call_id=tc.id, + tool_name=tc.name, + tool_args=tc.arguments, + tool_status="running" # ์‹คํ–‰ ์ค‘ ์ƒํƒœ + ), + user_id + ) + + # IMPORTANT: Yield tool_call FIRST so frontend can show UI + # This is critical for interactive tools like 'question' + yield chunk + + # ๋„๊ตฌ ์‹คํ–‰ ์‹œ์ž‘ ์ด๋ฒคํŠธ ๋ฐœํ–‰ + await Bus.publish(TOOL_STATE_CHANGED, ToolStatePayload( + session_id=session_id, + message_id=assistant_msg.id, + part_id=tool_part.id, + tool_name=tc.name, + status="running" + )) + + # Execute tool (may block for user input, e.g., question tool) + tool_result, tool_status = await cls._execute_tool( + session_id, + assistant_msg.id, + tc.id, + tc.name, + tc.arguments, + user_id + ) + + # tool_call ํŒŒํŠธ์˜ status๋ฅผ completed/error๋กœ ์—…๋ฐ์ดํŠธ + await Message.update_part( + session_id, + assistant_msg.id, + tool_part.id, + {"tool_status": tool_status}, + user_id + ) + + # ๋„๊ตฌ ์™„๋ฃŒ ์ด๋ฒคํŠธ ๋ฐœํ–‰ + await Bus.publish(TOOL_STATE_CHANGED, ToolStatePayload( + session_id=session_id, + message_id=assistant_msg.id, + part_id=tool_part.id, + tool_name=tc.name, + status=tool_status + )) + + yield StreamChunk( + type="tool_result", + text=tool_result + ) + else: + yield chunk + + elif chunk.type == "reasoning": + # reasoning ์ €์žฅ (๊ธฐ์กด์—๋Š” yield๋งŒ ํ–ˆ์Œ) + accumulated_reasoning += chunk.text or "" + + if current_reasoning_part is None: + current_reasoning_part = await Message.add_part( + assistant_msg.id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=assistant_msg.id, + type="reasoning", + content=accumulated_reasoning + ), + user_id + ) + else: + await Message.update_part( + session_id, + assistant_msg.id, + current_reasoning_part.id, + {"content": accumulated_reasoning}, + user_id + ) + + yield chunk + + elif chunk.type == "done": + if chunk.usage: + await Message.set_usage(session_id, assistant_msg.id, chunk.usage, user_id) + yield chunk + + elif chunk.type == "error": + await Message.set_error(session_id, assistant_msg.id, chunk.error or "Unknown error", user_id) + yield chunk + + await Session.touch(session_id) + + except Exception as e: + error_msg = str(e) + await Message.set_error(session_id, assistant_msg.id, error_msg, user_id) + yield StreamChunk(type="error", error=error_msg) + + @classmethod + def _detect_fake_tool_call(cls, text: str) -> Optional[Dict[str, Any]]: + """ + Detect if the model wrote a fake tool call as text instead of using actual tool calling. + Returns parsed tool call info if detected, None otherwise. 
+ + Patterns detected: + - [Called tool: toolname({...})] + - [Called tool: toolname({'key': 'value'})] + """ + if not text: + return None + + match = FAKE_TOOL_CALL_PATTERN.search(text) + if match: + tool_name = match.group(1) + args_str = match.group(2).strip() + + # Try to parse arguments + args = {} + if args_str: + try: + # Handle both JSON and Python dict formats + args_str = args_str.replace("'", '"') # Convert Python dict to JSON + args = json.loads(args_str) + except json.JSONDecodeError: + # Try to extract key-value pairs manually + # Pattern: 'key': 'value' or "key": "value" + kv_pattern = re.compile(r'["\']?(\w+)["\']?\s*:\s*["\']([^"\']+)["\']') + for kv_match in kv_pattern.finditer(args_str): + args[kv_match.group(1)] = kv_match.group(2) + + return { + "name": tool_name, + "arguments": args + } + + return None + + @classmethod + def _build_system_prompt( + cls, + agent: AgentInfo, + provider_id: str, + custom_system: Optional[str] = None + ) -> Optional[str]: + """Build the complete system prompt. + + Args: + agent: The agent configuration + provider_id: The provider identifier for selecting optimized prompt + custom_system: Optional custom system prompt to append + + Returns: + The complete system prompt, or None if empty + """ + parts = [] + + # Add provider-specific system prompt (optimized for Claude/Gemini/etc.) + provider_prompt = get_prompt_for_provider(provider_id) + if provider_prompt: + parts.append(provider_prompt) + + # Add agent-specific prompt (if defined and different from provider prompt) + agent_prompt = get_system_prompt(agent) + if agent_prompt and agent_prompt != provider_prompt: + parts.append(agent_prompt) + + # Add custom system prompt + if custom_system: + parts.append(custom_system) + + return "\n\n".join(parts) if parts else None + + @classmethod + def _build_messages( + cls, + history: List, + include_tool_results: bool = True + ) -> List[ProviderMessage]: + """Build message list for LLM including tool calls and results. + + Proper tool calling flow: + 1. User message + 2. Assistant message (may include tool calls) + 3. Tool results (as user message with tool context) + 4. 
Assistant continues + """ + messages = [] + + for msg in history: + if msg.role == "user": + # Skip empty user messages (continuations) + if msg.content: + messages.append(ProviderMessage(role="user", content=msg.content)) + + elif msg.role == "assistant": + # Collect all parts + text_parts = [] + tool_calls = [] + tool_results = [] + + for part in getattr(msg, "parts", []): + if part.type == "text" and part.content: + text_parts.append(part.content) + elif part.type == "tool_call" and include_tool_results: + tool_calls.append({ + "id": part.tool_call_id, + "name": part.tool_name, + "arguments": part.tool_args or {} + }) + elif part.type == "tool_result" and include_tool_results: + tool_results.append({ + "tool_call_id": part.tool_call_id, + "output": part.tool_output or "" + }) + + # Build assistant content - only text, NO tool call summaries + # IMPORTANT: Do NOT include "[Called tool: ...]" patterns as this causes + # models like Gemini to mimic the pattern instead of using actual tool calls + assistant_content_parts = [] + + if text_parts: + assistant_content_parts.append("".join(text_parts)) + + if assistant_content_parts: + messages.append(ProviderMessage( + role="assistant", + content="\n".join(assistant_content_parts) + )) + + # Add tool results as user message (simulating tool response) + if tool_results: + result_content = [] + for result in tool_results: + result_content.append(f"Tool result:\n{result['output']}") + messages.append(ProviderMessage( + role="user", + content="\n\n".join(result_content) + )) + + return messages + + @classmethod + async def _execute_tool( + cls, + session_id: str, + message_id: str, + tool_call_id: str, + tool_name: str, + tool_args: Dict[str, Any], + user_id: Optional[str] = None + ) -> tuple[str, str]: + """Execute a tool and store the result. 
Returns (output, status).""" + # SessionProcessor๋ฅผ ํ†ตํ•œ doom loop ๊ฐ์ง€ + # tool_args๋„ ์ „๋‹ฌํ•˜์—ฌ ๊ฐ™์€ ๋„๊ตฌ + ๊ฐ™์€ ์ธ์ž์ผ ๋•Œ๋งŒ doom loop์œผ๋กœ ํŒ๋‹จ + processor = SessionProcessor.get_or_create(session_id) + is_doom_loop = processor.record_tool_call(tool_name, tool_args) + + if is_doom_loop: + error_output = f"Error: Doom loop detected - tool '{tool_name}' called repeatedly" + await Message.add_part( + message_id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=message_id, + type="tool_result", + tool_call_id=tool_call_id, + tool_output=error_output + ), + user_id + ) + return error_output, "error" + + # Registry์—์„œ ๋„๊ตฌ ๊ฐ€์ ธ์˜ค๊ธฐ + registry = get_registry() + tool = registry.get(tool_name) + + if not tool: + error_output = f"Error: Tool '{tool_name}' not found" + await Message.add_part( + message_id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=message_id, + type="tool_result", + tool_call_id=tool_call_id, + tool_output=error_output + ), + user_id + ) + return error_output, "error" + + ctx = ToolContext( + session_id=session_id, + message_id=message_id, + tool_call_id=tool_call_id, + ) + + try: + result = await tool.execute(tool_args, ctx) + + # ์ถœ๋ ฅ ๊ธธ์ด ์ œํ•œ ์ ์šฉ + truncated_output = tool.truncate_output(result.output) + output = f"[{result.title}]\n{truncated_output}" + status = "completed" + except Exception as e: + output = f"Error executing tool: {str(e)}" + status = "error" + + await Message.add_part( + message_id, + session_id, + MessagePart( + id="", + session_id=session_id, + message_id=message_id, + type="tool_result", + tool_call_id=tool_call_id, + tool_output=output + ), + user_id + ) + + return output, status + + @classmethod + def cancel(cls, session_id: str) -> bool: + """Cancel an active session.""" + cancelled = False + + if session_id in cls._active_sessions: + cls._active_sessions[session_id].cancel() + del cls._active_sessions[session_id] + cancelled = True + + if session_id in cls._loop_states: + cls._loop_states[session_id].paused = True + cls._loop_states[session_id].pause_reason = "cancelled" + del cls._loop_states[session_id] + cancelled = True + + return cancelled + + @classmethod + def get_loop_state(cls, session_id: str) -> Optional[LoopState]: + """Get the current loop state for a session.""" + return cls._loop_states.get(session_id) + + @classmethod + async def resume(cls, session_id: str) -> AsyncIterator[StreamChunk]: + state = cls._loop_states.get(session_id) + if not state or not state.paused: + yield StreamChunk(type="error", error="No paused loop to resume") + return + + state.paused = False + state.pause_reason = None + + session = await Session.get(session_id) + agent_id = session.agent_id or "build" + agent = get_agent(agent_id) or default_agent() + + continue_input = PromptInput(content="") + + while state.stop_reason == "tool_calls" and not state.paused and state.step < state.max_steps: + state.step += 1 + + yield StreamChunk(type="text", text=f"\n[Resuming... 
step {state.step}/{state.max_steps}]\n") + + async for chunk in cls._single_turn(session_id, continue_input, agent, is_continuation=True): + yield chunk + + if chunk.type == "tool_call" and chunk.tool_call: + if chunk.tool_call.name == "question" and agent.pause_on_question: + state.paused = True + state.pause_reason = "question" + + elif chunk.type == "done": + state.stop_reason = chunk.stop_reason diff --git a/src/opencode_api/session/session.py b/src/opencode_api/session/session.py new file mode 100644 index 0000000000000000000000000000000000000000..893cd72b92b4a7d9277b54e16b33cd0591ffb1b5 --- /dev/null +++ b/src/opencode_api/session/session.py @@ -0,0 +1,159 @@ +from typing import Optional, List, Dict, Any +from pydantic import BaseModel +from datetime import datetime + +from ..core.storage import Storage, NotFoundError +from ..core.bus import Bus, SESSION_CREATED, SESSION_UPDATED, SESSION_DELETED, SessionPayload +from ..core.identifier import Identifier +from ..core.supabase import get_client, is_enabled as supabase_enabled + + +class SessionInfo(BaseModel): + id: str + user_id: Optional[str] = None + title: str + created_at: datetime + updated_at: datetime + provider_id: Optional[str] = None + model_id: Optional[str] = None + agent_id: Optional[str] = None + + +class SessionCreate(BaseModel): + title: Optional[str] = None + provider_id: Optional[str] = None + model_id: Optional[str] = None + agent_id: Optional[str] = None + + +class Session: + + @staticmethod + async def create(data: Optional[SessionCreate] = None, user_id: Optional[str] = None) -> SessionInfo: + session_id = Identifier.generate("session") + now = datetime.utcnow() + + info = SessionInfo( + id=session_id, + user_id=user_id, + title=data.title if data and data.title else f"Session {now.isoformat()}", + created_at=now, + updated_at=now, + provider_id=data.provider_id if data else None, + model_id=data.model_id if data else None, + agent_id=data.agent_id if data else "build", + ) + + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_sessions").insert({ + "id": session_id, + "user_id": user_id, + "title": info.title, + "agent_id": info.agent_id, + "provider_id": info.provider_id, + "model_id": info.model_id, + }).execute() + else: + await Storage.write(["session", session_id], info) + + await Bus.publish(SESSION_CREATED, SessionPayload(id=session_id, title=info.title)) + return info + + @staticmethod + async def get(session_id: str, user_id: Optional[str] = None) -> SessionInfo: + if supabase_enabled() and user_id: + client = get_client() + result = client.table("opencode_sessions").select("*").eq("id", session_id).eq("user_id", user_id).single().execute() + if not result.data: + raise NotFoundError(["session", session_id]) + return SessionInfo( + id=result.data["id"], + user_id=result.data["user_id"], + title=result.data["title"], + created_at=result.data["created_at"], + updated_at=result.data["updated_at"], + provider_id=result.data.get("provider_id"), + model_id=result.data.get("model_id"), + agent_id=result.data.get("agent_id"), + ) + + data = await Storage.read(["session", session_id]) + if not data: + raise NotFoundError(["session", session_id]) + return SessionInfo(**data) + + @staticmethod + async def update(session_id: str, updates: Dict[str, Any], user_id: Optional[str] = None) -> SessionInfo: + updates["updated_at"] = datetime.utcnow().isoformat() + + if supabase_enabled() and user_id: + client = get_client() + result = 
client.table("opencode_sessions").update(updates).eq("id", session_id).eq("user_id", user_id).execute() + if not result.data: + raise NotFoundError(["session", session_id]) + return await Session.get(session_id, user_id) + + def updater(data: Dict[str, Any]): + data.update(updates) + + data = await Storage.update(["session", session_id], updater) + info = SessionInfo(**data) + await Bus.publish(SESSION_UPDATED, SessionPayload(id=session_id, title=info.title)) + return info + + @staticmethod + async def delete(session_id: str, user_id: Optional[str] = None) -> bool: + if supabase_enabled() and user_id: + client = get_client() + client.table("opencode_sessions").delete().eq("id", session_id).eq("user_id", user_id).execute() + await Bus.publish(SESSION_DELETED, SessionPayload(id=session_id, title="")) + return True + + info = await Session.get(session_id) + message_keys = await Storage.list(["message", session_id]) + for key in message_keys: + await Storage.remove(key) + + await Storage.remove(["session", session_id]) + await Bus.publish(SESSION_DELETED, SessionPayload(id=session_id, title=info.title)) + return True + + @staticmethod + async def list(limit: Optional[int] = None, user_id: Optional[str] = None) -> List[SessionInfo]: + if supabase_enabled() and user_id: + client = get_client() + query = client.table("opencode_sessions").select("*").eq("user_id", user_id).order("updated_at", desc=True) + if limit: + query = query.limit(limit) + result = query.execute() + return [ + SessionInfo( + id=row["id"], + user_id=row["user_id"], + title=row["title"], + created_at=row["created_at"], + updated_at=row["updated_at"], + provider_id=row.get("provider_id"), + model_id=row.get("model_id"), + agent_id=row.get("agent_id"), + ) + for row in result.data + ] + + session_keys = await Storage.list(["session"]) + sessions = [] + + for key in session_keys: + if limit and len(sessions) >= limit: + break + data = await Storage.read(key) + if data: + sessions.append(SessionInfo(**data)) + + sessions.sort(key=lambda s: s.updated_at, reverse=True) + return sessions + + @staticmethod + async def touch(session_id: str, user_id: Optional[str] = None) -> None: + await Session.update(session_id, {}, user_id) diff --git a/src/opencode_api/tool/__init__.py b/src/opencode_api/tool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66e71d2596f622629d1791737fad41ebeafe76e4 --- /dev/null +++ b/src/opencode_api/tool/__init__.py @@ -0,0 +1,27 @@ +from .tool import Tool, ToolContext, ToolResult, register_tool, get_tool, list_tools, get_tools_schema +from .registry import ToolRegistry, get_registry +from .websearch import WebSearchTool +from .webfetch import WebFetchTool +from .todo import TodoTool +from .question import ( + QuestionTool, + QuestionInfo, + QuestionOption, + QuestionRequest, + QuestionReply, + ask_questions, + reply_to_question, + reject_question, + get_pending_questions, +) +from .skill import SkillTool, SkillInfo, register_skill, get_skill, list_skills + +__all__ = [ + "Tool", "ToolContext", "ToolResult", + "register_tool", "get_tool", "list_tools", "get_tools_schema", + "ToolRegistry", "get_registry", + "WebSearchTool", "WebFetchTool", "TodoTool", + "QuestionTool", "QuestionInfo", "QuestionOption", "QuestionRequest", "QuestionReply", + "ask_questions", "reply_to_question", "reject_question", "get_pending_questions", + "SkillTool", "SkillInfo", "register_skill", "get_skill", "list_skills", +] diff --git a/src/opencode_api/tool/__pycache__/__init__.cpython-312.pyc 
b/src/opencode_api/tool/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde2e5f153f605ba3a555612079fd99b57d57e57 Binary files /dev/null and b/src/opencode_api/tool/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/question.cpython-312.pyc b/src/opencode_api/tool/__pycache__/question.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1904a4353ed35158ba2f45aeef0ccc2aa6a4c048 Binary files /dev/null and b/src/opencode_api/tool/__pycache__/question.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/registry.cpython-312.pyc b/src/opencode_api/tool/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435ccfc38bb6353e415af21b72949cc6235fa503 Binary files /dev/null and b/src/opencode_api/tool/__pycache__/registry.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/skill.cpython-312.pyc b/src/opencode_api/tool/__pycache__/skill.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e904bc544e993fa3553ebbdabd7360739e9d98fe Binary files /dev/null and b/src/opencode_api/tool/__pycache__/skill.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/todo.cpython-312.pyc b/src/opencode_api/tool/__pycache__/todo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..856f28b0054f6eebd6ab9bffc17c854eb16f8ded Binary files /dev/null and b/src/opencode_api/tool/__pycache__/todo.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/tool.cpython-312.pyc b/src/opencode_api/tool/__pycache__/tool.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5cca01840e730cfb9849ea3ae5f5bb2db25b9aa Binary files /dev/null and b/src/opencode_api/tool/__pycache__/tool.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/webfetch.cpython-312.pyc b/src/opencode_api/tool/__pycache__/webfetch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6ebbbc3338f84545e29bd95f36da10d0b3f9725 Binary files /dev/null and b/src/opencode_api/tool/__pycache__/webfetch.cpython-312.pyc differ diff --git a/src/opencode_api/tool/__pycache__/websearch.cpython-312.pyc b/src/opencode_api/tool/__pycache__/websearch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fa8ac7227726bac5c13e406a8dba55bb773cc61 Binary files /dev/null and b/src/opencode_api/tool/__pycache__/websearch.cpython-312.pyc differ diff --git a/src/opencode_api/tool/question.py b/src/opencode_api/tool/question.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd01c17773868e035d342f4001d361155e88823 --- /dev/null +++ b/src/opencode_api/tool/question.py @@ -0,0 +1,308 @@ +"""Question tool - allows agent to ask user questions during execution.""" +from typing import Dict, Any, List, Optional +from pydantic import BaseModel, Field +import asyncio +import logging + +from .tool import BaseTool, ToolResult, ToolContext +from ..core.identifier import generate_id +from ..core.bus import Bus + +logger = logging.getLogger(__name__) + + +# Question schemas +class QuestionOption(BaseModel): + """A single option for a question.""" + label: str = Field(..., description="Display text (1-5 words, concise)") + description: str = Field(..., description="Explanation of choice") + + +class QuestionInfo(BaseModel): + """A question to ask the user.""" + question: 
str = Field(..., description="Complete question") + header: str = Field(..., description="Very short label (max 30 chars)") + options: List[QuestionOption] = Field(default_factory=list, description="Available choices") + multiple: bool = Field(default=False, description="Allow selecting multiple choices") + custom: bool = Field(default=True, description="Allow typing a custom answer") + + +class QuestionRequest(BaseModel): + """A request containing questions for the user.""" + id: str + session_id: str + questions: List[QuestionInfo] + tool_call_id: Optional[str] = None + message_id: Optional[str] = None + + +class QuestionReply(BaseModel): + """User's reply to questions.""" + request_id: str + answers: List[List[str]] = Field(..., description="Answers in order (each is array of selected labels)") + + +# Events +QUESTION_ASKED = "question.asked" +QUESTION_REPLIED = "question.replied" +QUESTION_REJECTED = "question.rejected" + + +# Pending questions state +_pending_questions: Dict[str, asyncio.Future] = {} + + +async def ask_questions( + session_id: str, + questions: List[QuestionInfo], + tool_call_id: Optional[str] = None, + message_id: Optional[str] = None, + timeout: float = 300.0, # 5 minutes default timeout +) -> List[List[str]]: + """Ask questions and wait for user response.""" + # tool_call_id๋ฅผ request_id๋กœ ์‚ฌ์šฉ (ํ”„๋ก ํŠธ์—”๋“œ์—์„œ ๋ฐ”๋กœ ์‚ฌ์šฉ ๊ฐ€๋Šฅ) + request_id = tool_call_id or generate_id("question") + + request = QuestionRequest( + id=request_id, + session_id=session_id, + questions=questions, + tool_call_id=tool_call_id, + message_id=message_id, + ) + + # Create future for response + # ์ค‘์š”: get_running_loop() ์‚ฌ์šฉ (get_event_loop()๋Š” FastAPI์—์„œ ์ž˜๋ชป๋œ loop ๋ฐ˜ํ™˜ ๊ฐ€๋Šฅ) + loop = asyncio.get_running_loop() + future: asyncio.Future[List[List[str]]] = loop.create_future() + _pending_questions[request_id] = future + + # Publish question event (will be sent via SSE) + await Bus.publish(QUESTION_ASKED, request.model_dump()) + + try: + # Wait for reply with timeout + logger.info(f"[question] Waiting for answer to request_id={request_id}, timeout={timeout}s") + answers = await asyncio.wait_for(future, timeout=timeout) + logger.info(f"[question] Received answer for request_id={request_id}: {answers}") + return answers + except asyncio.TimeoutError: + logger.error(f"[question] Timeout for request_id={request_id} after {timeout}s") + del _pending_questions[request_id] + raise TimeoutError(f"Question timed out after {timeout} seconds") + except Exception as e: + logger.error(f"[question] Error waiting for answer: {type(e).__name__}: {e}") + raise + finally: + if request_id in _pending_questions: + del _pending_questions[request_id] + + +async def reply_to_question(request_id: str, answers: List[List[str]]) -> bool: + """Submit answers to a pending question.""" + logger.info(f"[question] reply_to_question called: request_id={request_id}, answers={answers}") + logger.info(f"[question] pending_questions keys: {list(_pending_questions.keys())}") + + if request_id not in _pending_questions: + logger.error(f"[question] request_id={request_id} NOT FOUND in pending_questions!") + return False + + future = _pending_questions[request_id] + if not future.done(): + logger.info(f"[question] Setting result for request_id={request_id}") + future.set_result(answers) + else: + logger.warning(f"[question] Future already done for request_id={request_id}") + + return True + + +async def reject_question(request_id: str) -> bool: + """Reject/dismiss a pending question.""" + if request_id 
not in _pending_questions: + return False + + future = _pending_questions[request_id] + if not future.done(): + future.set_exception(QuestionRejectedError()) + + return True + + +def get_pending_questions(session_id: Optional[str] = None) -> List[str]: + """Get list of pending question request IDs.""" + return list(_pending_questions.keys()) + + +class QuestionRejectedError(Exception): + """Raised when user dismisses a question.""" + def __init__(self): + super().__init__("The user dismissed this question") + + +QUESTION_DESCRIPTION = """Use this tool when you need to ask the user questions during execution. This allows you to: +1. Gather user preferences or requirements +2. Clarify ambiguous instructions +3. Get decisions on implementation choices as you work +4. Offer choices to the user about what direction to take. + +IMPORTANT: You MUST provide at least 2 options for each question. Never ask open-ended questions without choices. + +Usage notes: +- REQUIRED: Every question MUST have at least 2 options (minItems: 2) +- When `custom` is enabled (default), a "Type your own answer" option is added automatically; don't include "Other" or catch-all options +- Answers are returned as arrays of labels; set `multiple: true` to allow selecting more than one +- If you recommend a specific option, make that the first option in the list and add "(Recommended)" at the end of the label +""" + + +class QuestionTool(BaseTool): + """Tool for asking user questions during execution.""" + + @property + def id(self) -> str: + return "question" + + @property + def description(self) -> str: + return QUESTION_DESCRIPTION + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "questions": { + "type": "array", + "description": "Questions to ask", + "items": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "Complete question" + }, + "header": { + "type": "string", + "description": "Very short label (max 30 chars)" + }, + "options": { + "type": "array", + "description": "Available choices (MUST provide at least 2 options)", + "minItems": 2, + "items": { + "type": "object", + "properties": { + "label": { + "type": "string", + "description": "Display text (1-5 words, concise)" + }, + "description": { + "type": "string", + "description": "Explanation of choice" + } + }, + "required": ["label", "description"] + } + }, + "multiple": { + "type": "boolean", + "description": "Allow selecting multiple choices", + "default": False + } + }, + "required": ["question", "header", "options"] + } + } + }, + "required": ["questions"] + } + + async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult: + logger.info(f"[question] execute called with args: {args}") + logger.info(f"[question] args type: {type(args)}") + + questions_data = args.get("questions", []) + logger.info(f"[question] questions_data type: {type(questions_data)}, len: {len(questions_data) if isinstance(questions_data, list) else 'N/A'}") + + if questions_data and len(questions_data) > 0: + logger.info(f"[question] first question type: {type(questions_data[0])}") + logger.info(f"[question] first question content: {questions_data[0]}") + + if not questions_data: + return ToolResult( + title="No questions", + output="No questions were provided.", + metadata={} + ) + + # Parse questions + questions = [] + try: + for idx, q in enumerate(questions_data): + logger.info(f"[question] Parsing question {idx}: type={type(q)}, value={q}") + + # q๊ฐ€ ๋ฌธ์ž์—ด์ธ 
๊ฒฝ์šฐ ์ฒ˜๋ฆฌ + if isinstance(q, str): + logger.error(f"[question] Question {idx} is a string, not a dict!") + continue + + options = [] + for opt_idx, opt in enumerate(q.get("options", [])): + logger.info(f"[question] Parsing option {opt_idx}: type={type(opt)}, value={opt}") + if isinstance(opt, dict): + options.append(QuestionOption(label=opt["label"], description=opt["description"])) + else: + logger.error(f"[question] Option {opt_idx} is not a dict: {type(opt)}") + + questions.append(QuestionInfo( + question=q["question"], + header=q["header"], + options=options, + multiple=q.get("multiple", False), + custom=q.get("custom", True), + )) + except Exception as e: + logger.error(f"[question] Error parsing questions: {type(e).__name__}: {e}") + import traceback + logger.error(f"[question] Traceback: {traceback.format_exc()}") + raise + + try: + # Ask questions and wait for response + answers = await ask_questions( + session_id=ctx.session_id, + questions=questions, + tool_call_id=ctx.tool_call_id, + message_id=ctx.message_id, + ) + + # Format response + def format_answer(answer: List[str]) -> str: + if not answer: + return "Unanswered" + return ", ".join(answer) + + formatted = ", ".join( + f'"{q.question}"="{format_answer(answers[i] if i < len(answers) else [])}"' + for i, q in enumerate(questions) + ) + + return ToolResult( + title=f"Asked {len(questions)} question{'s' if len(questions) > 1 else ''}", + output=f"User has answered your questions: {formatted}. You can now continue with the user's answers in mind.", + metadata={"answers": answers} + ) + + except QuestionRejectedError: + return ToolResult( + title="Questions dismissed", + output="The user dismissed the questions without answering.", + metadata={"rejected": True} + ) + except TimeoutError as e: + return ToolResult( + title="Questions timed out", + output=str(e), + metadata={"timeout": True} + ) diff --git a/src/opencode_api/tool/registry.py b/src/opencode_api/tool/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..991e94ec987843a828d0c20694132cda2dd66aa4 --- /dev/null +++ b/src/opencode_api/tool/registry.py @@ -0,0 +1,48 @@ +from typing import Dict, Any, List, Optional +from .tool import BaseTool +import os +import importlib.util + + +class ToolRegistry: + """๋„๊ตฌ ๋ ˆ์ง€์ŠคํŠธ๋ฆฌ - ๋„๊ตฌ ๋“ฑ๋ก ๋ฐ ๊ด€๋ฆฌ""" + + def __init__(self): + self._tools: Dict[str, BaseTool] = {} + + def register(self, tool: BaseTool) -> None: + """๋„๊ตฌ ๋“ฑ๋ก""" + self._tools[tool.id] = tool + + def get(self, tool_id: str) -> Optional[BaseTool]: + """๋„๊ตฌ ID๋กœ ์กฐํšŒ""" + return self._tools.get(tool_id) + + def list(self) -> List[BaseTool]: + """๋“ฑ๋ก๋œ ๋ชจ๋“  ๋„๊ตฌ ๋ชฉ๋ก ๋ฐ˜ํ™˜""" + return list(self._tools.values()) + + def get_schema(self) -> List[Dict[str, Any]]: + """๋ชจ๋“  ๋„๊ตฌ์˜ ์Šคํ‚ค๋งˆ ๋ฐ˜ํ™˜""" + return [tool.get_schema() for tool in self._tools.values()] + + def load_from_directory(self, path: str) -> None: + """ + ๋””๋ ‰ํ† ๋ฆฌ์—์„œ ๋„๊ตฌ๋ฅผ ๋™์ ์œผ๋กœ ๋กœ๋“œ + (๋‚˜์ค‘์— ๊ตฌํ˜„ ๊ฐ€๋Šฅ - ํ”Œ๋Ÿฌ๊ทธ์ธ ์‹œ์Šคํ…œ) + """ + if not os.path.exists(path): + raise ValueError(f"Directory not found: {path}") + + # ํ–ฅํ›„ ๊ตฌํ˜„: .py ํŒŒ์ผ์„ ์Šค์บ”ํ•˜๊ณ  BaseTool ์„œ๋ธŒํด๋ž˜์Šค๋ฅผ ์ฐพ์•„ ์ž๋™ ๋“ฑ๋ก + # ํ˜„์žฌ๋Š” placeholder + pass + + +# ์ „์—ญ ์‹ฑ๊ธ€ํ†ค ์ธ์Šคํ„ด์Šค +_registry = ToolRegistry() + + +def get_registry() -> ToolRegistry: + """์ „์—ญ ๋ ˆ์ง€์ŠคํŠธ๋ฆฌ ์ธ์Šคํ„ด์Šค ๋ฐ˜ํ™˜""" + return _registry diff --git a/src/opencode_api/tool/skill.py b/src/opencode_api/tool/skill.py new file mode 
100644 index 0000000000000000000000000000000000000000..fb30fea76a661850023c04a1ad2d7c1d401d4361 --- /dev/null +++ b/src/opencode_api/tool/skill.py @@ -0,0 +1,369 @@ +"""Skill tool - loads detailed instructions for specific tasks.""" +from typing import Dict, Any, List, Optional +from pydantic import BaseModel, Field + +from .tool import BaseTool, ToolResult, ToolContext + + +class SkillInfo(BaseModel): + """Information about a skill.""" + name: str + description: str + content: str + + +# Built-in skills registry +_skills: Dict[str, SkillInfo] = {} + + +def register_skill(skill: SkillInfo) -> None: + """Register a skill.""" + _skills[skill.name] = skill + + +def get_skill(name: str) -> Optional[SkillInfo]: + """Get a skill by name.""" + return _skills.get(name) + + +def list_skills() -> List[SkillInfo]: + """List all registered skills.""" + return list(_skills.values()) + + +# Built-in default skills +DEFAULT_SKILLS = [ + SkillInfo( + name="web-research", + description="Comprehensive web research methodology for gathering information from multiple sources", + content="""# Web Research Skill + +## Purpose +Guide for conducting thorough web research to answer questions or gather information. + +## Methodology + +### 1. Query Formulation +- Break down complex questions into specific search queries +- Use different phrasings to get diverse results +- Include domain-specific terms when relevant + +### 2. Source Evaluation +- Prioritize authoritative sources (official docs, reputable publications) +- Cross-reference information across multiple sources +- Note publication dates for time-sensitive information + +### 3. Information Synthesis +- Compile findings from multiple sources +- Identify consensus vs. conflicting information +- Summarize key points clearly + +### 4. Citation +- Always provide source URLs +- Note when information might be outdated + +## Tools to Use +- `websearch`: For finding relevant pages +- `webfetch`: For extracting content from specific URLs + +## Best Practices +- Start broad, then narrow down +- Use quotes for exact phrases +- Filter by date when freshness matters +- Verify claims with multiple sources +""" + ), + SkillInfo( + name="code-explanation", + description="Methodology for explaining code clearly to users of varying skill levels", + content="""# Code Explanation Skill + +## Purpose +Guide for explaining code in a clear, educational manner. + +## Approach + +### 1. Assess Context +- Determine user's apparent skill level +- Identify what aspect they're asking about +- Note any specific confusion points + +### 2. Structure Explanation +- Start with high-level overview (what does it do?) +- Break down into logical sections +- Explain each component's purpose + +### 3. Use Analogies +- Relate concepts to familiar ideas +- Use real-world metaphors when helpful +- Avoid overly technical jargon initially + +### 4. Provide Examples +- Show simple examples first +- Build up to complex cases +- Include edge cases when relevant + +### 5. Verify Understanding +- Use the question tool to check comprehension +- Offer to elaborate on specific parts +- Provide additional resources if needed + +## Best Practices +- Don't assume prior knowledge +- Explain "why" not just "what" +- Use code comments effectively +- Highlight common pitfalls +""" + ), + SkillInfo( + name="api-integration", + description="Best practices for integrating with external APIs", + content="""# API Integration Skill + +## Purpose +Guide for properly integrating with external APIs. 
+ +## Key Considerations + +### 1. Authentication +- Store API keys securely (environment variables) +- Never hardcode credentials +- Handle token refresh if applicable + +### 2. Error Handling +- Implement retry logic for transient failures +- Handle rate limiting gracefully +- Log errors with context + +### 3. Request Best Practices +- Set appropriate timeouts +- Use connection pooling +- Implement circuit breakers for resilience + +### 4. Response Handling +- Validate response schemas +- Handle pagination properly +- Cache responses when appropriate + +### 5. Testing +- Mock API responses in tests +- Test error scenarios +- Verify rate limit handling + +## Common Patterns + +```python +# Example: Robust API call +async def call_api(url, retries=3): + for attempt in range(retries): + try: + response = await httpx.get(url, timeout=30) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 429: + await asyncio.sleep(2 ** attempt) + elif e.response.status_code >= 500: + await asyncio.sleep(1) + else: + raise + raise Exception("Max retries exceeded") +``` +""" + ), + SkillInfo( + name="debugging", + description="Systematic approach to debugging problems", + content="""# Debugging Skill + +## Purpose +Systematic methodology for identifying and fixing bugs. + +## Process + +### 1. Reproduce the Issue +- Get exact steps to reproduce +- Note environment details +- Identify when it started happening + +### 2. Gather Information +- Check error messages and stack traces +- Review recent changes +- Check logs for anomalies + +### 3. Form Hypotheses +- List possible causes +- Rank by likelihood +- Consider recent changes first + +### 4. Test Hypotheses +- Start with most likely cause +- Make minimal changes to test +- Verify each hypothesis before moving on + +### 5. Implement Fix +- Fix root cause, not symptoms +- Add tests to prevent regression +- Document the fix + +### 6. Verify Fix +- Confirm original issue is resolved +- Check for side effects +- Test related functionality + +## Debugging Questions +- What changed recently? +- Does it happen consistently? +- What's different when it works? +- What are the exact inputs? + +## Tools +- Use print/log statements strategically +- Leverage debuggers when available +- Check version differences +""" + ), + SkillInfo( + name="task-planning", + description="Breaking down complex tasks into manageable steps", + content="""# Task Planning Skill + +## Purpose +Guide for decomposing complex tasks into actionable steps. + +## Methodology + +### 1. Understand the Goal +- Clarify the end objective +- Identify success criteria +- Note any constraints + +### 2. Identify Components +- Break into major phases +- List dependencies between parts +- Identify parallel vs. sequential work + +### 3. Create Action Items +- Make each item specific and actionable +- Estimate effort/complexity +- Assign priorities + +### 4. Sequence Work +- Order by dependencies +- Front-load risky items +- Plan for blockers + +### 5. Track Progress +- Use todo tool to track items +- Update status as work progresses +- Re-plan when needed + +## Best Practices +- Start with end goal in mind +- Keep items small (< 1 hour ideal) +- Include verification steps +- Plan for error cases + +## Example Structure +1. Research & understand requirements +2. Design approach +3. Implement core functionality +4. Add error handling +5. Test thoroughly +6. 
Document changes +""" + ), +] + + +def _get_skill_description(skills: List[SkillInfo]) -> str: + """Generate description with available skills.""" + if not skills: + return "Load a skill to get detailed instructions for a specific task. No skills are currently available." + + lines = [ + "Load a skill to get detailed instructions for a specific task.", + "Skills provide specialized knowledge and step-by-step guidance.", + "Use this when a task matches an available skill's description.", + "", + "", + ] + + for skill in skills: + lines.extend([ + f" ", + f" {skill.name}", + f" {skill.description}", + f" ", + ]) + + lines.append("") + + return "\n".join(lines) + + +class SkillTool(BaseTool): + """Tool for loading skill instructions.""" + + def __init__(self, additional_skills: Optional[List[SkillInfo]] = None): + """Initialize with optional additional skills.""" + # Register default skills + for skill in DEFAULT_SKILLS: + register_skill(skill) + + # Register additional skills if provided + if additional_skills: + for skill in additional_skills: + register_skill(skill) + + @property + def id(self) -> str: + return "skill" + + @property + def description(self) -> str: + return _get_skill_description(list_skills()) + + @property + def parameters(self) -> Dict[str, Any]: + skill_names = [s.name for s in list_skills()] + examples = ", ".join(f"'{n}'" for n in skill_names[:3]) + hint = f" (e.g., {examples}, ...)" if examples else "" + + return { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": f"The skill identifier from available_skills{hint}", + "enum": skill_names if skill_names else None + } + }, + "required": ["name"] + } + + async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult: + skill_name = args.get("name", "") + + skill = get_skill(skill_name) + + if not skill: + available = ", ".join(s.name for s in list_skills()) + return ToolResult( + title=f"Skill not found: {skill_name}", + output=f'Skill "{skill_name}" not found. Available skills: {available or "none"}', + metadata={"error": True} + ) + + output = f"""## Skill: {skill.name} + +**Description**: {skill.description} + +{skill.content} +""" + + return ToolResult( + title=f"Loaded skill: {skill.name}", + output=output, + metadata={"name": skill.name} + ) diff --git a/src/opencode_api/tool/todo.py b/src/opencode_api/tool/todo.py new file mode 100644 index 0000000000000000000000000000000000000000..81b9646139a0cc072e44157adc87570a1cbd1e0b --- /dev/null +++ b/src/opencode_api/tool/todo.py @@ -0,0 +1,128 @@ +from typing import Dict, Any, List, Optional +from pydantic import BaseModel +from .tool import BaseTool, ToolContext, ToolResult +from ..core.storage import Storage + + +class TodoItem(BaseModel): + id: str + content: str + status: str = "pending" # pending, in_progress, completed, cancelled + priority: str = "medium" # high, medium, low + + +class TodoTool(BaseTool): + + @property + def id(self) -> str: + return "todo" + + @property + def description(self) -> str: + return ( + "Manage a todo list for tracking tasks. Use this to create, update, " + "and track progress on multi-step tasks. Supports pending, in_progress, " + "completed, and cancelled statuses." 
+ ) + + @property + def parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["read", "write"], + "description": "Action to perform: 'read' to get todos, 'write' to update todos" + }, + "todos": { + "type": "array", + "description": "List of todos (required for 'write' action)", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "content": {"type": "string"}, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed", "cancelled"] + }, + "priority": { + "type": "string", + "enum": ["high", "medium", "low"] + } + }, + "required": ["id", "content", "status", "priority"] + } + } + }, + "required": ["action"] + } + + async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult: + action = args["action"] + session_id = ctx.session_id + + if action == "read": + return await self._read_todos(session_id) + elif action == "write": + todos_data = args.get("todos", []) + return await self._write_todos(session_id, todos_data) + else: + return ToolResult( + title="Todo Error", + output=f"Unknown action: {action}", + metadata={"error": "invalid_action"} + ) + + async def _read_todos(self, session_id: str) -> ToolResult: + todos = await Storage.read(["todo", session_id]) + + if not todos: + return ToolResult( + title="Todo List", + output="No todos found for this session.", + metadata={"count": 0} + ) + + items = [TodoItem(**t) for t in todos] + lines = self._format_todos(items) + + return ToolResult( + title="Todo List", + output="\n".join(lines), + metadata={"count": len(items)} + ) + + async def _write_todos(self, session_id: str, todos_data: List[Dict]) -> ToolResult: + items = [TodoItem(**t) for t in todos_data] + await Storage.write(["todo", session_id], [t.model_dump() for t in items]) + + lines = self._format_todos(items) + + return ToolResult( + title="Todo List Updated", + output="\n".join(lines), + metadata={"count": len(items)} + ) + + def _format_todos(self, items: List[TodoItem]) -> List[str]: + status_icons = { + "pending": "[ ]", + "in_progress": "[~]", + "completed": "[x]", + "cancelled": "[-]" + } + priority_icons = { + "high": "!!!", + "medium": "!!", + "low": "!" + } + + lines = [] + for item in items: + icon = status_icons.get(item.status, "[ ]") + priority = priority_icons.get(item.priority, "") + lines.append(f"{icon} {priority} {item.content} (id: {item.id})") + + return lines if lines else ["No todos."] diff --git a/src/opencode_api/tool/tool.py b/src/opencode_api/tool/tool.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7e1b54607f009ce840d3282a27a88d3c0e79dd --- /dev/null +++ b/src/opencode_api/tool/tool.py @@ -0,0 +1,109 @@ +from typing import Dict, Any, List, Optional, Callable, Awaitable, Protocol, runtime_checkable +from pydantic import BaseModel +from abc import ABC, abstractmethod +from datetime import datetime + + +class ToolContext(BaseModel): + session_id: str + message_id: str + tool_call_id: Optional[str] = None + agent: str = "default" + + +class ToolResult(BaseModel): + title: str + output: str + metadata: Dict[str, Any] = {} + truncated: bool = False + original_length: int = 0 + + +@runtime_checkable +class Tool(Protocol): + + @property + def id(self) -> str: ... + + @property + def description(self) -> str: ... + + @property + def parameters(self) -> Dict[str, Any]: ... + + async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult: ... 
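# A minimal sketch of a class that satisfies the runtime-checkable Tool
# protocol above. The "echo" id, its schema, and its behaviour are illustrative
# only; concrete tools in this codebase subclass BaseTool (defined below).
class EchoTool:
    @property
    def id(self) -> str:
        return "echo"

    @property
    def description(self) -> str:
        return "Echo the provided text back to the model."

    @property
    def parameters(self) -> Dict[str, Any]:
        # JSON Schema describing the single "text" argument.
        return {
            "type": "object",
            "properties": {"text": {"type": "string", "description": "Text to echo"}},
            "required": ["text"],
        }

    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
        return ToolResult(title="Echo", output=args.get("text", ""), metadata={})

# isinstance(EchoTool(), Tool) evaluates True, since Tool is @runtime_checkable.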
+
+
+class BaseTool(ABC):
+    MAX_OUTPUT_LENGTH = 50000
+
+    def __init__(self):
+        self.status: str = "pending"
+        self.time_start: Optional[datetime] = None
+        self.time_end: Optional[datetime] = None
+
+    @property
+    @abstractmethod
+    def id(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def description(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def parameters(self) -> Dict[str, Any]:
+        pass
+
+    @abstractmethod
+    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
+        pass
+
+    def get_schema(self) -> Dict[str, Any]:
+        return {
+            "name": self.id,
+            "description": self.description,
+            "parameters": self.parameters
+        }
+
+    def truncate_output(self, output: str) -> str:
+        """Truncate output that exceeds MAX_OUTPUT_LENGTH and append a notice."""
+        if len(output) <= self.MAX_OUTPUT_LENGTH:
+            return output
+
+        truncated = output[:self.MAX_OUTPUT_LENGTH]
+        truncated += "\n\n[Output truncated...]"
+        return truncated
+
+    def update_status(self, status: str) -> None:
+        """Update tool status (pending, running, completed, error)."""
+        self.status = status
+        if status == "running" and self.time_start is None:
+            self.time_start = datetime.now()
+        elif status in ("completed", "error") and self.time_end is None:
+            self.time_end = datetime.now()
+
+
+# Imported after the class definitions to avoid a circular import with .registry.
+from .registry import get_registry
+
+
+def register_tool(tool: BaseTool) -> None:
+    """Register a tool (compatibility wrapper around ToolRegistry)."""
+    get_registry().register(tool)
+
+
+def get_tool(tool_id: str) -> Optional[BaseTool]:
+    """Look up a tool by id (compatibility wrapper around ToolRegistry)."""
+    return get_registry().get(tool_id)
+
+
+def list_tools() -> List[BaseTool]:
+    """List registered tools (compatibility wrapper around ToolRegistry)."""
+    return get_registry().list()
+
+
+def get_tools_schema() -> List[Dict[str, Any]]:
+    """Get schemas for all registered tools (compatibility wrapper around ToolRegistry)."""
+    return get_registry().get_schema()
diff --git a/src/opencode_api/tool/webfetch.py b/src/opencode_api/tool/webfetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd60c01b3d90d44c08c6832af684db12d5cc9cda
--- /dev/null
+++ b/src/opencode_api/tool/webfetch.py
@@ -0,0 +1,117 @@
+from typing import Dict, Any
+import httpx
+from .tool import BaseTool, ToolContext, ToolResult
+
+
+class WebFetchTool(BaseTool):
+
+    @property
+    def id(self) -> str:
+        return "webfetch"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Fetch content from a URL and convert it to readable text or markdown. "
+            "Use this when you need to read the content of a specific web page."
+        )
+
+    @property
+    def parameters(self) -> Dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "url": {
+                    "type": "string",
+                    "description": "The URL to fetch"
+                },
+                "format": {
+                    "type": "string",
+                    "enum": ["text", "markdown", "html"],
+                    "description": "Output format (default: markdown)",
+                    "default": "markdown"
+                }
+            },
+            "required": ["url"]
+        }
+
+    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
+        url = args["url"]
+        output_format = args.get("format", "markdown")
+
+        if not url.startswith(("http://", "https://")):
+            url = "https://" + url
+
+        try:
+            async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
+                response = await client.get(
+                    url,
+                    headers={
+                        "User-Agent": "Mozilla/5.0 (compatible; OpenCode-API/1.0)"
+                    }
+                )
+                response.raise_for_status()
+                html_content = response.text
+
+            if output_format == "html":
+                content = html_content[:50000]  # Limit size
+            elif output_format == "text":
+                content = self._html_to_text(html_content)
+            else:  # markdown
+                content = self._html_to_markdown(html_content)
+
+            if len(content) > 50000:
+                content = content[:50000] + "\n\n[Content truncated...]"
+
+            return ToolResult(
+                title=f"Fetched: {url}",
+                output=content,
+                metadata={"url": url, "format": output_format, "length": len(content)}
+            )
+
+        except httpx.HTTPStatusError as e:
+            return ToolResult(
+                title=f"Fetch failed: {url}",
+                output=f"HTTP Error {e.response.status_code}: {e.response.reason_phrase}",
+                metadata={"error": "http_error", "status_code": e.response.status_code}
+            )
+        except httpx.RequestError as e:
+            return ToolResult(
+                title=f"Fetch failed: {url}",
+                output=f"Request error: {str(e)}",
+                metadata={"error": "request_error"}
+            )
+        except Exception as e:
+            return ToolResult(
+                title=f"Fetch failed: {url}",
+                output=f"Error: {str(e)}",
+                metadata={"error": str(e)}
+            )
+
+    def _html_to_text(self, html: str) -> str:
+        try:
+            from bs4 import BeautifulSoup
+            soup = BeautifulSoup(html, "html.parser")
+
+            for tag in soup(["script", "style", "nav", "footer", "header"]):
+                tag.decompose()
+
+            return soup.get_text(separator="\n", strip=True)
+        except ImportError:
+            # Fallback: crude regex-based stripping when BeautifulSoup is unavailable.
+            import re
+            text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.DOTALL | re.IGNORECASE)
+            text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.DOTALL | re.IGNORECASE)
+            text = re.sub(r"<[^>]+>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+            return text.strip()
+
+    def _html_to_markdown(self, html: str) -> str:
+        try:
+            import html2text
+            h = html2text.HTML2Text()
+            h.ignore_links = False
+            h.ignore_images = True
+            h.body_width = 0
+            return h.handle(html)
+        except ImportError:
+            return self._html_to_text(html)
diff --git a/src/opencode_api/tool/websearch.py b/src/opencode_api/tool/websearch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2acf3f807b0a53220057bded0dbcb9d77809146b
--- /dev/null
+++ b/src/opencode_api/tool/websearch.py
@@ -0,0 +1,85 @@
+from typing import Dict, Any, List
+from .tool import BaseTool, ToolContext, ToolResult
+
+
+class WebSearchTool(BaseTool):
+
+    @property
+    def id(self) -> str:
+        return "websearch"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Search the web using DuckDuckGo. Returns relevant search results "
+            "with titles, URLs, and snippets. Use this when you need current "
+            "information from the internet."
+        )
+
+    @property
+    def parameters(self) -> Dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "The search query"
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return (default: 5)",
+                    "default": 5
+                }
+            },
+            "required": ["query"]
+        }
+
+    async def execute(self, args: Dict[str, Any], ctx: ToolContext) -> ToolResult:
+        query = args["query"]
+        max_results = args.get("max_results", 5)
+
+        try:
+            from ddgs import DDGS
+
+            results = []
+            with DDGS() as ddgs:
+                # Search results scoped to the Korean region
+                for r in ddgs.text(query, region="kr-kr", max_results=max_results):
+                    results.append({
+                        "title": r.get("title", ""),
+                        "url": r.get("href", ""),
+                        "snippet": r.get("body", "")
+                    })
+
+            if not results:
+                return ToolResult(
+                    title=f"Web search: {query}",
+                    output="No results found.",
+                    metadata={"query": query, "count": 0}
+                )
+
+            output_lines = []
+            for i, r in enumerate(results, 1):
+                output_lines.append(f"{i}. {r['title']}")
+                output_lines.append(f" URL: {r['url']}")
+                output_lines.append(f" {r['snippet']}")
+                output_lines.append("")
+
+            return ToolResult(
+                title=f"Web search: {query}",
+                output="\n".join(output_lines),
+                metadata={"query": query, "count": len(results)}
+            )
+
+        except ImportError:
+            return ToolResult(
+                title=f"Web search: {query}",
+                output="Error: ddgs package not installed. Run: pip install ddgs",
+                metadata={"error": "missing_dependency"}
+            )
+        except Exception as e:
+            return ToolResult(
+                title=f"Web search: {query}",
+                output=f"Error performing search: {str(e)}",
+                metadata={"error": str(e)}
+            )
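
For reference, a minimal sketch of exercising one of the tools above outside the FastAPI app. This assumes the repository root is on `PYTHONPATH` so `src.opencode_api` is importable, the `ddgs` dependency is installed, and the session/message ids are placeholders rather than values from the real session layer:

```python
import asyncio

from src.opencode_api.tool.tool import ToolContext
from src.opencode_api.tool.websearch import WebSearchTool


async def main() -> None:
    # Placeholder ids; in the server these come from session management.
    ctx = ToolContext(session_id="demo-session", message_id="demo-message")

    tool = WebSearchTool()
    result = await tool.execute({"query": "FastAPI SSE streaming", "max_results": 3}, ctx)

    # ToolResult carries a short title plus the formatted output shown to the model.
    print(result.title)
    print(result.output)


if __name__ == "__main__":
    asyncio.run(main())
```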