Spaces:
Sleeping
Sleeping
YanBoChen
commited on
Commit
·
30fc9ee
1
Parent(s):
a1e2d00
feat(user_prompt): update UserPromptProcessor to integrate Llama3-Med42-70B and enhance query validation; add unit tests for condition extraction and matching mechanisms
Browse files- .gitignore +1 -0
- src/__init__.py +12 -2
- src/llm_clients.py +155 -138
- src/medical_conditions.py +1 -1
- src/retrieval.py +18 -0
- src/user_prompt.py +181 -102
- tests/test_user_prompt.py +92 -0
.gitignore
CHANGED
|
@@ -37,6 +37,7 @@ __pycache__/
|
|
| 37 |
*.pem
|
| 38 |
credentials.json
|
| 39 |
token.json
|
|
|
|
| 40 |
|
| 41 |
# 🚫 Large files - models
|
| 42 |
models/cache/
|
|
|
|
| 37 |
*.pem
|
| 38 |
credentials.json
|
| 39 |
token.json
|
| 40 |
+
*.mdc
|
| 41 |
|
| 42 |
# 🚫 Large files - models
|
| 43 |
models/cache/
|
src/__init__.py
CHANGED
|
@@ -3,6 +3,16 @@ OnCall.ai src package
|
|
| 3 |
|
| 4 |
This package contains the core implementation of the OnCall.ai system.
|
| 5 |
"""
|
| 6 |
-
|
| 7 |
# Version
|
| 8 |
-
__version__ = '0.1.0'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
This package contains the core implementation of the OnCall.ai system.
|
| 5 |
"""
|
| 6 |
+
|
| 7 |
# Version
|
| 8 |
+
__version__ = '0.1.0'
|
| 9 |
+
|
| 10 |
+
# import key modules
|
| 11 |
+
from .llm_clients import llm_Med42_70BClient
|
| 12 |
+
from .user_prompt import UserPromptProcessor
|
| 13 |
+
from .retrieval import BasicRetrievalSystem
|
| 14 |
+
from .medical_conditions import (
|
| 15 |
+
CONDITION_KEYWORD_MAPPING,
|
| 16 |
+
get_condition_keywords,
|
| 17 |
+
validate_condition
|
| 18 |
+
)
|
src/llm_clients.py
CHANGED
|
@@ -9,91 +9,71 @@ Date: 2025-07-29
|
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
-
from typing import Dict, Optional
|
| 13 |
-
import torch
|
| 14 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 15 |
from huggingface_hub import InferenceClient
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
|
| 18 |
# Load environment variables from .env file
|
| 19 |
load_dotenv()
|
| 20 |
|
| 21 |
-
class
|
| 22 |
def __init__(
|
| 23 |
self,
|
| 24 |
-
model_name: str = "
|
| 25 |
-
local_model_path: Optional[str] = None,
|
| 26 |
-
use_local: bool = False,
|
| 27 |
timeout: float = 30.0
|
| 28 |
):
|
| 29 |
"""
|
| 30 |
-
Initialize
|
| 31 |
|
| 32 |
Args:
|
| 33 |
model_name: Hugging Face model name
|
| 34 |
-
local_model_path: Path to local model files
|
| 35 |
-
use_local: Flag to use local model
|
| 36 |
timeout: API call timeout duration
|
| 37 |
|
| 38 |
Warning: This model should not be used for professional medical advice.
|
| 39 |
"""
|
| 40 |
self.logger = logging.getLogger(__name__)
|
| 41 |
self.timeout = timeout
|
| 42 |
-
self.use_local = use_local
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
"Not for professional medical diagnosis."
|
| 67 |
-
)
|
| 68 |
-
except Exception as e:
|
| 69 |
-
self.logger.error(f"Failed to load local model: {str(e)}")
|
| 70 |
-
raise ValueError(f"Failed to initialize local Meditron client: {str(e)}")
|
| 71 |
-
else:
|
| 72 |
-
# Existing InferenceClient logic
|
| 73 |
-
hf_token = os.getenv('HF_TOKEN')
|
| 74 |
-
if not hf_token:
|
| 75 |
-
raise ValueError(
|
| 76 |
-
"HF_TOKEN not found in environment variables. "
|
| 77 |
-
"Please set HF_TOKEN in your .env file or environment."
|
| 78 |
-
)
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
|
| 91 |
def analyze_medical_query(
|
| 92 |
self,
|
| 93 |
query: str,
|
| 94 |
max_tokens: int = 100,
|
| 95 |
timeout: Optional[float] = None
|
| 96 |
-
) -> Dict[str, str]:
|
| 97 |
"""
|
| 98 |
Analyze medical query and extract condition.
|
| 99 |
|
|
@@ -103,82 +83,74 @@ class MeditronClient:
|
|
| 103 |
timeout: Specific API call timeout
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
-
Extracted medical condition information
|
| 107 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
try:
|
| 109 |
-
|
| 110 |
-
prompt = f"""<|im_start|>system
|
| 111 |
-
You are a professional medical assistant trained to extract medical conditions.
|
| 112 |
-
Provide only the most representative condition name.
|
| 113 |
-
DO NOT provide medical advice.
|
| 114 |
-
<|im_end|>
|
| 115 |
-
<|im_start|>user
|
| 116 |
-
{query}
|
| 117 |
-
<|im_end|>
|
| 118 |
-
<|im_start|>assistant
|
| 119 |
-
"""
|
| 120 |
-
|
| 121 |
-
self.logger.info(f"Calling Meditron with query: {query}")
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
response_text = self.tokenizer.decode(response[0], skip_special_tokens=True)
|
| 136 |
-
self.logger.info(f"Local model response: {response_text}")
|
| 137 |
-
else:
|
| 138 |
-
# InferenceClient inference
|
| 139 |
-
self.logger.info(f"Using model: {self.client.model}")
|
| 140 |
-
|
| 141 |
-
# Test API connection first
|
| 142 |
-
try:
|
| 143 |
-
test_response = self.client.text_generation(
|
| 144 |
-
"Hello",
|
| 145 |
-
max_new_tokens=5,
|
| 146 |
-
temperature=0.7,
|
| 147 |
-
top_k=50
|
| 148 |
-
)
|
| 149 |
-
self.logger.info("API connection test successful")
|
| 150 |
-
except Exception as test_error:
|
| 151 |
-
self.logger.error(f"API connection test failed: {str(test_error)}")
|
| 152 |
-
return {
|
| 153 |
-
'extracted_condition': '',
|
| 154 |
-
'confidence': 0,
|
| 155 |
-
'error': f"API connection failed: {str(test_error)}"
|
| 156 |
}
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# Extract condition from response
|
| 166 |
extracted_condition = self._extract_condition(response_text)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
return {
|
| 169 |
'extracted_condition': extracted_condition,
|
| 170 |
-
'confidence': 0.8,
|
| 171 |
-
'raw_response': response_text
|
|
|
|
| 172 |
}
|
| 173 |
|
| 174 |
except Exception as e:
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
return {
|
| 179 |
'extracted_condition': '',
|
| 180 |
-
'confidence': 0,
|
| 181 |
-
'error':
|
|
|
|
| 182 |
}
|
| 183 |
|
| 184 |
def _extract_condition(self, response: str) -> str:
|
|
@@ -193,26 +165,29 @@ DO NOT provide medical advice.
|
|
| 193 |
"""
|
| 194 |
from medical_conditions import CONDITION_KEYWORD_MAPPING
|
| 195 |
|
| 196 |
-
# Remove prompt parts, keep only generated content
|
| 197 |
-
generated_text = response.split('<|im_start|>assistant\n')[-1].strip()
|
| 198 |
-
|
| 199 |
# Search in known medical conditions
|
| 200 |
for condition in CONDITION_KEYWORD_MAPPING.keys():
|
| 201 |
-
if condition.lower() in
|
| 202 |
return condition
|
| 203 |
|
| 204 |
-
return
|
| 205 |
|
| 206 |
def main():
|
| 207 |
"""
|
| 208 |
-
Test
|
| 209 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
)
|
| 216 |
|
| 217 |
test_queries = [
|
| 218 |
"patient experiencing chest pain",
|
|
@@ -220,24 +195,66 @@ def main():
|
|
| 220 |
"severe headache with neurological symptoms"
|
| 221 |
]
|
| 222 |
|
|
|
|
|
|
|
|
|
|
| 223 |
for query in test_queries:
|
| 224 |
print(f"\nTesting query: {query}")
|
| 225 |
result = client.analyze_medical_query(query)
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
if 'error' in result:
|
| 229 |
print("Error:", result['error'])
|
| 230 |
print("---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
except Exception as e:
|
| 233 |
print(f"Client initialization error: {str(e)}")
|
| 234 |
-
print("
|
| 235 |
-
print("1.
|
| 236 |
-
print("2.
|
| 237 |
-
print("3.
|
| 238 |
-
print("\
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
if __name__ == "__main__":
|
| 243 |
main()
|
|
|
|
| 9 |
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
+
from typing import Dict, Optional, Union
|
|
|
|
|
|
|
| 13 |
from huggingface_hub import InferenceClient
|
| 14 |
from dotenv import load_dotenv
|
| 15 |
|
| 16 |
# Load environment variables from .env file
|
| 17 |
load_dotenv()
|
| 18 |
|
| 19 |
+
class llm_Med42_70BClient:
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
+
model_name: str = "m42-health/Llama3-Med42-70B",
|
|
|
|
|
|
|
| 23 |
timeout: float = 30.0
|
| 24 |
):
|
| 25 |
"""
|
| 26 |
+
Initialize Medical LLM client for query processing.
|
| 27 |
|
| 28 |
Args:
|
| 29 |
model_name: Hugging Face model name
|
|
|
|
|
|
|
| 30 |
timeout: API call timeout duration
|
| 31 |
|
| 32 |
Warning: This model should not be used for professional medical advice.
|
| 33 |
"""
|
| 34 |
self.logger = logging.getLogger(__name__)
|
| 35 |
self.timeout = timeout
|
|
|
|
| 36 |
|
| 37 |
+
# Configure logging to show detailed information
|
| 38 |
+
logging.basicConfig(
|
| 39 |
+
level=logging.INFO,
|
| 40 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Get Hugging Face token from environment
|
| 44 |
+
hf_token = os.getenv('HF_TOKEN')
|
| 45 |
+
if not hf_token:
|
| 46 |
+
self.logger.error("HF_TOKEN is missing from environment variables.")
|
| 47 |
+
raise ValueError(
|
| 48 |
+
"HF_TOKEN not found in environment variables. "
|
| 49 |
+
"Please set HF_TOKEN in your .env file or environment. "
|
| 50 |
+
"Ensure the token is not empty and is correctly set."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
# Initialize InferenceClient with the new model
|
| 55 |
+
self.client = InferenceClient(
|
| 56 |
+
provider="featherless-ai",
|
| 57 |
+
api_key=hf_token
|
| 58 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
self.logger.info(f"Medical LLM client initialized with model: {model_name}")
|
| 61 |
+
self.logger.warning(
|
| 62 |
+
"Medical LLM Model: Research tool only. "
|
| 63 |
+
"Not for professional medical diagnosis."
|
| 64 |
+
)
|
| 65 |
+
except Exception as e:
|
| 66 |
+
self.logger.error(f"Failed to initialize InferenceClient: {str(e)}")
|
| 67 |
+
self.logger.error(f"Error Type: {type(e).__name__}")
|
| 68 |
+
self.logger.error(f"Detailed Error: {repr(e)}")
|
| 69 |
+
raise ValueError(f"Failed to initialize Medical LLM client: {str(e)}") from e
|
| 70 |
|
| 71 |
def analyze_medical_query(
|
| 72 |
self,
|
| 73 |
query: str,
|
| 74 |
max_tokens: int = 100,
|
| 75 |
timeout: Optional[float] = None
|
| 76 |
+
) -> Dict[str, Union[str, float]]:
|
| 77 |
"""
|
| 78 |
Analyze medical query and extract condition.
|
| 79 |
|
|
|
|
| 83 |
timeout: Specific API call timeout
|
| 84 |
|
| 85 |
Returns:
|
| 86 |
+
Extracted medical condition information with latency
|
| 87 |
"""
|
| 88 |
+
import time
|
| 89 |
+
|
| 90 |
+
# Start timing
|
| 91 |
+
start_time = time.time()
|
| 92 |
+
|
| 93 |
try:
|
| 94 |
+
self.logger.info(f"Calling Medical LLM with query: {query}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
# Prepare chat completion request
|
| 97 |
+
response = self.client.chat.completions.create(
|
| 98 |
+
model="m42-health/Llama3-Med42-70B",
|
| 99 |
+
messages=[
|
| 100 |
+
{
|
| 101 |
+
"role": "system",
|
| 102 |
+
"content": "You are a professional medical assistant trained to extract medical conditions. Provide only the most representative condition name. DO NOT provide medical advice."
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"role": "user",
|
| 106 |
+
"content": query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
}
|
| 108 |
+
],
|
| 109 |
+
max_tokens=max_tokens
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Calculate latency
|
| 113 |
+
end_time = time.time()
|
| 114 |
+
latency = end_time - start_time
|
| 115 |
+
|
| 116 |
+
# Extract the response text
|
| 117 |
+
response_text = response.choices[0].message.content or ""
|
| 118 |
+
|
| 119 |
+
# Log raw response and latency
|
| 120 |
+
self.logger.info(f"Raw LLM Response: {response_text}")
|
| 121 |
+
self.logger.info(f"Query Latency: {latency:.4f} seconds")
|
| 122 |
|
| 123 |
# Extract condition from response
|
| 124 |
extracted_condition = self._extract_condition(response_text)
|
| 125 |
|
| 126 |
+
# Log the extracted condition
|
| 127 |
+
self.logger.info(f"Extracted Condition: {extracted_condition}")
|
| 128 |
+
|
| 129 |
return {
|
| 130 |
'extracted_condition': extracted_condition,
|
| 131 |
+
'confidence': '0.8',
|
| 132 |
+
'raw_response': response_text,
|
| 133 |
+
'latency': latency # Add latency to the return dictionary
|
| 134 |
}
|
| 135 |
|
| 136 |
except Exception as e:
|
| 137 |
+
# Calculate latency even for failed requests
|
| 138 |
+
end_time = time.time()
|
| 139 |
+
latency = end_time - start_time
|
| 140 |
+
|
| 141 |
+
self.logger.error(f"Medical LLM query error: {str(e)}")
|
| 142 |
+
self.logger.error(f"Error Type: {type(e).__name__}")
|
| 143 |
+
self.logger.error(f"Detailed Error: {repr(e)}")
|
| 144 |
+
self.logger.error(f"Query Latency (on error): {latency:.4f} seconds")
|
| 145 |
+
|
| 146 |
+
# Additional context logging
|
| 147 |
+
self.logger.error(f"Query that caused error: {query}")
|
| 148 |
+
|
| 149 |
return {
|
| 150 |
'extracted_condition': '',
|
| 151 |
+
'confidence': '0',
|
| 152 |
+
'error': str(e),
|
| 153 |
+
'latency': latency # Include latency even for error cases
|
| 154 |
}
|
| 155 |
|
| 156 |
def _extract_condition(self, response: str) -> str:
|
|
|
|
| 165 |
"""
|
| 166 |
from medical_conditions import CONDITION_KEYWORD_MAPPING
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
# Search in known medical conditions
|
| 169 |
for condition in CONDITION_KEYWORD_MAPPING.keys():
|
| 170 |
+
if condition.lower() in response.lower():
|
| 171 |
return condition
|
| 172 |
|
| 173 |
+
return response.split('\n')[0].strip() or ""
|
| 174 |
|
| 175 |
def main():
|
| 176 |
"""
|
| 177 |
+
Test Medical LLM client functionality
|
| 178 |
"""
|
| 179 |
+
import time
|
| 180 |
+
from datetime import datetime
|
| 181 |
+
|
| 182 |
+
# Record total execution start time
|
| 183 |
+
total_start_time = time.time()
|
| 184 |
+
execution_start_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 185 |
+
|
| 186 |
try:
|
| 187 |
+
print(f"Execution Started at: {execution_start_timestamp}")
|
| 188 |
+
|
| 189 |
+
# Test client initialization
|
| 190 |
+
client = llm_Med42_70BClient()
|
|
|
|
| 191 |
|
| 192 |
test_queries = [
|
| 193 |
"patient experiencing chest pain",
|
|
|
|
| 195 |
"severe headache with neurological symptoms"
|
| 196 |
]
|
| 197 |
|
| 198 |
+
# Store individual query results
|
| 199 |
+
query_results = []
|
| 200 |
+
|
| 201 |
for query in test_queries:
|
| 202 |
print(f"\nTesting query: {query}")
|
| 203 |
result = client.analyze_medical_query(query)
|
| 204 |
+
|
| 205 |
+
# Store query result
|
| 206 |
+
query_result = {
|
| 207 |
+
'query': query,
|
| 208 |
+
'extracted_condition': result.get('extracted_condition', 'N/A'),
|
| 209 |
+
'confidence': result.get('confidence', 'N/A'),
|
| 210 |
+
'latency': result.get('latency', 'N/A')
|
| 211 |
+
}
|
| 212 |
+
query_results.append(query_result)
|
| 213 |
+
|
| 214 |
+
# Print individual query results
|
| 215 |
+
print("Extracted Condition:", query_result['extracted_condition'])
|
| 216 |
+
print("Confidence:", query_result['confidence'])
|
| 217 |
+
print(f"Latency: {query_result['latency']:.4f} seconds")
|
| 218 |
+
|
| 219 |
if 'error' in result:
|
| 220 |
print("Error:", result['error'])
|
| 221 |
print("---")
|
| 222 |
+
|
| 223 |
+
# Calculate total execution time
|
| 224 |
+
total_end_time = time.time()
|
| 225 |
+
total_execution_time = total_end_time - total_start_time
|
| 226 |
+
execution_end_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 227 |
+
|
| 228 |
+
# Print summary
|
| 229 |
+
print("\n--- Execution Summary ---")
|
| 230 |
+
print(f"Execution Started at: {execution_start_timestamp}")
|
| 231 |
+
print(f"Execution Ended at: {execution_end_timestamp}")
|
| 232 |
+
print(f"Total Execution Time: {total_execution_time:.4f} seconds")
|
| 233 |
+
|
| 234 |
+
# Optional: Return results for potential further processing
|
| 235 |
+
return {
|
| 236 |
+
'start_time': execution_start_timestamp,
|
| 237 |
+
'end_time': execution_end_timestamp,
|
| 238 |
+
'total_execution_time': total_execution_time,
|
| 239 |
+
'query_results': query_results
|
| 240 |
+
}
|
| 241 |
|
| 242 |
except Exception as e:
|
| 243 |
print(f"Client initialization error: {str(e)}")
|
| 244 |
+
print("Possible issues:")
|
| 245 |
+
print("1. Invalid or missing Hugging Face token")
|
| 246 |
+
print("2. Network connectivity problems")
|
| 247 |
+
print("3. Model access restrictions")
|
| 248 |
+
print("\nPlease check your .env file and Hugging Face token.")
|
| 249 |
+
|
| 250 |
+
# Calculate total execution time even in case of error
|
| 251 |
+
total_end_time = time.time()
|
| 252 |
+
total_execution_time = total_end_time - total_start_time
|
| 253 |
+
|
| 254 |
+
return {
|
| 255 |
+
'error': str(e),
|
| 256 |
+
'total_execution_time': total_execution_time
|
| 257 |
+
}
|
| 258 |
|
| 259 |
if __name__ == "__main__":
|
| 260 |
main()
|
src/medical_conditions.py
CHANGED
|
@@ -26,7 +26,7 @@ CONDITION_KEYWORD_MAPPING: Dict[str, Dict[str, str]] = {
|
|
| 26 |
"emergency": "chest pain|shortness of breath|sudden dyspnea",
|
| 27 |
"treatment": "anticoagulation|heparin|embolectomy"
|
| 28 |
},
|
| 29 |
-
#
|
| 30 |
"acute_ischemic_stroke": {
|
| 31 |
"emergency": "ischemic stroke|neurological deficit",
|
| 32 |
"treatment": "tPA|stroke unit management"
|
|
|
|
| 26 |
"emergency": "chest pain|shortness of breath|sudden dyspnea",
|
| 27 |
"treatment": "anticoagulation|heparin|embolectomy"
|
| 28 |
},
|
| 29 |
+
# extended from @20250729Test_Retrieval.md
|
| 30 |
"acute_ischemic_stroke": {
|
| 31 |
"emergency": "ischemic stroke|neurological deficit",
|
| 32 |
"treatment": "tPA|stroke unit management"
|
src/retrieval.py
CHANGED
|
@@ -368,4 +368,22 @@ class BasicRetrievalSystem:
|
|
| 368 |
|
| 369 |
except Exception as e:
|
| 370 |
logger.error(f"Sliding window search failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
return []
|
|
|
|
| 368 |
|
| 369 |
except Exception as e:
|
| 370 |
logger.error(f"Sliding window search failed: {e}")
|
| 371 |
+
return []
|
| 372 |
+
|
| 373 |
+
def search_generic_medical_content(self, query: str, top_k: int = 5) -> List[Dict]:
|
| 374 |
+
"""
|
| 375 |
+
Perform generic medical content search
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
query: Search query
|
| 379 |
+
top_k: Number of top results to return
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
List of search results
|
| 383 |
+
"""
|
| 384 |
+
try:
|
| 385 |
+
# re-use search_sliding_window_chunks method
|
| 386 |
+
return self.search_sliding_window_chunks(query, top_k=top_k)
|
| 387 |
+
except Exception as e:
|
| 388 |
+
logger.error(f"Generic medical content search error: {e}")
|
| 389 |
return []
|
src/user_prompt.py
CHANGED
|
@@ -34,15 +34,15 @@ logging.basicConfig(
|
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
class UserPromptProcessor:
|
| 37 |
-
def __init__(self,
|
| 38 |
"""
|
| 39 |
-
Initialize UserPromptProcessor with optional
|
| 40 |
|
| 41 |
Args:
|
| 42 |
-
|
| 43 |
retrieval_system: Optional retrieval system for semantic search
|
| 44 |
"""
|
| 45 |
-
self.
|
| 46 |
self.retrieval_system = retrieval_system
|
| 47 |
self.embedding_model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
|
| 48 |
|
|
@@ -66,11 +66,11 @@ class UserPromptProcessor:
|
|
| 66 |
if predefined_result:
|
| 67 |
return predefined_result
|
| 68 |
|
| 69 |
-
# Level 2:
|
| 70 |
-
if self.
|
| 71 |
-
|
| 72 |
-
if
|
| 73 |
-
return
|
| 74 |
|
| 75 |
# Level 3: Semantic Search Fallback
|
| 76 |
semantic_result = self._semantic_search_fallback(user_query)
|
|
@@ -112,9 +112,9 @@ class UserPromptProcessor:
|
|
| 112 |
|
| 113 |
return None
|
| 114 |
|
| 115 |
-
def
|
| 116 |
"""
|
| 117 |
-
Use
|
| 118 |
|
| 119 |
Args:
|
| 120 |
user_query: User's medical query
|
|
@@ -122,17 +122,17 @@ class UserPromptProcessor:
|
|
| 122 |
Returns:
|
| 123 |
Dict with condition and keywords, or None
|
| 124 |
"""
|
| 125 |
-
if not self.
|
| 126 |
return None
|
| 127 |
|
| 128 |
try:
|
| 129 |
-
|
| 130 |
query=user_query,
|
| 131 |
max_tokens=100,
|
| 132 |
timeout=2.0
|
| 133 |
)
|
| 134 |
|
| 135 |
-
extracted_condition =
|
| 136 |
|
| 137 |
if extracted_condition and validate_condition(extracted_condition):
|
| 138 |
condition_details = get_condition_keywords(extracted_condition)
|
|
@@ -145,12 +145,12 @@ class UserPromptProcessor:
|
|
| 145 |
return None
|
| 146 |
|
| 147 |
except Exception as e:
|
| 148 |
-
logger.error(f"
|
| 149 |
return None
|
| 150 |
|
| 151 |
def _semantic_search_fallback(self, user_query: str) -> Optional[Dict[str, str]]:
|
| 152 |
"""
|
| 153 |
-
Perform semantic search for condition extraction
|
| 154 |
|
| 155 |
Args:
|
| 156 |
user_query: User's medical query
|
|
@@ -158,31 +158,45 @@ class UserPromptProcessor:
|
|
| 158 |
Returns:
|
| 159 |
Dict with condition and keywords, or None
|
| 160 |
"""
|
|
|
|
|
|
|
| 161 |
if not self.retrieval_system:
|
|
|
|
| 162 |
return None
|
| 163 |
|
| 164 |
try:
|
| 165 |
# Perform semantic search on sliding window chunks
|
| 166 |
semantic_results = self.retrieval_system.search_sliding_window_chunks(user_query)
|
| 167 |
|
|
|
|
|
|
|
| 168 |
if semantic_results:
|
| 169 |
# Extract condition from top semantic result
|
| 170 |
top_result = semantic_results[0]
|
| 171 |
condition = self._infer_condition_from_text(top_result['text'])
|
| 172 |
|
|
|
|
|
|
|
| 173 |
if condition and validate_condition(condition):
|
| 174 |
condition_details = get_condition_keywords(condition)
|
| 175 |
-
|
| 176 |
'condition': condition,
|
| 177 |
'emergency_keywords': condition_details.get('emergency', ''),
|
| 178 |
'treatment_keywords': condition_details.get('treatment', ''),
|
| 179 |
'semantic_confidence': top_result.get('distance', 0)
|
| 180 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
|
|
|
| 182 |
return None
|
| 183 |
|
| 184 |
except Exception as e:
|
| 185 |
-
logger.error(f"Semantic search fallback error: {e}")
|
| 186 |
return None
|
| 187 |
|
| 188 |
def _generic_medical_search(self, user_query: str) -> Optional[Dict[str, str]]:
|
|
@@ -369,97 +383,162 @@ Please confirm:
|
|
| 369 |
'extracted_info': extracted_info
|
| 370 |
}
|
| 371 |
|
| 372 |
-
def
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
#
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
)
|
| 422 |
-
|
| 423 |
-
# If Meditron successfully extracts a medical condition
|
| 424 |
-
if meditron_result.get('extracted_condition'):
|
| 425 |
-
return None # Validated by Meditron
|
| 426 |
-
|
| 427 |
-
except Exception as e:
|
| 428 |
-
# Log Meditron analysis failure without blocking the process
|
| 429 |
-
self.logger.warning(f"Meditron query validation failed: {e}")
|
| 430 |
-
|
| 431 |
-
# If no medical relevance is found
|
| 432 |
-
return self._generate_invalid_query_response()
|
| 433 |
|
| 434 |
-
def
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
def main():
|
| 453 |
"""
|
| 454 |
-
Example usage and testing of UserPromptProcessor
|
|
|
|
| 455 |
"""
|
| 456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
|
| 458 |
-
#
|
| 459 |
test_queries = [
|
| 460 |
-
"
|
| 461 |
-
"
|
| 462 |
-
"
|
| 463 |
]
|
| 464 |
|
| 465 |
for query in test_queries:
|
|
|
|
| 34 |
logger = logging.getLogger(__name__)
|
| 35 |
|
| 36 |
class UserPromptProcessor:
|
| 37 |
+
def __init__(self, llm_client=None, retrieval_system=None):
|
| 38 |
"""
|
| 39 |
+
Initialize UserPromptProcessor with optional LLM and retrieval system
|
| 40 |
|
| 41 |
Args:
|
| 42 |
+
llm_client: Optional Llama3-Med42-70B client for advanced condition extraction
|
| 43 |
retrieval_system: Optional retrieval system for semantic search
|
| 44 |
"""
|
| 45 |
+
self.llm_client = llm_client
|
| 46 |
self.retrieval_system = retrieval_system
|
| 47 |
self.embedding_model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
|
| 48 |
|
|
|
|
| 66 |
if predefined_result:
|
| 67 |
return predefined_result
|
| 68 |
|
| 69 |
+
# Level 2: Llama3-Med42-70B Extraction (if available)
|
| 70 |
+
if self.llm_client:
|
| 71 |
+
llm_result = self._extract_with_llm(user_query)
|
| 72 |
+
if llm_result:
|
| 73 |
+
return llm_result
|
| 74 |
|
| 75 |
# Level 3: Semantic Search Fallback
|
| 76 |
semantic_result = self._semantic_search_fallback(user_query)
|
|
|
|
| 112 |
|
| 113 |
return None
|
| 114 |
|
| 115 |
+
def _extract_with_llm(self, user_query: str) -> Optional[Dict[str, str]]:
|
| 116 |
"""
|
| 117 |
+
Use Llama3-Med42-70B for advanced condition extraction
|
| 118 |
|
| 119 |
Args:
|
| 120 |
user_query: User's medical query
|
|
|
|
| 122 |
Returns:
|
| 123 |
Dict with condition and keywords, or None
|
| 124 |
"""
|
| 125 |
+
if not self.llm_client:
|
| 126 |
return None
|
| 127 |
|
| 128 |
try:
|
| 129 |
+
llama_response = self.llm_client.analyze_medical_query(
|
| 130 |
query=user_query,
|
| 131 |
max_tokens=100,
|
| 132 |
timeout=2.0
|
| 133 |
)
|
| 134 |
|
| 135 |
+
extracted_condition = llama_response.get('extracted_condition', '')
|
| 136 |
|
| 137 |
if extracted_condition and validate_condition(extracted_condition):
|
| 138 |
condition_details = get_condition_keywords(extracted_condition)
|
|
|
|
| 145 |
return None
|
| 146 |
|
| 147 |
except Exception as e:
|
| 148 |
+
logger.error(f"Llama3-Med42-70B condition extraction error: {e}")
|
| 149 |
return None
|
| 150 |
|
| 151 |
def _semantic_search_fallback(self, user_query: str) -> Optional[Dict[str, str]]:
|
| 152 |
"""
|
| 153 |
+
Perform semantic search for condition extraction using sliding window chunks
|
| 154 |
|
| 155 |
Args:
|
| 156 |
user_query: User's medical query
|
|
|
|
| 158 |
Returns:
|
| 159 |
Dict with condition and keywords, or None
|
| 160 |
"""
|
| 161 |
+
logger.info(f"Starting semantic search fallback for query: '{user_query}'")
|
| 162 |
+
|
| 163 |
if not self.retrieval_system:
|
| 164 |
+
logger.warning("No retrieval system available for semantic search")
|
| 165 |
return None
|
| 166 |
|
| 167 |
try:
|
| 168 |
# Perform semantic search on sliding window chunks
|
| 169 |
semantic_results = self.retrieval_system.search_sliding_window_chunks(user_query)
|
| 170 |
|
| 171 |
+
logger.info(f"Semantic search returned {len(semantic_results)} results")
|
| 172 |
+
|
| 173 |
if semantic_results:
|
| 174 |
# Extract condition from top semantic result
|
| 175 |
top_result = semantic_results[0]
|
| 176 |
condition = self._infer_condition_from_text(top_result['text'])
|
| 177 |
|
| 178 |
+
logger.info(f"Inferred condition: {condition}")
|
| 179 |
+
|
| 180 |
if condition and validate_condition(condition):
|
| 181 |
condition_details = get_condition_keywords(condition)
|
| 182 |
+
result = {
|
| 183 |
'condition': condition,
|
| 184 |
'emergency_keywords': condition_details.get('emergency', ''),
|
| 185 |
'treatment_keywords': condition_details.get('treatment', ''),
|
| 186 |
'semantic_confidence': top_result.get('distance', 0)
|
| 187 |
}
|
| 188 |
+
|
| 189 |
+
logger.info(f"Semantic search successful. Condition: {condition}, "
|
| 190 |
+
f"Confidence: {result['semantic_confidence']}")
|
| 191 |
+
return result
|
| 192 |
+
else:
|
| 193 |
+
logger.warning(f"Condition validation failed for: {condition}")
|
| 194 |
|
| 195 |
+
logger.info("No suitable condition found in semantic search")
|
| 196 |
return None
|
| 197 |
|
| 198 |
except Exception as e:
|
| 199 |
+
logger.error(f"Semantic search fallback error: {e}", exc_info=True)
|
| 200 |
return None
|
| 201 |
|
| 202 |
def _generic_medical_search(self, user_query: str) -> Optional[Dict[str, str]]:
|
|
|
|
| 383 |
'extracted_info': extracted_info
|
| 384 |
}
|
| 385 |
|
| 386 |
+
def _handle_matching_failure_level1(self, condition: str) -> Optional[Dict[str, Any]]:
|
| 387 |
+
"""
|
| 388 |
+
Level 1 Fallback: Loose keyword matching for medical conditions
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
condition: The condition to match loosely
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
Dict with matched keywords or None
|
| 395 |
+
"""
|
| 396 |
+
# Predefined loose matching keywords for different medical domains
|
| 397 |
+
loose_medical_keywords = {
|
| 398 |
+
'emergency': [
|
| 399 |
+
'urgent', 'critical', 'severe', 'acute',
|
| 400 |
+
'immediate', 'life-threatening', 'emergency'
|
| 401 |
+
],
|
| 402 |
+
'treatment': [
|
| 403 |
+
'manage', 'cure', 'heal', 'recover',
|
| 404 |
+
'therapy', 'medication', 'intervention'
|
| 405 |
+
]
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
# Normalize condition
|
| 409 |
+
condition_lower = condition.lower().strip()
|
| 410 |
+
|
| 411 |
+
# Check emergency keywords
|
| 412 |
+
emergency_matches = [
|
| 413 |
+
kw for kw in loose_medical_keywords['emergency']
|
| 414 |
+
if kw in condition_lower
|
| 415 |
+
]
|
| 416 |
+
|
| 417 |
+
# Check treatment keywords
|
| 418 |
+
treatment_matches = [
|
| 419 |
+
kw for kw in loose_medical_keywords['treatment']
|
| 420 |
+
if kw in condition_lower
|
| 421 |
+
]
|
| 422 |
+
|
| 423 |
+
# If matches found, return result
|
| 424 |
+
if emergency_matches or treatment_matches:
|
| 425 |
+
logger.info(f"Loose keyword match for condition: {condition}")
|
| 426 |
+
return {
|
| 427 |
+
'type': 'loose_keyword_match',
|
| 428 |
+
'condition': condition,
|
| 429 |
+
'emergency_keywords': '|'.join(emergency_matches),
|
| 430 |
+
'treatment_keywords': '|'.join(treatment_matches),
|
| 431 |
+
'confidence': 0.5 # Lower confidence due to loose matching
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
# No loose matches found
|
| 435 |
+
logger.info(f"No loose keyword match for condition: {condition}")
|
| 436 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
+
def validate_medical_query(self, user_query: str) -> Dict[str, Any]:
    """Multi-layer check that a query is medically relevant.

    Layer 1 scans a predefined medical vocabulary; layer 2 asks the
    Llama3-Med42-70B client to extract a condition.

    Args:
        user_query: User's input query.

    Returns:
        None when the query is judged medical; otherwise an
        invalid-query guidance dict.
    """
    # Expanded medical vocabulary covering symptoms, professional terms,
    # body systems, actions, and specialties.
    predefined_medical_keywords = frozenset({
        # Symptoms and signs
        'pain', 'symptom', 'ache', 'fever', 'inflammation',
        'bleeding', 'swelling', 'rash', 'bruise', 'wound',

        # Medical professional terms
        'disease', 'condition', 'syndrome', 'disorder',
        'medical', 'health', 'diagnosis', 'treatment',
        'therapy', 'medication', 'prescription',

        # Body systems and organs
        'heart', 'lung', 'brain', 'kidney', 'liver',
        'blood', 'nerve', 'muscle', 'bone', 'joint',

        # Medical actions
        'examine', 'check', 'test', 'scan', 'surgery',
        'operation', 'emergency', 'urgent', 'critical',

        # Specific medical fields
        'cardiology', 'neurology', 'oncology', 'pediatrics',
        'psychiatry', 'dermatology', 'orthopedics',
    })

    lowered = user_query.lower()
    for keyword in predefined_medical_keywords:
        if keyword in lowered:
            return None  # Layer 1: predefined vocabulary confirms medical intent

    try:
        # Layer 2: defer to Llama3-Med42-70B when no keyword matched.
        client = getattr(self, 'llm_client', None)
        if client is None:
            self.logger.warning("Llama3-Med42-70B client not initialized")
            return self._generate_invalid_query_response()

        analysis = client.analyze_medical_query(
            query=user_query,
            max_tokens=100  # Limit tokens for efficiency
        )
        if analysis.get('extracted_condition'):
            return None  # Validated by Llama3-Med42-70B

    except Exception as e:
        # LLM failure must not block the pipeline; fall through to rejection.
        self.logger.warning(f"Llama3-Med42-70B query validation failed: {e}")

    # Neither layer found medical relevance.
    return self._generate_invalid_query_response()
|
| 499 |
+
|
| 500 |
+
def _generate_invalid_query_response(self) -> Dict[str, Any]:
|
| 501 |
+
"""
|
| 502 |
+
Generate response for non-medical queries
|
| 503 |
+
|
| 504 |
+
Returns:
|
| 505 |
+
Dict with invalid query guidance
|
| 506 |
+
"""
|
| 507 |
+
return {
|
| 508 |
+
'type': 'invalid_query',
|
| 509 |
+
'message': "This is OnCall.AI, a clinical medical assistance platform. "
|
| 510 |
+
"Please input a medical problem you need help resolving. "
|
| 511 |
+
"\n\nExamples:\n"
|
| 512 |
+
"- 'I'm experiencing chest pain'\n"
|
| 513 |
+
"- 'What are symptoms of stroke?'\n"
|
| 514 |
+
"- 'How to manage acute asthma?'\n"
|
| 515 |
+
"- 'I have a persistent headache'"
|
| 516 |
+
}
|
| 517 |
|
| 518 |
def main():
|
| 519 |
"""
|
| 520 |
+
Example usage and testing of UserPromptProcessor with Llama3-Med42-70B
|
| 521 |
+
Demonstrates condition extraction and query validation
|
| 522 |
"""
|
| 523 |
+
from .retrieval import BasicRetrievalSystem
|
| 524 |
+
|
| 525 |
+
# use relative import to avoid circular import
|
| 526 |
+
from .llm_clients import llm_Med42_70BClient
|
| 527 |
+
|
| 528 |
+
# Initialize LLM client
|
| 529 |
+
llm_client = llm_Med42_70BClient()
|
| 530 |
+
retrieval_system = BasicRetrievalSystem()
|
| 531 |
+
|
| 532 |
+
# Initialize UserPromptProcessor with the LLM client
|
| 533 |
+
processor = UserPromptProcessor(
|
| 534 |
+
llm_client=llm_client, retrieval_system=retrieval_system
|
| 535 |
+
)
|
| 536 |
|
| 537 |
+
# Update test cases with more representative medical queries
|
| 538 |
test_queries = [
|
| 539 |
+
"patient with severe chest pain and shortness of breath",
|
| 540 |
+
"sudden neurological symptoms suggesting stroke",
|
| 541 |
+
"persistent headache with vision changes"
|
| 542 |
]
|
| 543 |
|
| 544 |
for query in test_queries:
|
tests/test_user_prompt.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User Prompt Processor Test Suite
|
| 3 |
+
|
| 4 |
+
Comprehensive unit tests for UserPromptProcessor class
|
| 5 |
+
Ensures robust functionality across medical query scenarios.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# Dynamically add project root to Python path
|
| 13 |
+
project_root = Path(__file__).parent.parent
|
| 14 |
+
sys.path.insert(0, str(project_root / "src"))
|
| 15 |
+
|
| 16 |
+
from user_prompt import UserPromptProcessor
|
| 17 |
+
|
| 18 |
+
class TestUserPromptProcessor:
    """Unit tests covering UserPromptProcessor extraction and fallback paths."""

    def setup_method(self):
        """Create a fresh processor before each test.

        NOTE(review): relies on UserPromptProcessor() accepting no
        arguments — confirm against the constructor signature.
        """
        self.processor = UserPromptProcessor()

    def test_extract_condition_keywords_predefined(self):
        """Predefined conditions should yield a keyword mapping."""
        extraction = self.processor.extract_condition_keywords("heart attack symptoms")

        assert extraction is not None
        for field in ('condition', 'emergency_keywords', 'treatment_keywords'):
            assert field in extraction

    def test_handle_matching_failure_level1(self):
        """Loose matching should fire on generic urgency phrasing."""
        phrases = (
            "urgent medical help",
            "critical condition",
            "severe symptoms",
        )

        for phrase in phrases:
            match = self.processor._handle_matching_failure_level1(phrase)

            assert match is not None
            assert match['type'] == 'loose_keyword_match'
            assert match['confidence'] == 0.5

    def test_semantic_search_fallback(self):
        """Semantic fallback may return None; any hit must carry keyword fields."""
        queries = (
            "how to manage chest pain",
            "treatment for acute stroke",
            "emergency cardiac care",
        )

        for query in queries:
            hit = self.processor._semantic_search_fallback(query)

            if hit is None:
                continue  # No match is an acceptable outcome here
            for field in ('condition', 'emergency_keywords', 'treatment_keywords'):
                assert field in hit

    def test_validate_keywords(self):
        """Validation accepts populated keyword maps and rejects empty ones."""
        populated = {
            'emergency_keywords': 'urgent|critical',
            'treatment_keywords': 'medication|therapy',
        }
        blank = {
            'emergency_keywords': '',
            'treatment_keywords': '',
        }

        assert self.processor.validate_keywords(populated) is True
        assert self.processor.validate_keywords(blank) is False
| 81 |
+
|
| 82 |
+
def main():
    """Run the test suite via pytest with verbose, short-traceback output."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("OnCall.ai: User Prompt Processor Test Suite")
    print(banner)

    # Delegate execution/reporting to pytest.
    pytest.main([__file__, '-v', '--tb=short'])


if __name__ == "__main__":
    main()
|