File size: 5,469 Bytes
95efa57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#!/usr/bin/env python3
"""

Test client for the ChatGPT Oasis Model Inference API

"""

import base64
import io
import json
import mimetypes
import os

import requests
from PIL import Image

# API base URL. Defaults to a local server; set the OASIS_API_URL environment
# variable to point the client at another host. A trailing slash is stripped
# so that f"{BASE_URL}/endpoint" never produces a double slash.
BASE_URL = (os.environ.get("OASIS_API_URL") or "http://localhost:8000").rstrip("/")

def test_health_check(timeout=10):
    """Query the ``/health`` endpoint and print the status and body.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on a hung server.

    Returns:
        True if the server answered with HTTP 200, False otherwise
        (including connection errors and timeouts).
    """
    print("Testing health check...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return False

    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): show it raw rather than
        # letting a decode error hide the real status code.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200

def test_list_models(timeout=10):
    """Query the ``/models`` endpoint and print the status and body.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on a hung server.

    Returns:
        True if the server answered with HTTP 200, False otherwise
        (including connection errors and timeouts).
    """
    print("\nTesting models list...")
    try:
        response = requests.get(f"{BASE_URL}/models", timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: {e}")
        return False

    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): show it raw rather than
        # letting a decode error hide the real status code.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200

def create_test_image(size=(224, 224), color="red"):
    """Return the JPEG-encoded bytes of a solid-colour RGB test image.

    Args:
        size: ``(width, height)`` in pixels; defaults to 224x224.
        color: Any colour spec accepted by ``PIL.Image.new``; defaults to red.

    Returns:
        The encoded JPEG file contents as ``bytes``.
    """
    img = Image.new("RGB", size, color=color)
    with io.BytesIO() as buffer:
        img.save(buffer, format="JPEG")
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek(0) is needed.
        return buffer.getvalue()

def test_base64_inference(model_names=("oasis500m", "vit-l-20"), timeout=120):
    """POST a synthetic image as base64 JSON to ``/inference`` for each model.

    Args:
        model_names: Model identifiers to exercise; one request per model.
        timeout: Per-request timeout in seconds. Generous by default because
            the server may need time to run (or load) the model.

    Returns:
        True if every model answered HTTP 200 with a response containing
        ``model_used`` and a non-empty ``predictions`` list, else False.
        Failures are printed and do not stop the remaining models.
    """
    print("\nTesting base64 inference...")

    # Encode once; the same payload is reused for every model.
    image_base64 = base64.b64encode(create_test_image()).decode("ascii")

    all_ok = True
    for model_name in model_names:
        print(f"\nTesting {model_name}...")
        try:
            # json= serialises the body and sets Content-Type: application/json.
            response = requests.post(
                f"{BASE_URL}/inference",
                json={"image": image_base64, "model_name": model_name},
                timeout=timeout,
            )
        except requests.RequestException as e:
            print(f"Error: {e}")
            all_ok = False
            continue

        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            all_ok = False
            continue

        try:
            result = response.json()
            print(f"Model used: {result['model_used']}")
            print(f"Top prediction: {result['predictions'][0]}")
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Malformed JSON, missing keys, or an empty predictions list.
            print(f"Error: unexpected response {response.text!r} ({e!r})")
            all_ok = False
    return all_ok

def test_file_upload_inference(model_names=("oasis500m", "vit-l-20"), timeout=120):
    """Upload a synthetic JPEG as multipart form data to ``/upload_inference``.

    Args:
        model_names: Model identifiers to exercise; one request per model.
        timeout: Per-request timeout in seconds. Generous by default because
            the server may need time to run (or load) the model.

    Returns:
        True if every model answered HTTP 200 with a response containing
        ``model_used`` and a non-empty ``predictions`` list, else False.
        Failures are printed and do not stop the remaining models.
    """
    print("\nTesting file upload inference...")

    # Build the image once; the bytes are re-sent for each model.
    image_data = create_test_image()

    all_ok = True
    for model_name in model_names:
        print(f"\nTesting {model_name} with file upload...")
        try:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={"file": ("test_image.jpg", image_data, "image/jpeg")},
                data={"model_name": model_name},
                timeout=timeout,
            )
        except requests.RequestException as e:
            print(f"Error: {e}")
            all_ok = False
            continue

        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            all_ok = False
            continue

        try:
            result = response.json()
            print(f"Model used: {result['model_used']}")
            print(f"Top prediction: {result['predictions'][0]}")
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Malformed JSON, missing keys, or an empty predictions list.
            print(f"Error: unexpected response {response.text!r} ({e!r})")
            all_ok = False
    return all_ok

def test_with_real_image(image_path, model_name="oasis500m", timeout=120):
    """Upload an image file from disk to ``/upload_inference`` and print the top 3.

    Args:
        image_path: Path to the image file to send.
        model_name: Model identifier passed as the ``model_name`` form field.
        timeout: Request timeout in seconds.

    Returns:
        True on HTTP 200 with a well-formed response, otherwise False
        (including when ``image_path`` is not an existing regular file).
    """
    if not os.path.isfile(image_path):
        print(f"Image file not found: {image_path}")
        return False

    print(f"\nTesting with real image: {image_path}")

    # Label the upload with its real type (e.g. image/png) instead of always
    # claiming JPEG. Fall back to JPEG for unknown extensions, matching the
    # previous fixed value.
    content_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    try:
        with open(image_path, "rb") as f:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={"file": (os.path.basename(image_path), f, content_type)},
                data={"model_name": model_name},
                timeout=timeout,
            )
    except (OSError, requests.RequestException) as e:
        print(f"Error: {e}")
        return False

    print(f"Status: {response.status_code}")
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return False

    try:
        result = response.json()
        print(f"Model used: {result['model_used']}")
        print("Top 3 predictions:")
        for rank, pred in enumerate(result["predictions"][:3], start=1):
            print(f"  {rank}. {pred['label']} ({pred['confidence']:.3f})")
    except (ValueError, KeyError, TypeError) as e:
        # Malformed JSON, missing keys, or a non-numeric confidence value.
        print(f"Error: unexpected response {response.text!r} ({e!r})")
        return False
    return True

def main():
    """Run every client check against ``BASE_URL`` and print a summary.

    The inference tests are skipped if the health check fails.

    Returns:
        Process exit status: 1 if the health check failed, otherwise 0.
    """
    print("ChatGPT Oasis Model Inference API - Test Client")
    print("=" * 50)

    # Test basic endpoints
    health_ok = test_health_check()
    models_ok = test_list_models()

    if not health_ok:
        print("Health check failed. Make sure the server is running!")
        return 1
    if not models_ok:
        # Not fatal: inference may still work even if listing models fails.
        print("Warning: model listing failed; continuing with inference tests.")

    # Test inference endpoints
    test_base64_inference()
    test_file_upload_inference()

    # Test with the first sample image found in the working directory, if any.
    candidates = ("test.jpg", "sample.jpg", "image.jpg")
    real_image = next((path for path in candidates if os.path.isfile(path)), None)
    if real_image is not None:
        test_with_real_image(real_image)

    print("\n" + "=" * 50)
    print("Test completed!")
    return 0


if __name__ == "__main__":
    # Propagate the status so shells/CI see a failure when the server is down.
    raise SystemExit(main())