| """ | |
| OnCall.ai Medical Advice Generation Module | |
| This module handles: | |
| 1. RAG prompt construction from retrieval results | |
| 2. Medical advice generation using Med42-70B | |
| 3. Response formatting and confidence assessment | |
| 4. Integration with multi-dataset architecture | |
| Author: OnCall.ai Team | |
| Date: 2025-07-31 | |
| """ | |
| import logging | |
| from typing import Dict, List, Optional, Any, Union | |
| from datetime import datetime | |
| import json | |
| # Import existing LLM client | |
| from llm_clients import llm_Med42_70BClient | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
| ) | |
| logger = logging.getLogger(__name__) | |
| class MedicalAdviceGenerator: | |
| """ | |
| Core generation module for medical advice using RAG approach | |
| """ | |
| def __init__(self, llm_client: Optional[llm_Med42_70BClient] = None): | |
| """ | |
| Initialize medical advice generator | |
| Args: | |
| llm_client: Optional Med42-70B client, creates new if None | |
| """ | |
| self.llm_client = llm_client or llm_Med42_70BClient() | |
| # Dataset source priorities for different intentions | |
| self.dataset_priorities = { | |
| "treatment": { | |
| "emergency_subset": 2, | |
| "treatment_subset": 4, | |
| "symptom_subset": 0, # Reserved for Dataset B | |
| "diagnosis_subset": 0 # Reserved for Dataset B | |
| }, | |
| "diagnosis": { | |
| "emergency_subset": 4, | |
| "treatment_subset": 2, | |
| "symptom_subset": 0, # Reserved for Dataset B | |
| "diagnosis_subset": 0 # Reserved for Dataset B | |
| }, | |
| # "STAT": { | |
| # # NOTE: Use when query contains urgent indicators like "NOW", "STAT", "critical" | |
| # "emergency_subset": 5, | |
| # "treatment_subset": 1, | |
| # "symptom_subset": 0, # Reserved for Dataset B | |
| # "diagnosis_subset": 0 # Reserved for Dataset B | |
| # } | |
| } | |
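        # Priority values act as per-subset chunk caps (see _select_chunks_by_intention):
        # e.g. a "treatment" query keeps at most 2 emergency + 4 treatment chunks,
        # so up to 6 chunks feed the prompt; subsets set to 0 are skipped for now.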

        logger.info("MedicalAdviceGenerator initialized")

    def generate_medical_advice(self, user_query: str, retrieval_results: Dict[str, Any],
                                intention: Optional[str] = None) -> Dict[str, Any]:
        """
        Complete pipeline: construct prompt → generate advice → format response

        Args:
            user_query: Original user medical query
            retrieval_results: Results from BasicRetrievalSystem.search()
            intention: Optional query intention ('treatment', 'diagnosis', 'STAT' (tentative))

        Returns:
            Dict containing formatted medical advice and metadata
        """
        try:
            logger.info(f"Generating medical advice for query: '{user_query[:50]}...'")
            start_time = datetime.now()

            # Step 1: Extract and classify chunks from retrieval results
            classified_chunks = self._classify_retrieval_chunks(retrieval_results)

            # Step 2: Build RAG prompt based on intention and chunk classification
            rag_prompt = self.generate_prompt(user_query, classified_chunks, intention)

            # Step 3: Generate medical advice using Med42-70B
            generation_result = self._generate_with_med42(rag_prompt)

            # Step 4: Format structured response
            formatted_response = self._format_medical_response(
                user_query=user_query,
                generated_advice=generation_result,
                chunks_used=classified_chunks,
                intention=intention,
                processing_time=(datetime.now() - start_time).total_seconds()
            )

            processing_duration = formatted_response.get('query_metadata', {}).get('processing_time_seconds', 0)
            logger.info(f"Medical advice generated successfully in {processing_duration:.3f}s")
            return formatted_response

        except Exception as e:
            logger.error(f"Medical advice generation failed: {e}")
            return self._generate_error_response(user_query, str(e))
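
    # Success responses use the structure built in _format_medical_response:
    #   {"medical_advice": str, "confidence_score": float, "query_metadata": {...},
    #    "generation_metadata": {...}, "sources": {...}, "disclaimer": str}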

    def generate_prompt(self, user_query: str, classified_chunks: Dict[str, List],
                        intention: Optional[str] = None) -> str:
        """
        Enhanced prompt generator with flexible dataset integration

        Args:
            user_query: User's medical query
            classified_chunks: Chunks classified by dataset source
            intention: Query intention if detected

        Returns:
            Structured RAG prompt for Med42-70B
        """
        logger.info(f"Generating prompt with intention: {intention}")

        # Extract chunks by dataset source
        emergency_chunks = classified_chunks.get("emergency_subset", [])
        treatment_chunks = classified_chunks.get("treatment_subset", [])
        symptom_chunks = classified_chunks.get("symptom_subset", [])      # Dataset B (future)
        diagnosis_chunks = classified_chunks.get("diagnosis_subset", [])  # Dataset B (future)

        # Select chunks based on intention or intelligent defaults
        selected_chunks = self._select_chunks_by_intention(
            intention=intention,
            emergency_chunks=emergency_chunks,
            treatment_chunks=treatment_chunks,
            symptom_chunks=symptom_chunks,
            diagnosis_chunks=diagnosis_chunks
        )

        # Build context block from selected chunks
        context_block = self._build_context_block(selected_chunks)

        # Construct medical RAG prompt
        prompt = self._construct_medical_prompt(user_query, context_block, intention)

        logger.info(f"Generated prompt with {len(selected_chunks)} chunks, {len(context_block)} chars")
        return prompt

    def _classify_retrieval_chunks(self, retrieval_results: Dict[str, Any]) -> Dict[str, List]:
        """
        Classify retrieval chunks by dataset source

        Args:
            retrieval_results: Results from BasicRetrievalSystem.search()

        Returns:
            Dict mapping dataset sources to chunk lists
        """
        classified = {
            "emergency_subset": [],
            "treatment_subset": [],
            "symptom_subset": [],    # Reserved for Dataset B
            "diagnosis_subset": []   # Reserved for Dataset B
        }

        # Process results from current dual-index system
        processed_results = retrieval_results.get('processed_results', [])

        for chunk in processed_results:
            chunk_type = chunk.get('type', 'unknown')

            # Map current system types to dataset sources
            if chunk_type == 'emergency':
                classified["emergency_subset"].append(chunk)
            elif chunk_type == 'treatment':
                classified["treatment_subset"].append(chunk)
            else:
                # Unknown type: fall back to the emergency subset until
                # content-based classification is implemented
                logger.warning(f"Unknown chunk type: {chunk_type}, defaulting to emergency_subset")
                classified["emergency_subset"].append(chunk)

        # TODO: Future integration point for Dataset B
        # When Dataset B team provides symptom/diagnosis data:
        # classified["symptom_subset"] = process_dataset_b_symptoms(retrieval_results)
        # classified["diagnosis_subset"] = process_dataset_b_diagnosis(retrieval_results)

        logger.info(f"Classified chunks: Emergency={len(classified['emergency_subset'])}, "
                    f"Treatment={len(classified['treatment_subset'])}")
        return classified
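
    # Input shape consumed by _classify_retrieval_chunks (as produced by
    # BasicRetrievalSystem.search(), per its usage in this module):
    #   {"processed_results": [{"type": "emergency" | "treatment",
    #                           "distance": float, "text": str, ...}, ...]}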

    def _select_chunks_by_intention(self, intention: Optional[str],
                                    emergency_chunks: List, treatment_chunks: List,
                                    symptom_chunks: List, diagnosis_chunks: List) -> List:
        """
        Select optimal chunk combination based on query intention

        Args:
            intention: Detected or specified intention
            *_chunks: Chunks from different dataset sources

        Returns:
            List of selected chunks for prompt construction
        """
        if intention and intention in self.dataset_priorities:
            # Use predefined priorities for known intentions
            priorities = self.dataset_priorities[intention]
            selected_chunks = []

            # Add chunks according to priority allocation
            selected_chunks.extend(emergency_chunks[:priorities["emergency_subset"]])
            selected_chunks.extend(treatment_chunks[:priorities["treatment_subset"]])

            # TODO: Future Dataset B integration
            # selected_chunks.extend(symptom_chunks[:priorities["symptom_subset"]])
            # selected_chunks.extend(diagnosis_chunks[:priorities["diagnosis_subset"]])

            logger.info(f"Selected chunks by intention '{intention}': {len(selected_chunks)} total")
        else:
            # No specific intention - let LLM judge from best available chunks
            all_chunks = emergency_chunks + treatment_chunks + symptom_chunks + diagnosis_chunks

            # Sort by relevance (ascending distance) and take top 6
            all_chunks_sorted = sorted(all_chunks, key=lambda x: x.get("distance", 999))
            selected_chunks = all_chunks_sorted[:6]

            logger.info(f"Selected chunks by relevance (no intention): {len(selected_chunks)} total")

        return selected_chunks
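
    # Worked example for _select_chunks_by_intention: with 3 emergency and 4 treatment
    # chunks and intention="diagnosis" (caps 4/2), the slices keep 3 + 2 = 5 chunks;
    # with no intention, all 7 are pooled, sorted by ascending distance, and the
    # closest 6 are kept.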

    def _build_context_block(self, selected_chunks: List) -> str:
        """
        Build formatted context block from selected chunks

        Args:
            selected_chunks: List of selected chunks

        Returns:
            Formatted context string for prompt
        """
        if not selected_chunks:
            return "No relevant medical guidelines found."

        context_parts = []
        for i, chunk in enumerate(selected_chunks, 1):
            chunk_text = chunk.get("text", "").strip()
            chunk_type = chunk.get("type", "unknown")
            distance = chunk.get("distance", 0)

            # Format each chunk with metadata
            context_part = f"""
[Guideline {i}] (Source: {chunk_type.title()}, Relevance: {1 - distance:.3f})
{chunk_text}
""".strip()
            context_parts.append(context_part)

        return "\n\n".join(context_parts)

    def _construct_medical_prompt(self, user_query: str, context_block: str,
                                  intention: Optional[str]) -> str:
        """
        Construct final medical RAG prompt with appropriate framing

        Args:
            user_query: Original user query
            context_block: Formatted context from selected chunks
            intention: Query intention if detected

        Returns:
            Complete RAG prompt for Med42-70B
        """
        # Customize prompt based on intention
        if intention == "treatment":
            focus_guidance = "Focus on providing specific treatment protocols, management steps, and therapeutic interventions."
        elif intention == "diagnosis":
            focus_guidance = "Focus on differential diagnosis, diagnostic criteria, and assessment approaches."
| elif intention == "STAT(tentative)": | |
| focus_guidance = "Focus on immediate emergency interventions and critical decision-making steps." | |
        else:
            focus_guidance = "Provide comprehensive medical guidance covering both diagnostic and treatment aspects as appropriate."

        prompt = f"""You are an experienced attending physician providing guidance to a junior clinician in an emergency setting. A colleague is asking for your expert medical opinion.

Clinical Question:
{user_query}

Relevant Medical Guidelines:
{context_block}

Instructions:
{focus_guidance}

Please provide a clear, actionable response that:
1. Addresses the specific clinical question asked
2. References relevant evidence from the provided guidelines
3. Offers practical, step-by-step guidance when appropriate
4. Maintains appropriate medical caution and emphasizes the need for clinical judgment

Your response should be concise but comprehensive, suitable for immediate clinical application."""

        return prompt

    def _generate_with_med42(self, prompt: str) -> Dict[str, Any]:
        """
        Generate medical advice using Med42-70B

        Args:
            prompt: Complete RAG prompt

        Returns:
            Generation result with metadata
        """
        try:
            logger.info("Calling Med42-70B for medical advice generation")

            result = self.llm_client.analyze_medical_query(
                query=prompt,
                max_tokens=500,   # Adjust based on needs
                timeout=30.0      # Allow more time for complex medical advice
            )

            if result.get('error'):
                raise Exception(f"Med42-70B generation error: {result['error']}")

            return result

        except Exception as e:
            logger.error(f"Med42-70B generation failed: {e}")
            raise
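
    # Keys read from the client result elsewhere in this module
    # (_format_medical_response and _calculate_confidence_score):
    #   'raw_response', 'extracted_condition', 'confidence', 'latency', 'error'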

    def _format_medical_response(self, user_query: str, generated_advice: Dict[str, Any],
                                 chunks_used: Dict[str, List], intention: Optional[str],
                                 processing_time: float) -> Dict[str, Any]:
        """
        Format final medical response with metadata and confidence assessment

        Args:
            user_query: Original query
            generated_advice: Result from Med42-70B
            chunks_used: Classification of chunks used
            intention: Detected intention
            processing_time: Total processing time

        Returns:
            Structured medical advice response
        """
        # Extract generated content - use raw_response for complete medical advice
        advice_content = generated_advice.get('raw_response', '')
        if not advice_content:
            advice_content = generated_advice.get('extracted_condition', 'Unable to generate medical advice.')

        # Calculate confidence based on available factors
        confidence_score = self._calculate_confidence_score(generated_advice, chunks_used)

        # Count chunks used by source
        chunk_counts = {source: len(chunks) for source, chunks in chunks_used.items()}
        total_chunks = sum(chunk_counts.values())

        formatted_response = {
            "medical_advice": advice_content,
            "confidence_score": confidence_score,
            "query_metadata": {
                "original_query": user_query,
                "detected_intention": intention,
                "processing_time_seconds": processing_time,
                "total_chunks_used": total_chunks,
                "chunks_by_source": chunk_counts
            },
            "generation_metadata": {
                "model_used": "m42-health/Llama3-Med42-70B",
                "generation_time": generated_advice.get('latency', 0),
                "model_confidence": generated_advice.get('confidence', 'unknown'),
                "timestamp": datetime.now().isoformat()
            },
            "sources": {
                "emergency_sources": len(chunks_used.get("emergency_subset", [])),
                "treatment_sources": len(chunks_used.get("treatment_subset", [])),
                "total_sources": total_chunks
            },
            "disclaimer": "This advice is for informational purposes only and should not replace professional medical consultation. Always consult with qualified healthcare providers for medical decisions."
        }

        return formatted_response

    def _calculate_confidence_score(self, generated_advice: Dict[str, Any],
                                    chunks_used: Dict[str, List]) -> float:
        """
        Calculate confidence score based on generation quality and source reliability

        Args:
            generated_advice: Result from Med42-70B
            chunks_used: Chunks used in generation

        Returns:
            Confidence score between 0.0 and 1.0
        """
        confidence_factors = []

        # Factor 1: Model confidence if available
        model_confidence = generated_advice.get('confidence', '0.5')
        try:
            model_conf_value = float(model_confidence)
            confidence_factors.append(model_conf_value)
        except (ValueError, TypeError):
            confidence_factors.append(0.5)  # Default neutral confidence

        # Factor 2: Number of sources used (more sources = higher confidence)
        total_chunks = sum(len(chunks) for chunks in chunks_used.values())
        source_confidence = min(total_chunks / 6.0, 1.0)  # Normalize to max 6 chunks
        confidence_factors.append(source_confidence)

        # Factor 3: Response length (reasonable length indicates comprehensive advice)
        response_length = len(generated_advice.get('raw_response', ''))
        length_confidence = min(response_length / 500.0, 1.0)  # Normalize to ~500 chars
        confidence_factors.append(length_confidence)

        # Factor 4: Processing success (no errors = higher confidence)
        if generated_advice.get('error'):
            confidence_factors.append(0.3)  # Lower confidence if errors occurred
        else:
            confidence_factors.append(0.8)  # Higher confidence for clean generation

        # Average the four factors (unweighted mean)
        final_confidence = sum(confidence_factors) / len(confidence_factors)

        # Clamp so the score never reads as absolute certainty or zero
        return max(0.1, min(0.95, final_confidence))
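
    # Worked example: model confidence 0.8, 4 chunks (4/6 ≈ 0.667), a 450-char
    # response (450/500 = 0.9), and no error (0.8) average to ≈ 0.792 before clamping.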

    def _generate_error_response(self, user_query: str, error_message: str) -> Dict[str, Any]:
        """
        Generate error response when generation fails

        Args:
            user_query: Original query
            error_message: Error details

        Returns:
            Error response in standard format
        """
        return {
            "medical_advice": "I apologize, but I encountered an error while processing your medical query. Please try rephrasing your question or contact technical support if the issue persists.",
            "confidence_score": 0.0,
            "query_metadata": {
                "original_query": user_query,
                "detected_intention": None,
                "processing_time_seconds": 0.0,
                "total_chunks_used": 0,
                "chunks_by_source": {}
            },
            "generation_metadata": {
                "model_used": "m42-health/Llama3-Med42-70B",
                "error": error_message,
                "timestamp": datetime.now().isoformat()
            },
            "sources": {
                "emergency_sources": 0,
                "treatment_sources": 0,
                "total_sources": 0
            },
            "disclaimer": "This system experienced a technical error. Please consult with qualified healthcare providers for medical decisions."
        }


# Example usage and testing
def main():
    """
    Test the medical advice generation system
    """
    # Initialize generator
    generator = MedicalAdviceGenerator()

    # Example retrieval results (simulated)
    example_retrieval_results = {
        "processed_results": [
            {
                "type": "emergency",
                "distance": 0.3,
                "text": "Acute myocardial infarction requires immediate assessment including ECG, cardiac enzymes, and chest X-ray. Time-sensitive condition requiring rapid intervention.",
                "matched": "MI|chest pain"
            },
            {
                "type": "treatment",
                "distance": 0.25,
                "text": "Treatment protocol for STEMI includes aspirin 325mg, clopidogrel loading dose, and urgent PCI within 90 minutes when available.",
                "matched_treatment": "aspirin|PCI|thrombolytic"
            }
        ]
    }

    # Test queries
    test_queries = [
        ("How should I treat a patient with chest pain?", "treatment"),
        ("What are the signs of acute MI?", "diagnosis"),
| # ("Emergency management of cardiac arrest", "STAT(tentative)") | |
    ]

    for query, intention in test_queries:
        print(f"\n{'='*60}")
        print(f"Testing: {query}")
        print(f"Intention: {intention}")

        try:
            result = generator.generate_medical_advice(
                user_query=query,
                retrieval_results=example_retrieval_results,
                intention=intention
            )

            print(f"✅ Success: {result['confidence_score']:.2f} confidence")
            print(f"Advice: {result['medical_advice'][:200]}...")

        except Exception as e:
            print(f"❌ Error: {e}")


if __name__ == "__main__":
    main()