Spaces:
Sleeping
feat(integration): integrate hospital customization pipeline with main RAG system
Browse files- Move hospital customization to Step 1.5 for early execution
- Add parallel retrieval: general medical guidelines + hospital-specific docs
- Rename customization/src/retrieval to custom_retrieval to resolve import conflicts
- Fix field name mismatch in app.py fallback flow (medical_advice vs advice)
- Add enhanced keyword extraction for better hospital document matching
- Update generation.py to handle hospital_custom chunk classification
- Ensure conditional return values based on DEBUG_MODE to fix Gradio warnings
Major architectural changes:
- Hospital docs now retrieved alongside general guidelines
- LLM-based keyword extraction improves hospital document relevance
- Graceful fallback when no medical condition found but hospital docs available
- All components properly integrated with error handling
- app.py +110 -15
- customization/customization_pipeline.py +25 -6
- customization/src/{retrieval → custom_retrieval}/__init__.py +0 -0
- customization/src/{retrieval → custom_retrieval}/chunk_retriever.py +0 -0
- customization/src/{retrieval → custom_retrieval}/document_retriever.py +0 -0
- customization/src/demos/demo_runner.py +2 -2
- customization/src/rag/medical_rag_pipeline.py +2 -2
- src/generation.py +42 -7
- src/llm_clients.py +81 -1
- test_retrieval_pipeline.py +0 -223
|
@@ -31,6 +31,9 @@ current_dir = Path(__file__).parent
|
|
| 31 |
src_dir = current_dir / "src"
|
| 32 |
sys.path.insert(0, str(src_dir))
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
# Import OnCall.ai modules
|
| 35 |
try:
|
| 36 |
from user_prompt import UserPromptProcessor
|
|
@@ -141,14 +144,84 @@ class OnCallAIInterface:
|
|
| 141 |
processing_steps.append(" 🚫 Query identified as non-medical")
|
| 142 |
return non_medical_msg, '\n'.join(processing_steps), "{}", "{}"
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# STEP 2: User Confirmation (Auto-simulated)
|
| 145 |
processing_steps.append("\n🤝 Step 2: User confirmation (auto-confirmed for demo)")
|
| 146 |
confirmation = self.user_prompt_processor.handle_user_confirmation(condition_result)
|
| 147 |
|
| 148 |
if not condition_result.get('condition'):
|
| 149 |
-
no_condition_msg = "Unable to identify a specific medical condition. Please rephrase your query with more specific medical terms."
|
| 150 |
processing_steps.append(" ⚠️ No medical condition identified")
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
processing_steps.append(f" ✅ Confirmed condition: {condition_result.get('condition')}")
|
| 154 |
|
|
@@ -161,9 +234,13 @@ class OnCallAIInterface:
|
|
| 161 |
if not search_query:
|
| 162 |
search_query = condition_result.get('condition', user_query)
|
| 163 |
|
| 164 |
-
|
|
|
|
| 165 |
step3_time = (datetime.now() - step3_start).total_seconds()
|
| 166 |
|
|
|
|
|
|
|
|
|
|
| 167 |
processed_results = retrieval_results.get('processed_results', [])
|
| 168 |
emergency_count = len([r for r in processed_results if r.get('type') == 'emergency'])
|
| 169 |
treatment_count = len([r for r in processed_results if r.get('type') == 'treatment'])
|
|
@@ -179,6 +256,8 @@ class OnCallAIInterface:
|
|
| 179 |
else:
|
| 180 |
guidelines_display = self._format_user_friendly_sources(processed_results)
|
| 181 |
|
|
|
|
|
|
|
| 182 |
# STEP 4: Medical Advice Generation
|
| 183 |
processing_steps.append("\n🧠 Step 4: Generating evidence-based medical advice...")
|
| 184 |
step4_start = datetime.now()
|
|
@@ -235,12 +314,20 @@ class OnCallAIInterface:
|
|
| 235 |
if not DEBUG_MODE:
|
| 236 |
technical_details = self._sanitize_technical_details(technical_details)
|
| 237 |
|
| 238 |
-
return
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
except Exception as e:
|
| 246 |
error_msg = f"❌ System error: {str(e)}"
|
|
@@ -252,12 +339,20 @@ class OnCallAIInterface:
|
|
| 252 |
"query": user_query
|
| 253 |
}
|
| 254 |
|
| 255 |
-
return
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
def _format_guidelines_display(self, processed_results: List[Dict]) -> str:
|
| 263 |
"""Format retrieved guidelines for user-friendly display"""
|
|
|
|
| 31 |
src_dir = current_dir / "src"
|
| 32 |
sys.path.insert(0, str(src_dir))
|
| 33 |
|
| 34 |
+
# Also add project root to ensure customization module can be imported
|
| 35 |
+
sys.path.insert(0, str(current_dir))
|
| 36 |
+
|
| 37 |
# Import OnCall.ai modules
|
| 38 |
try:
|
| 39 |
from user_prompt import UserPromptProcessor
|
|
|
|
| 144 |
processing_steps.append(" 🚫 Query identified as non-medical")
|
| 145 |
return non_medical_msg, '\n'.join(processing_steps), "{}", "{}"
|
| 146 |
|
| 147 |
+
# STEP 1.5: Hospital-Specific Customization (Early retrieval)
|
| 148 |
+
# Run this early since it has its own keyword extraction
|
| 149 |
+
customization_results = []
|
| 150 |
+
retrieval_results = {} # Initialize early for hospital results
|
| 151 |
+
try:
|
| 152 |
+
from customization.customization_pipeline import retrieve_document_chunks
|
| 153 |
+
|
| 154 |
+
processing_steps.append("\n🏥 Step 1.5: Checking hospital-specific guidelines...")
|
| 155 |
+
custom_start = datetime.now()
|
| 156 |
+
|
| 157 |
+
# Use original user query since hospital module has its own keyword extraction
|
| 158 |
+
custom_results = retrieve_document_chunks(user_query, top_k=3, llm_client=self.llm_client)
|
| 159 |
+
custom_time = (datetime.now() - custom_start).total_seconds()
|
| 160 |
+
|
| 161 |
+
if custom_results:
|
| 162 |
+
processing_steps.append(f" 📋 Found {len(custom_results)} hospital-specific guidelines")
|
| 163 |
+
processing_steps.append(f" ⏱️ Customization time: {custom_time:.3f}s")
|
| 164 |
+
|
| 165 |
+
# Store customization results for later use
|
| 166 |
+
customization_results = custom_results
|
| 167 |
+
|
| 168 |
+
# Add custom results to retrieval_results for the generator
|
| 169 |
+
retrieval_results['customization_results'] = custom_results
|
| 170 |
+
else:
|
| 171 |
+
processing_steps.append(" ℹ️ No hospital-specific guidelines found")
|
| 172 |
+
except ImportError as e:
|
| 173 |
+
processing_steps.append(f" ⚠️ Hospital customization module not available: {str(e)}")
|
| 174 |
+
if DEBUG_MODE:
|
| 175 |
+
print(f"Import error: {traceback.format_exc()}")
|
| 176 |
+
except Exception as e:
|
| 177 |
+
processing_steps.append(f" ⚠️ Customization search skipped: {str(e)}")
|
| 178 |
+
if DEBUG_MODE:
|
| 179 |
+
print(f"Customization error: {traceback.format_exc()}")
|
| 180 |
+
|
| 181 |
# STEP 2: User Confirmation (Auto-simulated)
|
| 182 |
processing_steps.append("\n🤝 Step 2: User confirmation (auto-confirmed for demo)")
|
| 183 |
confirmation = self.user_prompt_processor.handle_user_confirmation(condition_result)
|
| 184 |
|
| 185 |
if not condition_result.get('condition'):
|
|
|
|
| 186 |
processing_steps.append(" ⚠️ No medical condition identified")
|
| 187 |
+
|
| 188 |
+
# If we have hospital customization results, we can still try to provide help
|
| 189 |
+
if customization_results:
|
| 190 |
+
processing_steps.append(" ℹ️ Using hospital-specific guidelines to assist...")
|
| 191 |
+
|
| 192 |
+
# Create a minimal retrieval_results structure for generation
|
| 193 |
+
retrieval_results['processed_results'] = []
|
| 194 |
+
|
| 195 |
+
# Skip to generation with hospital results only
|
| 196 |
+
processing_steps.append("\n🧠 Step 4: Generating advice based on hospital guidelines...")
|
| 197 |
+
gen_start = datetime.now()
|
| 198 |
+
|
| 199 |
+
medical_advice_result = self.medical_generator.generate_medical_advice(
|
| 200 |
+
condition_result.get('condition', user_query),
|
| 201 |
+
retrieval_results,
|
| 202 |
+
intention="general"
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
gen_time = (datetime.now() - gen_start).total_seconds()
|
| 206 |
+
medical_advice = medical_advice_result.get('medical_advice', 'Unable to generate advice')
|
| 207 |
+
|
| 208 |
+
processing_steps.append(f" ⏱️ Generation time: {gen_time:.3f}s")
|
| 209 |
+
|
| 210 |
+
# Format guidelines display
|
| 211 |
+
guidelines_display = f"Hospital Guidelines Found: {len(customization_results)}"
|
| 212 |
+
|
| 213 |
+
# Conditional return based on DEBUG_MODE
|
| 214 |
+
if DEBUG_MODE:
|
| 215 |
+
return (medical_advice, '\n'.join(processing_steps), guidelines_display, "{}")
|
| 216 |
+
else:
|
| 217 |
+
return (medical_advice, '\n'.join(processing_steps), guidelines_display)
|
| 218 |
+
else:
|
| 219 |
+
# No condition and no hospital results
|
| 220 |
+
no_condition_msg = "Unable to identify a specific medical condition. Please rephrase your query with more specific medical terms."
|
| 221 |
+
if DEBUG_MODE:
|
| 222 |
+
return no_condition_msg, '\n'.join(processing_steps), "{}", "{}"
|
| 223 |
+
else:
|
| 224 |
+
return no_condition_msg, '\n'.join(processing_steps), "{}"
|
| 225 |
|
| 226 |
processing_steps.append(f" ✅ Confirmed condition: {condition_result.get('condition')}")
|
| 227 |
|
|
|
|
| 234 |
if not search_query:
|
| 235 |
search_query = condition_result.get('condition', user_query)
|
| 236 |
|
| 237 |
+
# Search for general medical guidelines
|
| 238 |
+
general_results = self.retrieval_system.search(search_query, top_k=5)
|
| 239 |
step3_time = (datetime.now() - step3_start).total_seconds()
|
| 240 |
|
| 241 |
+
# Merge with existing retrieval_results (which contains hospital customization)
|
| 242 |
+
retrieval_results.update(general_results)
|
| 243 |
+
|
| 244 |
processed_results = retrieval_results.get('processed_results', [])
|
| 245 |
emergency_count = len([r for r in processed_results if r.get('type') == 'emergency'])
|
| 246 |
treatment_count = len([r for r in processed_results if r.get('type') == 'treatment'])
|
|
|
|
| 256 |
else:
|
| 257 |
guidelines_display = self._format_user_friendly_sources(processed_results)
|
| 258 |
|
| 259 |
+
# Hospital customization already done in Step 1.5
|
| 260 |
+
|
| 261 |
# STEP 4: Medical Advice Generation
|
| 262 |
processing_steps.append("\n🧠 Step 4: Generating evidence-based medical advice...")
|
| 263 |
step4_start = datetime.now()
|
|
|
|
| 314 |
if not DEBUG_MODE:
|
| 315 |
technical_details = self._sanitize_technical_details(technical_details)
|
| 316 |
|
| 317 |
+
# Conditional return based on DEBUG_MODE
|
| 318 |
+
if DEBUG_MODE:
|
| 319 |
+
return (
|
| 320 |
+
medical_advice,
|
| 321 |
+
'\n'.join(processing_steps),
|
| 322 |
+
guidelines_display,
|
| 323 |
+
json.dumps(technical_details, indent=2)
|
| 324 |
+
)
|
| 325 |
+
else:
|
| 326 |
+
return (
|
| 327 |
+
medical_advice,
|
| 328 |
+
'\n'.join(processing_steps),
|
| 329 |
+
guidelines_display
|
| 330 |
+
)
|
| 331 |
|
| 332 |
except Exception as e:
|
| 333 |
error_msg = f"❌ System error: {str(e)}"
|
|
|
|
| 339 |
"query": user_query
|
| 340 |
}
|
| 341 |
|
| 342 |
+
# Conditional return based on DEBUG_MODE
|
| 343 |
+
if DEBUG_MODE:
|
| 344 |
+
return (
|
| 345 |
+
"I apologize, but I encountered an error while processing your medical query. Please try rephrasing your question or contact technical support.",
|
| 346 |
+
'\n'.join(processing_steps),
|
| 347 |
+
"{}",
|
| 348 |
+
json.dumps(error_details, indent=2)
|
| 349 |
+
)
|
| 350 |
+
else:
|
| 351 |
+
return (
|
| 352 |
+
"I apologize, but I encountered an error while processing your medical query. Please try rephrasing your question or contact technical support.",
|
| 353 |
+
'\n'.join(processing_steps),
|
| 354 |
+
"{}"
|
| 355 |
+
)
|
| 356 |
|
| 357 |
def _format_guidelines_display(self, processed_results: List[Dict]) -> str:
|
| 358 |
"""Format retrieved guidelines for user-friendly display"""
|
|
@@ -9,7 +9,9 @@ from pathlib import Path
|
|
| 9 |
from typing import List, Dict
|
| 10 |
|
| 11 |
# Add src directory to Python path
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Import necessary modules
|
| 15 |
from models.embedding_models import load_biomedbert_model
|
|
@@ -17,8 +19,8 @@ from data.loaders import load_annotations
|
|
| 17 |
from indexing.document_indexer import build_document_index
|
| 18 |
from indexing.embedding_creator import create_tag_embeddings, create_chunk_embeddings
|
| 19 |
from indexing.storage import save_document_system, load_document_system_with_annoy
|
| 20 |
-
from
|
| 21 |
-
from
|
| 22 |
|
| 23 |
|
| 24 |
def build_customization_embeddings():
|
|
@@ -68,7 +70,7 @@ def build_customization_embeddings():
|
|
| 68 |
return True
|
| 69 |
|
| 70 |
|
| 71 |
-
def retrieve_document_chunks(query: str, top_k: int = 5) -> List[Dict]:
|
| 72 |
"""Retrieve relevant document chunks using two-stage ANNOY retrieval.
|
| 73 |
|
| 74 |
Stage 1: Find relevant documents using tag embeddings (medical concepts)
|
|
@@ -77,6 +79,7 @@ def retrieve_document_chunks(query: str, top_k: int = 5) -> List[Dict]:
|
|
| 77 |
Args:
|
| 78 |
query: The search query
|
| 79 |
top_k: Number of chunks to retrieve
|
|
|
|
| 80 |
|
| 81 |
Returns:
|
| 82 |
List of dictionaries containing chunk information
|
|
@@ -98,8 +101,24 @@ def retrieve_document_chunks(query: str, top_k: int = 5) -> List[Dict]:
|
|
| 98 |
print("❌ Failed to load ANNOY manager")
|
| 99 |
return []
|
| 100 |
|
| 101 |
-
#
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
# Stage 1: Find relevant documents using tag ANNOY index
|
| 105 |
print(f"🔍 Stage 1: Finding relevant documents for query: '{query}'")
|
|
|
|
| 9 |
from typing import List, Dict
|
| 10 |
|
| 11 |
# Add src directory to Python path
|
| 12 |
+
src_path = Path(__file__).parent / 'src'
|
| 13 |
+
if str(src_path) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(src_path))
|
| 15 |
|
| 16 |
# Import necessary modules
|
| 17 |
from models.embedding_models import load_biomedbert_model
|
|
|
|
| 19 |
from indexing.document_indexer import build_document_index
|
| 20 |
from indexing.embedding_creator import create_tag_embeddings, create_chunk_embeddings
|
| 21 |
from indexing.storage import save_document_system, load_document_system_with_annoy
|
| 22 |
+
from custom_retrieval.document_retriever import create_document_tag_mapping
|
| 23 |
+
from custom_retrieval.chunk_retriever import find_relevant_chunks_with_fallback
|
| 24 |
|
| 25 |
|
| 26 |
def build_customization_embeddings():
|
|
|
|
| 70 |
return True
|
| 71 |
|
| 72 |
|
| 73 |
+
def retrieve_document_chunks(query: str, top_k: int = 5, llm_client=None) -> List[Dict]:
|
| 74 |
"""Retrieve relevant document chunks using two-stage ANNOY retrieval.
|
| 75 |
|
| 76 |
Stage 1: Find relevant documents using tag embeddings (medical concepts)
|
|
|
|
| 79 |
Args:
|
| 80 |
query: The search query
|
| 81 |
top_k: Number of chunks to retrieve
|
| 82 |
+
llm_client: Optional LLM client for keyword extraction
|
| 83 |
|
| 84 |
Returns:
|
| 85 |
List of dictionaries containing chunk information
|
|
|
|
| 101 |
print("❌ Failed to load ANNOY manager")
|
| 102 |
return []
|
| 103 |
|
| 104 |
+
# Extract medical keywords for better matching
|
| 105 |
+
search_query = query
|
| 106 |
+
if llm_client:
|
| 107 |
+
try:
|
| 108 |
+
print(f"🔍 Extracting medical keywords from: '{query}'")
|
| 109 |
+
keywords = llm_client.extract_medical_keywords_for_customization(query)
|
| 110 |
+
if keywords:
|
| 111 |
+
search_query = " ".join(keywords)
|
| 112 |
+
print(f"✅ Using keywords for search: '{search_query}'")
|
| 113 |
+
else:
|
| 114 |
+
print("ℹ️ No keywords extracted, using original query")
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print(f"⚠️ Keyword extraction failed, using original query: {e}")
|
| 117 |
+
else:
|
| 118 |
+
print("ℹ️ No LLM client provided, using original query")
|
| 119 |
+
|
| 120 |
+
# Create query embedding using processed search query
|
| 121 |
+
query_embedding = embedding_model.encode(search_query)
|
| 122 |
|
| 123 |
# Stage 1: Find relevant documents using tag ANNOY index
|
| 124 |
print(f"🔍 Stage 1: Finding relevant documents for query: '{query}'")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -7,11 +7,11 @@ from data.loaders import load_annotations
|
|
| 7 |
from indexing.document_indexer import build_document_index
|
| 8 |
from indexing.embedding_creator import create_tag_embeddings, create_chunk_embeddings
|
| 9 |
from indexing.storage import save_document_system, load_document_system, load_document_system_with_annoy
|
| 10 |
-
from
|
| 11 |
create_document_tag_mapping, find_relevant_documents,
|
| 12 |
find_relevant_documents_with_fallback
|
| 13 |
)
|
| 14 |
-
from
|
| 15 |
find_relevant_chunks, get_documents_for_rag, get_chunks_for_rag,
|
| 16 |
find_relevant_chunks_with_fallback
|
| 17 |
)
|
|
|
|
| 7 |
from indexing.document_indexer import build_document_index
|
| 8 |
from indexing.embedding_creator import create_tag_embeddings, create_chunk_embeddings
|
| 9 |
from indexing.storage import save_document_system, load_document_system, load_document_system_with_annoy
|
| 10 |
+
from custom_retrieval.document_retriever import (
|
| 11 |
create_document_tag_mapping, find_relevant_documents,
|
| 12 |
find_relevant_documents_with_fallback
|
| 13 |
)
|
| 14 |
+
from custom_retrieval.chunk_retriever import (
|
| 15 |
find_relevant_chunks, get_documents_for_rag, get_chunks_for_rag,
|
| 16 |
find_relevant_chunks_with_fallback
|
| 17 |
)
|
|
@@ -7,8 +7,8 @@ from typing import Dict, List, Optional, Tuple
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
|
| 9 |
# Import existing retrieval components
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
from models.embedding_models import load_biomedbert_model
|
| 13 |
|
| 14 |
|
|
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
|
| 9 |
# Import existing retrieval components
|
| 10 |
+
from custom_retrieval.document_retriever import find_relevant_documents
|
| 11 |
+
from custom_retrieval.chunk_retriever import find_relevant_chunks, get_chunks_for_rag
|
| 12 |
from models.embedding_models import load_biomedbert_model
|
| 13 |
|
| 14 |
|
|
@@ -128,6 +128,7 @@ class MedicalAdviceGenerator:
|
|
| 128 |
treatment_chunks = classified_chunks.get("treatment_subset", [])
|
| 129 |
symptom_chunks = classified_chunks.get("symptom_subset", []) # Dataset B (future)
|
| 130 |
diagnosis_chunks = classified_chunks.get("diagnosis_subset", []) # Dataset B (future)
|
|
|
|
| 131 |
|
| 132 |
# Select chunks based on intention or intelligent defaults
|
| 133 |
selected_chunks = self._select_chunks_by_intention(
|
|
@@ -135,7 +136,8 @@ class MedicalAdviceGenerator:
|
|
| 135 |
emergency_chunks=emergency_chunks,
|
| 136 |
treatment_chunks=treatment_chunks,
|
| 137 |
symptom_chunks=symptom_chunks,
|
| 138 |
-
diagnosis_chunks=diagnosis_chunks
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
# Build context block from selected chunks
|
|
@@ -161,7 +163,8 @@ class MedicalAdviceGenerator:
|
|
| 161 |
"emergency_subset": [],
|
| 162 |
"treatment_subset": [],
|
| 163 |
"symptom_subset": [], # Reserved for Dataset B
|
| 164 |
-
"diagnosis_subset": [] # Reserved for Dataset B
|
|
|
|
| 165 |
}
|
| 166 |
|
| 167 |
# Process results from current dual-index system
|
|
@@ -180,29 +183,49 @@ class MedicalAdviceGenerator:
|
|
| 180 |
logger.warning(f"Unknown chunk type: {chunk_type}, defaulting to STAT (tentative)")
|
| 181 |
classified["emergency_subset"].append(chunk)
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
# TODO: Future integration point for Dataset B
|
| 184 |
# When Dataset B team provides symptom/diagnosis data:
|
| 185 |
# classified["symptom_subset"] = process_dataset_b_symptoms(retrieval_results)
|
| 186 |
# classified["diagnosis_subset"] = process_dataset_b_diagnosis(retrieval_results)
|
| 187 |
|
| 188 |
logger.info(f"Classified chunks: Emergency={len(classified['emergency_subset'])}, "
|
| 189 |
-
f"Treatment={len(classified['treatment_subset'])}"
|
|
|
|
| 190 |
|
| 191 |
return classified
|
| 192 |
|
| 193 |
def _select_chunks_by_intention(self, intention: Optional[str],
|
| 194 |
emergency_chunks: List, treatment_chunks: List,
|
| 195 |
-
symptom_chunks: List, diagnosis_chunks: List
|
|
|
|
| 196 |
"""
|
| 197 |
Select optimal chunk combination based on query intention
|
| 198 |
|
| 199 |
Args:
|
| 200 |
intention: Detected or specified intention
|
| 201 |
*_chunks: Chunks from different dataset sources
|
|
|
|
| 202 |
|
| 203 |
Returns:
|
| 204 |
List of selected chunks for prompt construction
|
| 205 |
"""
|
|
|
|
|
|
|
| 206 |
if intention and intention in self.dataset_priorities:
|
| 207 |
# Use predefined priorities for known intentions
|
| 208 |
priorities = self.dataset_priorities[intention]
|
|
@@ -212,6 +235,9 @@ class MedicalAdviceGenerator:
|
|
| 212 |
selected_chunks.extend(emergency_chunks[:priorities["emergency_subset"]])
|
| 213 |
selected_chunks.extend(treatment_chunks[:priorities["treatment_subset"]])
|
| 214 |
|
|
|
|
|
|
|
|
|
|
| 215 |
# TODO: Future Dataset B integration
|
| 216 |
# selected_chunks.extend(symptom_chunks[:priorities["symptom_subset"]])
|
| 217 |
# selected_chunks.extend(diagnosis_chunks[:priorities["diagnosis_subset"]])
|
|
@@ -220,7 +246,7 @@ class MedicalAdviceGenerator:
|
|
| 220 |
|
| 221 |
else:
|
| 222 |
# No specific intention - let LLM judge from best available chunks
|
| 223 |
-
all_chunks = emergency_chunks + treatment_chunks + symptom_chunks + diagnosis_chunks
|
| 224 |
|
| 225 |
# Sort by relevance (distance) and take top 6
|
| 226 |
all_chunks_sorted = sorted(all_chunks, key=lambda x: x.get("distance", 999))
|
|
@@ -251,10 +277,19 @@ class MedicalAdviceGenerator:
|
|
| 251 |
distance = chunk.get("distance", 0)
|
| 252 |
|
| 253 |
# Format each chunk with metadata
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
[Guideline {i}] (Source: {chunk_type.title()}, Relevance: {1-distance:.3f})
|
| 256 |
{chunk_text}
|
| 257 |
-
|
| 258 |
|
| 259 |
context_parts.append(context_part)
|
| 260 |
|
|
|
|
| 128 |
treatment_chunks = classified_chunks.get("treatment_subset", [])
|
| 129 |
symptom_chunks = classified_chunks.get("symptom_subset", []) # Dataset B (future)
|
| 130 |
diagnosis_chunks = classified_chunks.get("diagnosis_subset", []) # Dataset B (future)
|
| 131 |
+
hospital_custom_chunks = classified_chunks.get("hospital_custom", []) # Hospital customization
|
| 132 |
|
| 133 |
# Select chunks based on intention or intelligent defaults
|
| 134 |
selected_chunks = self._select_chunks_by_intention(
|
|
|
|
| 136 |
emergency_chunks=emergency_chunks,
|
| 137 |
treatment_chunks=treatment_chunks,
|
| 138 |
symptom_chunks=symptom_chunks,
|
| 139 |
+
diagnosis_chunks=diagnosis_chunks,
|
| 140 |
+
hospital_custom_chunks=hospital_custom_chunks
|
| 141 |
)
|
| 142 |
|
| 143 |
# Build context block from selected chunks
|
|
|
|
| 163 |
"emergency_subset": [],
|
| 164 |
"treatment_subset": [],
|
| 165 |
"symptom_subset": [], # Reserved for Dataset B
|
| 166 |
+
"diagnosis_subset": [], # Reserved for Dataset B
|
| 167 |
+
"hospital_custom": [] # Hospital-specific customization
|
| 168 |
}
|
| 169 |
|
| 170 |
# Process results from current dual-index system
|
|
|
|
| 183 |
logger.warning(f"Unknown chunk type: {chunk_type}, defaulting to STAT (tentative)")
|
| 184 |
classified["emergency_subset"].append(chunk)
|
| 185 |
|
| 186 |
+
# Process hospital customization results if available
|
| 187 |
+
customization_results = retrieval_results.get('customization_results', [])
|
| 188 |
+
if customization_results:
|
| 189 |
+
for custom_chunk in customization_results:
|
| 190 |
+
# Convert customization format to standard chunk format
|
| 191 |
+
standardized_chunk = {
|
| 192 |
+
'type': 'hospital_custom',
|
| 193 |
+
'text': custom_chunk.get('chunk_text', ''),
|
| 194 |
+
'distance': 1 - custom_chunk.get('score', 0), # Convert score to distance
|
| 195 |
+
'matched': f"Hospital Doc: {custom_chunk.get('document', 'Unknown')}",
|
| 196 |
+
'metadata': custom_chunk.get('metadata', {})
|
| 197 |
+
}
|
| 198 |
+
classified["hospital_custom"].append(standardized_chunk)
|
| 199 |
+
logger.info(f"Added {len(customization_results)} hospital-specific chunks")
|
| 200 |
+
|
| 201 |
# TODO: Future integration point for Dataset B
|
| 202 |
# When Dataset B team provides symptom/diagnosis data:
|
| 203 |
# classified["symptom_subset"] = process_dataset_b_symptoms(retrieval_results)
|
| 204 |
# classified["diagnosis_subset"] = process_dataset_b_diagnosis(retrieval_results)
|
| 205 |
|
| 206 |
logger.info(f"Classified chunks: Emergency={len(classified['emergency_subset'])}, "
|
| 207 |
+
f"Treatment={len(classified['treatment_subset'])}, "
|
| 208 |
+
f"Hospital Custom={len(classified['hospital_custom'])}")
|
| 209 |
|
| 210 |
return classified
|
| 211 |
|
| 212 |
def _select_chunks_by_intention(self, intention: Optional[str],
|
| 213 |
emergency_chunks: List, treatment_chunks: List,
|
| 214 |
+
symptom_chunks: List, diagnosis_chunks: List,
|
| 215 |
+
hospital_custom_chunks: List = None) -> List:
|
| 216 |
"""
|
| 217 |
Select optimal chunk combination based on query intention
|
| 218 |
|
| 219 |
Args:
|
| 220 |
intention: Detected or specified intention
|
| 221 |
*_chunks: Chunks from different dataset sources
|
| 222 |
+
hospital_custom_chunks: Hospital-specific customization chunks
|
| 223 |
|
| 224 |
Returns:
|
| 225 |
List of selected chunks for prompt construction
|
| 226 |
"""
|
| 227 |
+
hospital_custom_chunks = hospital_custom_chunks or []
|
| 228 |
+
|
| 229 |
if intention and intention in self.dataset_priorities:
|
| 230 |
# Use predefined priorities for known intentions
|
| 231 |
priorities = self.dataset_priorities[intention]
|
|
|
|
| 235 |
selected_chunks.extend(emergency_chunks[:priorities["emergency_subset"]])
|
| 236 |
selected_chunks.extend(treatment_chunks[:priorities["treatment_subset"]])
|
| 237 |
|
| 238 |
+
# Add hospital custom chunks alongside
|
| 239 |
+
selected_chunks.extend(hospital_custom_chunks)
|
| 240 |
+
|
| 241 |
# TODO: Future Dataset B integration
|
| 242 |
# selected_chunks.extend(symptom_chunks[:priorities["symptom_subset"]])
|
| 243 |
# selected_chunks.extend(diagnosis_chunks[:priorities["diagnosis_subset"]])
|
|
|
|
| 246 |
|
| 247 |
else:
|
| 248 |
# No specific intention - let LLM judge from best available chunks
|
| 249 |
+
all_chunks = emergency_chunks + treatment_chunks + symptom_chunks + diagnosis_chunks + hospital_custom_chunks
|
| 250 |
|
| 251 |
# Sort by relevance (distance) and take top 6
|
| 252 |
all_chunks_sorted = sorted(all_chunks, key=lambda x: x.get("distance", 999))
|
|
|
|
| 277 |
distance = chunk.get("distance", 0)
|
| 278 |
|
| 279 |
# Format each chunk with metadata
|
| 280 |
+
if chunk_type == 'hospital_custom':
|
| 281 |
+
# Special formatting for hospital-specific guidelines
|
| 282 |
+
source_label = "Hospital Protocol"
|
| 283 |
+
context_part = f"""
|
| 284 |
+
[Guideline {i}] (Source: {source_label}, Relevance: {1-distance:.3f})
|
| 285 |
+
📋 {chunk.get('matched', 'Hospital Document')}
|
| 286 |
+
{chunk_text}
|
| 287 |
+
""".strip()
|
| 288 |
+
else:
|
| 289 |
+
context_part = f"""
|
| 290 |
[Guideline {i}] (Source: {chunk_type.title()}, Relevance: {1-distance:.3f})
|
| 291 |
{chunk_text}
|
| 292 |
+
""".strip()
|
| 293 |
|
| 294 |
context_parts.append(context_part)
|
| 295 |
|
|
@@ -9,7 +9,7 @@ Date: 2025-07-29
|
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
-
from typing import Dict, Optional, Union
|
| 13 |
from huggingface_hub import InferenceClient
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
|
|
@@ -162,6 +162,86 @@ DO NOT provide medical advice."""
|
|
| 162 |
'latency': latency # Include latency even for error cases
|
| 163 |
}
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
def _extract_condition(self, response: str) -> str:
|
| 166 |
"""
|
| 167 |
Extract medical condition from model response.
|
|
|
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
+
from typing import Dict, Optional, Union, List
|
| 13 |
from huggingface_hub import InferenceClient
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
|
|
|
|
| 162 |
'latency': latency # Include latency even for error cases
|
| 163 |
}
|
| 164 |
|
| 165 |
+
def extract_medical_keywords_for_customization(
|
| 166 |
+
self,
|
| 167 |
+
query: str,
|
| 168 |
+
max_tokens: int = 50,
|
| 169 |
+
timeout: Optional[float] = None
|
| 170 |
+
) -> List[str]:
|
| 171 |
+
"""
|
| 172 |
+
Extract key medical concepts for hospital customization matching.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
query: Medical query text
|
| 176 |
+
max_tokens: Maximum tokens to generate
|
| 177 |
+
timeout: Specific API call timeout
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
List of key medical keywords/concepts
|
| 181 |
+
"""
|
| 182 |
+
import time
|
| 183 |
+
|
| 184 |
+
# Start timing
|
| 185 |
+
start_time = time.time()
|
| 186 |
+
|
| 187 |
+
try:
|
| 188 |
+
self.logger.info(f"Extracting medical keywords for: {query}")
|
| 189 |
+
|
| 190 |
+
# Prepare chat completion request for keyword extraction
|
| 191 |
+
response = self.client.chat.completions.create(
|
| 192 |
+
model="m42-health/Llama3-Med42-70B",
|
| 193 |
+
messages=[
|
| 194 |
+
{
|
| 195 |
+
"role": "system",
|
| 196 |
+
"content": """You are a medical keyword extractor. Extract 2-4 key medical concepts from queries for hospital document matching.
|
| 197 |
+
|
| 198 |
+
Return ONLY the key medical terms/concepts, separated by commas.
|
| 199 |
+
|
| 200 |
+
Examples:
|
| 201 |
+
- "Patient with severe chest pain and shortness of breath" → "chest pain, dyspnea, cardiac"
|
| 202 |
+
- "How to manage atrial fibrillation in emergency?" → "atrial fibrillation, arrhythmia, emergency"
|
| 203 |
+
- "Stroke protocol for elderly patient" → "stroke, cerebrovascular, elderly"
|
| 204 |
+
|
| 205 |
+
Focus on: conditions, symptoms, procedures, body systems."""
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"role": "user",
|
| 209 |
+
"content": query
|
| 210 |
+
}
|
| 211 |
+
],
|
| 212 |
+
max_tokens=max_tokens
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# Calculate latency
|
| 216 |
+
end_time = time.time()
|
| 217 |
+
latency = end_time - start_time
|
| 218 |
+
|
| 219 |
+
# Extract keywords from response
|
| 220 |
+
keywords_text = response.choices[0].message.content or ""
|
| 221 |
+
|
| 222 |
+
# Log response and latency
|
| 223 |
+
self.logger.info(f"Keywords extracted: {keywords_text}")
|
| 224 |
+
self.logger.info(f"Keyword extraction latency: {latency:.4f} seconds")
|
| 225 |
+
|
| 226 |
+
# Parse keywords
|
| 227 |
+
keywords = [k.strip() for k in keywords_text.split(',') if k.strip()]
|
| 228 |
+
|
| 229 |
+
# Filter out empty or very short keywords
|
| 230 |
+
keywords = [k for k in keywords if len(k) > 2]
|
| 231 |
+
|
| 232 |
+
return keywords
|
| 233 |
+
|
| 234 |
+
except Exception as e:
|
| 235 |
+
# Calculate latency even for failed requests
|
| 236 |
+
end_time = time.time()
|
| 237 |
+
latency = end_time - start_time
|
| 238 |
+
|
| 239 |
+
self.logger.error(f"Medical keyword extraction error: {str(e)}")
|
| 240 |
+
self.logger.error(f"Query that caused error: {query}")
|
| 241 |
+
|
| 242 |
+
# Return empty list on error
|
| 243 |
+
return []
|
| 244 |
+
|
| 245 |
def _extract_condition(self, response: str) -> str:
|
| 246 |
"""
|
| 247 |
Extract medical condition from model response.
|
|
@@ -1,223 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Test script for OnCall.ai retrieval pipeline
|
| 4 |
-
|
| 5 |
-
This script tests the complete flow:
|
| 6 |
-
user_input → user_prompt.py → retrieval.py
|
| 7 |
-
|
| 8 |
-
Author: OnCall.ai Team
|
| 9 |
-
Date: 2025-07-30
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import sys
|
| 13 |
-
import os
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
import logging
|
| 16 |
-
import json
|
| 17 |
-
from datetime import datetime
|
| 18 |
-
|
| 19 |
-
# Add src directory to Python path
|
| 20 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 21 |
-
|
| 22 |
-
# Import our modules
|
| 23 |
-
from user_prompt import UserPromptProcessor
|
| 24 |
-
from retrieval import BasicRetrievalSystem
|
| 25 |
-
from llm_clients import llm_Med42_70BClient
|
| 26 |
-
|
| 27 |
-
# Configure logging
# Emit to both the console and a persistent log file so that failed runs
# can be inspected after the fact.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('test_retrieval_pipeline.log')
    ]
)
# Module-level logger used by the test helpers below.
logger = logging.getLogger(__name__)
|
| 37 |
-
|
| 38 |
-
def test_retrieval_pipeline():
    """
    Run the complete OnCall.ai retrieval pipeline end-to-end.

    Flow exercised: user query -> condition keyword extraction
    (UserPromptProcessor) -> simulated user confirmation -> document
    retrieval (BasicRetrievalSystem).

    Returns:
        list[dict]: One summary dict per test query (success flag, extracted
        condition, keyword strings, and result counts), or an empty list when
        component initialization itself fails.
    """
    print("=" * 60)
    print("OnCall.ai Retrieval Pipeline Test")
    print("=" * 60)
    print(f"Test started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    try:
        # Initialize components
        print("🔧 Initializing components...")

        # LLM client used for condition/keyword extraction
        llm_client = llm_Med42_70BClient()
        print("✅ LLM client initialized")

        # Vector retrieval backend
        retrieval_system = BasicRetrievalSystem()
        print("✅ Retrieval system initialized")

        # Orchestrates query analysis + confirmation on top of the two above
        user_prompt_processor = UserPromptProcessor(
            llm_client=llm_client,
            retrieval_system=retrieval_system
        )
        print("✅ User prompt processor initialized")
        print()

        # Queries cover both condition-style and symptom-style phrasing
        test_queries = [
            "how to treat acute MI?",
            "patient with chest pain and shortness of breath",
            "sudden neurological symptoms suggesting stroke",
            "acute stroke management protocol"
        ]

        results = []

        for i, query in enumerate(test_queries, 1):
            print(f"🔍 Test {i}/{len(test_queries)}: Testing query: '{query}'")
            print("-" * 50)

            try:
                # Step 1: Extract condition keywords
                print("Step 1: Extracting condition keywords...")
                condition_result = user_prompt_processor.extract_condition_keywords(query)

                print(f"   Condition: {condition_result.get('condition', 'None')}")
                print(f"   Emergency keywords: {condition_result.get('emergency_keywords', 'None')}")
                print(f"   Treatment keywords: {condition_result.get('treatment_keywords', 'None')}")

                if not condition_result.get('condition'):
                    print("   ⚠️ No condition extracted, skipping retrieval")
                    continue

                # Step 2: User confirmation (simulated)
                print("\nStep 2: User confirmation (simulated as 'yes')")
                confirmation = user_prompt_processor.handle_user_confirmation(condition_result)
                print(f"   Confirmation type: {confirmation.get('type', 'Unknown')}")

                # Step 3: Perform retrieval
                print("\nStep 3: Performing retrieval...")
                search_query = f"{condition_result.get('emergency_keywords', '')} {condition_result.get('treatment_keywords', '')}".strip()

                # Fall back to the raw condition (or the original query) when
                # no keywords were produced, so retrieval always has input.
                if not search_query:
                    search_query = condition_result.get('condition', query)

                print(f"   Search query: '{search_query}'")

                retrieval_results = retrieval_system.search(search_query, top_k=5)

                # Display results
                print(f"\n📊 Retrieval Results:")
                print(f"   Total results: {retrieval_results.get('total_results', 0)}")

                emergency_results = retrieval_results.get('emergency_results', [])
                treatment_results = retrieval_results.get('treatment_results', [])

                print(f"   Emergency results: {len(emergency_results)}")
                print(f"   Treatment results: {len(treatment_results)}")

                # Show top 3 processed results, when present
                if 'processed_results' in retrieval_results:
                    processed_results = retrieval_results['processed_results'][:3]
                    print(f"\n   Top {len(processed_results)} results:")
                    for j, result in enumerate(processed_results, 1):
                        print(f"   {j}. Type: {result.get('type', 'Unknown')}")
                        # BUGFIX: the previous code applied the ':.4f' format
                        # spec directly to result.get('distance', 'Unknown'),
                        # which raises ValueError whenever 'distance' is
                        # missing (format spec on a str). Format only numbers.
                        distance = result.get('distance')
                        distance_text = f"{distance:.4f}" if isinstance(distance, (int, float)) else "Unknown"
                        print(f"      Distance: {distance_text}")
                        print(f"      Text preview: {result.get('text', '')[:100]}...")
                        print(f"      Matched: {result.get('matched', 'None')}")
                        print(f"      Treatment matched: {result.get('matched_treatment', 'None')}")
                        print()

                # Store results for summary
                test_result = {
                    'query': query,
                    'condition_extracted': condition_result.get('condition', ''),
                    'emergency_keywords': condition_result.get('emergency_keywords', ''),
                    'treatment_keywords': condition_result.get('treatment_keywords', ''),
                    'search_query': search_query,
                    'total_results': retrieval_results.get('total_results', 0),
                    'emergency_count': len(emergency_results),
                    'treatment_count': len(treatment_results),
                    'success': True
                }
                results.append(test_result)

                print("✅ Test completed successfully")

            except Exception as e:
                # Per-query failures are recorded; remaining queries still run.
                logger.error(f"Error in test {i}: {e}", exc_info=True)
                test_result = {
                    'query': query,
                    'error': str(e),
                    'success': False
                }
                results.append(test_result)
                print(f"❌ Test failed: {e}")

            print("\n" + "="*60 + "\n")

        # Print summary
        print_test_summary(results)

        # Save results to file
        save_test_results(results)

        return results

    except Exception as e:
        # Initialization failures abort the entire run.
        logger.error(f"Critical error in pipeline test: {e}", exc_info=True)
        print(f"❌ Critical error: {e}")
        return []
|
| 173 |
-
|
| 174 |
-
def print_test_summary(results):
    """
    Print an aggregate pass/fail summary of pipeline test results.

    Args:
        results: List of per-query result dicts as produced by
            test_retrieval_pipeline(); each carries a boolean 'success' key
            plus either result counts (success) or an 'error' string (failure).
    """
    print("📋 TEST SUMMARY")
    print("="*60)

    successful_tests = [r for r in results if r.get('success', False)]
    failed_tests = [r for r in results if not r.get('success', False)]

    total = len(results)
    # BUGFIX: guard against an empty result list, which previously raised
    # ZeroDivisionError when computing the success rate.
    success_rate = (len(successful_tests) / total * 100) if total else 0.0

    print(f"Total tests: {total}")
    print(f"Successful: {len(successful_tests)}")
    print(f"Failed: {len(failed_tests)}")
    print(f"Success rate: {success_rate:.1f}%")
    print()

    if successful_tests:
        print("✅ Successful tests:")
        for result in successful_tests:
            print(f"   - '{result['query']}'")
            print(f"     Condition: {result.get('condition_extracted', 'None')}")
            print(f"     Results: {result.get('total_results', 0)} total "
                  f"({result.get('emergency_count', 0)} emergency, "
                  f"{result.get('treatment_count', 0)} treatment)")
        print()

    if failed_tests:
        print("❌ Failed tests:")
        for result in failed_tests:
            print(f"   - '{result['query']}': {result.get('error', 'Unknown error')}")
        print()
|
| 203 |
-
|
| 204 |
-
def save_test_results(results):
    """
    Persist test results to a timestamped JSON file in the working directory.

    Args:
        results: List of per-query result dicts to serialize.

    Side effects:
        Writes test_results_<YYYYmmdd_HHMMSS>.json; failures are logged and
        reported on stdout rather than raised.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"test_results_{timestamp}.json"

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump({
                'timestamp': datetime.now().isoformat(),
                'test_results': results
            }, f, indent=2, ensure_ascii=False)

        # BUGFIX: the message previously printed a literal placeholder
        # instead of interpolating the generated file name.
        print(f"📁 Test results saved to: {filename}")

    except Exception as e:
        logger.error(f"Failed to save test results: {e}")
        print(f"⚠️ Failed to save test results: {e}")
|
| 221 |
-
|
| 222 |
-
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_retrieval_pipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|