| """ | |
| Session prompt handling with agentic loop support. | |
| """ | |
| from typing import Optional, List, Dict, Any, AsyncIterator, Literal | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| from .session import Session | |
| from .message import Message, MessagePart, AssistantMessage | |
| from .processor import SessionProcessor | |
| from ..provider import get_provider, list_providers | |
| from ..provider.provider import Message as ProviderMessage, StreamChunk, ToolCall | |
| from ..tool import get_tool, get_tools_schema, ToolContext, get_registry | |
| from ..core.config import settings | |
| from ..core.bus import Bus, PART_UPDATED, PartPayload, STEP_STARTED, STEP_FINISHED, StepPayload, TOOL_STATE_CHANGED, ToolStatePayload | |
| from ..agent import get as get_agent, default_agent, get_system_prompt, is_tool_allowed, AgentInfo, get_prompt_for_provider | |


class PromptInput(BaseModel):
    content: str
    provider_id: Optional[str] = None
    model_id: Optional[str] = None
    system: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    tools_enabled: bool = True
    # Agentic loop options
    auto_continue: Optional[bool] = None  # None = use agent default
    max_steps: Optional[int] = None  # None = use agent default


class LoopState(BaseModel):
    step: int = 0
    max_steps: int = 50
    auto_continue: bool = True
    stop_reason: Optional[str] = None
    paused: bool = False
    pause_reason: Optional[str] = None


# Matches text like "[Called tool: name({...})]" that some models emit
# as plain text instead of making a real tool call.
FAKE_TOOL_CALL_PATTERN = re.compile(
    r'\[Called\s+tool:\s*(\w+)\s*\(\s*(\{[^}]*\}|\{[^)]*\}|[^)]*)\s*\)\]',
    re.IGNORECASE
)
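
# Illustrative example (an assumption about typical model output, not a fixed
# spec): the pattern is intended to catch strings such as
#     [Called tool: read_file({"path": "main.py"})]
# so that FAKE_TOOL_CALL_PATTERN.search(...) on that string yields
# group(1) == "read_file" and group(2) == '{"path": "main.py"}'.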


class SessionPrompt:
    """Coordinates prompt handling for a session, including the agentic loop."""

    _active_sessions: Dict[str, asyncio.Task] = {}
    _loop_states: Dict[str, LoopState] = {}

    @classmethod
    async def prompt(
        cls,
        session_id: str,
        input: PromptInput,
        user_id: Optional[str] = None
    ) -> AsyncIterator[StreamChunk]:
        session = await Session.get(session_id, user_id)

        # Get agent configuration
        agent_id = session.agent_id or "build"
        agent = get_agent(agent_id) or default_agent()

        # Determine loop settings
        auto_continue = input.auto_continue if input.auto_continue is not None else agent.auto_continue
        max_steps = input.max_steps if input.max_steps is not None else agent.max_steps

        if auto_continue:
            async for chunk in cls._agentic_loop(session_id, input, agent, max_steps, user_id):
                yield chunk
        else:
            async for chunk in cls._single_turn(session_id, input, agent, user_id=user_id):
                yield chunk

    @classmethod
    async def _agentic_loop(
        cls,
        session_id: str,
        input: PromptInput,
        agent: AgentInfo,
        max_steps: int,
        user_id: Optional[str] = None
    ) -> AsyncIterator[StreamChunk]:
        state = LoopState(step=0, max_steps=max_steps, auto_continue=True)
        cls._loop_states[session_id] = state

        # Get (or create) the SessionProcessor for this session
        processor = SessionProcessor.get_or_create(session_id, max_steps=max_steps)

        try:
            while processor.should_continue() and not state.paused:
                state.step += 1

                # Start the step
                step_info = processor.start_step()
                await Bus.publish(STEP_STARTED, StepPayload(
                    session_id=session_id,
                    step=state.step,
                    max_steps=max_steps
                ))
                print(f"[AGENTIC LOOP] Starting step {state.step}, stop_reason={state.stop_reason}", flush=True)

                turn_input = input if state.step == 1 else PromptInput(
                    content="",
                    provider_id=input.provider_id,
                    model_id=input.model_id,
                    temperature=input.temperature,
                    max_tokens=input.max_tokens,
                    tools_enabled=input.tools_enabled,
                    auto_continue=False,
                )

                if state.step > 1:
                    yield StreamChunk(type="step", text=f"Step {state.step}")

                # Track tool calls in this turn
                has_tool_calls_this_turn = False

                async for chunk in cls._single_turn(
                    session_id,
                    turn_input,
                    agent,
                    is_continuation=(state.step > 1),
                    user_id=user_id
                ):
                    yield chunk

                    if chunk.type == "tool_call" and chunk.tool_call:
                        has_tool_calls_this_turn = True
                        print(f"[AGENTIC LOOP] tool_call: {chunk.tool_call.name}", flush=True)
                        if chunk.tool_call.name == "question" and agent.pause_on_question:
                            state.paused = True
                            state.pause_reason = "question"
                    # Unpause once the question tool completes (answer received)
                    elif chunk.type == "tool_result":
                        if state.paused and state.pause_reason == "question":
                            state.paused = False
                            state.pause_reason = None
                    elif chunk.type == "done":
                        state.stop_reason = chunk.stop_reason
                        print(f"[AGENTIC LOOP] done: stop_reason={chunk.stop_reason}", flush=True)

                # Finish the step
                step_status = "completed"
                if processor.is_doom_loop():
                    step_status = "doom_loop"
                    print("[AGENTIC LOOP] Doom loop detected! Stopping execution.", flush=True)
                    yield StreamChunk(type="text", text="\n[Warning: repeated identical tool calls detected, stopping the loop]\n")
                processor.finish_step(status=step_status)
                await Bus.publish(STEP_FINISHED, StepPayload(
                    session_id=session_id,
                    step=state.step,
                    max_steps=max_steps
                ))
                print(f"[AGENTIC LOOP] End of step {state.step}: stop_reason={state.stop_reason}, has_tool_calls={has_tool_calls_this_turn}", flush=True)

                # Stop if a doom loop was detected
                if processor.is_doom_loop():
                    break

                # If this turn had no new tool calls (just a text response), we're done
                if state.stop_reason != "tool_calls":
                    print("[AGENTIC LOOP] Breaking: stop_reason != tool_calls", flush=True)
                    break

            # After the loop ends, emit only a status message (no extra summary LLM call!)
            if state.paused:
                yield StreamChunk(type="text", text=f"\n[Paused: {state.pause_reason}]\n")
            elif state.step >= state.max_steps:
                yield StreamChunk(type="text", text=f"\n[Max steps ({state.max_steps}) reached]\n")
            # else: natural completion, no extra output
        finally:
            if session_id in cls._loop_states:
                del cls._loop_states[session_id]
            # Clean up the SessionProcessor
            SessionProcessor.remove(session_id)

    @classmethod
    def _infer_provider_from_model(cls, model_id: str) -> str:
        """Infer the provider_id from a model_id."""
        # Models with a LiteLLM-style prefix go through the litellm provider
        litellm_prefixes = ["gemini/", "groq/", "deepseek/", "openrouter/", "zai/"]
        for prefix in litellm_prefixes:
            if model_id.startswith(prefix):
                return "litellm"
        # Claude models
        if model_id.startswith("claude-"):
            return "litellm"
        # GPT / o1 models
        if model_id.startswith("gpt-") or model_id.startswith("o1"):
            return "litellm"
        # Fallback
        return settings.default_provider
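
    # Illustrative behavior of the inference above (example model IDs are
    # assumptions, not an exhaustive list):
    #   _infer_provider_from_model("gemini/gemini-2.0-flash") -> "litellm"
    #   _infer_provider_from_model("claude-3-5-sonnet")       -> "litellm"
    #   _infer_provider_from_model("some-local-model")        -> settings.default_provider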

    @classmethod
    async def _single_turn(
        cls,
        session_id: str,
        input: PromptInput,
        agent: AgentInfo,
        is_continuation: bool = False,
        user_id: Optional[str] = None
    ) -> AsyncIterator[StreamChunk]:
        session = await Session.get(session_id, user_id)
        model_id = input.model_id or session.model_id or settings.default_model

        # If provider_id isn't given explicitly, infer it from model_id
        if input.provider_id:
            provider_id = input.provider_id
        elif session.provider_id:
            provider_id = session.provider_id
        else:
            provider_id = cls._infer_provider_from_model(model_id)

        print(f"[Prompt DEBUG] input.provider_id={input.provider_id}, session.provider_id={session.provider_id}", flush=True)
        print(f"[Prompt DEBUG] Final provider_id={provider_id}, model_id={model_id}", flush=True)
        provider = get_provider(provider_id)
        print(f"[Prompt DEBUG] Got provider: {provider}", flush=True)
        if not provider:
            yield StreamChunk(type="error", error=f"Provider not found: {provider_id}")
            return

        # Only create a user message if there's content (i.e., not a continuation)
        if input.content and not is_continuation:
            user_msg = await Message.create_user(session_id, input.content, user_id)

        assistant_msg = await Message.create_assistant(session_id, provider_id, model_id, user_id)

        # Build message history
        history = await Message.list(session_id, user_id=user_id)
        messages = cls._build_messages(history[:-1], include_tool_results=True)

        # Build the system prompt with provider-specific optimization
        system_prompt = cls._build_system_prompt(agent, provider_id, input.system)

        # Get the tools schema
        tools_schema = get_tools_schema() if input.tools_enabled else None

        current_text_part: Optional[MessagePart] = None
        accumulated_text = ""
        # State for persisting reasoning output
        current_reasoning_part: Optional[MessagePart] = None
        accumulated_reasoning = ""

        try:
            async for chunk in provider.stream(
                model_id=model_id,
                messages=messages,
                tools=tools_schema,
                system=system_prompt,
                temperature=input.temperature or agent.temperature,
                max_tokens=input.max_tokens or agent.max_tokens,
            ):
                if chunk.type == "text":
                    accumulated_text += chunk.text or ""
                    if current_text_part is None:
                        current_text_part = await Message.add_part(
                            assistant_msg.id,
                            session_id,
                            MessagePart(
                                id="",
                                session_id=session_id,
                                message_id=assistant_msg.id,
                                type="text",
                                content=accumulated_text
                            ),
                            user_id
                        )
                    else:
                        await Message.update_part(
                            session_id,
                            assistant_msg.id,
                            current_text_part.id,
                            {"content": accumulated_text},
                            user_id
                        )
                    yield chunk

                elif chunk.type == "tool_call":
                    tc = chunk.tool_call
                    if tc:
                        # Check permission
                        permission = is_tool_allowed(agent, tc.name)
                        if permission == "deny":
                            yield StreamChunk(
                                type="tool_result",
                                text=f"Error: Tool '{tc.name}' is not allowed for this agent"
                            )
                            continue

                        tool_part = await Message.add_part(
                            assistant_msg.id,
                            session_id,
                            MessagePart(
                                id="",
                                session_id=session_id,
                                message_id=assistant_msg.id,
                                type="tool_call",
                                tool_call_id=tc.id,
                                tool_name=tc.name,
                                tool_args=tc.arguments,
                                tool_status="running"  # currently executing
                            ),
                            user_id
                        )

                        # IMPORTANT: Yield tool_call FIRST so the frontend can show UI.
                        # This is critical for interactive tools like 'question'.
                        yield chunk

                        # Publish a tool-execution-started event
                        await Bus.publish(TOOL_STATE_CHANGED, ToolStatePayload(
                            session_id=session_id,
                            message_id=assistant_msg.id,
                            part_id=tool_part.id,
                            tool_name=tc.name,
                            status="running"
                        ))

                        # Execute the tool (may block for user input, e.g., the question tool)
                        tool_result, tool_status = await cls._execute_tool(
                            session_id,
                            assistant_msg.id,
                            tc.id,
                            tc.name,
                            tc.arguments,
                            user_id
                        )

                        # Update the tool_call part's status to completed/error
                        await Message.update_part(
                            session_id,
                            assistant_msg.id,
                            tool_part.id,
                            {"tool_status": tool_status},
                            user_id
                        )

                        # Publish a tool-completed event
                        await Bus.publish(TOOL_STATE_CHANGED, ToolStatePayload(
                            session_id=session_id,
                            message_id=assistant_msg.id,
                            part_id=tool_part.id,
                            tool_name=tc.name,
                            status=tool_status
                        ))

                        yield StreamChunk(
                            type="tool_result",
                            text=tool_result
                        )
                    else:
                        yield chunk

                elif chunk.type == "reasoning":
                    # Persist reasoning (previously this was only yielded, not stored)
                    accumulated_reasoning += chunk.text or ""
                    if current_reasoning_part is None:
                        current_reasoning_part = await Message.add_part(
                            assistant_msg.id,
                            session_id,
                            MessagePart(
                                id="",
                                session_id=session_id,
                                message_id=assistant_msg.id,
                                type="reasoning",
                                content=accumulated_reasoning
                            ),
                            user_id
                        )
                    else:
                        await Message.update_part(
                            session_id,
                            assistant_msg.id,
                            current_reasoning_part.id,
                            {"content": accumulated_reasoning},
                            user_id
                        )
                    yield chunk

                elif chunk.type == "done":
                    if chunk.usage:
                        await Message.set_usage(session_id, assistant_msg.id, chunk.usage, user_id)
                    yield chunk

                elif chunk.type == "error":
                    await Message.set_error(session_id, assistant_msg.id, chunk.error or "Unknown error", user_id)
                    yield chunk

            await Session.touch(session_id)
        except Exception as e:
            error_msg = str(e)
            await Message.set_error(session_id, assistant_msg.id, error_msg, user_id)
            yield StreamChunk(type="error", error=error_msg)

    @classmethod
    def _detect_fake_tool_call(cls, text: str) -> Optional[Dict[str, Any]]:
        """
        Detect if the model wrote a fake tool call as text instead of using actual tool calling.
        Returns parsed tool call info if detected, None otherwise.

        Patterns detected:
        - [Called tool: toolname({...})]
        - [Called tool: toolname({'key': 'value'})]
        """
        if not text:
            return None

        match = FAKE_TOOL_CALL_PATTERN.search(text)
        if match:
            tool_name = match.group(1)
            args_str = match.group(2).strip()

            # Try to parse the arguments
            args = {}
            if args_str:
                try:
                    # Handle both JSON and Python-dict-style formats
                    args_str = args_str.replace("'", '"')  # naive Python dict -> JSON conversion
                    args = json.loads(args_str)
                except json.JSONDecodeError:
                    # Fall back to extracting key-value pairs manually.
                    # Pattern: 'key': 'value' or "key": "value"
                    kv_pattern = re.compile(r'["\']?(\w+)["\']?\s*:\s*["\']([^"\']+)["\']')
                    for kv_match in kv_pattern.finditer(args_str):
                        args[kv_match.group(1)] = kv_match.group(2)

            return {
                "name": tool_name,
                "arguments": args
            }
        return None
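
    # Illustrative usage (hypothetical input string; this helper is not wired
    # into the streaming path above):
    #   SessionPrompt._detect_fake_tool_call("[Called tool: bash({'cmd': 'ls'})]")
    #   -> {"name": "bash", "arguments": {"cmd": "ls"}}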

    @classmethod
    def _build_system_prompt(
        cls,
        agent: AgentInfo,
        provider_id: str,
        custom_system: Optional[str] = None
    ) -> Optional[str]:
        """Build the complete system prompt.

        Args:
            agent: The agent configuration
            provider_id: The provider identifier for selecting an optimized prompt
            custom_system: Optional custom system prompt to append

        Returns:
            The complete system prompt, or None if empty
        """
        parts = []

        # Add the provider-specific system prompt (optimized for Claude/Gemini/etc.)
        provider_prompt = get_prompt_for_provider(provider_id)
        if provider_prompt:
            parts.append(provider_prompt)

        # Add the agent-specific prompt (if defined and different from the provider prompt)
        agent_prompt = get_system_prompt(agent)
        if agent_prompt and agent_prompt != provider_prompt:
            parts.append(agent_prompt)

        # Add the custom system prompt
        if custom_system:
            parts.append(custom_system)

        return "\n\n".join(parts) if parts else None

    @classmethod
    def _build_messages(
        cls,
        history: List,
        include_tool_results: bool = True
    ) -> List[ProviderMessage]:
        """Build the message list for the LLM, including tool calls and results.

        Proper tool-calling flow:
        1. User message
        2. Assistant message (may include tool calls)
        3. Tool results (as a user message with tool context)
        4. Assistant continues
        """
        messages = []
        for msg in history:
            if msg.role == "user":
                # Skip empty user messages (continuations)
                if msg.content:
                    messages.append(ProviderMessage(role="user", content=msg.content))
            elif msg.role == "assistant":
                # Collect all parts
                text_parts = []
                tool_calls = []
                tool_results = []
                for part in getattr(msg, "parts", []):
                    if part.type == "text" and part.content:
                        text_parts.append(part.content)
                    elif part.type == "tool_call" and include_tool_results:
                        tool_calls.append({
                            "id": part.tool_call_id,
                            "name": part.tool_name,
                            "arguments": part.tool_args or {}
                        })
                    elif part.type == "tool_result" and include_tool_results:
                        tool_results.append({
                            "tool_call_id": part.tool_call_id,
                            "output": part.tool_output or ""
                        })

                # Build assistant content from text only, with NO tool call summaries.
                # IMPORTANT: Do NOT include "[Called tool: ...]" patterns, as this causes
                # models like Gemini to mimic the pattern instead of using actual tool calls.
                assistant_content_parts = []
                if text_parts:
                    assistant_content_parts.append("".join(text_parts))

                if assistant_content_parts:
                    messages.append(ProviderMessage(
                        role="assistant",
                        content="\n".join(assistant_content_parts)
                    ))

                # Add tool results as a user message (simulating a tool response)
                if tool_results:
                    result_content = []
                    for result in tool_results:
                        result_content.append(f"Tool result:\n{result['output']}")
                    messages.append(ProviderMessage(
                        role="user",
                        content="\n\n".join(result_content)
                    ))
        return messages
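
    # Illustrative output shape (part contents assumed for the example): a
    # history of [user "list files", assistant text + tool_call + tool_result]
    # becomes:
    #   [ProviderMessage(role="user", content="list files"),
    #    ProviderMessage(role="assistant", content="I'll list the files."),
    #    ProviderMessage(role="user", content="Tool result:\nmain.py\napp.py")]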

    @classmethod
    async def _execute_tool(
        cls,
        session_id: str,
        message_id: str,
        tool_call_id: str,
        tool_name: str,
        tool_args: Dict[str, Any],
        user_id: Optional[str] = None
    ) -> tuple[str, str]:
        """Execute a tool and store the result. Returns (output, status)."""
        # Doom-loop detection via the SessionProcessor. tool_args is passed along so
        # only repeated calls to the same tool with the same arguments count as a doom loop.
        processor = SessionProcessor.get_or_create(session_id)
        is_doom_loop = processor.record_tool_call(tool_name, tool_args)

        if is_doom_loop:
            error_output = f"Error: Doom loop detected - tool '{tool_name}' called repeatedly"
            await Message.add_part(
                message_id,
                session_id,
                MessagePart(
                    id="",
                    session_id=session_id,
                    message_id=message_id,
                    type="tool_result",
                    tool_call_id=tool_call_id,
                    tool_output=error_output
                ),
                user_id
            )
            return error_output, "error"

        # Fetch the tool from the registry
        registry = get_registry()
        tool = registry.get(tool_name)
        if not tool:
            error_output = f"Error: Tool '{tool_name}' not found"
            await Message.add_part(
                message_id,
                session_id,
                MessagePart(
                    id="",
                    session_id=session_id,
                    message_id=message_id,
                    type="tool_result",
                    tool_call_id=tool_call_id,
                    tool_output=error_output
                ),
                user_id
            )
            return error_output, "error"

        ctx = ToolContext(
            session_id=session_id,
            message_id=message_id,
            tool_call_id=tool_call_id,
        )
        try:
            result = await tool.execute(tool_args, ctx)
            # Apply the output length limit
            truncated_output = tool.truncate_output(result.output)
            output = f"[{result.title}]\n{truncated_output}"
            status = "completed"
        except Exception as e:
            output = f"Error executing tool: {str(e)}"
            status = "error"

        await Message.add_part(
            message_id,
            session_id,
            MessagePart(
                id="",
                session_id=session_id,
                message_id=message_id,
                type="tool_result",
                tool_call_id=tool_call_id,
                tool_output=output
            ),
            user_id
        )
        return output, status

    @classmethod
    def cancel(cls, session_id: str) -> bool:
        """Cancel an active session."""
        cancelled = False
        if session_id in cls._active_sessions:
            cls._active_sessions[session_id].cancel()
            del cls._active_sessions[session_id]
            cancelled = True
        if session_id in cls._loop_states:
            cls._loop_states[session_id].paused = True
            cls._loop_states[session_id].pause_reason = "cancelled"
            del cls._loop_states[session_id]
            cancelled = True
        return cancelled

    @classmethod
    def get_loop_state(cls, session_id: str) -> Optional[LoopState]:
        """Get the current loop state for a session."""
        return cls._loop_states.get(session_id)

    @classmethod
    async def resume(cls, session_id: str) -> AsyncIterator[StreamChunk]:
        """Resume a paused agentic loop."""
        state = cls._loop_states.get(session_id)
        if not state or not state.paused:
            yield StreamChunk(type="error", error="No paused loop to resume")
            return

        state.paused = False
        state.pause_reason = None

        session = await Session.get(session_id)
        agent_id = session.agent_id or "build"
        agent = get_agent(agent_id) or default_agent()

        continue_input = PromptInput(content="")
        while state.stop_reason == "tool_calls" and not state.paused and state.step < state.max_steps:
            state.step += 1
            yield StreamChunk(type="text", text=f"\n[Resuming... step {state.step}/{state.max_steps}]\n")
            async for chunk in cls._single_turn(session_id, continue_input, agent, is_continuation=True):
                yield chunk
                if chunk.type == "tool_call" and chunk.tool_call:
                    if chunk.tool_call.name == "question" and agent.pause_on_question:
                        state.paused = True
                        state.pause_reason = "question"
                elif chunk.type == "done":
                    state.stop_reason = chunk.stop_reason