File size: 2,447 Bytes
f37a598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import pytest
from fastapi.testclient import TestClient

# Configure the app through environment variables BEFORE importing it:
# presumably api.ask reads these when the module is first imported, which is
# why the import below must come after these assignments.
# NOTE: these writes mutate os.environ for the whole process and are never
# restored, so they apply to every test collected in the same pytest session.
os.environ["DEMO_MODE"] = "1"
os.environ["RATE_LIMIT_REQUESTS"] = "100"  # high limit for tests

from api.ask import app  # import after setting env var


@pytest.fixture
def client():
    """Yield a ``TestClient`` bound to the demo-mode ``app``.

    The client is entered as a context manager so that the application's
    startup/shutdown (lifespan) handlers, if it defines any, actually run
    during tests, and so the underlying HTTP client is closed once the test
    that requested this fixture finishes.
    """
    with TestClient(app) as test_client:
        yield test_client


def test_api_demo_mode_basic(client):
    """A plain question in demo mode returns a demo-sourced ``result``."""
    resp = client.post("/", json={"prompt": "Explain gravity in simple terms"})
    assert resp.status_code == 200

    body = resp.json()
    assert isinstance(body, dict)
    assert "result" in body
    assert body["source"] == "demo"

    answer = body["result"]
    # The reply must contain "Demo" (case-sensitive) or "explain" (any case).
    assert "Demo" in answer or "explain" in answer.lower()


def test_api_demo_mode_code_prompt(client):
    """A programming-style prompt in demo mode still produces a ``result``."""
    resp = client.post("/", json={"prompt": "How to implement quicksort"})
    assert resp.status_code == 200

    body = resp.json()
    assert "result" in body
    # Case-insensitive match on either keyword is enough.
    lowered = body["result"].lower()
    assert any(word in lowered for word in ("steps", "implement"))


def test_api_session_id_returned(client):
    """The response carries a non-empty ``session_id`` even when none is sent."""
    response = client.post("/", json={"prompt": "Test prompt"})
    assert response.status_code == 200

    body = response.json()
    assert "session_id" in body
    returned_id = body["session_id"]
    assert len(returned_id) > 0


def test_api_session_id_persistence(client):
    """A caller-supplied ``session_id`` is echoed back unchanged."""
    expected_id = "test-session-123"
    response = client.post(
        "/",
        json={"prompt": "Test prompt", "session_id": expected_id},
    )
    assert response.status_code == 200
    assert response.json()["session_id"] == expected_id


def test_api_empty_prompt(client):
    """An empty prompt is answered with 200 and a "Please enter" hint."""
    response = client.post("/", json={"prompt": ""})
    assert response.status_code == 200

    body = response.json()
    assert "result" in body
    hint = body["result"]
    assert "Please enter" in hint


def test_api_history_endpoint(client):
    """History for a session that has made a request is returned as a list."""
    session_id = "test-history-session"

    # Seed the session first. The POST status is checked so that a failing ask
    # request cannot be masked by the history endpoint still returning a
    # (possibly empty) list below.
    seed = client.post(
        "/",
        json={"prompt": "Test question", "session_id": session_id},
    )
    assert seed.status_code == 200

    # Then retrieve history for the same session.
    resp = client.get(f"/history/{session_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert "history" in data
    assert isinstance(data["history"], list)