Spaces:
Sleeping
Sleeping
| """ | |
| Medical Query Router for RAG AI Advisor | |
| """ | |
| import asyncio | |
| import logging | |
| from fastapi import APIRouter, HTTPException, status | |
| from fastapi.responses import StreamingResponse | |
| import sys | |
| import os | |
| import json | |
| # Add src to path for imports | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | |
| from core.agent import safe_run_agent, safe_run_agent_streaming, clear_session_memory, get_active_sessions | |
| from api.models import ChatRequest, ChatResponse, HBVPatientInput, HBVAssessmentResponse | |
| from typing import Optional | |
| logger = logging.getLogger(__name__) | |
| router = APIRouter(tags=["medical"]) | |
def _build_contextual_query(
    query: str,
    patient_context: Optional[HBVPatientInput] = None,
    assessment_result: Optional[HBVAssessmentResponse] = None
) -> str:
    """
    Append patient and prior-assessment details to a doctor's question.

    If neither context object is supplied (or both are falsy), the query is
    returned untouched. Otherwise the query is followed by a bulleted
    "[PATIENT CONTEXT FOR THIS QUESTION]" section and/or a
    "[PRIOR ASSESSMENT RESULT]" section, all joined with newlines.

    Args:
        query: The doctor's original question.
        patient_context: Optional patient data; age, sex, serology, HBV DNA,
            ALT and fibrosis are always listed, while pregnancy,
            immunosuppression and coinfections are listed only when set.
        assessment_result: Optional assessment outcome; eligibility plus the
            recommendations text, cut to 200 characters with a trailing "...".

    Returns:
        The original query, or the query followed by the context sections.
    """
    if not (patient_context or assessment_result):
        # Nothing to add: hand the question back as-is.
        return query

    sections = [query]

    if patient_context:
        pc = patient_context
        # Core fields are always reported, in a fixed order.
        sections.extend([
            "\n\n[PATIENT CONTEXT FOR THIS QUESTION]",
            f"- Age: {pc.age}, Sex: {pc.sex}",
            f"- HBsAg: {pc.hbsag_status}, HBeAg: {pc.hbeag_status}",
            f"- HBV DNA: {pc.hbv_dna_level:,.0f} IU/mL",
            f"- ALT: {pc.alt_level} U/L",
            f"- Fibrosis: {pc.fibrosis_stage}",
        ])

        # Optional fields are only mentioned when they carry information.
        if pc.pregnancy_status == "Pregnant":
            sections.append(f"- Pregnancy: {pc.pregnancy_status}")

        immunosuppression = pc.immunosuppression_status
        if immunosuppression and immunosuppression != "None":
            sections.append(f"- Immunosuppression: {immunosuppression}")

        coinfections = pc.coinfections
        if coinfections:
            sections.append(f"- Coinfections: {', '.join(coinfections)}")

    if assessment_result:
        recommendations = assessment_result.recommendations
        # Keep only the first 200 characters of long recommendation text.
        if len(recommendations) > 200:
            summary = recommendations[:200] + "..."
        else:
            summary = recommendations

        sections.extend([
            "\n[PRIOR ASSESSMENT RESULT]",
            f"- Eligible for treatment: {assessment_result.eligible}",
            f"- Assessment summary: {summary}",
        ])

    return "\n".join(sections)
async def ask(request: ChatRequest):
    """
    Interactive chat endpoint for doctors to ask questions about HBV guidelines.

    This endpoint:
    1. Rejects empty, whitespace-only or over-long (> 2000 characters) queries
    2. Enriches the query with any patient/assessment context on the request
    3. Runs the enriched query through the agent, forwarding session_id
    4. Returns the agent's answer together with the same session_id

    Args:
        request: ChatRequest containing query, session_id, and optional
            patient/assessment context

    Returns:
        ChatResponse with AI answer and session_id

    Raises:
        HTTPException: 400 when the query is empty or longer than 2000
            characters; 500 when the agent returns an empty answer or any
            unexpected error occurs.
    """
    try:
        # Validate input
        if not request.query or not request.query.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query cannot be empty"
            )
        if len(request.query) > 2000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query is too long. Maximum length is 2000 characters."
            )

        logger.info(
            "Processing chat request - Session: %s, Query length: %s",
            request.session_id,
            len(request.query),
        )

        # Build enhanced query with context if provided
        enhanced_query = _build_contextual_query(
            query=request.query,
            patient_context=request.patient_context,
            assessment_result=request.assessment_result
        )

        # Process through agent with session context
        response = await safe_run_agent(
            user_input=enhanced_query,
            session_id=request.session_id
        )

        if not response or not response.strip():
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Received empty response from AI agent"
            )

        logger.info("Chat request completed - Session: %s", request.session_id)
        return ChatResponse(
            response=response,
            session_id=request.session_id
        )
    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        # The full error (with traceback) goes to the server log only. The
        # client gets a generic message because raw exception text can expose
        # internals, and the agent input may include patient context.
        logger.error("Error processing chat request: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=(
                "Error processing medical query. Please try again or contact "
                "support if the issue persists."
            )
        ) from e
async def ask_stream(request: ChatRequest):
    """
    Interactive streaming chat endpoint for doctors to ask questions about HBV guidelines.

    This endpoint:
    1. Validates the query before any streaming starts, so bad input still
       produces a normal HTTP 400 response
    2. Enriches the query with any patient/assessment context on the request
    3. Streams the agent's output, forwarding session_id to the agent
    4. Batches agent chunks until at least 10 characters are pending
    5. On a mid-stream failure, sends whatever text was already buffered and
       then a generic markdown error notice (the status line has already been
       sent at that point, so the error cannot be reported as an HTTP status)

    Args:
        request: ChatRequest containing query, session_id, and optional
            patient/assessment context

    Returns:
        StreamingResponse with media type "text/markdown"

    Raises:
        HTTPException: 400 when the query is empty, longer than 2000
            characters, or validation fails unexpectedly.
    """
    # Validate input before starting stream
    try:
        if not request.query or not request.query.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query cannot be empty"
            )
        if len(request.query) > 2000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query is too long. Maximum length is 2000 characters."
            )
        logger.info(
            "Processing streaming chat request - Session: %s, Query length: %s",
            request.session_id,
            len(request.query),
        )
    except HTTPException:
        raise
    except Exception as e:
        # Details stay in the server log; the client only learns that the
        # request was rejected.
        logger.error("Validation error in streaming chat: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid request"
        ) from e

    async def event_stream():
        # Declared outside the try so the error handler can still flush text
        # that was received from the agent but not yet sent.
        chunk_buffer = ""
        chunk_count = 0
        try:
            # Build enhanced query with context if provided
            enhanced_query = _build_contextual_query(
                query=request.query,
                patient_context=request.patient_context,
                assessment_result=request.assessment_result
            )

            async for chunk in safe_run_agent_streaming(
                user_input=enhanced_query,
                session_id=request.session_id
            ):
                chunk_buffer += chunk
                chunk_count += 1

                # Send chunks in reasonable sizes for smoother streaming
                if len(chunk_buffer) >= 10:
                    # Clear the buffer before yielding so a failure at the
                    # yield point can never cause the same text to be resent.
                    pending, chunk_buffer = chunk_buffer, ""
                    yield pending
                    await asyncio.sleep(0.01)

            # Send any remaining content
            if chunk_buffer:
                pending, chunk_buffer = chunk_buffer, ""
                yield pending

            logger.info(
                "Streaming chat completed - Session: %s, Chunks: %s",
                request.session_id,
                chunk_count,
            )
        except Exception as e:
            error_msg = "\n\n**Error**: An error occurred while processing your request. Please try again or contact support if the issue persists."
            logger.error("Error in streaming chat: %s", e, exc_info=True)
            # Do not drop text the agent already produced: flush it first so
            # the answer is not cut off mid-word before the error notice.
            if chunk_buffer:
                yield chunk_buffer
            yield error_msg

    return StreamingResponse(event_stream(), media_type="text/markdown")
async def clear_session(session_id: str):
    """
    Clear conversation history for a specific session.

    Delegates to clear_session_memory and reports whether it signalled
    success. A falsy result is reported as "not_found" in a normal response
    body, not as an HTTP error.

    Args:
        session_id: The session identifier to clear

    Returns:
        Dict with "status" ("success" or "not_found"), a human-readable
        "message" and the "session_id" that was requested.

    Raises:
        HTTPException: 500 when clearing the session raises unexpectedly.
    """
    try:
        logger.info("Clearing session: %s", session_id)
        success = clear_session_memory(session_id)
    except Exception as e:
        # Keep the raw exception text in the server log only; the client gets
        # a generic message so internal details are not exposed.
        logger.error("Error clearing session %s: %s", session_id, e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error clearing session"
        ) from e

    if success:
        return {
            "status": "success",
            "message": f"Session '{session_id}' cleared successfully",
            "session_id": session_id
        }
    return {
        "status": "not_found",
        "message": f"Session '{session_id}' not found or already cleared",
        "session_id": session_id
    }
async def list_sessions():
    """
    List all active chat sessions.

    Returns:
        Dict with "status" set to "success", "active_sessions" holding the
        value returned by get_active_sessions, and "count" holding its length.

    Raises:
        HTTPException: 500 when the sessions cannot be retrieved or counted.
    """
    try:
        sessions = get_active_sessions()
        return {
            "status": "success",
            "active_sessions": sessions,
            "count": len(sessions)
        }
    except Exception as e:
        # Keep the raw exception text in the server log only; the client gets
        # a generic message so internal details are not exposed.
        logger.error("Error listing sessions: %s", e, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error listing sessions"
        ) from e