Spaces:
Running
Running
A-Mahla
commited on
Commit
·
e0d4a07
1
Parent(s):
97e46c6
ADD Backend V1 (#2)
Browse files* ADD backend logic
* FIX pre-commit
* FIX pre-commit
* FIX pre-commit
* FIX Agent loop
* ADD pytest
* ADD pytest
* CHG github workflow
- .github/workflows/pre-commit.yml +1 -1
- Makefile +8 -0
- cua2-core/pyproject.toml +3 -3
- cua2-core/pytest.ini +13 -0
- cua2-core/src/cua2_core/app.py +6 -5
- cua2-core/src/cua2_core/models/models.py +113 -47
- cua2-core/src/cua2_core/routes/routes.py +29 -20
- cua2-core/src/cua2_core/routes/websocket.py +10 -10
- cua2-core/src/cua2_core/services/agent_service.py +315 -106
- cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py +231 -0
- cua2-core/src/cua2_core/services/agent_utils/function_parser.py +560 -0
- cua2-core/src/cua2_core/services/agent_utils/get_model.py +18 -0
- cua2-core/src/cua2_core/services/agent_utils/prompt.py +136 -0
- cua2-core/src/cua2_core/services/sandbox_service.py +90 -0
- cua2-core/src/cua2_core/websocket/websocket_manager.py +62 -62
- cua2-core/tests/__init__.py +1 -0
- cua2-core/tests/test_routes.py +311 -0
- cua2-front/src/components/mock/TaskButton.tsx +4 -4
- cua2-front/src/pages/Index.tsx +9 -11
- cua2-front/src/types/agent.ts +15 -0
.github/workflows/pre-commit.yml
CHANGED
|
@@ -31,4 +31,4 @@ jobs:
|
|
| 31 |
|
| 32 |
- name: Run pre-commit
|
| 33 |
run: |
|
| 34 |
-
|
|
|
|
| 31 |
|
| 32 |
- name: Run pre-commit
|
| 33 |
run: |
|
| 34 |
+
make pre-commit
|
Makefile
CHANGED
|
@@ -23,6 +23,14 @@ dev-frontend:
|
|
| 23 |
|
| 24 |
pre-commit:
|
| 25 |
uv run pre-commit run --all-files --show-diff-on-failure
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
clean:
|
| 28 |
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
|
|
|
| 23 |
|
| 24 |
pre-commit:
|
| 25 |
uv run pre-commit run --all-files --show-diff-on-failure
|
| 26 |
+
make test
|
| 27 |
+
|
| 28 |
+
# Run tests
|
| 29 |
+
test:
|
| 30 |
+
cd cua2-core && uv run pytest tests/ -v
|
| 31 |
+
|
| 32 |
+
test-coverage:
|
| 33 |
+
cd cua2-core && uv run pytest tests/ -v --cov=cua2_core --cov-report=html --cov-report=term
|
| 34 |
|
| 35 |
clean:
|
| 36 |
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
cua2-core/pyproject.toml
CHANGED
|
@@ -33,9 +33,9 @@ dependencies = [
|
|
| 33 |
"httpx>=0.27.1",
|
| 34 |
"asyncio-mqtt==0.16.1",
|
| 35 |
"aiofiles==23.2.1",
|
| 36 |
-
"smolagents[openai,litellm]==1.
|
| 37 |
-
"openai==
|
| 38 |
-
"
|
| 39 |
]
|
| 40 |
|
| 41 |
[project.optional-dependencies]
|
|
|
|
| 33 |
"httpx>=0.27.1",
|
| 34 |
"asyncio-mqtt==0.16.1",
|
| 35 |
"aiofiles==23.2.1",
|
| 36 |
+
"smolagents[openai,litellm]==1.22.0",
|
| 37 |
+
"openai==2.6.1",
|
| 38 |
+
"e2b-desktop==2.1.0",
|
| 39 |
]
|
| 40 |
|
| 41 |
[project.optional-dependencies]
|
cua2-core/pytest.ini
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
testpaths = tests
|
| 3 |
+
python_files = test_*.py
|
| 4 |
+
python_classes = Test*
|
| 5 |
+
python_functions = test_*
|
| 6 |
+
addopts =
|
| 7 |
+
-v
|
| 8 |
+
--strict-markers
|
| 9 |
+
--tb=short
|
| 10 |
+
--disable-warnings
|
| 11 |
+
markers =
|
| 12 |
+
unit: Unit tests
|
| 13 |
+
integration: Integration tests
|
cua2-core/src/cua2_core/app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from contextlib import asynccontextmanager
|
| 2 |
|
| 3 |
from cua2_core.services.agent_service import AgentService
|
|
|
|
| 4 |
from cua2_core.websocket.websocket_manager import WebSocketManager
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
from fastapi import FastAPI
|
|
@@ -16,23 +17,23 @@ async def lifespan(app: FastAPI):
|
|
| 16 |
# Startup: Initialize services
|
| 17 |
print("Initializing services...")
|
| 18 |
|
| 19 |
-
# Initialize WebSocket manager
|
| 20 |
websocket_manager = WebSocketManager()
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
# Store services in app state for access in routes
|
| 26 |
app.state.websocket_manager = websocket_manager
|
|
|
|
| 27 |
app.state.agent_service = agent_service
|
| 28 |
|
| 29 |
print("Services initialized successfully")
|
| 30 |
|
| 31 |
yield
|
| 32 |
|
| 33 |
-
# Shutdown: Clean up resources
|
| 34 |
print("Shutting down services...")
|
| 35 |
-
|
| 36 |
print("Services shut down successfully")
|
| 37 |
|
| 38 |
|
|
|
|
| 1 |
from contextlib import asynccontextmanager
|
| 2 |
|
| 3 |
from cua2_core.services.agent_service import AgentService
|
| 4 |
+
from cua2_core.services.sandbox_service import SandboxService
|
| 5 |
from cua2_core.websocket.websocket_manager import WebSocketManager
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
from fastapi import FastAPI
|
|
|
|
| 17 |
# Startup: Initialize services
|
| 18 |
print("Initializing services...")
|
| 19 |
|
|
|
|
| 20 |
websocket_manager = WebSocketManager()
|
| 21 |
|
| 22 |
+
sandbox_service = SandboxService()
|
| 23 |
+
|
| 24 |
+
agent_service = AgentService(websocket_manager, sandbox_service)
|
| 25 |
|
| 26 |
# Store services in app state for access in routes
|
| 27 |
app.state.websocket_manager = websocket_manager
|
| 28 |
+
app.state.sandbox_service = sandbox_service
|
| 29 |
app.state.agent_service = agent_service
|
| 30 |
|
| 31 |
print("Services initialized successfully")
|
| 32 |
|
| 33 |
yield
|
| 34 |
|
|
|
|
| 35 |
print("Shutting down services...")
|
| 36 |
+
await sandbox_service.cleanup_sandboxes()
|
| 37 |
print("Services shut down successfully")
|
| 38 |
|
| 39 |
|
cua2-core/src/cua2_core/models/models.py
CHANGED
|
@@ -1,70 +1,91 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
| 3 |
from datetime import datetime
|
| 4 |
from typing import Annotated, Literal, Optional
|
| 5 |
|
| 6 |
-
from
|
|
|
|
| 7 |
from typing_extensions import TypeAlias
|
| 8 |
|
| 9 |
#################### Backend -> Frontend ########################
|
| 10 |
|
| 11 |
|
| 12 |
-
class AgentAction(
|
| 13 |
"""Agent action structure"""
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
"refresh",
|
| 24 |
-
"go_back",
|
| 25 |
-
]
|
| 26 |
-
actionArguments: dict
|
| 27 |
|
| 28 |
def to_string(self) -> str:
|
| 29 |
"""Convert action to a human-readable string"""
|
| 30 |
-
action_type = self.
|
| 31 |
-
args = self.
|
| 32 |
|
| 33 |
if action_type == "click":
|
| 34 |
-
x = args.get("x"
|
| 35 |
-
y = args.get("y"
|
| 36 |
return f"Click at coordinates ({x}, {y})"
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
elif action_type == "write":
|
| 39 |
-
text = args.get("text"
|
| 40 |
return f"Type text: '{text}'"
|
| 41 |
|
| 42 |
elif action_type == "press":
|
| 43 |
-
key = args.get("key"
|
| 44 |
return f"Press key: {key}"
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
elif action_type == "scroll":
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
return f"Scroll {direction} by {amount}"
|
| 50 |
|
| 51 |
elif action_type == "wait":
|
| 52 |
-
seconds = args.get("seconds"
|
| 53 |
return f"Wait for {seconds} seconds"
|
| 54 |
|
| 55 |
elif action_type == "open":
|
| 56 |
-
|
| 57 |
-
return f"Open: {
|
| 58 |
-
|
| 59 |
-
elif action_type == "launch_app":
|
| 60 |
-
app_name = args.get("app_name", "")
|
| 61 |
-
return f"Launch app: {app_name}"
|
| 62 |
|
| 63 |
-
elif action_type == "
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
return "Go back one page"
|
| 68 |
|
| 69 |
|
| 70 |
class AgentStep(BaseModel):
|
|
@@ -85,10 +106,10 @@ class AgentStep(BaseModel):
|
|
| 85 |
def serialize_actions(self, actions: list[AgentAction], _info):
|
| 86 |
"""Convert actions to list of strings when dumping (controlled by context)"""
|
| 87 |
|
| 88 |
-
if _info.context and _info.context.get("actions_as_json",
|
| 89 |
return [action.model_dump(mode="json") for action in actions]
|
| 90 |
|
| 91 |
-
return [action.
|
| 92 |
|
| 93 |
|
| 94 |
class AgentTraceMetadata(BaseModel):
|
|
@@ -100,6 +121,7 @@ class AgentTraceMetadata(BaseModel):
|
|
| 100 |
duration: float = 0.0 # in seconds
|
| 101 |
numberOfSteps: int = 0
|
| 102 |
maxSteps: int = 0
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
class AgentTrace(BaseModel):
|
|
@@ -204,29 +226,54 @@ class ActiveTask(BaseModel):
|
|
| 204 |
|
| 205 |
message_id: str
|
| 206 |
instruction: str
|
| 207 |
-
|
| 208 |
timestamp: datetime = datetime.now()
|
| 209 |
steps: list[AgentStep] = []
|
| 210 |
traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
|
|
|
|
| 211 |
|
| 212 |
@property
|
| 213 |
def trace_path(self):
|
| 214 |
"""Trace path"""
|
| 215 |
-
return f"data/trace-{self.message_id}-{self.
|
| 216 |
|
| 217 |
@model_validator(mode="after")
|
| 218 |
def store_model(self):
|
| 219 |
"""Validate model ID"""
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
class HealthResponse(BaseModel):
|
|
@@ -249,3 +296,22 @@ class ActiveTasksResponse(BaseModel):
|
|
| 249 |
|
| 250 |
active_tasks: dict[str, ActiveTask]
|
| 251 |
total_connections: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
import threading
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import Annotated, Literal, Optional
|
| 6 |
|
| 7 |
+
from cua2_core.services.agent_utils.function_parser import FunctionCall
|
| 8 |
+
from pydantic import BaseModel, Field, PrivateAttr, field_serializer, model_validator
|
| 9 |
from typing_extensions import TypeAlias
|
| 10 |
|
| 11 |
#################### Backend -> Frontend ########################
|
| 12 |
|
| 13 |
|
| 14 |
+
class AgentAction(FunctionCall):
|
| 15 |
"""Agent action structure"""
|
| 16 |
|
| 17 |
+
@classmethod
|
| 18 |
+
def from_function_calls(
|
| 19 |
+
cls, function_calls: list[FunctionCall]
|
| 20 |
+
) -> list["AgentAction"]:
|
| 21 |
+
list_of_actions = [cls(**action.model_dump()) for action in function_calls]
|
| 22 |
+
for action in list_of_actions:
|
| 23 |
+
action.description = action.to_string()
|
| 24 |
+
return list_of_actions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def to_string(self) -> str:
|
| 27 |
"""Convert action to a human-readable string"""
|
| 28 |
+
action_type = self.function_name
|
| 29 |
+
args = self.parameters
|
| 30 |
|
| 31 |
if action_type == "click":
|
| 32 |
+
x = args.get("x") or args.get("arg_0")
|
| 33 |
+
y = args.get("y") or args.get("arg_1")
|
| 34 |
return f"Click at coordinates ({x}, {y})"
|
| 35 |
|
| 36 |
+
if action_type == "right_click":
|
| 37 |
+
x = args.get("x") or args.get("arg_0")
|
| 38 |
+
y = args.get("y") or args.get("arg_1")
|
| 39 |
+
return f"Right click at coordinates ({x}, {y})"
|
| 40 |
+
|
| 41 |
+
if action_type == "double_click":
|
| 42 |
+
x = args.get("x") or args.get("arg_0")
|
| 43 |
+
y = args.get("y") or args.get("arg_1")
|
| 44 |
+
return f"Right click at coordinates ({x}, {y})"
|
| 45 |
+
|
| 46 |
+
if action_type == "move_mouse":
|
| 47 |
+
x = args.get("x") or args.get("arg_0")
|
| 48 |
+
y = args.get("y") or args.get("arg_1")
|
| 49 |
+
return f"Move mouse to coordinates ({x}, {y})"
|
| 50 |
+
|
| 51 |
elif action_type == "write":
|
| 52 |
+
text = args.get("text") or args.get("arg_0")
|
| 53 |
return f"Type text: '{text}'"
|
| 54 |
|
| 55 |
elif action_type == "press":
|
| 56 |
+
key = args.get("key") or args.get("arg_0")
|
| 57 |
return f"Press key: {key}"
|
| 58 |
|
| 59 |
+
elif action_type == "go_back":
|
| 60 |
+
return "Go back one page"
|
| 61 |
+
|
| 62 |
+
elif action_type == "drag":
|
| 63 |
+
x1 = args.get("x1") or args.get("arg_0")
|
| 64 |
+
y1 = args.get("y1") or args.get("arg_1")
|
| 65 |
+
x2 = args.get("x2") or args.get("arg_2")
|
| 66 |
+
y2 = args.get("y2") or args.get("arg_3")
|
| 67 |
+
return f"Drag from ({x1}, {y1}) to ({x2}, {y2})"
|
| 68 |
+
|
| 69 |
elif action_type == "scroll":
|
| 70 |
+
x = args.get("x") or args.get("arg_0")
|
| 71 |
+
y = args.get("y") or args.get("arg_1")
|
| 72 |
+
direction = args.get("direction") or args.get("arg_2")
|
| 73 |
+
amount = args.get("amount") or args.get("arg_3") or 2
|
| 74 |
return f"Scroll {direction} by {amount}"
|
| 75 |
|
| 76 |
elif action_type == "wait":
|
| 77 |
+
seconds = args.get("seconds") or args.get("arg_0")
|
| 78 |
return f"Wait for {seconds} seconds"
|
| 79 |
|
| 80 |
elif action_type == "open":
|
| 81 |
+
url = args.get("url") or args.get("arg_0")
|
| 82 |
+
return f"Open: {url}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
elif action_type == "final_answer":
|
| 85 |
+
answer = args.get("answer") or args.get("arg_0")
|
| 86 |
+
return f"Final answer: {answer}"
|
| 87 |
|
| 88 |
+
return "Unknown action"
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
class AgentStep(BaseModel):
|
|
|
|
| 106 |
def serialize_actions(self, actions: list[AgentAction], _info):
|
| 107 |
"""Convert actions to list of strings when dumping (controlled by context)"""
|
| 108 |
|
| 109 |
+
if _info.context and _info.context.get("actions_as_json", True):
|
| 110 |
return [action.model_dump(mode="json") for action in actions]
|
| 111 |
|
| 112 |
+
return [action.description for action in actions]
|
| 113 |
|
| 114 |
|
| 115 |
class AgentTraceMetadata(BaseModel):
|
|
|
|
| 121 |
duration: float = 0.0 # in seconds
|
| 122 |
numberOfSteps: int = 0
|
| 123 |
maxSteps: int = 0
|
| 124 |
+
completed: bool = False
|
| 125 |
|
| 126 |
|
| 127 |
class AgentTrace(BaseModel):
|
|
|
|
| 226 |
|
| 227 |
message_id: str
|
| 228 |
instruction: str
|
| 229 |
+
model_id: str
|
| 230 |
timestamp: datetime = datetime.now()
|
| 231 |
steps: list[AgentStep] = []
|
| 232 |
traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
|
| 233 |
+
_file_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
| 234 |
|
| 235 |
@property
|
| 236 |
def trace_path(self):
|
| 237 |
"""Trace path"""
|
| 238 |
+
return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
|
| 239 |
|
| 240 |
@model_validator(mode="after")
|
| 241 |
def store_model(self):
|
| 242 |
"""Validate model ID"""
|
| 243 |
+
with self._file_lock:
|
| 244 |
+
os.makedirs(self.trace_path, exist_ok=True)
|
| 245 |
+
with open(f"{self.trace_path}/tasks.json", "w") as f:
|
| 246 |
+
json.dump(
|
| 247 |
+
self.model_dump(
|
| 248 |
+
mode="json",
|
| 249 |
+
exclude={"_file_locks"},
|
| 250 |
+
context={"actions_as_json": True},
|
| 251 |
+
),
|
| 252 |
+
f,
|
| 253 |
+
indent=2,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def update_step(self, step: AgentStep):
|
| 257 |
+
"""Update step"""
|
| 258 |
+
with self._file_lock:
|
| 259 |
+
if int(step.stepId) <= len(self.steps):
|
| 260 |
+
self.steps[int(step.stepId) - 1] = step
|
| 261 |
+
else:
|
| 262 |
+
self.steps.append(step)
|
| 263 |
+
self.traceMetadata.numberOfSteps = len(self.steps)
|
| 264 |
+
with open(f"{self.trace_path}/tasks.json", "w") as f:
|
| 265 |
+
json.dump(
|
| 266 |
+
self.model_dump(
|
| 267 |
+
mode="json",
|
| 268 |
+
exclude={"_file_locks"},
|
| 269 |
+
context={"actions_as_json": True},
|
| 270 |
+
),
|
| 271 |
+
f,
|
| 272 |
+
indent=2,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
#################### API Routes Models ########################
|
| 277 |
|
| 278 |
|
| 279 |
class HealthResponse(BaseModel):
|
|
|
|
| 296 |
|
| 297 |
active_tasks: dict[str, ActiveTask]
|
| 298 |
total_connections: int
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class UpdateStepRequest(BaseModel):
|
| 302 |
+
"""Request model for updating a step"""
|
| 303 |
+
|
| 304 |
+
step_evaluation: Literal["like", "dislike", "neutral"]
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class UpdateStepResponse(BaseModel):
|
| 308 |
+
"""Response model for step update"""
|
| 309 |
+
|
| 310 |
+
success: bool
|
| 311 |
+
message: str
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class AvailableModelsResponse(BaseModel):
|
| 315 |
+
"""Response for available models"""
|
| 316 |
+
|
| 317 |
+
models: list[str]
|
cua2-core/src/cua2_core/routes/routes.py
CHANGED
|
@@ -2,11 +2,13 @@ from datetime import datetime
|
|
| 2 |
|
| 3 |
# Get services from app state
|
| 4 |
from cua2_core.models.models import (
|
| 5 |
-
|
| 6 |
HealthResponse,
|
| 7 |
-
|
|
|
|
| 8 |
)
|
| 9 |
from cua2_core.services.agent_service import AgentService
|
|
|
|
| 10 |
from cua2_core.websocket.websocket_manager import WebSocketManager
|
| 11 |
from fastapi import APIRouter, Depends, HTTPException, Request
|
| 12 |
|
|
@@ -36,24 +38,31 @@ async def health_check(
|
|
| 36 |
)
|
| 37 |
|
| 38 |
|
| 39 |
-
@router.get("/
|
| 40 |
-
async def
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
):
|
| 44 |
-
"""Get currently active tasks"""
|
| 45 |
-
return ActiveTasksResponse(
|
| 46 |
-
active_tasks=agent_service.get_active_tasks(),
|
| 47 |
-
total_connections=websocket_manager.get_connection_count(),
|
| 48 |
-
)
|
| 49 |
|
| 50 |
|
| 51 |
-
@router.
|
| 52 |
-
async def
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
):
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
# Get services from app state
|
| 4 |
from cua2_core.models.models import (
|
| 5 |
+
AvailableModelsResponse,
|
| 6 |
HealthResponse,
|
| 7 |
+
UpdateStepRequest,
|
| 8 |
+
UpdateStepResponse,
|
| 9 |
)
|
| 10 |
from cua2_core.services.agent_service import AgentService
|
| 11 |
+
from cua2_core.services.agent_utils.get_model import AVAILABLE_MODELS
|
| 12 |
from cua2_core.websocket.websocket_manager import WebSocketManager
|
| 13 |
from fastapi import APIRouter, Depends, HTTPException, Request
|
| 14 |
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
|
| 41 |
+
@router.get("/models", response_model=AvailableModelsResponse)
|
| 42 |
+
async def get_available_models():
|
| 43 |
+
"""Get list of all available model IDs"""
|
| 44 |
+
return AvailableModelsResponse(models=AVAILABLE_MODELS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
+
@router.patch("/traces/{trace_id}/steps/{step_id}", response_model=UpdateStepResponse)
|
| 48 |
+
async def update_trace_step(
|
| 49 |
+
trace_id: str,
|
| 50 |
+
step_id: str,
|
| 51 |
+
request: UpdateStepRequest,
|
| 52 |
+
agent_service: AgentService = Depends(get_agent_service),
|
| 53 |
):
|
| 54 |
+
"""Update a specific step in a trace (e.g., update step evaluation)"""
|
| 55 |
+
try:
|
| 56 |
+
agent_service.update_trace_step(
|
| 57 |
+
trace_id=trace_id,
|
| 58 |
+
step_id=step_id,
|
| 59 |
+
step_evaluation=request.step_evaluation,
|
| 60 |
+
)
|
| 61 |
+
return UpdateStepResponse(
|
| 62 |
+
success=True,
|
| 63 |
+
message="Step updated successfully",
|
| 64 |
+
)
|
| 65 |
+
except ValueError as e:
|
| 66 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 67 |
+
except FileNotFoundError as e:
|
| 68 |
+
raise HTTPException(status_code=404, detail=str(e))
|
cua2-core/src/cua2_core/routes/websocket.py
CHANGED
|
@@ -3,6 +3,8 @@ import json
|
|
| 3 |
# Get services from app state
|
| 4 |
from cua2_core.app import app
|
| 5 |
from cua2_core.models.models import AgentTrace, HeartbeatEvent
|
|
|
|
|
|
|
| 6 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 7 |
|
| 8 |
# Create router
|
|
@@ -13,15 +15,15 @@ router = APIRouter()
|
|
| 13 |
async def websocket_endpoint(websocket: WebSocket):
|
| 14 |
"""WebSocket endpoint for real-time communication"""
|
| 15 |
|
| 16 |
-
websocket_manager = app.state.websocket_manager
|
| 17 |
-
agent_service = app.state.agent_service
|
| 18 |
|
| 19 |
await websocket_manager.connect(websocket)
|
| 20 |
|
| 21 |
try:
|
| 22 |
# Send welcome heartbeat
|
| 23 |
welcome_message = HeartbeatEvent(type="heartbeat")
|
| 24 |
-
await websocket_manager.
|
| 25 |
|
| 26 |
# Keep the connection alive and wait for messages
|
| 27 |
while True:
|
|
@@ -50,7 +52,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 50 |
trace = AgentTrace(**trace_data)
|
| 51 |
|
| 52 |
# Process the user task with the trace
|
| 53 |
-
trace_id = await agent_service.process_user_task(
|
|
|
|
|
|
|
| 54 |
print(f"Started processing trace: {trace_id}")
|
| 55 |
else:
|
| 56 |
print("No trace data in message")
|
|
@@ -62,9 +66,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 62 |
error_response = AgentErrorEvent(
|
| 63 |
type="agent_error", error="Invalid JSON format"
|
| 64 |
)
|
| 65 |
-
await websocket_manager.
|
| 66 |
-
error_response, websocket
|
| 67 |
-
)
|
| 68 |
|
| 69 |
except Exception as e:
|
| 70 |
print(f"Error processing message: {e}")
|
|
@@ -76,9 +78,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 76 |
error_response = AgentErrorEvent(
|
| 77 |
type="agent_error", error=f"Error processing message: {str(e)}"
|
| 78 |
)
|
| 79 |
-
await websocket_manager.
|
| 80 |
-
error_response, websocket
|
| 81 |
-
)
|
| 82 |
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Error receiving WebSocket message: {e}")
|
|
|
|
| 3 |
# Get services from app state
|
| 4 |
from cua2_core.app import app
|
| 5 |
from cua2_core.models.models import AgentTrace, HeartbeatEvent
|
| 6 |
+
from cua2_core.services.agent_service import AgentService
|
| 7 |
+
from cua2_core.websocket.websocket_manager import WebSocketManager
|
| 8 |
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 9 |
|
| 10 |
# Create router
|
|
|
|
| 15 |
async def websocket_endpoint(websocket: WebSocket):
|
| 16 |
"""WebSocket endpoint for real-time communication"""
|
| 17 |
|
| 18 |
+
websocket_manager: WebSocketManager = app.state.websocket_manager
|
| 19 |
+
agent_service: AgentService = app.state.agent_service
|
| 20 |
|
| 21 |
await websocket_manager.connect(websocket)
|
| 22 |
|
| 23 |
try:
|
| 24 |
# Send welcome heartbeat
|
| 25 |
welcome_message = HeartbeatEvent(type="heartbeat")
|
| 26 |
+
await websocket_manager.send_message(welcome_message, websocket)
|
| 27 |
|
| 28 |
# Keep the connection alive and wait for messages
|
| 29 |
while True:
|
|
|
|
| 52 |
trace = AgentTrace(**trace_data)
|
| 53 |
|
| 54 |
# Process the user task with the trace
|
| 55 |
+
trace_id = await agent_service.process_user_task(
|
| 56 |
+
trace, websocket
|
| 57 |
+
)
|
| 58 |
print(f"Started processing trace: {trace_id}")
|
| 59 |
else:
|
| 60 |
print("No trace data in message")
|
|
|
|
| 66 |
error_response = AgentErrorEvent(
|
| 67 |
type="agent_error", error="Invalid JSON format"
|
| 68 |
)
|
| 69 |
+
await websocket_manager.send_message(error_response, websocket)
|
|
|
|
|
|
|
| 70 |
|
| 71 |
except Exception as e:
|
| 72 |
print(f"Error processing message: {e}")
|
|
|
|
| 78 |
error_response = AgentErrorEvent(
|
| 79 |
type="agent_error", error=f"Error processing message: {str(e)}"
|
| 80 |
)
|
| 81 |
+
await websocket_manager.send_message(error_response, websocket)
|
|
|
|
|
|
|
| 82 |
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Error receiving WebSocket message: {e}")
|
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -1,39 +1,45 @@
|
|
| 1 |
import asyncio
|
| 2 |
import base64
|
| 3 |
import json
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from cua2_core.models.models import (
|
| 8 |
ActiveTask,
|
| 9 |
AgentAction,
|
| 10 |
-
AgentCompleteEvent,
|
| 11 |
-
AgentErrorEvent,
|
| 12 |
-
AgentProgressEvent,
|
| 13 |
-
AgentStartEvent,
|
| 14 |
AgentStep,
|
| 15 |
AgentTrace,
|
| 16 |
AgentTraceMetadata,
|
| 17 |
-
VncUrlSetEvent,
|
| 18 |
-
VncUrlUnsetEvent,
|
| 19 |
)
|
| 20 |
-
from cua2_core.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class AgentService:
|
| 24 |
"""Service for handling agent tasks and processing"""
|
| 25 |
|
| 26 |
-
def __init__(
|
|
|
|
|
|
|
| 27 |
self.active_tasks: dict[str, ActiveTask] = {}
|
| 28 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 29 |
-
self.
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
self.simulation_images_path = (
|
| 33 |
-
Path(__file__).parent / "simulation_metadata" / "images"
|
| 34 |
-
)
|
| 35 |
|
| 36 |
-
async def process_user_task(self, trace: AgentTrace) -> str:
|
| 37 |
"""Process a user task and return the trace ID"""
|
| 38 |
|
| 39 |
trace_id = trace.id
|
|
@@ -44,123 +50,326 @@ class AgentService:
|
|
| 44 |
self.active_tasks[trace_id] = ActiveTask(
|
| 45 |
message_id=trace_id,
|
| 46 |
instruction=trace.instruction,
|
| 47 |
-
|
| 48 |
timestamp=trace.timestamp,
|
| 49 |
steps=trace.steps,
|
| 50 |
traceMetadata=trace.traceMetadata,
|
| 51 |
)
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
| 56 |
return trace_id
|
| 57 |
|
| 58 |
-
async def
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
try:
|
| 63 |
-
#
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
|
| 68 |
-
start_event = AgentStartEvent(type="agent_start", agentTrace=trace)
|
| 69 |
-
await self.websocket_manager.broadcast(start_event)
|
| 70 |
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
image_path = (
|
| 85 |
-
self.simulation_images_path / step_data["image"].split("/")[-1]
|
| 86 |
-
)
|
| 87 |
-
with open(image_path, "rb") as img_file:
|
| 88 |
-
image_bytes = img_file.read()
|
| 89 |
-
image_base64 = f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
|
| 90 |
-
|
| 91 |
-
# Convert actions to AgentAction objects
|
| 92 |
-
actions = [
|
| 93 |
-
AgentAction(
|
| 94 |
-
actionType=action["actionType"],
|
| 95 |
-
actionArguments=action["actionArguments"],
|
| 96 |
-
)
|
| 97 |
-
for action in step_data["actions"]
|
| 98 |
-
]
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
actions=actions,
|
| 107 |
-
error="",
|
| 108 |
-
duration=step_data["duration"],
|
| 109 |
-
inputTokensUsed=step_data["inputTokensUsed"],
|
| 110 |
-
outputTokensUsed=step_data["outputTokensUsed"],
|
| 111 |
-
step_evaluation=step_data["step_evaluation"],
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
await self.websocket_manager.broadcast(vnc_unset_event)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
# Send completion event
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
)
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
#
|
| 144 |
-
|
| 145 |
-
if trace_id in self.active_tasks:
|
| 146 |
-
del self.active_tasks[trace_id]
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
-
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import base64
|
| 3 |
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from typing import Callable, Literal
|
| 9 |
|
| 10 |
from cua2_core.models.models import (
|
| 11 |
ActiveTask,
|
| 12 |
AgentAction,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
AgentStep,
|
| 14 |
AgentTrace,
|
| 15 |
AgentTraceMetadata,
|
|
|
|
|
|
|
| 16 |
)
|
| 17 |
+
from cua2_core.services.agent_utils.desktop_agent import E2BVisionAgent
|
| 18 |
+
from cua2_core.services.agent_utils.function_parser import parse_function_call
|
| 19 |
+
from cua2_core.services.agent_utils.get_model import get_model
|
| 20 |
+
from cua2_core.services.sandbox_service import SandboxService
|
| 21 |
+
from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
|
| 22 |
+
from e2b_desktop import Sandbox
|
| 23 |
+
from fastapi import WebSocket
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
|
| 30 |
class AgentService:
|
| 31 |
"""Service for handling agent tasks and processing"""
|
| 32 |
|
| 33 |
+
def __init__(
|
| 34 |
+
self, websocket_manager: WebSocketManager, sandbox_service: SandboxService
|
| 35 |
+
):
|
| 36 |
self.active_tasks: dict[str, ActiveTask] = {}
|
| 37 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 38 |
+
self.task_websockets: dict[str, WebSocket] = {}
|
| 39 |
+
self.sandbox_service: SandboxService = sandbox_service
|
| 40 |
+
self.last_screenshot: dict[str, AgentImage] = {}
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
async def process_user_task(self, trace: AgentTrace, websocket: WebSocket) -> str:
|
| 43 |
"""Process a user task and return the trace ID"""
|
| 44 |
|
| 45 |
trace_id = trace.id
|
|
|
|
| 50 |
self.active_tasks[trace_id] = ActiveTask(
|
| 51 |
message_id=trace_id,
|
| 52 |
instruction=trace.instruction,
|
| 53 |
+
model_id=trace.modelId,
|
| 54 |
timestamp=trace.timestamp,
|
| 55 |
steps=trace.steps,
|
| 56 |
traceMetadata=trace.traceMetadata,
|
| 57 |
)
|
| 58 |
|
| 59 |
+
# Store the websocket for this task
|
| 60 |
+
self.task_websockets[trace_id] = websocket
|
| 61 |
+
|
| 62 |
+
asyncio.create_task(self._agent_processing(trace_id))
|
| 63 |
|
| 64 |
return trace_id
|
| 65 |
|
| 66 |
+
async def _agent_runner(
|
| 67 |
+
self,
|
| 68 |
+
message_id: str,
|
| 69 |
+
step_callback: Callable[[ActionStep, E2BVisionAgent], None],
|
| 70 |
+
):
|
| 71 |
+
"""Run the task with the appropriate agent"""
|
| 72 |
+
|
| 73 |
+
sandbox: Sandbox | None = None
|
| 74 |
+
agent = None
|
| 75 |
+
novnc_active = False
|
| 76 |
+
websocket_exception = False
|
| 77 |
|
| 78 |
try:
|
| 79 |
+
# Get the websocket for this task
|
| 80 |
+
websocket = self.task_websockets.get(message_id)
|
| 81 |
+
|
| 82 |
+
await self.websocket_manager.send_agent_start(
|
| 83 |
+
active_task=self.active_tasks[message_id], websocket=websocket
|
| 84 |
+
)
|
| 85 |
|
| 86 |
+
model = get_model(self.active_tasks[message_id].model_id)
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
# Acquire a sandbox from the pool
|
| 89 |
+
sandbox = await self.sandbox_service.acquire_sandbox(message_id)
|
| 90 |
+
if sandbox is None:
|
| 91 |
+
raise Exception("No sandbox available: pool limit reached")
|
| 92 |
|
| 93 |
+
data_dir = self.active_tasks[message_id].trace_path
|
| 94 |
+
user_content = self.active_tasks[message_id].instruction
|
| 95 |
|
| 96 |
+
agent = E2BVisionAgent(
|
| 97 |
+
model=model,
|
| 98 |
+
data_dir=data_dir,
|
| 99 |
+
desktop=sandbox,
|
| 100 |
+
step_callbacks=[step_callback],
|
| 101 |
+
)
|
| 102 |
|
| 103 |
+
self.active_tasks[message_id].traceMetadata.maxSteps = agent.max_steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
await self.websocket_manager.send_vnc_url_set(
|
| 106 |
+
vnc_url=sandbox.stream.get_url(
|
| 107 |
+
auto_connect=True,
|
| 108 |
+
view_only=True,
|
| 109 |
+
resize="scale",
|
| 110 |
+
auth_key=sandbox.stream.get_auth_key(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
)
|
| 112 |
+
or "",
|
| 113 |
+
websocket=websocket,
|
| 114 |
+
)
|
| 115 |
+
novnc_active = True
|
| 116 |
|
| 117 |
+
step_filename = f"{message_id}-1"
|
| 118 |
+
screenshot_bytes = agent.desktop.screenshot()
|
| 119 |
+
image = Image.open(BytesIO(screenshot_bytes))
|
| 120 |
+
screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
|
| 121 |
+
image.save(screenshot_path)
|
| 122 |
|
| 123 |
+
self.last_screenshot[message_id] = image
|
| 124 |
+
|
| 125 |
+
await asyncio.to_thread(
|
| 126 |
+
agent.run,
|
| 127 |
+
user_content,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.active_tasks[message_id].traceMetadata.completed = True
|
| 131 |
|
| 132 |
+
except WebSocketException:
|
| 133 |
+
websocket_exception = True
|
| 134 |
+
pass
|
| 135 |
|
| 136 |
+
except (Exception, KeyboardInterrupt):
|
| 137 |
+
import traceback
|
|
|
|
| 138 |
|
| 139 |
+
logger.error(
|
| 140 |
+
f"Error processing task: {traceback.format_exc()}", exc_info=True
|
| 141 |
+
)
|
| 142 |
+
await self.websocket_manager.send_agent_error(
|
| 143 |
+
error="Error processing task", websocket=websocket
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
finally:
|
| 147 |
# Send completion event
|
| 148 |
+
if not websocket_exception:
|
| 149 |
+
await self.websocket_manager.send_agent_complete(
|
| 150 |
+
metadata=self.active_tasks[message_id].traceMetadata,
|
| 151 |
+
websocket=websocket,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
if novnc_active:
|
| 155 |
+
await self.websocket_manager.send_vnc_url_unset(websocket=websocket)
|
| 156 |
+
|
| 157 |
+
novnc_active = False
|
| 158 |
+
|
| 159 |
+
# Clean up
|
| 160 |
+
if message_id in self.active_tasks:
|
| 161 |
+
self.active_tasks[message_id].store_model()
|
| 162 |
+
del self.active_tasks[message_id]
|
| 163 |
+
|
| 164 |
+
# Clean up websocket reference
|
| 165 |
+
if message_id in self.task_websockets:
|
| 166 |
+
del self.task_websockets[message_id]
|
| 167 |
+
|
| 168 |
+
if message_id in self.last_screenshot:
|
| 169 |
+
del self.last_screenshot[message_id]
|
| 170 |
+
|
| 171 |
+
# Release sandbox back to the pool
|
| 172 |
+
if sandbox:
|
| 173 |
+
await self.sandbox_service.release_sandbox(sandbox)
|
| 174 |
+
|
| 175 |
+
async def _agent_processing(
|
| 176 |
+
self,
|
| 177 |
+
message_id: str,
|
| 178 |
+
):
|
| 179 |
+
"""Process the user task with the appropriate agent"""
|
| 180 |
+
|
| 181 |
+
# Set up log file for this task
|
| 182 |
+
active_task = self.active_tasks[message_id]
|
| 183 |
+
|
| 184 |
+
# Ensure the directory exists
|
| 185 |
+
os.makedirs(active_task.trace_path, exist_ok=True)
|
| 186 |
+
|
| 187 |
+
# Capture the event loop reference in the async context
|
| 188 |
+
# This will be used in the callback to safely schedule coroutines from the worker thread
|
| 189 |
+
loop = asyncio.get_running_loop()
|
| 190 |
+
|
| 191 |
+
def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
|
| 192 |
+
assert memory_step.step_number is not None
|
| 193 |
+
|
| 194 |
+
time.sleep(3)
|
| 195 |
+
|
| 196 |
+
if message_id in self.last_screenshot:
|
| 197 |
+
memory_step.observations_images = [
|
| 198 |
+
self.last_screenshot[message_id].copy()
|
| 199 |
+
]
|
| 200 |
+
else:
|
| 201 |
+
image = self.last_screenshot[message_id]
|
| 202 |
+
# agent.last_marked_screenshot = AgentImage(screenshot_path)
|
| 203 |
+
|
| 204 |
+
for previous_memory_step in (
|
| 205 |
+
agent.memory.steps
|
| 206 |
+
): # Remove previous screenshots from logs for lean processing
|
| 207 |
+
if (
|
| 208 |
+
isinstance(previous_memory_step, ActionStep)
|
| 209 |
+
and previous_memory_step.step_number is not None
|
| 210 |
+
and previous_memory_step.step_number
|
| 211 |
+
<= memory_step.step_number - 1
|
| 212 |
+
):
|
| 213 |
+
previous_memory_step.observations_images = None
|
| 214 |
+
elif isinstance(previous_memory_step, TaskStep):
|
| 215 |
+
previous_memory_step.task_images = None
|
| 216 |
+
|
| 217 |
+
memory_step.observations_images = [image.copy()]
|
| 218 |
+
|
| 219 |
+
model_output = (
|
| 220 |
+
memory_step.model_output_message.content
|
| 221 |
+
if memory_step.model_output_message
|
| 222 |
+
else None
|
| 223 |
+
)
|
| 224 |
+
if model_output is None and isinstance(
|
| 225 |
+
memory_step.error, AgentMaxStepsError
|
| 226 |
+
):
|
| 227 |
+
model_output = memory_step.action_output
|
| 228 |
+
|
| 229 |
+
thought = (
|
| 230 |
+
model_output.split("```")[0].replace("\nAction:\n", "")
|
| 231 |
+
if model_output
|
| 232 |
+
and (
|
| 233 |
+
memory_step.error is None
|
| 234 |
+
or isinstance(memory_step.error, AgentMaxStepsError)
|
| 235 |
+
)
|
| 236 |
+
else None
|
| 237 |
+
)
|
| 238 |
+
action_sequence = (
|
| 239 |
+
model_output.split("```")[1]
|
| 240 |
+
if model_output and memory_step.error is None
|
| 241 |
+
else None
|
| 242 |
)
|
| 243 |
+
if memory_step.observations_images:
|
| 244 |
+
image = memory_step.observations_images[0]
|
| 245 |
+
buffered = BytesIO()
|
| 246 |
+
image.save(buffered, format="PNG")
|
| 247 |
+
image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
|
| 248 |
+
del buffered
|
| 249 |
+
del image
|
| 250 |
+
else:
|
| 251 |
+
image_base64 = None
|
| 252 |
|
| 253 |
+
step = AgentStep(
|
| 254 |
+
traceId=message_id,
|
| 255 |
+
stepId=str(memory_step.step_number),
|
| 256 |
+
image=image_base64,
|
| 257 |
+
thought=thought,
|
| 258 |
+
actions=AgentAction.from_function_calls(
|
| 259 |
+
parse_function_call(action_sequence)
|
| 260 |
+
)
|
| 261 |
+
if action_sequence
|
| 262 |
+
else None,
|
| 263 |
+
error=memory_step.error.message if memory_step.error else None,
|
| 264 |
+
duration=memory_step.timing.duration,
|
| 265 |
+
inputTokensUsed=memory_step.token_usage.input_tokens,
|
| 266 |
+
outputTokensUsed=memory_step.token_usage.output_tokens,
|
| 267 |
+
step_evaluation="neutral",
|
| 268 |
+
)
|
| 269 |
+
self.active_tasks[
|
| 270 |
+
message_id
|
| 271 |
+
].traceMetadata.inputTokensUsed += memory_step.token_usage.input_tokens
|
| 272 |
+
self.active_tasks[
|
| 273 |
+
message_id
|
| 274 |
+
].traceMetadata.outputTokensUsed += memory_step.token_usage.output_tokens
|
| 275 |
+
self.active_tasks[message_id].traceMetadata.numberOfSteps += 1
|
| 276 |
+
self.active_tasks[
|
| 277 |
+
message_id
|
| 278 |
+
].traceMetadata.duration += memory_step.timing.duration
|
| 279 |
|
| 280 |
+
# Add step to active task
|
| 281 |
+
self.active_tasks[message_id].update_step(step)
|
|
|
|
|
|
|
| 282 |
|
| 283 |
+
websocket = self.task_websockets.get(message_id)
|
| 284 |
+
future = asyncio.run_coroutine_threadsafe(
|
| 285 |
+
self.websocket_manager.send_agent_progress(
|
| 286 |
+
step=step,
|
| 287 |
+
metadata=self.active_tasks[message_id].traceMetadata,
|
| 288 |
+
websocket=websocket,
|
| 289 |
+
),
|
| 290 |
+
loop,
|
| 291 |
)
|
| 292 |
+
future.result()
|
| 293 |
|
| 294 |
+
step_filename = f"{message_id}-{memory_step.step_number}"
|
| 295 |
+
screenshot_bytes = agent.desktop.screenshot()
|
| 296 |
+
image = Image.open(BytesIO(screenshot_bytes))
|
| 297 |
+
screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
|
| 298 |
+
image.save(screenshot_path)
|
| 299 |
+
del self.last_screenshot[message_id]
|
| 300 |
+
self.last_screenshot[message_id] = image
|
| 301 |
+
|
| 302 |
+
await self._agent_runner(message_id, step_callback)
|
| 303 |
+
|
| 304 |
+
def update_trace_step(
|
| 305 |
+
self,
|
| 306 |
+
trace_id: str,
|
| 307 |
+
step_id: str,
|
| 308 |
+
step_evaluation: Literal["like", "dislike", "neutral"],
|
| 309 |
+
):
|
| 310 |
+
"""
|
| 311 |
+
Update a specific step in a trace (e.g., update step evaluation)
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
trace_id: The trace ID
|
| 315 |
+
step_id: The step ID (1-indexed)
|
| 316 |
+
step_evaluation: The evaluation value to set
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
The updated AgentStep
|
| 320 |
+
|
| 321 |
+
Raises:
|
| 322 |
+
ValueError: If step_id is invalid or step not found
|
| 323 |
+
FileNotFoundError: If trace not found
|
| 324 |
+
"""
|
| 325 |
+
# Try to find in active tasks first
|
| 326 |
+
active_task = self.active_tasks.get(trace_id)
|
| 327 |
+
|
| 328 |
+
if active_task:
|
| 329 |
+
# Task is still active
|
| 330 |
+
try:
|
| 331 |
+
step_index = int(step_id) - 1
|
| 332 |
+
if 0 <= step_index < len(active_task.steps):
|
| 333 |
+
active_task.steps[step_index].step_evaluation = step_evaluation
|
| 334 |
+
active_task.update_step(active_task.steps[step_index])
|
| 335 |
+
else:
|
| 336 |
+
raise ValueError(f"Step {step_id} not found in trace")
|
| 337 |
+
except (ValueError, TypeError) as e:
|
| 338 |
+
raise ValueError(f"Invalid step_id format: {e}")
|
| 339 |
+
else:
|
| 340 |
+
# Task is not active, try to load from file
|
| 341 |
+
data_dir = "data"
|
| 342 |
+
trace_dirs = [
|
| 343 |
+
d for d in os.listdir(data_dir) if d.startswith(f"trace-{trace_id}")
|
| 344 |
+
]
|
| 345 |
+
|
| 346 |
+
if not trace_dirs:
|
| 347 |
+
raise FileNotFoundError("Trace not found")
|
| 348 |
+
|
| 349 |
+
trace_path = os.path.join(data_dir, trace_dirs[0])
|
| 350 |
+
tasks_file = os.path.join(trace_path, "tasks.json")
|
| 351 |
+
|
| 352 |
+
if not os.path.exists(tasks_file):
|
| 353 |
+
raise FileNotFoundError("Trace data not found")
|
| 354 |
+
|
| 355 |
+
try:
|
| 356 |
+
# Load the trace data
|
| 357 |
+
with open(tasks_file, "r") as f:
|
| 358 |
+
task_data = json.load(f)
|
| 359 |
+
|
| 360 |
+
# Find and update the step
|
| 361 |
+
step_index = int(step_id) - 1
|
| 362 |
+
if 0 <= step_index < len(task_data["steps"]):
|
| 363 |
+
task_data["steps"][step_index]["step_evaluation"] = step_evaluation
|
| 364 |
|
| 365 |
+
# Save the updated data
|
| 366 |
+
with open(tasks_file, "w") as f:
|
| 367 |
+
json.dump(task_data, f, indent=2)
|
| 368 |
|
| 369 |
+
# Convert to AgentStep for response
|
| 370 |
+
updated_step = AgentStep(**task_data["steps"][step_index])
|
| 371 |
+
return updated_step
|
| 372 |
+
else:
|
| 373 |
+
raise ValueError(f"Step {step_id} not found in trace")
|
| 374 |
+
except (ValueError, KeyError, TypeError) as e:
|
| 375 |
+
raise ValueError(f"Error processing step update: {e}")
|
cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import unicodedata
|
| 4 |
+
|
| 5 |
+
from cua2_core.services.agent_utils.prompt import E2B_SYSTEM_PROMPT_TEMPLATE
|
| 6 |
+
|
| 7 |
+
# E2B imports
|
| 8 |
+
from e2b_desktop import Sandbox
|
| 9 |
+
|
| 10 |
+
# SmolaAgents imports
|
| 11 |
+
from smolagents import CodeAgent, Model, tool
|
| 12 |
+
from smolagents.monitoring import LogLevel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class E2BVisionAgent(CodeAgent):
|
| 16 |
+
"""Agent for e2b desktop automation with Qwen2.5VL vision capabilities"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
model: Model,
|
| 21 |
+
data_dir: str,
|
| 22 |
+
desktop: Sandbox,
|
| 23 |
+
max_steps: int = 200,
|
| 24 |
+
verbosity_level: LogLevel = 2,
|
| 25 |
+
planning_interval: int | None = None,
|
| 26 |
+
use_v1_prompt: bool = False,
|
| 27 |
+
**kwargs,
|
| 28 |
+
):
|
| 29 |
+
self.desktop = desktop
|
| 30 |
+
self.data_dir = data_dir
|
| 31 |
+
self.planning_interval = planning_interval
|
| 32 |
+
# Initialize Desktop
|
| 33 |
+
self.width, self.height = self.desktop.get_screen_size()
|
| 34 |
+
print(f"Screen size: {self.width}x{self.height}")
|
| 35 |
+
|
| 36 |
+
# Set up temp directory
|
| 37 |
+
os.makedirs(self.data_dir, exist_ok=True)
|
| 38 |
+
print(f"Screenshots and steps will be saved to: {self.data_dir}")
|
| 39 |
+
|
| 40 |
+
self.use_v1_prompt = use_v1_prompt
|
| 41 |
+
# Initialize base agent
|
| 42 |
+
super().__init__(
|
| 43 |
+
tools=[],
|
| 44 |
+
model=model,
|
| 45 |
+
max_steps=max_steps,
|
| 46 |
+
verbosity_level=verbosity_level,
|
| 47 |
+
planning_interval=self.planning_interval,
|
| 48 |
+
stream_outputs=True,
|
| 49 |
+
**kwargs,
|
| 50 |
+
)
|
| 51 |
+
self.prompt_templates["system_prompt"] = E2B_SYSTEM_PROMPT_TEMPLATE.replace(
|
| 52 |
+
"<<resolution_x>>", str(self.width)
|
| 53 |
+
).replace("<<resolution_y>>", str(self.height))
|
| 54 |
+
|
| 55 |
+
# Add screen info to state
|
| 56 |
+
self.state["screen_width"] = self.width
|
| 57 |
+
self.state["screen_height"] = self.height
|
| 58 |
+
|
| 59 |
+
# Add default tools
|
| 60 |
+
self.logger.log("Setting up agent tools...")
|
| 61 |
+
self._setup_desktop_tools()
|
| 62 |
+
|
| 63 |
+
def _setup_desktop_tools(self):
|
| 64 |
+
"""Register all desktop tools"""
|
| 65 |
+
|
| 66 |
+
@tool
|
| 67 |
+
def click(x: int, y: int) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Performs a left-click at the specified coordinates
|
| 70 |
+
Args:
|
| 71 |
+
x: The x coordinate (horizontal position)
|
| 72 |
+
y: The y coordinate (vertical position)
|
| 73 |
+
"""
|
| 74 |
+
self.desktop.move_mouse(x, y)
|
| 75 |
+
self.desktop.left_click()
|
| 76 |
+
self.click_coordinates = [x, y]
|
| 77 |
+
self.logger.log(f"Clicked at coordinates ({x}, {y})")
|
| 78 |
+
return f"Clicked at coordinates ({x}, {y})"
|
| 79 |
+
|
| 80 |
+
@tool
|
| 81 |
+
def right_click(x: int, y: int) -> str:
|
| 82 |
+
"""
|
| 83 |
+
Performs a right-click at the specified coordinates
|
| 84 |
+
Args:
|
| 85 |
+
x: The x coordinate (horizontal position)
|
| 86 |
+
y: The y coordinate (vertical position)
|
| 87 |
+
"""
|
| 88 |
+
self.desktop.move_mouse(x, y)
|
| 89 |
+
self.desktop.right_click()
|
| 90 |
+
self.click_coordinates = [x, y]
|
| 91 |
+
self.logger.log(f"Right-clicked at coordinates ({x}, {y})")
|
| 92 |
+
return f"Right-clicked at coordinates ({x}, {y})"
|
| 93 |
+
|
| 94 |
+
@tool
|
| 95 |
+
def double_click(x: int, y: int) -> str:
|
| 96 |
+
"""
|
| 97 |
+
Performs a double-click at the specified coordinates
|
| 98 |
+
Args:
|
| 99 |
+
x: The x coordinate (horizontal position)
|
| 100 |
+
y: The y coordinate (vertical position)
|
| 101 |
+
"""
|
| 102 |
+
self.desktop.move_mouse(x, y)
|
| 103 |
+
self.desktop.double_click()
|
| 104 |
+
self.click_coordinates = [x, y]
|
| 105 |
+
self.logger.log(f"Double-clicked at coordinates ({x}, {y})")
|
| 106 |
+
return f"Double-clicked at coordinates ({x}, {y})"
|
| 107 |
+
|
| 108 |
+
@tool
|
| 109 |
+
def move_mouse(x: int, y: int) -> str:
|
| 110 |
+
"""
|
| 111 |
+
Moves the mouse cursor to the specified coordinates
|
| 112 |
+
Args:
|
| 113 |
+
x: The x coordinate (horizontal position)
|
| 114 |
+
y: The y coordinate (vertical position)
|
| 115 |
+
"""
|
| 116 |
+
self.desktop.move_mouse(x, y)
|
| 117 |
+
self.logger.log(f"Moved mouse to coordinates ({x}, {y})")
|
| 118 |
+
return f"Moved mouse to coordinates ({x}, {y})"
|
| 119 |
+
|
| 120 |
+
def normalize_text(text):
|
| 121 |
+
return "".join(
|
| 122 |
+
c
|
| 123 |
+
for c in unicodedata.normalize("NFD", text)
|
| 124 |
+
if not unicodedata.combining(c)
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
@tool
|
| 128 |
+
def write(text: str) -> str:
|
| 129 |
+
"""
|
| 130 |
+
Types the specified text at the current cursor position.
|
| 131 |
+
Args:
|
| 132 |
+
text: The text to type
|
| 133 |
+
"""
|
| 134 |
+
clean_text = normalize_text(text)
|
| 135 |
+
self.desktop.write(clean_text, delay_in_ms=75)
|
| 136 |
+
self.logger.log(f"Typed text: '{clean_text}'")
|
| 137 |
+
return f"Typed text: '{clean_text}'"
|
| 138 |
+
|
| 139 |
+
@tool
|
| 140 |
+
def press(key: str) -> str:
|
| 141 |
+
"""
|
| 142 |
+
Presses a keyboard key
|
| 143 |
+
Args:
|
| 144 |
+
key: The key to press (e.g. "enter", "space", "backspace", etc.).
|
| 145 |
+
"""
|
| 146 |
+
self.desktop.press(key)
|
| 147 |
+
self.logger.log(f"Pressed key: {key}")
|
| 148 |
+
return f"Pressed key: {key}"
|
| 149 |
+
|
| 150 |
+
@tool
|
| 151 |
+
def go_back() -> str:
|
| 152 |
+
"""
|
| 153 |
+
Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
|
| 154 |
+
Args:
|
| 155 |
+
"""
|
| 156 |
+
self.desktop.press(["alt", "left"])
|
| 157 |
+
self.logger.log("Went back one page")
|
| 158 |
+
return "Went back one page"
|
| 159 |
+
|
| 160 |
+
@tool
|
| 161 |
+
def drag(x1: int, y1: int, x2: int, y2: int) -> str:
|
| 162 |
+
"""
|
| 163 |
+
Clicks [x1, y1], drags mouse to [x2, y2], then release click.
|
| 164 |
+
Args:
|
| 165 |
+
x1: origin x coordinate
|
| 166 |
+
y1: origin y coordinate
|
| 167 |
+
x2: end x coordinate
|
| 168 |
+
y2: end y coordinate
|
| 169 |
+
"""
|
| 170 |
+
self.desktop.drag([x1, y1], [x2, y2])
|
| 171 |
+
message = f"Dragged and dropped from [{x1}, {y1}] to [{x2}, {y2}]"
|
| 172 |
+
self.logger.log(message)
|
| 173 |
+
return message
|
| 174 |
+
|
| 175 |
+
@tool
|
| 176 |
+
def scroll(x: int, y: int, direction: str = "down", amount: int = 2) -> str:
|
| 177 |
+
"""
|
| 178 |
+
Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
|
| 179 |
+
Args:
|
| 180 |
+
x: The x coordinate (horizontal position) of the element to scroll/zoom
|
| 181 |
+
y: The y coordinate (vertical position) of the element to scroll/zoom
|
| 182 |
+
direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
|
| 183 |
+
amount: The amount to scroll. A good amount is 1 or 2.
|
| 184 |
+
"""
|
| 185 |
+
self.desktop.move_mouse(x, y)
|
| 186 |
+
self.desktop.scroll(direction=direction, amount=amount)
|
| 187 |
+
message = f"Scrolled {direction} by {amount}"
|
| 188 |
+
self.logger.log(message)
|
| 189 |
+
return message
|
| 190 |
+
|
| 191 |
+
@tool
|
| 192 |
+
def wait(seconds: float) -> str:
|
| 193 |
+
"""
|
| 194 |
+
Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
|
| 195 |
+
Args:
|
| 196 |
+
seconds: Number of seconds to wait, generally 3 is enough.
|
| 197 |
+
"""
|
| 198 |
+
time.sleep(seconds)
|
| 199 |
+
self.logger.log(f"Waited for {seconds} seconds")
|
| 200 |
+
return f"Waited for {seconds} seconds"
|
| 201 |
+
|
| 202 |
+
@tool
|
| 203 |
+
def open(url: str) -> str:
|
| 204 |
+
"""
|
| 205 |
+
Directly opens a browser with the specified url: use this at start of web searches rather than trying to click the browser.
|
| 206 |
+
Args:
|
| 207 |
+
url: The URL to open
|
| 208 |
+
"""
|
| 209 |
+
# Make sure URL has http/https prefix
|
| 210 |
+
if not url.startswith(("http://", "https://")):
|
| 211 |
+
url = "https://" + url
|
| 212 |
+
|
| 213 |
+
self.desktop.open(url)
|
| 214 |
+
# Give it time to load
|
| 215 |
+
time.sleep(2)
|
| 216 |
+
self.logger.log(f"Opening URL: {url}")
|
| 217 |
+
return f"Opened URL: {url}"
|
| 218 |
+
|
| 219 |
+
# Register the tools
|
| 220 |
+
self.tools["click"] = click
|
| 221 |
+
self.tools["right_click"] = right_click
|
| 222 |
+
self.tools["double_click"] = double_click
|
| 223 |
+
self.tools["move_mouse"] = move_mouse
|
| 224 |
+
self.tools["write"] = write
|
| 225 |
+
self.tools["press"] = press
|
| 226 |
+
self.tools["scroll"] = scroll
|
| 227 |
+
self.tools["wait"] = wait
|
| 228 |
+
self.tools["open"] = open
|
| 229 |
+
self.tools["go_back"] = go_back
|
| 230 |
+
self.tools["drag"] = drag
|
| 231 |
+
self.tools["scroll"] = scroll
|
cua2-core/src/cua2_core/services/agent_utils/function_parser.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Function parser for extracting function names, parameter names, and values from string function calls.
|
| 4 |
+
Supports both mobile and pyautogui function patterns.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
from typing import Any, Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FunctionCall(BaseModel):
|
| 15 |
+
"""Represents a parsed function call with its parameters."""
|
| 16 |
+
|
| 17 |
+
function_name: str
|
| 18 |
+
parameters: Dict[str, Any]
|
| 19 |
+
original_string: str
|
| 20 |
+
description: str = ""
|
| 21 |
+
|
| 22 |
+
def to_string(self) -> str:
|
| 23 |
+
"""
|
| 24 |
+
Reconstruct the function call string from the parsed data.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
String representation of the function call
|
| 28 |
+
|
| 29 |
+
Examples:
|
| 30 |
+
>>> call = FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)")
|
| 31 |
+
>>> call.to_string()
|
| 32 |
+
"mobile.wait(seconds=3)"
|
| 33 |
+
|
| 34 |
+
>>> call = FunctionCall("function", {"arg_0": 1, "arg_1": 2, "x": 0.5}, "function(1, 2, x=0.5)")
|
| 35 |
+
>>> call.to_string()
|
| 36 |
+
"function(1, 2, x=0.5)"
|
| 37 |
+
"""
|
| 38 |
+
if not self.parameters:
|
| 39 |
+
return f"{self.function_name}()"
|
| 40 |
+
|
| 41 |
+
# Separate positional and named arguments
|
| 42 |
+
positional_args = []
|
| 43 |
+
named_args = []
|
| 44 |
+
|
| 45 |
+
for name, value in self.parameters.items():
|
| 46 |
+
if name.startswith("arg_"):
|
| 47 |
+
# Positional argument
|
| 48 |
+
positional_args.append((int(name.split("_")[1]), value))
|
| 49 |
+
else:
|
| 50 |
+
# kwargs
|
| 51 |
+
named_args.append((name, value))
|
| 52 |
+
|
| 53 |
+
# Sort positional arguments by index
|
| 54 |
+
positional_args.sort(key=lambda x: x[0])
|
| 55 |
+
|
| 56 |
+
# Build parameter string
|
| 57 |
+
param_parts = []
|
| 58 |
+
|
| 59 |
+
# Add positional arguments
|
| 60 |
+
for _, value in positional_args:
|
| 61 |
+
param_parts.append(self._value_to_string(value))
|
| 62 |
+
|
| 63 |
+
# Add named arguments
|
| 64 |
+
for name, value in named_args:
|
| 65 |
+
param_parts.append(f"{name}={self._value_to_string(value)}")
|
| 66 |
+
|
| 67 |
+
return f"{self.function_name}({', '.join(param_parts)})"
|
| 68 |
+
|
| 69 |
+
def _value_to_string(self, value: Any) -> str:
|
| 70 |
+
"""
|
| 71 |
+
Convert a value to its string representation for function calls.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
value: The value to convert
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
String representation of the value
|
| 78 |
+
"""
|
| 79 |
+
if isinstance(value, str):
|
| 80 |
+
# Quote strings
|
| 81 |
+
return f"'{value}'"
|
| 82 |
+
elif isinstance(value, (list, tuple)):
|
| 83 |
+
# Convert lists/tuples to string representation
|
| 84 |
+
items = [self._value_to_string(item) for item in value]
|
| 85 |
+
return f"[{', '.join(items)}]"
|
| 86 |
+
elif isinstance(value, dict):
|
| 87 |
+
# Convert dictionaries to string representation
|
| 88 |
+
items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()]
|
| 89 |
+
return f"{{{', '.join(items)}}}"
|
| 90 |
+
elif isinstance(value, bool):
|
| 91 |
+
# Convert booleans to lowercase
|
| 92 |
+
return str(value).lower()
|
| 93 |
+
elif value is None:
|
| 94 |
+
return "None"
|
| 95 |
+
else:
|
| 96 |
+
# Numbers and other types
|
| 97 |
+
return str(value)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def parse_function_call(
|
| 101 |
+
function_string: str, pattern_to_match: list[str] = []
|
| 102 |
+
) -> List[FunctionCall]:
|
| 103 |
+
"""
|
| 104 |
+
Parse a function call string and extract all function calls found.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
function_string: String representation of function calls
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
List of FunctionCall objects with parsed information
|
| 111 |
+
|
| 112 |
+
Examples:
|
| 113 |
+
>>> parse_function_call("mobile.wait(seconds=3)")
|
| 114 |
+
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
|
| 115 |
+
|
| 116 |
+
>>> parse_function_call("mobile. wait(seconds=3)")
|
| 117 |
+
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
|
| 118 |
+
|
| 119 |
+
>>> parse_function_call("mobile.wait(seconds=3) mobile.home()")
|
| 120 |
+
[FunctionCall(function_name='wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='home', parameters={}, ...)]
|
| 121 |
+
"""
|
| 122 |
+
# Remove any leading/trailing whitespace
|
| 123 |
+
function_string = function_string.strip()
|
| 124 |
+
|
| 125 |
+
# Pattern to match function calls with parameters
|
| 126 |
+
# Matches: function_name(param1=value1, param2=value2, ...)
|
| 127 |
+
# Can have any characters before the function call, extracts just the function name
|
| 128 |
+
pattern = r".*?([a-zA-Z_][a-zA-Z0-9_.]*)\(([^)]*)\)"
|
| 129 |
+
|
| 130 |
+
matches = re.findall(pattern, function_string)
|
| 131 |
+
if not matches:
|
| 132 |
+
# No valid function calls found in: {function_string}
|
| 133 |
+
return []
|
| 134 |
+
|
| 135 |
+
results = []
|
| 136 |
+
for match in matches:
|
| 137 |
+
function_name = match[0]
|
| 138 |
+
params_string = match[1]
|
| 139 |
+
|
| 140 |
+
if pattern_to_match and all(
|
| 141 |
+
pattern not in function_name for pattern in pattern_to_match
|
| 142 |
+
):
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
# Parse parameters
|
| 146 |
+
parameters = parse_parameters(params_string)
|
| 147 |
+
|
| 148 |
+
# Create the original string for this specific function call
|
| 149 |
+
original_string = f"{function_name}({params_string})"
|
| 150 |
+
|
| 151 |
+
results.append(
|
| 152 |
+
FunctionCall(
|
| 153 |
+
function_name=function_name,
|
| 154 |
+
parameters=parameters,
|
| 155 |
+
original_string=original_string,
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
return results
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def parse_parameters(params_string: str) -> Dict[str, Any]:
|
| 163 |
+
"""
|
| 164 |
+
Parse parameter string and extract parameter names and values.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'")
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
Dictionary mapping parameter names to their values
|
| 171 |
+
|
| 172 |
+
Examples:
|
| 173 |
+
>>> parse_parameters("x=0.5, y=0.6")
|
| 174 |
+
{'x': 0.5, 'y': 0.6}
|
| 175 |
+
|
| 176 |
+
>>> parse_parameters("app_name='drupe'")
|
| 177 |
+
{'app_name': 'drupe'}
|
| 178 |
+
|
| 179 |
+
>>> parse_parameters("'text'")
|
| 180 |
+
{'arg_0': 'text'}
|
| 181 |
+
|
| 182 |
+
>>> parse_parameters("1, 3, 4")
|
| 183 |
+
{'arg_0': 1, 'arg_1': 3, 'arg_2': 4}
|
| 184 |
+
|
| 185 |
+
>>> parse_parameters("arg1, arg2, x=0.5")
|
| 186 |
+
{'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5}
|
| 187 |
+
"""
|
| 188 |
+
if not params_string.strip():
|
| 189 |
+
return {}
|
| 190 |
+
|
| 191 |
+
parameters = OrderedDict()
|
| 192 |
+
|
| 193 |
+
# Split by commas, but be careful with commas inside quotes or brackets
|
| 194 |
+
param_parts = split_parameters(params_string)
|
| 195 |
+
|
| 196 |
+
positional_index = 0
|
| 197 |
+
|
| 198 |
+
for part in param_parts:
|
| 199 |
+
part = part.strip()
|
| 200 |
+
if not part:
|
| 201 |
+
continue
|
| 202 |
+
|
| 203 |
+
# Parse individual parameter
|
| 204 |
+
name, value = parse_single_parameter(part)
|
| 205 |
+
|
| 206 |
+
# For positional arguments, use index-based naming
|
| 207 |
+
if name.startswith("arg_"):
|
| 208 |
+
name = f"arg_{positional_index}"
|
| 209 |
+
positional_index += 1
|
| 210 |
+
|
| 211 |
+
parameters[name] = value
|
| 212 |
+
|
| 213 |
+
return parameters
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def split_parameters(params_string: str) -> List[str]:
|
| 217 |
+
"""
|
| 218 |
+
Split parameter string by commas, respecting quotes and brackets.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
params_string: String containing parameters
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
List of individual parameter strings
|
| 225 |
+
"""
|
| 226 |
+
parts = []
|
| 227 |
+
current_part = ""
|
| 228 |
+
paren_count = 0
|
| 229 |
+
bracket_count = 0
|
| 230 |
+
brace_count = 0
|
| 231 |
+
in_quotes = False
|
| 232 |
+
quote_char = None
|
| 233 |
+
|
| 234 |
+
for char in params_string:
|
| 235 |
+
if char in ['"', "'"] and (not in_quotes or char == quote_char):
|
| 236 |
+
if not in_quotes:
|
| 237 |
+
in_quotes = True
|
| 238 |
+
quote_char = char
|
| 239 |
+
else:
|
| 240 |
+
in_quotes = False
|
| 241 |
+
quote_char = None
|
| 242 |
+
elif not in_quotes:
|
| 243 |
+
if char == "(":
|
| 244 |
+
paren_count += 1
|
| 245 |
+
elif char == ")":
|
| 246 |
+
paren_count -= 1
|
| 247 |
+
elif char == "[":
|
| 248 |
+
bracket_count += 1
|
| 249 |
+
elif char == "]":
|
| 250 |
+
bracket_count -= 1
|
| 251 |
+
elif char == "{":
|
| 252 |
+
brace_count += 1
|
| 253 |
+
elif char == "}":
|
| 254 |
+
brace_count -= 1
|
| 255 |
+
elif (
|
| 256 |
+
char == ","
|
| 257 |
+
and paren_count == 0
|
| 258 |
+
and bracket_count == 0
|
| 259 |
+
and brace_count == 0
|
| 260 |
+
):
|
| 261 |
+
parts.append(current_part.strip())
|
| 262 |
+
current_part = ""
|
| 263 |
+
continue
|
| 264 |
+
|
| 265 |
+
current_part += char
|
| 266 |
+
|
| 267 |
+
if current_part.strip():
|
| 268 |
+
parts.append(current_part.strip())
|
| 269 |
+
|
| 270 |
+
return parts
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def parse_single_parameter(param_string: str) -> Tuple[str, Any]:
|
| 274 |
+
"""
|
| 275 |
+
Parse a single parameter string into name and value.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
param_string: String like "x=0.5" or "app_name='drupe'" or just "value"
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
Tuple of (parameter_name, parameter_value)
|
| 282 |
+
|
| 283 |
+
Examples:
|
| 284 |
+
>>> parse_single_parameter("x=0.5")
|
| 285 |
+
('x', 0.5)
|
| 286 |
+
|
| 287 |
+
>>> parse_single_parameter("app_name='drupe'")
|
| 288 |
+
('app_name', 'drupe')
|
| 289 |
+
|
| 290 |
+
>>> parse_single_parameter("'text'")
|
| 291 |
+
('arg_0', 'text')
|
| 292 |
+
|
| 293 |
+
>>> parse_single_parameter("123")
|
| 294 |
+
('arg_0', 123)
|
| 295 |
+
|
| 296 |
+
>>> parse_single_parameter("3")
|
| 297 |
+
('arg_0', 3)
|
| 298 |
+
"""
|
| 299 |
+
# Pattern to match parameter name and value
|
| 300 |
+
pattern = r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$"
|
| 301 |
+
|
| 302 |
+
match = re.match(pattern, param_string)
|
| 303 |
+
if match:
|
| 304 |
+
# Named parameter
|
| 305 |
+
param_name = match.group(1)
|
| 306 |
+
param_value_str = match.group(2).strip()
|
| 307 |
+
param_value = parse_value(param_value_str)
|
| 308 |
+
return param_name, param_value
|
| 309 |
+
else:
|
| 310 |
+
# Positional parameter - treat as unnamed argument
|
| 311 |
+
param_value = parse_value(param_string)
|
| 312 |
+
return "arg_0", param_value
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def parse_value(value_string: str) -> Any:
|
| 316 |
+
"""
|
| 317 |
+
Parse a value string into appropriate Python type.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
value_string: String representation of a value
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
Parsed value (int, float, str, list, etc.)
|
| 324 |
+
|
| 325 |
+
Examples:
|
| 326 |
+
>>> parse_value("3")
|
| 327 |
+
3
|
| 328 |
+
|
| 329 |
+
>>> parse_value("3.14")
|
| 330 |
+
3.14
|
| 331 |
+
|
| 332 |
+
>>> parse_value("'hello'")
|
| 333 |
+
'hello'
|
| 334 |
+
|
| 335 |
+
>>> parse_value("[0.581, 0.898]")
|
| 336 |
+
[0.581, 0.898]
|
| 337 |
+
"""
|
| 338 |
+
value_string = value_string.strip()
|
| 339 |
+
|
| 340 |
+
# String values (quoted)
|
| 341 |
+
if (value_string.startswith("'") and value_string.endswith("'")) or (
|
| 342 |
+
value_string.startswith('"') and value_string.endswith('"')
|
| 343 |
+
):
|
| 344 |
+
return value_string[1:-1]
|
| 345 |
+
|
| 346 |
+
# List values
|
| 347 |
+
if value_string.startswith("[") and value_string.endswith("]"):
|
| 348 |
+
return parse_list(value_string)
|
| 349 |
+
|
| 350 |
+
# Dictionary values
|
| 351 |
+
if value_string.startswith("{") and value_string.endswith("}"):
|
| 352 |
+
return parse_dict(value_string)
|
| 353 |
+
|
| 354 |
+
# Boolean values
|
| 355 |
+
if value_string.lower() in ["true", "false"]:
|
| 356 |
+
return value_string.lower() == "true"
|
| 357 |
+
|
| 358 |
+
# None value
|
| 359 |
+
if value_string.lower() == "none":
|
| 360 |
+
return None
|
| 361 |
+
|
| 362 |
+
# Numeric values
|
| 363 |
+
try:
|
| 364 |
+
# Try integer first
|
| 365 |
+
if "." not in value_string:
|
| 366 |
+
return int(value_string)
|
| 367 |
+
else:
|
| 368 |
+
return float(value_string)
|
| 369 |
+
except ValueError:
|
| 370 |
+
# If it's not a number, return as string (remove quotes if present)
|
| 371 |
+
if value_string.startswith("'") and value_string.endswith("'"):
|
| 372 |
+
return value_string[1:-1]
|
| 373 |
+
elif value_string.startswith('"') and value_string.endswith('"'):
|
| 374 |
+
return value_string[1:-1]
|
| 375 |
+
else:
|
| 376 |
+
return value_string
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def parse_list(list_string: str) -> List[Any]:
|
| 380 |
+
"""
|
| 381 |
+
Parse a list string into a Python list.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
list_string: String like "[0.581, 0.898]"
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
List of parsed values
|
| 388 |
+
|
| 389 |
+
Examples:
|
| 390 |
+
>>> parse_list("[0.581, 0.898]")
|
| 391 |
+
[0.581, 0.898]
|
| 392 |
+
"""
|
| 393 |
+
# Remove outer brackets
|
| 394 |
+
content = list_string[1:-1].strip()
|
| 395 |
+
if not content:
|
| 396 |
+
return []
|
| 397 |
+
|
| 398 |
+
# Split by commas, respecting nested structures
|
| 399 |
+
parts = split_parameters(content)
|
| 400 |
+
|
| 401 |
+
return [parse_value(part.strip()) for part in parts]
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def parse_dict(dict_string: str) -> Dict[str, Any]:
|
| 405 |
+
"""
|
| 406 |
+
Parse a dictionary string into a Python dict.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
dict_string: String like "{'key': 'value'}"
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
Dictionary of parsed key-value pairs
|
| 413 |
+
"""
|
| 414 |
+
# Remove outer braces
|
| 415 |
+
content = dict_string[1:-1].strip()
|
| 416 |
+
if not content:
|
| 417 |
+
return {}
|
| 418 |
+
|
| 419 |
+
# Split by commas, respecting nested structures
|
| 420 |
+
parts = split_parameters(content)
|
| 421 |
+
|
| 422 |
+
result = {}
|
| 423 |
+
for part in parts:
|
| 424 |
+
part = part.strip()
|
| 425 |
+
if ":" in part:
|
| 426 |
+
key_str, value_str = part.split(":", 1)
|
| 427 |
+
key = parse_value(key_str.strip())
|
| 428 |
+
value = parse_value(value_str.strip())
|
| 429 |
+
result[key] = value
|
| 430 |
+
|
| 431 |
+
return result
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]:
|
| 435 |
+
"""
|
| 436 |
+
Parse multiple function call strings.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
function_strings: List of function call strings
|
| 440 |
+
|
| 441 |
+
Returns:
|
| 442 |
+
List of FunctionCall objects
|
| 443 |
+
"""
|
| 444 |
+
results = []
|
| 445 |
+
for func_str in function_strings:
|
| 446 |
+
try:
|
| 447 |
+
result_list = parse_function_call(func_str)
|
| 448 |
+
results.extend(result_list)
|
| 449 |
+
except Exception as e:
|
| 450 |
+
print(f"Warning: Could not parse function call '{func_str}': {e}")
|
| 451 |
+
continue
|
| 452 |
+
|
| 453 |
+
return results
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def extract_function_calls_from_text(text: str) -> List[FunctionCall]:
|
| 457 |
+
"""
|
| 458 |
+
Extract and parse function calls from a text block.
|
| 459 |
+
|
| 460 |
+
Args:
|
| 461 |
+
text: Text containing function calls
|
| 462 |
+
|
| 463 |
+
Returns:
|
| 464 |
+
List of FunctionCall objects
|
| 465 |
+
"""
|
| 466 |
+
# Pattern to find function calls in text
|
| 467 |
+
# Matches: function_name(param1=value1, param2=value2)
|
| 468 |
+
pattern = r"[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)"
|
| 469 |
+
|
| 470 |
+
matches = re.findall(pattern, text)
|
| 471 |
+
return parse_multiple_functions(matches)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# Example usage and testing
|
| 475 |
+
if __name__ == "__main__":
|
| 476 |
+
test_cases = [
|
| 477 |
+
"mobile.home()",
|
| 478 |
+
"mobile.open_app(app_name='drupe')",
|
| 479 |
+
"mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
| 480 |
+
"mobile.back()",
|
| 481 |
+
"mobile.long_press(x=0.799, y=0.911)",
|
| 482 |
+
"mobile.terminate(status='success')",
|
| 483 |
+
"answer('text')",
|
| 484 |
+
"pyautogui.hscroll(page=-0.1)",
|
| 485 |
+
"pyautogui.scroll(page=-0.1)",
|
| 486 |
+
"pyautogui.scroll(0.13)",
|
| 487 |
+
"pyautogui.click(x=0.8102, y=0.9463)",
|
| 488 |
+
"pyautogui.hotkey(keys=['ctrl', 'c'])",
|
| 489 |
+
"pyautogui.press(keys='enter')",
|
| 490 |
+
"pyautogui.press(keys=['enter'])",
|
| 491 |
+
"pyautogui.moveTo(x=0.04, y=0.405)",
|
| 492 |
+
"pyautogui.write(message='bread buns')",
|
| 493 |
+
"pyautogui.dragTo(x=0.8102, y=0.9463)",
|
| 494 |
+
"mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
| 495 |
+
# Additional test cases for multiple positional arguments
|
| 496 |
+
"function(arg1, arg2, arg3)",
|
| 497 |
+
"function('hello', 123, x=0.5)",
|
| 498 |
+
"function(arg1, arg2, named_param='value')",
|
| 499 |
+
"function(1, 2, 3, 4, 5)",
|
| 500 |
+
"function('a', 'b', 'c', x=1, y=2)",
|
| 501 |
+
]
|
| 502 |
+
|
| 503 |
+
print("Testing function parser:")
|
| 504 |
+
print("=" * 50)
|
| 505 |
+
|
| 506 |
+
for test_case in test_cases:
|
| 507 |
+
try:
|
| 508 |
+
results = parse_function_call(test_case)
|
| 509 |
+
print(f"✓ {test_case}")
|
| 510 |
+
for result in results:
|
| 511 |
+
print(f" Function: {result.function_name}")
|
| 512 |
+
print(f" Parameters: {result.parameters}")
|
| 513 |
+
print()
|
| 514 |
+
except Exception as e:
|
| 515 |
+
print(f"✗ {test_case}")
|
| 516 |
+
print(f" Error: {e}")
|
| 517 |
+
print()
|
| 518 |
+
|
| 519 |
+
# Test extracting from text
|
| 520 |
+
print("Testing text extraction:")
|
| 521 |
+
print("=" * 50)
|
| 522 |
+
|
| 523 |
+
sample_text = """
|
| 524 |
+
mobile.wait(seconds=3)
|
| 525 |
+
mobile.open_app(app_name='drupe')
|
| 526 |
+
pyautogui.click(x=0.8102, y=0.9463)
|
| 527 |
+
pyautogui.write(message='bread buns')
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
extracted = extract_function_calls_from_text(sample_text)
|
| 531 |
+
for func_call in extracted:
|
| 532 |
+
print(f"Found: {func_call.function_name} with params: {func_call.parameters}")
|
| 533 |
+
|
| 534 |
+
# Test reconstruction
|
| 535 |
+
print("\nTesting function call reconstruction:")
|
| 536 |
+
print("=" * 50)
|
| 537 |
+
|
| 538 |
+
reconstruction_tests = [
|
| 539 |
+
"mobile.wait(seconds=3)",
|
| 540 |
+
"mobile.home()",
|
| 541 |
+
"mobile.open_app(app_name='drupe')",
|
| 542 |
+
"mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
|
| 543 |
+
"answer('text')",
|
| 544 |
+
"pyautogui.scroll(0.13)",
|
| 545 |
+
"pyautogui.click(x=0.8102, y=0.9463)",
|
| 546 |
+
"pyautogui.hotkey(keys=['ctrl', 'c'])",
|
| 547 |
+
"function(1, 2, 3)",
|
| 548 |
+
"function('hello', 123, x=0.5, y=0.8)",
|
| 549 |
+
"function([1, 3], 'arg2', named_param='value')",
|
| 550 |
+
]
|
| 551 |
+
|
| 552 |
+
for test_case in reconstruction_tests:
|
| 553 |
+
parsed_list = parse_function_call(test_case)
|
| 554 |
+
for parsed in parsed_list:
|
| 555 |
+
reconstructed = parsed.to_string()
|
| 556 |
+
print(f"Original: {test_case}")
|
| 557 |
+
print(f"Reconstructed: {reconstructed}")
|
| 558 |
+
print(f"Match: {test_case == reconstructed}")
|
| 559 |
+
assert test_case == reconstructed
|
| 560 |
+
print()
|
cua2-core/src/cua2_core/services/agent_utils/get_model.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from smolagents import InferenceClientModel, Model
|
| 2 |
+
|
| 3 |
+
# Available model IDs
|
| 4 |
+
AVAILABLE_MODELS = [
|
| 5 |
+
"Qwen/Qwen3-VL-2B-Instruct",
|
| 6 |
+
"Qwen/Qwen3-VL-2B-Thinking",
|
| 7 |
+
"Qwen/Qwen3-VL-4B-Instruct",
|
| 8 |
+
"Qwen/Qwen3-VL-4B-Thinking",
|
| 9 |
+
"Qwen/Qwen3-VL-8B-Instruct",
|
| 10 |
+
"Qwen/Qwen3-VL-8B-Thinking",
|
| 11 |
+
"Qwen/Qwen3-VL-30B-A3B-Instruct",
|
| 12 |
+
"Qwen/Qwen3-VL-30B-A3B-Thinking",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_model(model_id: str) -> Model:
|
| 17 |
+
"""Get the model"""
|
| 18 |
+
return InferenceClientModel(model_id=model_id)
|
cua2-core/src/cua2_core/services/agent_utils/prompt.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
E2B_SYSTEM_PROMPT_TEMPLATE = """You are a computer-use automation assistant controlling a full desktop remotely.
|
| 4 |
+
The current date is <<current_date>>.
|
| 5 |
+
|
| 6 |
+
<mission>
|
| 7 |
+
Your objective is to complete a given task step-by-step by interacting with the desktop.
|
| 8 |
+
At every step, you:
|
| 9 |
+
1. Observe the latest screenshot (always analyze it carefully).
|
| 10 |
+
2. Reflect briefly on what you see and what to do next.
|
| 11 |
+
3. Produce **one precise action**, formatted exactly as Python code in a fenced block.
|
| 12 |
+
|
| 13 |
+
You will receive a new screenshot after each action.
|
| 14 |
+
Never skip the structure below.
|
| 15 |
+
</mission>
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
<action_process>
|
| 20 |
+
For every step, strictly follow this format:
|
| 21 |
+
|
| 22 |
+
Short term goal: what you’re trying to accomplish in this step.
|
| 23 |
+
What I see: describe key elements visible on the desktop.
|
| 24 |
+
Reflection: reasoning that justifies your next move (mention errors or corrections if needed).
|
| 25 |
+
**Action:**
|
| 26 |
+
```python
|
| 27 |
+
click(x, y)
|
| 28 |
+
```<end_code>
|
| 29 |
+
</action_process>
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
<environment>
|
| 34 |
+
The desktop resolution is <<resolution_x>>x<<resolution_y>> pixels.
|
| 35 |
+
You can only interact through the following tools:
|
| 36 |
+
|
| 37 |
+
{%- for tool in tools.values() %}
|
| 38 |
+
- **{{ tool.name }}**: {{ tool.description }}
|
| 39 |
+
- Inputs: {{ tool.inputs }}
|
| 40 |
+
- Returns: {{ tool.output_type }}
|
| 41 |
+
{%- endfor %}
|
| 42 |
+
|
| 43 |
+
If a task requires a specific application or website, **use**:
|
| 44 |
+
```python
|
| 45 |
+
open("app_or_url")
|
| 46 |
+
```
|
| 47 |
+
to launch it before interacting.
|
| 48 |
+
Never manually click the browser icon — use `open_url()` directly for web pages.
|
| 49 |
+
</environment>
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
<click_guidelines>
|
| 54 |
+
- Always click using **real, visible coordinates** based on the current screenshot.
|
| 55 |
+
- Click precisely **in the center** of the intended target (button, text, icon).
|
| 56 |
+
- Avoid random or approximate coordinates.
|
| 57 |
+
- If nothing changes after a click, check if you misclicked (green crosshair = last click position).
|
| 58 |
+
- If a menu item shows a ▶ (triangle), it means it expands—click directly on the text, not the icon.
|
| 59 |
+
- Use `scroll()` only within scrollable views (webpages, app lists, etc.).
|
| 60 |
+
</click_guidelines>
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
<workflow_guidelines>
|
| 65 |
+
- **ALWAYS START** by analyzing if the task requires opening an application or URL. If so, your **first action** must be:
|
| 66 |
+
- For websites: `open_url("https://google.com")`
|
| 67 |
+
- For applications: `open("app_name")`
|
| 68 |
+
- Never manually navigate to apps via clicking icons—use the open tools directly.
|
| 69 |
+
- Complete one atomic action per step: e.g., **click**, **type**, or **wait**.
|
| 70 |
+
- Never combine multiple tool calls in one step.
|
| 71 |
+
- Validate that your previous action succeeded before continuing.
|
| 72 |
+
- If the interface hasn't changed, adjust your strategy instead of repeating endlessly.
|
| 73 |
+
- Use `wait(seconds)` for short delays if the interface is loading.
|
| 74 |
+
- Always conclude with:
|
| 75 |
+
```python
|
| 76 |
+
final_answer("Answer the user's question or resume the task")
|
| 77 |
+
```
|
| 78 |
+
once the task is fully completed and verified. Answer the user's question or resume the task.
|
| 79 |
+
</workflow_guidelines>
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
<example>
|
| 84 |
+
Task: *Open a text editor and write “Hello World”*
|
| 85 |
+
|
| 86 |
+
Step 1
|
| 87 |
+
Short term goal: Launch the text editor.
|
| 88 |
+
What I see: “Text Editor” visible under Accessories.
|
| 89 |
+
Reflection: Clicking directly on “Text Editor”.
|
| 90 |
+
Action:
|
| 91 |
+
```python
|
| 92 |
+
open("text_editor")
|
| 93 |
+
```<end_code>
|
| 94 |
+
|
| 95 |
+
Step 2
|
| 96 |
+
Short term goal: click on the text editor page.
|
| 97 |
+
What I see: Text editor page.
|
| 98 |
+
Reflection: Click on the text editor page to write "Hello World".
|
| 99 |
+
Action:
|
| 100 |
+
```python
|
| 101 |
+
click(52, 10)
|
| 102 |
+
```<end_code>
|
| 103 |
+
|
| 104 |
+
Step 3
|
| 105 |
+
Short term goal: Type text.
|
| 106 |
+
What I see: Empty notepad open.
|
| 107 |
+
Reflection: Ready to type.
|
| 108 |
+
Action:
|
| 109 |
+
```python
|
| 110 |
+
write("Hello World")
|
| 111 |
+
```<end_code>
|
| 112 |
+
|
| 113 |
+
Step 3
|
| 114 |
+
Short term goal: Verify text and conclude.
|
| 115 |
+
What I see: “Hello World” visible in notepad.
|
| 116 |
+
Reflection: Task successful.
|
| 117 |
+
Action:
|
| 118 |
+
```python
|
| 119 |
+
final_answer("The task is complete and the text 'Hello World' is visible in the notepad.")
|
| 120 |
+
```<end_code>
|
| 121 |
+
</example>
|
| 122 |
+
|
| 123 |
+
---
|
| 124 |
+
|
| 125 |
+
<core_principles>
|
| 126 |
+
- Think visually and spatially.
|
| 127 |
+
- Always ground your reasoning in what’s visible in the screenshot.
|
| 128 |
+
- Never assume what’s on the next screen.
|
| 129 |
+
- Always check the result of your last action.
|
| 130 |
+
- Be deliberate, consistent, and patient.
|
| 131 |
+
- **ALWAYS START** by analyzing if the task requires opening an application or URL. If so, your **first action** must be:
|
| 132 |
+
- For websites: `open_url("https://google.com")`
|
| 133 |
+
- For applications: `open("app_name")`
|
| 134 |
+
- **NEVER** manually navigate to apps via clicking icons—use the open tools directly.
|
| 135 |
+
</core_principles>
|
| 136 |
+
""".replace("<<current_date>>", datetime.now().strftime("%A, %d-%B-%Y"))
|
cua2-core/src/cua2_core/services/sandbox_service.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from e2b_desktop import Sandbox
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
SANDBOX_METADATA: dict[str, dict[str, Any]] = {}
|
| 13 |
+
SANDBOX_TIMEOUT = 300
|
| 14 |
+
WIDTH = 1280
|
| 15 |
+
HEIGHT = 960
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SandboxService:
|
| 19 |
+
def __init__(self, max_sandboxes: int = 50):
|
| 20 |
+
if not os.getenv("E2B_API_KEY"):
|
| 21 |
+
raise ValueError("E2B_API_KEY is not set")
|
| 22 |
+
self.max_sandboxes = max_sandboxes
|
| 23 |
+
self.sandboxes: dict[str, Sandbox] = {}
|
| 24 |
+
self.sandbox_metadata: dict[str, dict[str, Any]] = {}
|
| 25 |
+
self.sandbox_lock = asyncio.Lock()
|
| 26 |
+
|
| 27 |
+
async def acquire_sandbox(self, session_hash: str) -> Sandbox | None:
|
| 28 |
+
async with self.sandbox_lock:
|
| 29 |
+
current_time = datetime.now()
|
| 30 |
+
|
| 31 |
+
if (
|
| 32 |
+
session_hash in self.sandboxes
|
| 33 |
+
and session_hash in self.sandbox_metadata
|
| 34 |
+
and current_time - self.sandbox_metadata[session_hash]["created_at"]
|
| 35 |
+
< SANDBOX_TIMEOUT
|
| 36 |
+
):
|
| 37 |
+
print(f"Reusing Sandbox for session {session_hash}")
|
| 38 |
+
self.sandbox_metadata[session_hash]["last_accessed"] = current_time
|
| 39 |
+
return self.sandboxes[session_hash]
|
| 40 |
+
|
| 41 |
+
if session_hash in self.sandboxes:
|
| 42 |
+
try:
|
| 43 |
+
print(f"Closing expired sandbox for session {session_hash}")
|
| 44 |
+
await asyncio.to_thread(self.sandboxes[session_hash].kill)
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"Error closing expired sandbox: {str(e)}")
|
| 47 |
+
elif len(self.sandboxes) >= self.max_sandboxes:
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
print(f"Creating new sandbox for session {session_hash}")
|
| 51 |
+
|
| 52 |
+
def create_and_setup_sandbox():
|
| 53 |
+
desktop = Sandbox.create(
|
| 54 |
+
api_key=os.getenv("E2B_API_KEY"),
|
| 55 |
+
resolution=(WIDTH, HEIGHT),
|
| 56 |
+
dpi=96,
|
| 57 |
+
timeout=SANDBOX_TIMEOUT,
|
| 58 |
+
template="k0wmnzir0zuzye6dndlw",
|
| 59 |
+
)
|
| 60 |
+
desktop.stream.start(require_auth=True)
|
| 61 |
+
setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
|
| 62 |
+
desktop.commands.run(setup_cmd)
|
| 63 |
+
time.sleep(3)
|
| 64 |
+
return desktop
|
| 65 |
+
|
| 66 |
+
desktop = await asyncio.to_thread(create_and_setup_sandbox)
|
| 67 |
+
|
| 68 |
+
print(f"Sandbox ID for session {session_hash} is {desktop.sandbox_id}.")
|
| 69 |
+
|
| 70 |
+
self.sandboxes[session_hash] = desktop
|
| 71 |
+
self.sandbox_metadata[session_hash] = {
|
| 72 |
+
"created_at": current_time,
|
| 73 |
+
"last_accessed": current_time,
|
| 74 |
+
}
|
| 75 |
+
return desktop
|
| 76 |
+
|
| 77 |
+
async def release_sandbox(self, session_hash: str):
|
| 78 |
+
async with self.sandbox_lock:
|
| 79 |
+
if session_hash in self.sandboxes:
|
| 80 |
+
print(f"Releasing sandbox for session {session_hash}")
|
| 81 |
+
await asyncio.to_thread(self.sandboxes[session_hash].kill)
|
| 82 |
+
del self.sandboxes[session_hash]
|
| 83 |
+
del self.sandbox_metadata[session_hash]
|
| 84 |
+
|
| 85 |
+
async def cleanup_sandboxes(self):
|
| 86 |
+
async with self.sandbox_lock:
|
| 87 |
+
for session_hash in list(self.sandboxes.keys()):
|
| 88 |
+
await asyncio.to_thread(self.sandboxes[session_hash].kill)
|
| 89 |
+
del self.sandboxes[session_hash]
|
| 90 |
+
del self.sandbox_metadata[session_hash]
|
cua2-core/src/cua2_core/websocket/websocket_manager.py
CHANGED
|
@@ -1,11 +1,29 @@
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
-
from typing import Dict,
|
| 4 |
-
|
| 5 |
-
from cua2_core.models.models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from fastapi import WebSocket
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
class WebSocketManager:
|
| 10 |
"""Manages WebSocket connections and broadcasting"""
|
| 11 |
|
|
@@ -29,90 +47,72 @@ class WebSocketManager:
|
|
| 29 |
f"WebSocket disconnected. Total connections: {len(self.active_connections)}"
|
| 30 |
)
|
| 31 |
|
| 32 |
-
async def
|
| 33 |
-
self, message: WebSocketEvent, websocket: WebSocket
|
| 34 |
-
):
|
| 35 |
"""Send a message to a specific WebSocket connection"""
|
| 36 |
try:
|
| 37 |
-
await websocket.send_text(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
except Exception as e:
|
| 39 |
print(f"Error sending personal message: {e}")
|
| 40 |
# Only disconnect if the connection is still in our set
|
| 41 |
if websocket in self.active_connections:
|
| 42 |
self.disconnect(websocket)
|
|
|
|
| 43 |
|
| 44 |
-
async def
|
| 45 |
-
"""Broadcast a message to all connected WebSockets"""
|
| 46 |
-
if not self.active_connections:
|
| 47 |
-
return
|
| 48 |
-
|
| 49 |
-
# Create a list of connections to remove if they fail
|
| 50 |
-
disconnected = []
|
| 51 |
-
|
| 52 |
-
for connection in self.active_connections.copy():
|
| 53 |
-
try:
|
| 54 |
-
await connection.send_text(json.dumps(message.model_dump(mode="json")))
|
| 55 |
-
except Exception as e:
|
| 56 |
-
print(f"Error broadcasting to connection: {e}")
|
| 57 |
-
disconnected.append(connection)
|
| 58 |
-
|
| 59 |
-
# Remove failed connections
|
| 60 |
-
for connection in disconnected:
|
| 61 |
-
if connection in self.active_connections:
|
| 62 |
-
self.disconnect(connection)
|
| 63 |
-
|
| 64 |
-
async def send_agent_start(self, content: str, message_id: str):
|
| 65 |
"""Send agent start event"""
|
| 66 |
-
event =
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
-
await self.
|
| 70 |
|
| 71 |
-
async def send_agent_progress(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
"""Send agent progress event"""
|
| 73 |
-
event =
|
| 74 |
-
|
|
|
|
| 75 |
)
|
| 76 |
-
await self.
|
| 77 |
|
| 78 |
async def send_agent_complete(
|
| 79 |
-
self,
|
| 80 |
-
content: str,
|
| 81 |
-
message_id: str,
|
| 82 |
-
metadata: Optional[AgentTraceMetadata] = None,
|
| 83 |
):
|
| 84 |
"""Send agent complete event"""
|
| 85 |
-
event =
|
| 86 |
-
|
| 87 |
-
content=content,
|
| 88 |
-
messageId=message_id,
|
| 89 |
-
metadata=metadata,
|
| 90 |
-
)
|
| 91 |
-
await self.broadcast(event)
|
| 92 |
|
| 93 |
-
async def send_agent_error(self,
|
| 94 |
"""Send agent error event"""
|
| 95 |
-
event =
|
| 96 |
-
|
| 97 |
-
)
|
| 98 |
-
await self.broadcast(event)
|
| 99 |
|
| 100 |
-
async def send_vnc_url_set(self, vnc_url: str,
|
| 101 |
"""Send VNC URL set event"""
|
| 102 |
-
event =
|
| 103 |
-
type="vnc_url_set",
|
| 104 |
-
content=content or f"VNC stream available at: {vnc_url}",
|
| 105 |
vncUrl=vnc_url,
|
| 106 |
)
|
| 107 |
-
await self.
|
| 108 |
|
| 109 |
-
async def send_vnc_url_unset(self,
|
| 110 |
"""Send VNC URL unset event (reset to default display)"""
|
| 111 |
-
event =
|
| 112 |
-
|
| 113 |
-
content=content or "VNC stream disconnected, showing default display",
|
| 114 |
-
)
|
| 115 |
-
await self.broadcast(event)
|
| 116 |
|
| 117 |
def get_connection_count(self) -> int:
|
| 118 |
"""Get the number of active connections"""
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
+
from typing import Dict, Set
|
| 4 |
+
|
| 5 |
+
from cua2_core.models.models import (
|
| 6 |
+
ActiveTask,
|
| 7 |
+
AgentCompleteEvent,
|
| 8 |
+
AgentErrorEvent,
|
| 9 |
+
AgentProgressEvent,
|
| 10 |
+
AgentStartEvent,
|
| 11 |
+
AgentStep,
|
| 12 |
+
AgentTrace,
|
| 13 |
+
AgentTraceMetadata,
|
| 14 |
+
VncUrlSetEvent,
|
| 15 |
+
VncUrlUnsetEvent,
|
| 16 |
+
WebSocketEvent,
|
| 17 |
+
)
|
| 18 |
from fastapi import WebSocket
|
| 19 |
|
| 20 |
|
| 21 |
+
class WebSocketException(Exception):
|
| 22 |
+
"""Exception for WebSocket errors"""
|
| 23 |
+
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
class WebSocketManager:
|
| 28 |
"""Manages WebSocket connections and broadcasting"""
|
| 29 |
|
|
|
|
| 47 |
f"WebSocket disconnected. Total connections: {len(self.active_connections)}"
|
| 48 |
)
|
| 49 |
|
| 50 |
+
async def send_message(self, message: WebSocketEvent, websocket: WebSocket):
|
|
|
|
|
|
|
| 51 |
"""Send a message to a specific WebSocket connection"""
|
| 52 |
try:
|
| 53 |
+
await websocket.send_text(
|
| 54 |
+
json.dumps(
|
| 55 |
+
message.model_dump(mode="json", context={"actions_as_json": False})
|
| 56 |
+
)
|
| 57 |
+
)
|
| 58 |
except Exception as e:
|
| 59 |
print(f"Error sending personal message: {e}")
|
| 60 |
# Only disconnect if the connection is still in our set
|
| 61 |
if websocket in self.active_connections:
|
| 62 |
self.disconnect(websocket)
|
| 63 |
+
raise WebSocketException()
|
| 64 |
|
| 65 |
+
async def send_agent_start(self, active_task: ActiveTask, websocket: WebSocket):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
"""Send agent start event"""
|
| 67 |
+
event = AgentStartEvent(
|
| 68 |
+
agentTrace=AgentTrace(
|
| 69 |
+
id=active_task.message_id,
|
| 70 |
+
timestamp=active_task.timestamp,
|
| 71 |
+
instruction=active_task.instruction,
|
| 72 |
+
modelId=active_task.model_id,
|
| 73 |
+
steps=active_task.steps,
|
| 74 |
+
traceMetadata=active_task.traceMetadata,
|
| 75 |
+
isRunning=True,
|
| 76 |
+
),
|
| 77 |
)
|
| 78 |
+
await self.send_message(event, websocket)
|
| 79 |
|
| 80 |
+
async def send_agent_progress(
|
| 81 |
+
self,
|
| 82 |
+
step: AgentStep,
|
| 83 |
+
metadata: AgentTraceMetadata,
|
| 84 |
+
websocket: WebSocket,
|
| 85 |
+
):
|
| 86 |
"""Send agent progress event"""
|
| 87 |
+
event = AgentProgressEvent(
|
| 88 |
+
agentStep=step,
|
| 89 |
+
traceMetadata=metadata,
|
| 90 |
)
|
| 91 |
+
await self.send_message(event, websocket)
|
| 92 |
|
| 93 |
async def send_agent_complete(
|
| 94 |
+
self, metadata: AgentTraceMetadata, websocket: WebSocket
|
|
|
|
|
|
|
|
|
|
| 95 |
):
|
| 96 |
"""Send agent complete event"""
|
| 97 |
+
event = AgentCompleteEvent(traceMetadata=metadata)
|
| 98 |
+
await self.send_message(event, websocket)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
async def send_agent_error(self, error: str, websocket: WebSocket):
|
| 101 |
"""Send agent error event"""
|
| 102 |
+
event = AgentErrorEvent(error=error)
|
| 103 |
+
await self.send_message(event, websocket)
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
async def send_vnc_url_set(self, vnc_url: str, websocket: WebSocket):
|
| 106 |
"""Send VNC URL set event"""
|
| 107 |
+
event = VncUrlSetEvent(
|
|
|
|
|
|
|
| 108 |
vncUrl=vnc_url,
|
| 109 |
)
|
| 110 |
+
await self.send_message(event, websocket)
|
| 111 |
|
| 112 |
+
async def send_vnc_url_unset(self, websocket: WebSocket):
|
| 113 |
"""Send VNC URL unset event (reset to default display)"""
|
| 114 |
+
event = VncUrlUnsetEvent()
|
| 115 |
+
await self.send_message(event, websocket)
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
def get_connection_count(self) -> int:
|
| 118 |
"""Get the number of active connections"""
|
cua2-core/tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Tests for cua2-core"""
|
cua2-core/tests/test_routes.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import Mock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
|
| 5 |
+
from cua2_core.routes.routes import router
|
| 6 |
+
from cua2_core.services.agent_service import AgentService
|
| 7 |
+
from cua2_core.services.agent_utils.get_model import AVAILABLE_MODELS
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from fastapi.testclient import TestClient
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture
|
| 14 |
+
def mock_agent_service():
|
| 15 |
+
"""Fixture to create a mocked AgentService"""
|
| 16 |
+
service = Mock(spec=AgentService)
|
| 17 |
+
service.active_tasks = {}
|
| 18 |
+
service.update_trace_step = Mock()
|
| 19 |
+
return service
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture
|
| 23 |
+
def mock_websocket_manager():
|
| 24 |
+
"""Fixture to create a mocked WebSocketManager"""
|
| 25 |
+
manager = Mock()
|
| 26 |
+
manager.get_connection_count = Mock(return_value=0)
|
| 27 |
+
return manager
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.fixture
|
| 31 |
+
def app(mock_agent_service, mock_websocket_manager):
|
| 32 |
+
"""Fixture to create FastAPI app with mocked services"""
|
| 33 |
+
# Create a test FastAPI app
|
| 34 |
+
test_app = FastAPI(title="Test App")
|
| 35 |
+
|
| 36 |
+
# Add CORS middleware
|
| 37 |
+
test_app.add_middleware(
|
| 38 |
+
CORSMiddleware,
|
| 39 |
+
allow_origins=["*"],
|
| 40 |
+
allow_credentials=True,
|
| 41 |
+
allow_methods=["*"],
|
| 42 |
+
allow_headers=["*"],
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Include the router
|
| 46 |
+
test_app.include_router(router)
|
| 47 |
+
|
| 48 |
+
# Mock the services in app state
|
| 49 |
+
test_app.state.agent_service = mock_agent_service
|
| 50 |
+
test_app.state.websocket_manager = mock_websocket_manager
|
| 51 |
+
|
| 52 |
+
return test_app
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@pytest.fixture
|
| 56 |
+
def client(app):
|
| 57 |
+
"""Fixture to create test client"""
|
| 58 |
+
return TestClient(app)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TestGetAvailableModels:
|
| 62 |
+
"""Test suite for GET /models endpoint"""
|
| 63 |
+
|
| 64 |
+
def test_get_available_models_success(self, client):
|
| 65 |
+
"""Test successful retrieval of available models"""
|
| 66 |
+
response = client.get("/models")
|
| 67 |
+
|
| 68 |
+
assert response.status_code == 200
|
| 69 |
+
data = response.json()
|
| 70 |
+
|
| 71 |
+
assert "models" in data
|
| 72 |
+
assert isinstance(data["models"], list)
|
| 73 |
+
assert len(data["models"]) > 0
|
| 74 |
+
|
| 75 |
+
# Verify models match AVAILABLE_MODELS
|
| 76 |
+
assert data["models"] == AVAILABLE_MODELS
|
| 77 |
+
|
| 78 |
+
def test_get_available_models_structure(self, client):
|
| 79 |
+
"""Test that response matches AvailableModelsResponse schema"""
|
| 80 |
+
response = client.get("/models")
|
| 81 |
+
|
| 82 |
+
assert response.status_code == 200
|
| 83 |
+
data = response.json()
|
| 84 |
+
|
| 85 |
+
# Validate against Pydantic model
|
| 86 |
+
models_response = AvailableModelsResponse(**data)
|
| 87 |
+
assert models_response.models == AVAILABLE_MODELS
|
| 88 |
+
|
| 89 |
+
def test_get_available_models_content(self, client):
|
| 90 |
+
"""Test that specific expected models are included"""
|
| 91 |
+
response = client.get("/models")
|
| 92 |
+
|
| 93 |
+
assert response.status_code == 200
|
| 94 |
+
data = response.json()
|
| 95 |
+
|
| 96 |
+
# Check for some specific models
|
| 97 |
+
expected_models = [
|
| 98 |
+
"Qwen/Qwen3-VL-2B-Instruct",
|
| 99 |
+
"Qwen/Qwen3-VL-30B-A3B-Instruct",
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
for model in expected_models:
|
| 103 |
+
assert model in data["models"]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TestUpdateTraceStep:
|
| 107 |
+
"""Test suite for PATCH /traces/{trace_id}/steps/{step_id} endpoint"""
|
| 108 |
+
|
| 109 |
+
def test_update_trace_step_success(self, client, mock_agent_service):
|
| 110 |
+
"""Test successful step update"""
|
| 111 |
+
trace_id = "test-trace-123"
|
| 112 |
+
step_id = "1"
|
| 113 |
+
request_data = {"step_evaluation": "like"}
|
| 114 |
+
|
| 115 |
+
# Mock the service method to succeed
|
| 116 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 117 |
+
|
| 118 |
+
response = client.patch(
|
| 119 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
assert response.status_code == 200
|
| 123 |
+
data = response.json()
|
| 124 |
+
|
| 125 |
+
assert data["success"] is True
|
| 126 |
+
assert data["message"] == "Step updated successfully"
|
| 127 |
+
|
| 128 |
+
# Verify the service was called correctly
|
| 129 |
+
mock_agent_service.update_trace_step.assert_called_once_with(
|
| 130 |
+
trace_id=trace_id, step_id=step_id, step_evaluation="like"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def test_update_trace_step_with_dislike(self, client, mock_agent_service):
|
| 134 |
+
"""Test step update with 'dislike' evaluation"""
|
| 135 |
+
trace_id = "test-trace-456"
|
| 136 |
+
step_id = "2"
|
| 137 |
+
request_data = {"step_evaluation": "dislike"}
|
| 138 |
+
|
| 139 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 140 |
+
|
| 141 |
+
response = client.patch(
|
| 142 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
assert response.status_code == 200
|
| 146 |
+
|
| 147 |
+
mock_agent_service.update_trace_step.assert_called_once_with(
|
| 148 |
+
trace_id=trace_id, step_id=step_id, step_evaluation="dislike"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def test_update_trace_step_with_neutral(self, client, mock_agent_service):
|
| 152 |
+
"""Test step update with 'neutral' evaluation"""
|
| 153 |
+
trace_id = "test-trace-789"
|
| 154 |
+
step_id = "3"
|
| 155 |
+
request_data = {"step_evaluation": "neutral"}
|
| 156 |
+
|
| 157 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 158 |
+
|
| 159 |
+
response = client.patch(
|
| 160 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
assert response.status_code == 200
|
| 164 |
+
|
| 165 |
+
mock_agent_service.update_trace_step.assert_called_once_with(
|
| 166 |
+
trace_id=trace_id, step_id=step_id, step_evaluation="neutral"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
def test_update_trace_step_invalid_evaluation(self, client, mock_agent_service):
|
| 170 |
+
"""Test step update with invalid evaluation value"""
|
| 171 |
+
trace_id = "test-trace-123"
|
| 172 |
+
step_id = "1"
|
| 173 |
+
request_data = {"step_evaluation": "invalid"}
|
| 174 |
+
|
| 175 |
+
response = client.patch(
|
| 176 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Should fail validation
|
| 180 |
+
assert response.status_code == 422
|
| 181 |
+
|
| 182 |
+
def test_update_trace_step_value_error(self, client, mock_agent_service):
|
| 183 |
+
"""Test step update when service raises ValueError"""
|
| 184 |
+
trace_id = "test-trace-123"
|
| 185 |
+
step_id = "invalid"
|
| 186 |
+
request_data = {"step_evaluation": "like"}
|
| 187 |
+
|
| 188 |
+
# Mock the service to raise ValueError
|
| 189 |
+
mock_agent_service.update_trace_step.side_effect = ValueError(
|
| 190 |
+
"Invalid step_id format"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
response = client.patch(
|
| 194 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
assert response.status_code == 400
|
| 198 |
+
assert "Invalid step_id format" in response.json()["detail"]
|
| 199 |
+
|
| 200 |
+
def test_update_trace_step_not_found(self, client, mock_agent_service):
|
| 201 |
+
"""Test step update when trace is not found"""
|
| 202 |
+
trace_id = "nonexistent-trace"
|
| 203 |
+
step_id = "1"
|
| 204 |
+
request_data = {"step_evaluation": "like"}
|
| 205 |
+
|
| 206 |
+
# Mock the service to raise FileNotFoundError
|
| 207 |
+
mock_agent_service.update_trace_step.side_effect = FileNotFoundError(
|
| 208 |
+
"Trace not found"
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
response = client.patch(
|
| 212 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
assert response.status_code == 404
|
| 216 |
+
assert "Trace not found" in response.json()["detail"]
|
| 217 |
+
|
| 218 |
+
def test_update_trace_step_step_not_found(self, client, mock_agent_service):
|
| 219 |
+
"""Test step update when step doesn't exist in trace"""
|
| 220 |
+
trace_id = "test-trace-123"
|
| 221 |
+
step_id = "999"
|
| 222 |
+
request_data = {"step_evaluation": "like"}
|
| 223 |
+
|
| 224 |
+
# Mock the service to raise ValueError for step not found
|
| 225 |
+
mock_agent_service.update_trace_step.side_effect = ValueError(
|
| 226 |
+
"Step 999 not found in trace"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
response = client.patch(
|
| 230 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
assert response.status_code == 400
|
| 234 |
+
assert "Step 999 not found in trace" in response.json()["detail"]
|
| 235 |
+
|
| 236 |
+
def test_update_trace_step_missing_request_body(self, client, mock_agent_service):
|
| 237 |
+
"""Test step update with missing request body"""
|
| 238 |
+
trace_id = "test-trace-123"
|
| 239 |
+
step_id = "1"
|
| 240 |
+
|
| 241 |
+
response = client.patch(f"/traces/{trace_id}/steps/{step_id}", json={})
|
| 242 |
+
|
| 243 |
+
# Should fail validation
|
| 244 |
+
assert response.status_code == 422
|
| 245 |
+
|
| 246 |
+
def test_update_trace_step_with_special_characters(
|
| 247 |
+
self, client, mock_agent_service
|
| 248 |
+
):
|
| 249 |
+
"""Test step update with trace_id containing special characters"""
|
| 250 |
+
trace_id = "trace-01K960P4FA2BVC058EZDXQEB5E-Qwen-Qwen3-VL-30B-A3B-Instruct"
|
| 251 |
+
step_id = "1"
|
| 252 |
+
request_data = {"step_evaluation": "like"}
|
| 253 |
+
|
| 254 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 255 |
+
|
| 256 |
+
response = client.patch(
|
| 257 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
assert response.status_code == 200
|
| 261 |
+
|
| 262 |
+
mock_agent_service.update_trace_step.assert_called_once_with(
|
| 263 |
+
trace_id=trace_id, step_id=step_id, step_evaluation="like"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def test_update_trace_step_response_structure(self, client, mock_agent_service):
|
| 267 |
+
"""Test that response matches UpdateStepResponse schema"""
|
| 268 |
+
trace_id = "test-trace-123"
|
| 269 |
+
step_id = "1"
|
| 270 |
+
request_data = {"step_evaluation": "like"}
|
| 271 |
+
|
| 272 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 273 |
+
|
| 274 |
+
response = client.patch(
|
| 275 |
+
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
assert response.status_code == 200
|
| 279 |
+
data = response.json()
|
| 280 |
+
|
| 281 |
+
# Validate against Pydantic model
|
| 282 |
+
update_response = UpdateStepResponse(**data)
|
| 283 |
+
assert update_response.success is True
|
| 284 |
+
assert update_response.message == "Step updated successfully"
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class TestRoutesIntegration:
|
| 288 |
+
"""Integration tests for multiple routes"""
|
| 289 |
+
|
| 290 |
+
def test_models_endpoint_available(self, client):
|
| 291 |
+
"""Test that models endpoint is available"""
|
| 292 |
+
response = client.get("/models")
|
| 293 |
+
assert response.status_code == 200
|
| 294 |
+
|
| 295 |
+
def test_update_step_endpoint_available(self, client, mock_agent_service):
|
| 296 |
+
"""Test that update step endpoint is available"""
|
| 297 |
+
mock_agent_service.update_trace_step.return_value = None
|
| 298 |
+
|
| 299 |
+
response = client.patch(
|
| 300 |
+
"/traces/test/steps/1", json={"step_evaluation": "like"}
|
| 301 |
+
)
|
| 302 |
+
assert response.status_code == 200
|
| 303 |
+
|
| 304 |
+
def test_invalid_route(self, client):
|
| 305 |
+
"""Test accessing an invalid route"""
|
| 306 |
+
response = client.get("/invalid-route")
|
| 307 |
+
assert response.status_code == 404
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
if __name__ == "__main__":
|
| 311 |
+
pytest.main([__file__, "-v"])
|
cua2-front/src/components/mock/TaskButton.tsx
CHANGED
|
@@ -12,8 +12,8 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
|
|
| 12 |
onClick={() => {
|
| 13 |
if (!isAgentProcessing && isConnected) {
|
| 14 |
onSendTask(
|
| 15 |
-
"
|
| 16 |
-
"
|
| 17 |
);
|
| 18 |
}
|
| 19 |
}}
|
|
@@ -56,7 +56,7 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
|
|
| 56 |
)}
|
| 57 |
</div>
|
| 58 |
<p style={{ fontSize: '15px', fontWeight: 500, color: '#1f2937' }}>
|
| 59 |
-
|
| 60 |
</p>
|
| 61 |
</div>
|
| 62 |
<div style={{
|
|
@@ -67,7 +67,7 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
|
|
| 67 |
}}>
|
| 68 |
<span style={{ fontSize: '11px', fontWeight: 600, color: 'rgba(0, 0, 0, 0.6)', textTransform: 'uppercase', letterSpacing: '1px' }}>Model</span>
|
| 69 |
<p style={{ fontSize: '12px', fontWeight: 600, color: '#1f2937', marginTop: '2px', whiteSpace: 'nowrap' }}>
|
| 70 |
-
|
| 71 |
</p>
|
| 72 |
</div>
|
| 73 |
</div>
|
|
|
|
| 12 |
onClick={() => {
|
| 13 |
if (!isAgentProcessing && isConnected) {
|
| 14 |
onSendTask(
|
| 15 |
+
"Find the price of a NVIDIA RTX 4090 GPU",
|
| 16 |
+
"Qwen/Qwen3-VL-30B-A3B-Instruct"
|
| 17 |
);
|
| 18 |
}
|
| 19 |
}}
|
|
|
|
| 56 |
)}
|
| 57 |
</div>
|
| 58 |
<p style={{ fontSize: '15px', fontWeight: 500, color: '#1f2937' }}>
|
| 59 |
+
Find the price of a NVIDIA RTX 4090 GPU
|
| 60 |
</p>
|
| 61 |
</div>
|
| 62 |
<div style={{
|
|
|
|
| 67 |
}}>
|
| 68 |
<span style={{ fontSize: '11px', fontWeight: 600, color: 'rgba(0, 0, 0, 0.6)', textTransform: 'uppercase', letterSpacing: '1px' }}>Model</span>
|
| 69 |
<p style={{ fontSize: '12px', fontWeight: 600, color: '#1f2937', marginTop: '2px', whiteSpace: 'nowrap' }}>
|
| 70 |
+
Qwen/Qwen3-VL-30B-A3B-Instruct
|
| 71 |
</p>
|
| 72 |
</div>
|
| 73 |
</div>
|
cua2-front/src/pages/Index.tsx
CHANGED
|
@@ -1,16 +1,14 @@
|
|
| 1 |
-
import
|
| 2 |
import { useWebSocket } from '@/hooks/useWebSocket';
|
| 3 |
-
import { WebSocketEvent } from '@/types/agent';
|
| 4 |
import { useState } from 'react';
|
| 5 |
-
import { AgentTrace, AgentStep } from '@/types/agent';
|
| 6 |
import { ulid } from 'ulid';
|
| 7 |
-
import { Header, VNCStream, Metadata, StackSteps } from '@/components/mock';
|
| 8 |
|
| 9 |
const Index = () => {
|
| 10 |
const [trace, setTrace] = useState<AgentTrace>();
|
| 11 |
const [isAgentProcessing, setIsAgentProcessing] = useState(false);
|
| 12 |
const [vncUrl, setVncUrl] = useState<string>('');
|
| 13 |
-
const [selectedModelId, setSelectedModelId] = useState<string>("
|
| 14 |
|
| 15 |
// #################### WebSocket Connection ########################
|
| 16 |
|
|
@@ -51,12 +49,12 @@ const Index = () => {
|
|
| 51 |
setIsAgentProcessing(false);
|
| 52 |
setTrace(trace => {
|
| 53 |
return trace.id === event.traceMetadata.traceId
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
});
|
| 61 |
console.log('Agent complete received:', event.traceMetadata);
|
| 62 |
break;
|
|
|
|
| 1 |
+
import { Header, Metadata, StackSteps, VNCStream } from '@/components/mock';
|
| 2 |
import { useWebSocket } from '@/hooks/useWebSocket';
|
| 3 |
+
import { AgentStep, AgentTrace, WebSocketEvent } from '@/types/agent';
|
| 4 |
import { useState } from 'react';
|
|
|
|
| 5 |
import { ulid } from 'ulid';
|
|
|
|
| 6 |
|
| 7 |
const Index = () => {
|
| 8 |
const [trace, setTrace] = useState<AgentTrace>();
|
| 9 |
const [isAgentProcessing, setIsAgentProcessing] = useState(false);
|
| 10 |
const [vncUrl, setVncUrl] = useState<string>('');
|
| 11 |
+
const [selectedModelId, setSelectedModelId] = useState<string>("Qwen/Qwen3-VL-30B-A3B-Instruct");
|
| 12 |
|
| 13 |
// #################### WebSocket Connection ########################
|
| 14 |
|
|
|
|
| 49 |
setIsAgentProcessing(false);
|
| 50 |
setTrace(trace => {
|
| 51 |
return trace.id === event.traceMetadata.traceId
|
| 52 |
+
? {
|
| 53 |
+
...trace,
|
| 54 |
+
isRunning: false,
|
| 55 |
+
metadata: event.traceMetadata,
|
| 56 |
+
}
|
| 57 |
+
: trace;
|
| 58 |
});
|
| 59 |
console.log('Agent complete received:', event.traceMetadata);
|
| 60 |
break;
|
cua2-front/src/types/agent.ts
CHANGED
|
@@ -82,3 +82,18 @@ export interface UserTaskMessage {
|
|
| 82 |
type: 'user_task';
|
| 83 |
trace: AgentTrace;
|
| 84 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
type: 'user_task';
|
| 83 |
trace: AgentTrace;
|
| 84 |
}
|
| 85 |
+
|
| 86 |
+
// #################### API Routes Types ########################
|
| 87 |
+
|
| 88 |
+
export interface AvailableModelsResponse {
|
| 89 |
+
models: string[];
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
export interface UpdateStepRequest {
|
| 93 |
+
step_evaluation: 'like' | 'dislike' | 'neutral';
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
export interface UpdateStepResponse {
|
| 97 |
+
success: boolean;
|
| 98 |
+
message: string;
|
| 99 |
+
}
|