Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Fallback Mechanism Test for OnCall.ai | |
| Tests the 3-tier fallback system using rare disease cases that are likely | |
| to trigger fallback mechanisms due to limited training data coverage. | |
| Test Case: Suspected Whipple's disease with cognitive changes | |
| - Rare disease likely to challenge primary generation | |
| - Complex symptoms requiring specialized medical knowledge | |
| - Good test for fallback system reliability | |
| Author: OnCall.ai Team | |
| Date: 2025-08-03 | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| import logging | |
| import json | |
| import traceback | |
| from datetime import datetime | |
| from typing import Dict, List, Any, Optional | |
| from unittest.mock import MagicMock, patch | |
| # Add src directory to Python path | |
| current_dir = Path(__file__).parent | |
| project_root = current_dir.parent | |
| src_dir = project_root / "src" | |
| sys.path.insert(0, str(src_dir)) | |
| # Import pipeline modules | |
| try: | |
| from user_prompt import UserPromptProcessor | |
| from retrieval import BasicRetrievalSystem | |
| from llm_clients import llm_Med42_70BClient | |
| from generation import MedicalAdviceGenerator, FALLBACK_TIMEOUTS, FALLBACK_TOKEN_LIMITS | |
| from medical_conditions import CONDITION_KEYWORD_MAPPING | |
| except ImportError as e: | |
| print(f"β Import Error: {e}") | |
| print(f"Current working directory: {os.getcwd()}") | |
| print(f"Python path: {sys.path}") | |
| sys.exit(1) | |
| # Configure detailed logging for fallback testing | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.StreamHandler(), | |
| logging.FileHandler(project_root / 'tests' / 'fallback_test.log') | |
| ] | |
| ) | |
| logger = logging.getLogger(__name__) | |
| class FallbackMechanismTester: | |
| """ | |
| Test class for validating fallback generation mechanisms | |
| """ | |
| def __init__(self): | |
| """Initialize test environment with mocked components for controlled testing""" | |
| self.test_query = "Suspected Whipple's disease with cognitive changes" | |
| self.setup_test_environment() | |
| def setup_test_environment(self): | |
| """Setup test components with proper initialization""" | |
| try: | |
| # Initialize components (will work with actual or mocked LLM) | |
| self.llm_client = llm_Med42_70BClient() | |
| self.retrieval_system = BasicRetrievalSystem() | |
| self.user_prompt_processor = UserPromptProcessor( | |
| llm_client=self.llm_client, | |
| retrieval_system=self.retrieval_system | |
| ) | |
| self.generator = MedicalAdviceGenerator(llm_client=self.llm_client) | |
| logger.info("β Test environment setup successful") | |
| except Exception as e: | |
| logger.error(f"β Test environment setup failed: {e}") | |
| # Create mock objects if real initialization fails | |
| self.setup_mock_environment() | |
| def setup_mock_environment(self): | |
| """Setup mock environment when real components fail""" | |
| logger.info("π§ Setting up mock test environment") | |
| # Mock LLM client | |
| self.llm_client = MagicMock() | |
| self.llm_client.analyze_medical_query.return_value = { | |
| 'extracted_condition': 'whipple disease', | |
| 'confidence': '0.3', | |
| 'raw_response': 'Whipple disease is rare...', | |
| 'latency': 15.0 | |
| } | |
| # Create generator with mock client | |
| self.generator = MedicalAdviceGenerator(llm_client=self.llm_client) | |
| logger.info("β Mock environment setup complete") | |
| def test_fallback_configuration(self): | |
| """Test 1: Verify fallback configuration constants are properly loaded""" | |
| logger.info("\n" + "="*60) | |
| logger.info("π§ͺ TEST 1: Fallback Configuration Validation") | |
| logger.info("="*60) | |
| try: | |
| # Test timeout configuration | |
| assert FALLBACK_TIMEOUTS["primary"] == 30.0 | |
| assert FALLBACK_TIMEOUTS["fallback_1"] == 15.0 | |
| assert FALLBACK_TIMEOUTS["fallback_2"] == 1.0 | |
| logger.info("β Timeout configuration correct") | |
| # Test token limits | |
| assert FALLBACK_TOKEN_LIMITS["primary"] == 800 | |
| assert FALLBACK_TOKEN_LIMITS["fallback_1"] == 300 | |
| assert FALLBACK_TOKEN_LIMITS["fallback_2"] == 0 | |
| logger.info("β Token limit configuration correct") | |
| # Test generator has access to fallback methods | |
| assert hasattr(self.generator, '_attempt_fallback_generation') | |
| assert hasattr(self.generator, '_attempt_simplified_med42') | |
| assert hasattr(self.generator, '_attempt_rag_template') | |
| logger.info("β Fallback methods are available") | |
| return True | |
| except Exception as e: | |
| logger.error(f"β Configuration test failed: {e}") | |
| return False | |
| def test_fallback_orchestration_logic(self): | |
| """Test 2: Test fallback orchestration with controlled error injection""" | |
| logger.info("\n" + "="*60) | |
| logger.info("π§ͺ TEST 2: Fallback Orchestration Logic") | |
| logger.info("="*60) | |
| try: | |
| # Create a sample RAG prompt that would fail | |
| test_prompt = """ | |
| You are an experienced attending physician providing guidance. | |
| Clinical Question: | |
| Suspected Whipple's disease with cognitive changes | |
| Relevant Medical Guidelines: | |
| [Large complex medical context that might cause timeout...] | |
| """ | |
| # Test the fallback orchestration | |
| logger.info("π Testing fallback orchestration with simulated primary failure") | |
| result = self.generator._attempt_fallback_generation( | |
| original_prompt=test_prompt, | |
| primary_error="Simulated timeout for testing" | |
| ) | |
| # Verify response structure | |
| assert isinstance(result, dict) | |
| assert 'fallback_method' in result | |
| logger.info(f"β Fallback orchestration returned: {result.get('fallback_method')}") | |
| # Since we have placeholder implementations, expect specific responses | |
| if result.get('fallback_method') == 'none': | |
| logger.info("β All fallbacks failed as expected (placeholder implementation)") | |
| else: | |
| logger.info(f"β Fallback method used: {result.get('fallback_method')}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"β Orchestration test failed: {e}") | |
| logger.error(traceback.format_exc()) | |
| return False | |
| def test_rare_disease_query_processing(self): | |
| """Test 3: Process rare disease query through complete pipeline""" | |
| logger.info("\n" + "="*60) | |
| logger.info("π§ͺ TEST 3: Rare Disease Query Processing") | |
| logger.info("="*60) | |
| try: | |
| logger.info(f"π Processing query: '{self.test_query}'") | |
| # Step 1: Test condition extraction | |
| logger.info("π Step 1: Condition extraction") | |
| extracted_keywords = self.user_prompt_processor.extract_condition_keywords(self.test_query) | |
| logger.info(f"Extracted keywords: {extracted_keywords}") | |
| # Step 2: Test retrieval (if available) | |
| if hasattr(self.retrieval_system, 'search_sliding_window_chunks'): | |
| logger.info("π Step 2: Retrieval system test") | |
| try: | |
| retrieval_results = self.retrieval_system.search_sliding_window_chunks(self.test_query) | |
| logger.info(f"Retrieved {len(retrieval_results)} results") | |
| except Exception as e: | |
| logger.warning(f"Retrieval failed (expected): {e}") | |
| # Step 3: Test generation pipeline with fallback | |
| logger.info("π Step 3: Generation with fallback testing") | |
| # Create mock retrieval results for generation | |
| mock_retrieval_results = { | |
| "emergency_subset": [ | |
| {"text": "Whipple disease emergency presentation guidelines", "chunk_id": 1}, | |
| {"text": "Cognitive changes in systemic diseases", "chunk_id": 2} | |
| ], | |
| "treatment_subset": [ | |
| {"text": "Antibiotic treatment for Whipple disease", "chunk_id": 3}, | |
| {"text": "Management of cognitive symptoms", "chunk_id": 4} | |
| ] | |
| } | |
| # Test generation (this should work even with mocked components) | |
| generation_result = self.generator.generate_medical_advice( | |
| user_query=self.test_query, | |
| retrieval_results=mock_retrieval_results, | |
| intention="diagnosis" | |
| ) | |
| logger.info("β Generation pipeline completed") | |
| logger.info(f"Confidence score: {generation_result.get('confidence_score', 'N/A')}") | |
| logger.info(f"Fallback method used: {generation_result.get('generation_metadata', {}).get('fallback_method', 'primary')}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"β Rare disease processing test failed: {e}") | |
| logger.error(traceback.format_exc()) | |
| return False | |
| def test_logging_format_validation(self): | |
| """Test 4: Validate logging format and emoji markers""" | |
| logger.info("\n" + "="*60) | |
| logger.info("π§ͺ TEST 4: Logging Format Validation") | |
| logger.info("="*60) | |
| try: | |
| # Capture log output during fallback | |
| with patch('builtins.print') as mock_print: | |
| # Test fallback with logging | |
| test_prompt = "Test prompt for logging validation" | |
| result = self.generator._attempt_fallback_generation( | |
| original_prompt=test_prompt, | |
| primary_error="Test error for logging" | |
| ) | |
| # Check that logging methods exist and can be called | |
| logger.info("π FALLBACK: Test logging message") | |
| logger.info("π FALLBACK 1: Test step message") | |
| logger.info("β FALLBACK 1: Test success message") | |
| logger.error("β FALLBACK 1: Test error message") | |
| logger.error("π« ALL FALLBACKS FAILED: Test final error") | |
| logger.info("β Logging format validation completed") | |
| return True | |
| except Exception as e: | |
| logger.error(f"β Logging validation failed: {e}") | |
| return False | |
| def run_all_tests(self): | |
| """Execute all fallback mechanism tests""" | |
| logger.info("\n" + "π STARTING FALLBACK MECHANISM TESTS") | |
| logger.info("=" * 80) | |
| logger.info(f"Test Query: '{self.test_query}'") | |
| logger.info(f"Test Time: {datetime.now().isoformat()}") | |
| logger.info("=" * 80) | |
| test_results = [] | |
| # Run individual tests | |
| tests = [ | |
| ("Configuration Validation", self.test_fallback_configuration), | |
| ("Orchestration Logic", self.test_fallback_orchestration_logic), | |
| ("Rare Disease Processing", self.test_rare_disease_query_processing), | |
| ("Logging Format", self.test_logging_format_validation) | |
| ] | |
| for test_name, test_func in tests: | |
| try: | |
| result = test_func() | |
| test_results.append((test_name, result)) | |
| status = "β PASSED" if result else "β FAILED" | |
| logger.info(f"\n{status}: {test_name}") | |
| except Exception as e: | |
| test_results.append((test_name, False)) | |
| logger.error(f"\nβ ERROR in {test_name}: {e}") | |
| # Summary | |
| logger.info("\n" + "=" * 80) | |
| logger.info("π TEST SUMMARY") | |
| logger.info("=" * 80) | |
| passed = sum(1 for _, result in test_results if result) | |
| total = len(test_results) | |
| for test_name, result in test_results: | |
| status = "β PASSED" if result else "β FAILED" | |
| logger.info(f"{status}: {test_name}") | |
| logger.info(f"\nOverall: {passed}/{total} tests passed") | |
| if passed == total: | |
| logger.info("π ALL TESTS PASSED - Fallback mechanism is working correctly!") | |
| else: | |
| logger.warning(f"β οΈ {total - passed} tests failed - Review implementation") | |
| return passed == total | |
| def main(): | |
| """Main test execution function""" | |
| print("π§ͺ OnCall.ai Fallback Mechanism Test") | |
| print("=" * 50) | |
| try: | |
| tester = FallbackMechanismTester() | |
| success = tester.run_all_tests() | |
| if success: | |
| print("\nπ All fallback tests completed successfully!") | |
| return 0 | |
| else: | |
| print("\nβ οΈ Some tests failed. Check logs for details.") | |
| return 1 | |
| except Exception as e: | |
| print(f"\nβ Test execution failed: {e}") | |
| traceback.print_exc() | |
| return 1 | |
| if __name__ == "__main__": | |
| exit_code = main() |