File size: 6,927 Bytes
016b413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Unit tests for HuggingFaceChatClient."""

from unittest.mock import MagicMock, patch

import pytest

from src.utils.exceptions import ConfigurationError
from src.utils.huggingface_chat_client import HuggingFaceChatClient


@pytest.mark.unit
class TestHuggingFaceChatClient:
    """Unit tests for HuggingFaceChatClient."""

    # Dotted path of the real inference class; patched out in every test so
    # no network client is ever constructed.
    _INFERENCE_CLIENT = "src.utils.huggingface_chat_client.InferenceClient"

    @staticmethod
    def _inline_loop():
        """Return a patcher that replaces asyncio.get_running_loop with a fake
        loop whose run_in_executor invokes the callable synchronously, so the
        client's thread-pool hop runs inline inside the test."""

        async def run_inline(executor, func, *args):
            return func()

        fake_loop = MagicMock()
        fake_loop.run_in_executor = run_inline
        return patch("asyncio.get_running_loop", return_value=fake_loop)

    @staticmethod
    def _assistant_reply(content, tool_calls):
        """Build a fake chat-completion response carrying a single choice."""
        message = MagicMock(role="assistant", content=content, tool_calls=tool_calls)
        reply = MagicMock()
        reply.choices = [MagicMock(message=message)]
        return reply

    def test_init_with_defaults(self):
        """Default model and provider are stored and forwarded verbatim."""
        with patch(self._INFERENCE_CLIENT) as inference_cls:
            chat = HuggingFaceChatClient()

        assert chat.model_name == "meta-llama/Llama-3.1-8B-Instruct"
        assert chat.provider == "auto"
        inference_cls.assert_called_once_with(
            model="meta-llama/Llama-3.1-8B-Instruct",
            api_key=None,
            provider="auto",
        )

    def test_init_with_custom_params(self):
        """Explicit model, token, and provider are passed straight through."""
        with patch(self._INFERENCE_CLIENT) as inference_cls:
            chat = HuggingFaceChatClient(
                model_name="meta-llama/Llama-3.1-70B-Instruct",
                api_key="hf_test_token",
                provider="together",
            )

        assert chat.model_name == "meta-llama/Llama-3.1-70B-Instruct"
        assert chat.provider == "together"
        inference_cls.assert_called_once_with(
            model="meta-llama/Llama-3.1-70B-Instruct",
            api_key="hf_test_token",
            provider="together",
        )

    def test_init_failure(self):
        """A failing InferenceClient constructor surfaces as ConfigurationError."""
        boom = Exception("Connection failed")
        with patch(self._INFERENCE_CLIENT, side_effect=boom):
            with pytest.raises(ConfigurationError, match="Failed to initialize"):
                HuggingFaceChatClient()

    @pytest.mark.asyncio
    async def test_chat_completion_basic(self):
        """A plain completion forwards messages and returns the raw response."""
        reply = self._assistant_reply("Hello! How can I help you?", None)
        inference = MagicMock()
        inference.chat_completion.return_value = reply

        with patch(self._INFERENCE_CLIENT, return_value=inference):
            chat = HuggingFaceChatClient()

        history = [{"role": "user", "content": "Hello"}]
        with self._inline_loop():
            result = await chat.chat_completion(messages=history)

        assert result == reply
        inference.chat_completion.assert_called_once_with(
            messages=history,
            tools=None,
            tool_choice=None,
            temperature=None,
            max_tokens=None,
        )

    @pytest.mark.asyncio
    async def test_chat_completion_with_tools(self):
        """Tool schemas and sampling options are forwarded natively."""
        call = MagicMock()
        call.function.name = "search_pubmed"
        call.function.arguments = '{"query": "metformin", "max_results": 10}'
        reply = self._assistant_reply(None, [call])

        inference = MagicMock()
        inference.chat_completion.return_value = reply
        with patch(self._INFERENCE_CLIENT, return_value=inference):
            chat = HuggingFaceChatClient()

        history = [{"role": "user", "content": "Search for metformin"}]
        tool_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "search_pubmed",
                    "description": "Search PubMed",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                            "max_results": {"type": "integer"},
                        },
                    },
                },
            },
        ]

        with self._inline_loop():
            result = await chat.chat_completion(
                messages=history,
                tools=tool_schemas,
                tool_choice="auto",
                temperature=0.3,
                max_tokens=500,
            )

        assert result == reply
        inference.chat_completion.assert_called_once_with(
            messages=history,
            tools=tool_schemas,  # ✅ Native support!
            tool_choice="auto",
            temperature=0.3,
            max_tokens=500,
        )

    @pytest.mark.asyncio
    async def test_chat_completion_error_handling(self):
        """Errors raised by the underlying client become ConfigurationError."""
        inference = MagicMock()
        inference.chat_completion.side_effect = Exception("API error")
        with patch(self._INFERENCE_CLIENT, return_value=inference):
            chat = HuggingFaceChatClient()

        history = [{"role": "user", "content": "Hello"}]
        with self._inline_loop():
            with pytest.raises(ConfigurationError, match="HuggingFace chat completion failed"):
                await chat.chat_completion(messages=history)