Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for interface components. | |
| """ | |
| import pytest | |
| from unittest.mock import MagicMock, patch | |
| import sys | |
| from pathlib import Path | |
# Put the project root (three directory levels above this test file) at the
# front of sys.path so the `app.*` imports below resolve no matter which
# directory pytest is launched from.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, f"{project_root}")
| from app.interface import ExpertClassifier, ResponseGenerator, GradioInterface | |
| from app.model_loader import ModelLoader | |
class TestExpertClassifier:
    """Test ExpertClassifier functionality.

    Each routing test feeds a batch of representative queries through
    ``ExpertClassifier.classify_query`` and expects every one of them to be
    routed to the same expert label.
    """

    @staticmethod
    def _assert_all_classified(queries, expected):
        """Assert that every query in *queries* is classified as *expected*.

        All queries are classified before asserting, so a failure reports
        every misrouted query (together with the label it actually got)
        instead of stopping at the first mismatch.

        Args:
            queries: Iterable of query strings to classify.
            expected: Expert label every query should be routed to.
        """
        misclassified = {}
        for query in queries:
            result = ExpertClassifier.classify_query(query)
            if result != expected:
                misclassified[query] = result
        assert not misclassified, (
            f"Expected every query to be classified as {expected!r}; "
            f"misclassified (query -> result): {misclassified}"
        )

    def test_classify_code_query(self):
        """Test classification of coding-related queries."""
        # Intended to route to Code via keywords such as "algorithm",
        # "Python", "JavaScript", "API" and "Java".
        self._assert_all_classified(
            [
                "How do I implement a binary search algorithm in Python?",
                "What's the best way to debug JavaScript code?",
                "Can you help me with this API integration?",
                "I need help with my Java class implementation",
            ],
            "Code",
        )

    def test_classify_math_query(self):
        """Test classification of math-related queries."""
        self._assert_all_classified(
            [
                "What is the derivative of x² + 3x + 1?",
                "How do I solve this calculus problem?",
                "Calculate the probability of this event",
                "What's the integral of sin(x)?",
            ],
            "Math",
        )

    def test_classify_reasoning_query(self):
        """Test classification of reasoning-related queries."""
        self._assert_all_classified(
            [
                "Explain the logical reasoning behind the Monty Hall problem",
                "Why does this argument make sense?",
                "How should I analyze this situation?",
                "What's the critical thinking approach here?",
            ],
            "Reasoning",
        )

    def test_classify_multilingual_query(self):
        """Test classification of multilingual queries."""
        self._assert_all_classified(
            [
                "Translate 'Hello, how are you?' to Spanish",
                "What does this French phrase mean?",
                "Help me learn German vocabulary",
                "How do you say 'thank you' in Japanese?",
            ],
            "Multilingual",
        )

    def test_classify_general_query(self):
        """Test classification of general queries."""
        self._assert_all_classified(
            [
                "What are the benefits of renewable energy?",
                "Tell me about the history of computers",
                "Hello, how can you help me?",
                "What's the weather like?",
            ],
            "General",
        )

    def test_classify_empty_query(self):
        """Test classification of empty query."""
        result = ExpertClassifier.classify_query("")
        assert result == "General", f"Empty query should be General, got {result}"

    def test_classify_none_query(self):
        """Test classification of None query."""
        result = ExpertClassifier.classify_query(None)
        assert result == "General", f"None query should be General, got {result}"

    def test_classify_ambiguous_query(self):
        """Test classification of ambiguous query with multiple keywords."""
        # Mixes Math vocabulary ("calculate") with Code vocabulary
        # ("algorithm", "Python"). The tie-break rule is deliberately not
        # pinned here, so either label is accepted.
        query = "How do I calculate the algorithm complexity in Python?"
        result = ExpertClassifier.classify_query(query)
        assert result in ("Code", "Math"), (
            f"Ambiguous query '{query}' should be Code or Math, got {result}"
        )
class TestResponseGenerator:
    """Test ResponseGenerator functionality."""

    def setup_method(self):
        """Create a fresh ResponseGenerator around a spec'd ModelLoader mock."""
        self.mock_model_loader = MagicMock(spec=ModelLoader)
        self.response_generator = ResponseGenerator(self.mock_model_loader)

    def _configure_loaded_model(self, generated_text):
        """Make the mocked loader look loaded and have it emit *generated_text*.

        Sets ``is_loaded = True`` and a ``pipeline`` mock that returns a
        one-element list holding ``{'generated_text': generated_text}``. It
        also sets a ``tokenizer`` mock with ``eos_token_id = 2``.
        ``MagicMock(spec=...)`` raises ``AttributeError`` for attributes the
        spec'd class does not define. ``pipeline`` and ``tokenizer`` are
        presumably instance attributes of ModelLoader, so they are assigned
        explicitly here.

        Returns:
            The pipeline mock, so callers can inspect its calls or add a
            ``side_effect``.
        """
        loader = self.mock_model_loader
        loader.is_loaded = True
        loader.pipeline = MagicMock(return_value=[{'generated_text': generated_text}])
        loader.tokenizer = MagicMock()
        loader.tokenizer.eos_token_id = 2
        return loader.pipeline

    def test_generate_fallback_response_code(self):
        """Test fallback response generation for code queries."""
        query = "How do I implement a function?"
        response = self.response_generator.generate_fallback_response(query, "Code")
        assert "Code Expert" in response
        assert query in response
        assert "code examples" in response

    def test_generate_fallback_response_math(self):
        """Test fallback response generation for math queries."""
        query = "What is the derivative of x²?"
        response = self.response_generator.generate_fallback_response(query, "Math")
        assert "Math Expert" in response
        assert query in response
        assert "step-by-step" in response

    def test_generate_fallback_response_unknown_expert(self):
        """Test fallback response for unknown expert type."""
        response = self.response_generator.generate_fallback_response("Test query", "Unknown")
        # Should default to General expert response
        assert "General Expert" in response

    def test_generate_response_model_not_loaded(self):
        """Test response generation when model is not loaded."""
        self.mock_model_loader.is_loaded = False
        response = self.response_generator.generate_response("Test query")
        assert "Expert Type:" in response
        assert "model is currently unavailable" in response

    def test_generate_response_model_loaded_success(self):
        """Test successful response generation when model is loaded."""
        pipeline = self._configure_loaded_model(
            'System message\nUser: Test query\nAssistant: This is a test response'
        )
        response = self.response_generator.generate_response("Test query")
        assert "Expert Type:" in response
        assert "This is a test response" in response
        # Without explicit arguments these generation settings are expected
        # to be forwarded to the pipeline.
        pipeline.assert_called_once()
        _, kwargs = pipeline.call_args
        assert kwargs['max_new_tokens'] == 500
        assert kwargs['temperature'] == 0.7
        assert kwargs['do_sample'] is True

    def test_generate_response_model_loaded_custom_params(self):
        """Test response generation with custom parameters."""
        pipeline = self._configure_loaded_model('Test response')
        self.response_generator.generate_response(
            "Test query",
            max_tokens=200,
            temperature=0.5,
        )
        # Caller-supplied values must override the defaults.
        pipeline.assert_called_once()
        _, kwargs = pipeline.call_args
        assert kwargs['max_new_tokens'] == 200
        assert kwargs['temperature'] == 0.5

    def test_generate_response_pipeline_error(self):
        """Test response generation when pipeline raises an error."""
        # Configure the full loaded state (including the tokenizer) so the
        # only failure is the pipeline's; otherwise a missing spec'd attribute
        # could raise first and mask the error under test.
        pipeline = self._configure_loaded_model('unused')
        pipeline.side_effect = Exception("Pipeline error")
        response = self.response_generator.generate_response("Test query")
        assert "Error generating response" in response
        assert "Pipeline error" in response
class TestGradioInterface:
    """Test GradioInterface functionality."""

    # Target patched to replace the Gradio Blocks constructor during tests.
    # NOTE(review): assumes app.interface does ``import gradio as gr`` and
    # builds the UI via ``gr.Blocks``; adjust if it imports Blocks directly.
    _BLOCKS_TARGET = "app.interface.gr.Blocks"

    def setup_method(self):
        """Create a spec'd ModelLoader mock reporting a loaded model."""
        self.mock_model_loader = MagicMock(spec=ModelLoader)
        self.mock_model_loader.get_model_info.return_value = {
            "status": "loaded",
            "model_id": "test/model",
            "revision": "abc123",
        }

    def test_gradio_interface_creation(self):
        """Test GradioInterface creation."""
        interface = GradioInterface(self.mock_model_loader)
        assert interface.model_loader == self.mock_model_loader
        assert interface.response_generator is not None
        assert interface.demo is None

    # The @patch decorator supplies ``mock_blocks``; without it pytest looks
    # for a fixture of that name and errors out.
    @patch(_BLOCKS_TARGET)
    def test_create_interface_model_loaded(self, mock_blocks):
        """Test interface creation when model is loaded."""
        self.mock_model_loader.is_loaded = True
        interface = GradioInterface(self.mock_model_loader)
        interface.create_interface()
        # Verify Blocks was called
        mock_blocks.assert_called_once()
        # Verify model info was requested
        self.mock_model_loader.get_model_info.assert_called()

    @patch(_BLOCKS_TARGET)
    def test_create_interface_model_not_loaded(self, mock_blocks):
        """Test interface creation when model is not loaded."""
        self.mock_model_loader.is_loaded = False
        self.mock_model_loader.get_model_info.return_value = {"status": "not_loaded"}
        interface = GradioInterface(self.mock_model_loader)
        interface.create_interface()
        # Verify Blocks was called
        mock_blocks.assert_called_once()

    @patch(_BLOCKS_TARGET)
    def test_launch_creates_interface_if_needed(self, mock_blocks):
        """Test that launch creates interface if it doesn't exist."""
        # Presumably create_interface uses ``with gr.Blocks(...) as demo``,
        # which binds the context manager's __enter__ result; route that to
        # a mock we can inspect.
        mock_demo = MagicMock()
        mock_blocks.return_value.__enter__.return_value = mock_demo
        interface = GradioInterface(self.mock_model_loader)
        # mock_demo is a MagicMock, so launch() never starts a server and the
        # call is recorded directly. Patching it again with patch.object would
        # restore an uncalled child mock on exit and break the assertion.
        interface.launch()
        # Verify interface was created
        assert interface.demo is not None
        mock_demo.launch.assert_called_once()

    def test_launch_uses_existing_interface(self):
        """Test that launch uses existing interface if available."""
        interface = GradioInterface(self.mock_model_loader)
        mock_demo = MagicMock()
        interface.demo = mock_demo
        # A MagicMock's launch() is inert, so no extra patching is needed.
        interface.launch(server_name="127.0.0.1", server_port=8080)
        # Verify existing demo was used
        mock_demo.launch.assert_called_once_with(
            server_name="127.0.0.1",
            server_port=8080,
        )
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))