File size: 6,719 Bytes
0ccf2f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Mock tests for OpenAI Embedding Provider.

These tests mock OpenAI API calls to ensure coverage without requiring network access.
"""

import sys
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

# Put the directory two levels above this file (presumably the project root --
# confirm against the repository layout) on sys.path BEFORE importing the
# package under test. Doing it after the import, as before, could never help
# that import resolve when running from a source checkout.
sys.path.insert(0, str(Path(__file__).parent.parent))

from warbler_cda.embeddings.openai_provider import OpenAIEmbeddingProvider  # noqa: E402  # pylint: disable=C0413


class TestOpenAIEmbeddingProvider:
    """Mock tests for OpenAI embedding provider.

    No test talks to the network. Tests that reach the provider's client
    register a ``MagicMock`` as ``sys.modules["openai"]`` for the duration of
    the call, so an ``import openai`` executed by the provider yields the mock
    while every other import keeps working normally.
    """

    # Configuration handed to the provider built in ``setup_method``; the
    # assertions below refer to these instead of repeating the literals.
    API_KEY = "test_key"
    MODEL = "text-embedding-ada-002"
    DIMENSION = 1536

    @staticmethod
    def _use_openai(module):
        """Return a context manager under which ``import openai`` gives *module*.

        ``sys.modules`` is restored when the context exits, so the fake module
        never leaks into other tests.
        """
        return patch.dict(sys.modules, {"openai": module})

    def setup_method(self):
        """Create a fresh provider before each test."""
        self.provider = OpenAIEmbeddingProvider({ # pylint: disable=W0201
            "api_key": self.API_KEY,
            "model": self.MODEL,
            "dimension": self.DIMENSION,
        })

    def test_init(self):
        """Explicit config values end up on the provider; no client is built yet."""
        config = {
            "api_key": "test_api_key",
            "model": "text-embedding-3-small",
            "dimension": 1536,
        }
        provider = OpenAIEmbeddingProvider(config)

        assert provider.api_key == "test_api_key"
        assert provider.model == "text-embedding-3-small"
        assert provider.dimension == 1536
        assert provider._client is None # pylint: disable=W0212

    def test_init_default_values(self):
        """A provider built without config exposes the expected defaults."""
        provider = OpenAIEmbeddingProvider()

        assert provider.api_key is None
        assert provider.model == "text-embedding-ada-002"
        assert provider.dimension == 1536
        assert provider._client is None # pylint: disable=W0212

    def test_get_client_lazy_initialization(self):
        """The client is created on first use and then reused."""
        first_module = MagicMock()
        with self._use_openai(first_module):
            client = self.provider._get_client() # pylint: disable=W0212

        assert client is first_module
        assert self.provider._client is first_module # pylint: disable=W0212
        # Test that API key is set
        assert first_module.api_key == self.API_KEY

        # Make a *different* module importable for the second call: only a
        # genuinely cached client can still be the first one. (Returning the
        # same mock from every import would pass even without caching.)
        replacement_module = MagicMock()
        with self._use_openai(replacement_module):
            cached = self.provider._get_client() # pylint: disable=W0212
        assert cached is first_module

    @patch("builtins.__import__")
    def test_get_client_import_error(self, mock_import):
        """An ImportError raised while importing openai reaches the caller."""
        # The import hook itself is patched here (rather than sys.modules) so
        # that the import raises an error whose message this test controls.
        mock_import.side_effect = ImportError("OpenAI package not installed")

        with pytest.raises(ImportError, match="OpenAI package not installed"):
            self.provider._get_client() # pylint: disable=W0212

        # A failed import must not leave a client behind.
        assert self.provider._client is None # pylint: disable=W0212

    def test_embed_text_success(self):
        """embed_text returns the embedding from the API response."""
        mock_openai = MagicMock()
        mock_openai.Embedding.create.return_value = {
            "data": [{"embedding": [0.1, 0.2, 0.3]}]
        }

        with self._use_openai(mock_openai):
            result = self.provider.embed_text("test text")

        assert result == [0.1, 0.2, 0.3]
        mock_openai.Embedding.create.assert_called_once_with(
            model=self.MODEL,
            input="test text"
        )

    def test_embed_text_fallback_on_failure(self):
        """embed_text still returns a full-size float vector when the API fails."""
        mock_openai = MagicMock()
        mock_openai.Embedding.create.side_effect = Exception("API error")

        with self._use_openai(mock_openai):
            result = self.provider.embed_text("test text")

        # The fallback must have been triggered by the failing API call.
        assert mock_openai.Embedding.create.called
        assert isinstance(result, list)
        assert len(result) == self.DIMENSION
        assert all(isinstance(x, float) for x in result)

    def test_embed_batch_success(self):
        """embed_batch returns one embedding per input, in response order."""
        mock_openai = MagicMock()
        mock_openai.Embedding.create.return_value = {
            "data": [
                {"embedding": [0.1, 0.2, 0.3]},
                {"embedding": [0.4, 0.5, 0.6]},
                {"embedding": [0.7, 0.8, 0.9]}
            ]
        }

        texts = ["text1", "text2", "text3"]
        with self._use_openai(mock_openai):
            result = self.provider.embed_batch(texts)

        assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
        mock_openai.Embedding.create.assert_called_once_with(
            model=self.MODEL,
            input=texts
        )

    def test_embed_batch_fallback_on_failure(self):
        """embed_batch still returns one float vector per text when the API fails."""
        mock_openai = MagicMock()
        mock_openai.Embedding.create.side_effect = Exception("API error")

        texts = ["text1", "text2", "text3"]
        with self._use_openai(mock_openai):
            result = self.provider.embed_batch(texts)

        # The fallback must have been triggered by the failing API call.
        assert mock_openai.Embedding.create.called
        assert isinstance(result, list)
        assert len(result) == len(texts)
        assert all(len(emb) == self.DIMENSION for emb in result)
        assert all(all(isinstance(x, float) for x in emb) for emb in result)

    def test_get_dimension(self):
        """get_dimension reports the configured dimension."""
        assert self.provider.get_dimension() == self.DIMENSION

    def test_create_mock_embedding(self):
        """Mock embeddings have the right size and are deterministic per text."""
        result = self.provider._create_mock_embedding("test text") # pylint: disable=W0212

        assert isinstance(result, list)
        assert len(result) == self.DIMENSION
        assert all(isinstance(x, float) for x in result)

        # Test that same text produces same embedding
        result2 = self.provider._create_mock_embedding("test text") # pylint: disable=W0212
        assert result == result2

        # Test that different text produces different embedding
        result3 = self.provider._create_mock_embedding("different text") # pylint: disable=W0212
        assert result != result3

    def test_create_mock_embedding_normalization(self):
        """Mock embeddings have (approximately) unit Euclidean length."""
        result = self.provider._create_mock_embedding("test") # pylint: disable=W0212

        # Calculate magnitude
        magnitude = sum(x * x for x in result) ** 0.5
        # Should be close to 1.0 (normalized)
        assert abs(magnitude - 1.0) < 0.01

    def test_provider_id(self):
        """provider_id is the provider's class name."""
        assert self.provider.provider_id == "OpenAIEmbeddingProvider"

    def test_provider_info(self):
        """get_provider_info exposes the id, the dimension and metadata keys."""
        info = self.provider.get_provider_info()

        assert info["provider_id"] == "OpenAIEmbeddingProvider"
        assert info["dimension"] == self.DIMENSION
        assert "created_at" in info
        assert "config_keys" in info