Spaces:
Running
Running
File size: 12,111 Bytes
ccd68a1 e0d4a07 ccd68a1 e0d4a07 ccd68a1 e0d4a07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
from unittest.mock import Mock, patch
import pytest
from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
from cua2_core.routes.routes import router
from cua2_core.services.agent_service import AgentService
from cua2_core.services.agent_utils.get_model import AVAILABLE_MODELS
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
@pytest.fixture
def mock_agent_service():
    """Provide a stand-in ``AgentService`` for the route tests.

    The mock is created with ``spec=AgentService``, starts with an empty
    ``active_tasks`` dict, and has ``update_trace_step`` replaced by a plain
    ``Mock`` so each test can set its return value or side effect.
    """
    agent_service_double = Mock(spec=AgentService)
    agent_service_double.update_trace_step = Mock()
    agent_service_double.active_tasks = {}
    return agent_service_double
@pytest.fixture
def mock_websocket_manager():
    """Provide a stand-in WebSocket manager that reports zero connections."""
    websocket_double = Mock()
    websocket_double.get_connection_count = Mock(return_value=0)
    return websocket_double
@pytest.fixture
def app(mock_agent_service, mock_websocket_manager):
    """Build a standalone FastAPI application wired to the mocked services.

    The application gets an allow-everything CORS middleware, the project
    router, and the two mocks stored on ``app.state`` as ``agent_service``
    and ``websocket_manager``.
    """
    cors_settings = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }

    application = FastAPI(title="Test App")
    application.add_middleware(CORSMiddleware, **cors_settings)
    application.include_router(router)

    # Expose the test doubles through the application state.
    application.state.agent_service = mock_agent_service
    application.state.websocket_manager = mock_websocket_manager
    return application
@pytest.fixture
def client(app):
    """Wrap the test application in a ``TestClient`` for issuing requests."""
    test_client = TestClient(app)
    return test_client
class TestGetAvailableModels:
    """Tests covering the GET /models endpoint."""

    def test_get_available_models_success(self, client):
        """GET /models answers 200 with the full, non-empty model list."""
        result = client.get("/models")
        assert result.status_code == 200

        payload = result.json()
        assert "models" in payload

        models = payload["models"]
        assert isinstance(models, list)
        assert len(models) > 0
        # The endpoint must return exactly the configured model list.
        assert models == AVAILABLE_MODELS

    def test_get_available_models_structure(self, client):
        """The JSON body can be loaded into ``AvailableModelsResponse``."""
        result = client.get("/models")
        assert result.status_code == 200

        # Constructing the Pydantic model validates the payload shape.
        parsed = AvailableModelsResponse(**result.json())
        assert parsed.models == AVAILABLE_MODELS

    def test_get_available_models_content(self, client):
        """A couple of specific model identifiers appear in the response."""
        result = client.get("/models")
        assert result.status_code == 200

        returned_models = result.json()["models"]
        must_be_present = (
            "Qwen/Qwen3-VL-8B-Instruct",
            "Qwen/Qwen3-VL-30B-A3B-Instruct",
        )
        for model_name in must_be_present:
            assert model_name in returned_models
class TestUpdateTraceStep:
    """Test suite for PATCH /traces/{trace_id}/steps/{step_id} endpoint.

    The ``mock_agent_service`` fixture stands in for the real service, so each
    test controls ``update_trace_step`` (its return value or the exception it
    raises) and checks the HTTP response the route produces.
    """

    @staticmethod
    def _patch_step(client, trace_id, step_id, body):
        """Send a PATCH for the given trace/step with ``body`` as the JSON payload."""
        return client.patch(f"/traces/{trace_id}/steps/{step_id}", json=body)

    def test_update_trace_step_success(self, client, mock_agent_service):
        """Test successful step update"""
        trace_id = "test-trace-123"
        step_id = "1"
        # Mock the service method to succeed
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, trace_id, step_id, {"step_evaluation": "like"}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["message"] == "Step updated successfully"
        # Verify the service was called correctly
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id=trace_id, step_id=step_id, step_evaluation="like"
        )

    def test_update_trace_step_with_dislike(self, client, mock_agent_service):
        """Test step update with 'dislike' evaluation"""
        trace_id = "test-trace-456"
        step_id = "2"
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, trace_id, step_id, {"step_evaluation": "dislike"}
        )

        assert response.status_code == 200
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id=trace_id, step_id=step_id, step_evaluation="dislike"
        )

    def test_update_trace_step_with_neutral(self, client, mock_agent_service):
        """Test step update with 'neutral' evaluation"""
        trace_id = "test-trace-789"
        step_id = "3"
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, trace_id, step_id, {"step_evaluation": "neutral"}
        )

        assert response.status_code == 200
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id=trace_id, step_id=step_id, step_evaluation="neutral"
        )

    def test_update_trace_step_invalid_evaluation(self, client, mock_agent_service):
        """Test step update with an evaluation value outside the accepted set"""
        response = self._patch_step(
            client, "test-trace-123", "1", {"step_evaluation": "invalid"}
        )

        # Should fail validation
        assert response.status_code == 422
        # A rejected request must never reach the service layer.
        mock_agent_service.update_trace_step.assert_not_called()

    def test_update_trace_step_value_error(self, client, mock_agent_service):
        """Test step update when service raises ValueError"""
        # Mock the service to raise ValueError
        mock_agent_service.update_trace_step.side_effect = ValueError(
            "Invalid step_id format"
        )

        response = self._patch_step(
            client, "test-trace-123", "invalid", {"step_evaluation": "like"}
        )

        assert response.status_code == 400
        assert "Invalid step_id format" in response.json()["detail"]

    def test_update_trace_step_not_found(self, client, mock_agent_service):
        """Test step update when trace is not found"""
        # Mock the service to raise FileNotFoundError
        mock_agent_service.update_trace_step.side_effect = FileNotFoundError(
            "Trace not found"
        )

        response = self._patch_step(
            client, "nonexistent-trace", "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 404
        assert "Trace not found" in response.json()["detail"]

    def test_update_trace_step_step_not_found(self, client, mock_agent_service):
        """Test step update when step doesn't exist in trace"""
        # Mock the service to raise ValueError for step not found
        mock_agent_service.update_trace_step.side_effect = ValueError(
            "Step 999 not found in trace"
        )

        response = self._patch_step(
            client, "test-trace-123", "999", {"step_evaluation": "like"}
        )

        assert response.status_code == 400
        assert "Step 999 not found in trace" in response.json()["detail"]

    def test_update_trace_step_missing_request_body(self, client, mock_agent_service):
        """Test step update with an empty JSON object (no ``step_evaluation``)"""
        response = self._patch_step(client, "test-trace-123", "1", {})

        # Should fail validation
        assert response.status_code == 422
        # A rejected request must never reach the service layer.
        mock_agent_service.update_trace_step.assert_not_called()

    def test_update_trace_step_with_special_characters(
        self, client, mock_agent_service
    ):
        """Test step update with a long, hyphenated trace_id embedding a model name"""
        trace_id = "trace-01K960P4FA2BVC058EZDXQEB5E-Qwen-Qwen3-VL-30B-A3B-Instruct"
        step_id = "1"
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, trace_id, step_id, {"step_evaluation": "like"}
        )

        assert response.status_code == 200
        # The identifier must be passed through to the service unchanged.
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id=trace_id, step_id=step_id, step_evaluation="like"
        )

    def test_update_trace_step_response_structure(self, client, mock_agent_service):
        """Test that response matches UpdateStepResponse schema"""
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, "test-trace-123", "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 200
        # Validate against Pydantic model
        update_response = UpdateStepResponse(**response.json())
        assert update_response.success is True
        assert update_response.message == "Step updated successfully"
class TestGenerateInstruction:
    """Tests covering the POST /generate-instruction endpoint."""

    @patch("cua2_core.routes.routes.InstructionService.generate_instruction")
    def test_generate_instruction_success(self, mock_generate, client):
        """A valid request returns the instruction produced by the patched service."""
        model_id = "Qwen/Qwen3-VL-8B-Instruct"
        prompt = "Generate a web browsing task"
        generated_text = "Open Google Chrome and navigate to example.com"
        # The patched generator hands back a fixed instruction.
        mock_generate.return_value = generated_text

        result = client.post(
            "/generate-instruction",
            json={"model_id": model_id, "prompt": prompt},
        )

        assert result.status_code == 200
        payload = result.json()
        assert payload["instruction"] == generated_text
        assert payload["model_id"] == model_id
        # The route must forward both request fields as keyword arguments.
        mock_generate.assert_called_once_with(model_id=model_id, prompt=prompt)

    @patch("cua2_core.routes.routes.InstructionService.generate_instruction")
    def test_generate_instruction_invalid_model(self, mock_generate, client):
        """A ``ValueError`` from the patched service surfaces as HTTP 400."""
        mock_generate.side_effect = ValueError(
            "Invalid model_id 'invalid-model'. Must be one of: Qwen/Qwen3-VL-2B-Instruct, ..."
        )

        result = client.post(
            "/generate-instruction",
            json={"model_id": "invalid-model", "prompt": "Generate a task"},
        )

        assert result.status_code == 400
        assert "Invalid model_id" in result.json()["detail"]
class TestRoutesIntegration:
    """Smoke tests checking that the expected routes are mounted."""

    def test_models_endpoint_available(self, client):
        """GET /models is reachable and answers 200."""
        assert client.get("/models").status_code == 200

    def test_update_step_endpoint_available(self, client, mock_agent_service):
        """PATCH on a trace step is reachable and answers 200."""
        mock_agent_service.update_trace_step.return_value = None
        body = {"step_evaluation": "like"}

        result = client.patch("/traces/test/steps/1", json=body)

        assert result.status_code == 200

    def test_invalid_route(self, client):
        """An unregistered path answers 404."""
        assert client.get("/invalid-route").status_code == 404
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|