Spaces:
Sleeping
Sleeping
File size: 6,548 Bytes
1397957 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
"""
Session processor for managing agentic loop execution.
"""
import asyncio
import hashlib
import json
from datetime import datetime, timezone
from typing import Optional, Dict, Any, AsyncIterator, List

from pydantic import BaseModel

from ..provider.provider import StreamChunk
class DoomLoopDetector:
    """Detect an agent stuck repeating the same tool call (a "doom loop").

    Like the original opencode implementation, a call is identified by BOTH the
    tool name and its arguments: calling the same tool with different arguments
    is treated as legitimate repetition, not a loop. A loop is reported once the
    last ``threshold`` recorded calls are all identical.
    """

    def __init__(self, threshold: int = 3):
        """Create a detector.

        Args:
            threshold: Number of consecutive identical calls that constitutes a
                loop. Must be >= 1.

        Raises:
            ValueError: If ``threshold`` is less than 1.
        """
        if threshold < 1:
            # threshold 0 would make history[-0:] select the *entire* history,
            # silently changing the detection rule.
            raise ValueError(f"threshold must be >= 1, got {threshold!r}")
        self.threshold = threshold
        # Trailing window of (tool_name, args_fingerprint) pairs, oldest first.
        # record() keeps at most ``threshold`` entries.
        self.history: List[tuple[str, str]] = []

    @staticmethod
    def _fingerprint(args: Optional[Dict[str, Any]]) -> str:
        """Return a stable digest of *args*; key order does not matter, None == {}."""
        # Canonical JSON (sorted keys) so equal dicts always hash the same;
        # default=str lets non-JSON values still participate in the comparison.
        canonical = json.dumps(args or {}, sort_keys=True, default=str)
        # Full SHA-256 instead of a truncated MD5: no realistic chance that two
        # different argument sets collide into a false loop, and MD5 may be
        # rejected on FIPS-restricted builds.
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

    def is_looping(self) -> bool:
        """Return True if the last ``threshold`` recorded calls are identical."""
        if len(self.history) < self.threshold:
            return False
        # Tuple comparison: both tool name and argument fingerprint must match.
        return len(set(self.history[-self.threshold:])) == 1

    def record(self, tool_name: str, args: Optional[Dict[str, Any]] = None) -> bool:
        """Record a tool call and report whether a doom loop is now detected.

        Args:
            tool_name: Name of the tool that was called.
            args: Arguments passed to the tool; ``None`` is treated as ``{}``.

        Returns:
            True if the last ``threshold`` recorded calls (including this one)
            all used the same tool with the same arguments, False otherwise.
        """
        self.history.append((tool_name, self._fingerprint(args)))
        # Only the trailing window is ever inspected; drop older entries so the
        # history cannot grow without bound over a long-lived session.
        if len(self.history) > self.threshold:
            del self.history[:-self.threshold]
        return self.is_looping()

    def reset(self):
        """Forget all recorded calls."""
        self.history = []
class RetryConfig(BaseModel):
    """Retry settings for the session processor's exponential-backoff helper.

    The wait before retry number ``attempt`` (0-based) is
    ``base_delay * exponential_base ** attempt`` seconds, capped at ``max_delay``.
    """

    # Total attempts made by SessionProcessor.retry_with_backoff, first call included.
    max_retries: int = 3
    # Seconds to wait after the first failure.
    base_delay: float = 1.0
    # Ceiling, in seconds, for any single wait.
    max_delay: float = 30.0
    # Growth factor applied to the delay for each further attempt.
    exponential_base: float = 2.0
class StepInfo(BaseModel):
    """Bookkeeping record for a single step of the agentic loop."""

    # 1-based position of this step within the session.
    step: int
    # When the step began.
    started_at: datetime
    # When the step ended; None while it is still in progress.
    finished_at: Optional[datetime] = None
    # Names of the tools invoked during this step, in call order.
    # pydantic copies mutable defaults per instance, so this [] is not shared.
    tool_calls: List[str] = []
    # One of: running, completed, error, doom_loop
    status: str = "running"
class SessionProcessor:
    """Processor that manages execution of an agentic loop for one session.

    Features:
    - Doom-loop prevention (detects consecutive identical tool calls).
    - Automatic retry with exponential backoff.
    - Step tracking (step-start / step-finish bookkeeping).
    """

    # Registry of processors keyed by session id. It is a class attribute, so it
    # is shared by every instance; entries stay until remove() is called.
    _processors: Dict[str, "SessionProcessor"] = {}

    def __init__(self, session_id: str, max_steps: int = 50, doom_threshold: int = 3):
        """Create a processor.

        Args:
            session_id: Session this processor belongs to.
            max_steps: Number of started steps after which should_continue()
                returns False.
            doom_threshold: Consecutive identical tool calls that count as a
                doom loop (passed to DoomLoopDetector).
        """
        self.session_id = session_id
        self.max_steps = max_steps
        self.doom_detector = DoomLoopDetector(threshold=doom_threshold)
        self.retry_config = RetryConfig()
        self.steps: List[StepInfo] = []
        self.current_step: Optional[StepInfo] = None
        self.aborted = False

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` (deprecated in
        Python 3.12); kept naive so stored timestamps do not change type.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def get_or_create(cls, session_id: str, **kwargs) -> "SessionProcessor":
        """Return the registered processor for *session_id*, creating it if needed.

        ``kwargs`` are forwarded to the constructor only when a new processor is
        created; they are ignored if one already exists for the session.
        """
        if session_id not in cls._processors:
            cls._processors[session_id] = cls(session_id, **kwargs)
        return cls._processors[session_id]

    @classmethod
    def remove(cls, session_id: str) -> None:
        """Drop the processor for *session_id* from the registry (no-op if absent)."""
        cls._processors.pop(session_id, None)

    def start_step(self) -> StepInfo:
        """Begin a new step, make it the current step, and return it.

        Step numbers are 1-based. A previous current step is not finished
        automatically.
        """
        self.current_step = StepInfo(
            step=len(self.steps) + 1,
            started_at=self._utcnow(),
        )
        self.steps.append(self.current_step)
        return self.current_step

    def finish_step(self, status: str = "completed") -> Optional[StepInfo]:
        """Mark the current step finished with *status* and return it.

        Returns None (and changes nothing) if no step has been started. The
        finished step remains ``current_step`` until the next start_step().
        """
        if self.current_step is not None:
            self.current_step.finished_at = self._utcnow()
            self.current_step.status = status
        return self.current_step

    def record_tool_call(self, tool_name: str, tool_args: Optional[Dict[str, Any]] = None) -> bool:
        """Record a tool call and report whether it completes a doom loop.

        Args:
            tool_name: Name of the tool that was called.
            tool_args: Tool arguments; used to tell identical repeated calls apart
                from repeated calls with different arguments.

        Returns:
            True if a doom loop is detected, False otherwise.
        """
        if self.current_step is not None:
            self.current_step.tool_calls.append(tool_name)
        return self.doom_detector.record(tool_name, tool_args)

    def is_doom_loop(self) -> bool:
        """Return True if the last ``threshold`` recorded tool calls were identical."""
        detector = self.doom_detector
        window = detector.threshold
        if len(detector.history) < window:
            return False
        return len(set(detector.history[-window:])) == 1

    def should_continue(self) -> bool:
        """Return False once aborted, out of steps, or stuck in a doom loop."""
        if self.aborted:
            return False
        if len(self.steps) >= self.max_steps:
            return False
        if self.is_doom_loop():
            return False
        return True

    def abort(self) -> None:
        """Stop the processor: should_continue() returns False and retries stop."""
        self.aborted = True

    async def calculate_retry_delay(self, attempt: int) -> float:
        """Return the backoff delay in seconds before retrying after 0-based *attempt*.

        Computes ``base_delay * exponential_base ** attempt`` capped at
        ``max_delay``. Nothing is awaited; the method stays a coroutine so
        existing ``await`` call sites keep working.
        """
        cfg = self.retry_config
        try:
            delay = cfg.base_delay * (cfg.exponential_base ** attempt)
        except OverflowError:
            # Very large attempts overflow float math instead of yielding inf;
            # the cap applies either way.
            return cfg.max_delay
        return min(delay, cfg.max_delay)

    async def retry_with_backoff(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying failures with exponential backoff.

        ``retry_config.max_retries`` is the total number of attempts; at least
        one attempt is always made, even if it is configured as 0 or negative.
        After a failed attempt the coroutine sleeps ``calculate_retry_delay``
        seconds before trying again. If the processor has been aborted when an
        attempt fails, no further attempts are made.

        Only ``Exception`` subclasses are retried; ``asyncio.CancelledError``
        (a ``BaseException`` since Python 3.8) propagates immediately.

        Raises:
            The exception raised by the last failed attempt.
        """
        attempts = max(1, self.retry_config.max_retries)
        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except Exception:
                # Out of attempts, or the loop was aborted: surface this failure.
                if attempt == attempts - 1 or self.aborted:
                    raise
                delay = await self.calculate_retry_delay(attempt)
            await asyncio.sleep(delay)

    def get_summary(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of the processor state.

        Each step's ``duration`` is in seconds, or None if the step has not
        finished. ``tool_calls`` lists are copies, so callers cannot mutate the
        processor's internal state through the summary.
        """
        return {
            "session_id": self.session_id,
            "total_steps": len(self.steps),
            "max_steps": self.max_steps,
            "aborted": self.aborted,
            "doom_loop_detected": self.is_doom_loop(),
            "steps": [
                {
                    "step": s.step,
                    "status": s.status,
                    "tool_calls": list(s.tool_calls),
                    "duration": (s.finished_at - s.started_at).total_seconds() if s.finished_at else None,
                }
                for s in self.steps
            ],
        }
|