"""MedGemma AI Assistant - Gradio UI for medical image analysis and chat""" import os import warnings # Import spaces BEFORE any torch/CUDA imports in HF Spaces IS_HF_SPACE = os.getenv("SPACE_ID") is not None if IS_HF_SPACE: import spaces # Now safe to import torch-dependent modules os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' warnings.filterwarnings('ignore', message="Can't initialize NVML") import gradio as gr import logging import base64 import json from PIL import Image from src.logger import setup_logger from src.models_client import classify_image, segment_case, detect_objects from src.config import MAX_CONCURRENT_USERS from src.styles import CUSTOM_CSS, MODAL_JS from src.server import respond_stream # Import server_hf early in HF Spaces to register @spaces.GPU decorator if IS_HF_SPACE: from src import server_hf from src.image_utils import ( get_images_from_folder, get_seg3d_folders, base64_to_image, save_image_to_temp, process_cam_visualizations, save_images_to_temp ) from src.ui_helpers import ( hide_all_results, hide_viewer_components, show_viewer_with_image, create_empty_state, format_classification_result, format_detection_result, format_segmentation_result ) # Setup logger logger = setup_logger(__name__) # Suppress verbose logging logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("gradio").setLevel(logging.WARNING) if IS_HF_SPACE: logger.info("🌐 Running in HF Spaces mode (transformers)") else: logger.info("🔧 Running in local GGUF mode (llama-cpp)") # Load prompts from info.json with open("info.json", "r") as f: INFO_CONFIG = json.load(f) SYSTEM_PROMPT = INFO_CONFIG["system_prompt"] # Helper functions are now imported from src.image_utils and src.ui_helpers async def run_classification(image_path, model_name): """Run classification with multiple CAM visualizations""" try: result = await classify_image(image_path, model_name) if not result or not result.get('top_prediction'): return [], "❌ Classification failed", "" # Format result text info_text = format_classification_result(result) # Process CAM visualizations viz_images, viz_labels = process_cam_visualizations(result.get('cam_visualizations', {})) # Create carousel label carousel_label = f"CAM Visualizations ({len(viz_images)} methods)" if viz_images else "No visualizations" return viz_images, info_text, carousel_label except Exception as e: logger.error(f"Classification error: {e}") return [], f"❌ Error: {str(e)}", "" async def run_detection(image_path, model_name): """Run YOLO object detection""" try: result = await detect_objects(image_path, model_name) if not result or not result.get('annotated_image'): return None, "❌ Detection failed" viz_img = base64_to_image(result['annotated_image']) info_text = format_detection_result(result) return viz_img, info_text except Exception as e: logger.error(f"Detection error: {e}") return None, f"❌ Error: {str(e)}" async def run_segmentation(case_folder): """Run 3D brain tumor segmentation""" try: case_path = f"Files_Seg3D/{case_folder}" result = await segment_case(case_path, "brats") if result and result.get('visualization'): single_slice = base64_to_image(result['additional_visualizations']['single_slice']) three_slices = base64_to_image(result['additional_visualizations']['three_slices']) five_slices = base64_to_image(result['additional_visualizations']['five_slices']) seg_images = [single_slice, three_slices, five_slices] html_3d = result.get('3d_html_visualization', '') # Check if 3D HTML is available if not html_3d: logger.warning("⚠️ No 3D 
HTML in result") logger.debug(f"Result keys: {result.keys()}") dice_scores = result['dice_scores'] avg_dice = result['average_dice'] vol_analysis = result.get('volumetric_analysis', {}) spatial = result.get('spatial_analysis', {}) info_text = f"🎯 SEGMENTATION RESULT\n\n" info_text += f"📊 Dice Scores:\n Average: {avg_dice:.4f}\n" for k, v in dice_scores.items(): info_text += f" • {k}: {v:.4f}\n" if vol_analysis: info_text += f"\n📐 Volumetric Analysis:\n" info_text += f" Whole Tumor: {vol_analysis['whole_tumor_volume_cm3']:.2f} cm³\n" info_text += f" Tumor Core: {vol_analysis['tumor_core_volume_cm3']:.2f} cm³\n" info_text += f" Enhancing Tumor: {vol_analysis['enhancing_tumor_volume_cm3']:.2f} cm³\n" if spatial and spatial.get('center_of_mass'): com = spatial['center_of_mass'] info_text += f"\n📍 Tumor Location:\n" info_text += f" Sagittal: {com['sagittal']:.1f}, Coronal: {com['coronal']:.1f}, Axial: {com['axial']:.1f}\n" # Start with single slice view (index 0) return seg_images, 0, single_slice, "View 1 of 3 (1 Slice)", info_text, html_3d else: return [], 0, None, "No images", "❌ Segmentation failed", "" except Exception as e: logger.error(f"Segmentation error: {e}") return [], 0, None, "No images", f"❌ Error: {str(e)}", "" async def respond_with_context_control(message, history, system_message, max_tokens, temperature, top_p, model_choice): """Wrapper for VLM streaming with session management""" from src.config import IS_HF_SPACE # Extract session ID try: request: gr.Request = gr.context.LocalContext.request.get() session_hash = request.session_hash if request and hasattr(request, 'session_hash') else None session_id = session_hash[:8] if session_hash and len(session_hash) > 8 else "default" except: session_id = "default" # Clear session on new conversation (only for local mode with session management) if not IS_HF_SPACE and (not history or len(history) == 0): logger.info(f"🔄 NEW CONVERSATION | Session: {session_id}") from src.session_manager import session_manager session_manager.clear_session(session_id) # Log model selection model_type = "FT" if model_choice == "Fine-Tuned (BraTS)" else "Base" logger.info(f"🤖 MODEL | Session: {session_id} | Type: {model_type}") if IS_HF_SPACE: # HF Spaces: Use transformers inference async for response in respond_stream(message, history, system_message, max_tokens, temperature, top_p, session_id, model_choice=model_choice): yield response else: # Local: Use GGUF llama-cpp servers from src.config import FT_SERVER_URL, BASE_SERVER_URL server_url = FT_SERVER_URL if model_choice == "Fine-Tuned (BraTS)" else BASE_SERVER_URL async for response in respond_stream(message, history, system_message, max_tokens, temperature, top_p, session_id, server_url, model_choice): yield response # Helper functions def get_explain_message(category, model_name, view_type=None): """Get the appropriate explanation prompt from info.json""" if category == "Classification": if model_name == "Brain_Tumor": return INFO_CONFIG["classification_brain_tumor_message"] elif model_name == "Chest_X-Ray": return INFO_CONFIG["classification_chest_xray_message"] elif model_name == "Lung_Cancer": return INFO_CONFIG["classification_lung_histopathology_message"] elif category == "Detection": if model_name == "Blood_Cell": return INFO_CONFIG["detection_blood_message"] elif model_name == "Breast_Cancer": return INFO_CONFIG["detection_breast_cancer_message"] elif model_name == "Fracture": return INFO_CONFIG["detection_fracture_message"] elif category == "Segmentation": if view_type == "single": return 
INFO_CONFIG["segmentation_brats_single_message"] elif view_type == "three": return INFO_CONFIG["segmentation_brats_three_message"] elif view_type == "five": return INFO_CONFIG["segmentation_brats_five_message"] return "Analyze this medical image and provide a detailed report." # Create chat interface - clean and simple chat_interface = gr.ChatInterface( respond_with_context_control, type="messages", multimodal=True, chatbot=gr.Chatbot( type="messages", scale=4, height=600, show_copy_button=True, ), textbox=gr.MultimodalTextbox( file_types=["image"], file_count="multiple", placeholder="💬 Type your medical question or upload images for analysis...", show_label=False, ), additional_inputs=[ gr.Textbox( value=SYSTEM_PROMPT, label="System Prompt", lines=6, max_lines=10, info="Customize the AI assistant's behavior and medical expertise" ), gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"), gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Temperature"), gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"), gr.Radio( choices=["Base (General)", "Fine-Tuned (BraTS)"], value="Base (General)", label="Model Selection", ), ], stop_btn=True, cache_examples=False, ) # Build full interface with Tabs with gr.Blocks(title="MedGemma AI", css=CUSTOM_CSS, head=MODAL_JS, fill_height=True) as demo: gr.Markdown("# 🏥 MedGemma AI Assistant") with gr.Tabs() as tabs: # Tab 1: Chat Interface with gr.Tab("💬 Chat", id=0): chat_interface.render() chat_textbox = chat_interface.textbox chat_chatbot = chat_interface.chatbot # Tab 2: Image Analysis (Left: Selection + Images, Right: Predictions) with gr.Tab("🔬 Image Analysis", id=1): with gr.Row(): # Left Column: Model Selection + Image Viewer with gr.Column(scale=1): gr.Markdown("### 🤖 Medical AI Models") # Classification section with gr.Accordion("🧠 Classification", open=True): brain_tumor_btn = gr.Button("🔬 Brain Tumor", size="sm") chest_xray_btn = gr.Button("🫁 Chest X-Ray", size="sm") lung_cancer_btn = gr.Button("💨 Lung Cancer", size="sm") # Detection section with gr.Accordion("🔍 Detection", open=False): blood_cell_btn = gr.Button("🩸 Blood Cell", size="sm") breast_cancer_btn = gr.Button("🎗️ Breast Cancer", size="sm") fracture_btn = gr.Button("🦴 Fracture", size="sm") # Segmentation section with gr.Accordion("📊 Segmentation", open=False): folders = get_seg3d_folders() seg3d_dropdown = gr.Dropdown( choices=folders, value=None, # Don't auto-select, let user choose label="Select Case", interactive=True, ) gr.Markdown("---") # Image Viewer gr.Markdown("### 🖼️ Image Viewer") dataset_info = gr.Markdown("", visible=False) # Hidden, not used anymore image_filename = gr.Markdown("", visible=False, elem_classes="image-filename") viewer_image = gr.Image(label="", show_label=False, height=400, visible=False, elem_classes="image-preview") with gr.Row(visible=False, elem_classes="nav-buttons") as nav_controls: prev_btn = gr.Button("◀", size="sm") image_counter = gr.Markdown("", elem_classes="counter-text") next_btn = gr.Button("▶", size="sm") action_btn = gr.Button("🔬 Analyze", size="lg", visible=False, elem_classes="action-predict-btn") # Right Column: Predictions/Results with gr.Column(scale=1): gr.Markdown("### 📊 Analysis Results") result_display = gr.Image(label="", show_label=False, height=400, visible=False, elem_classes="image-preview") html_display = gr.HTML(visible=False, elem_classes="html-3d-preview") result_info = gr.Textbox(label="", show_label=False, lines=12, visible=False) # CAM/Segmentation navigation (for 

# Create chat interface - clean and simple
chat_interface = gr.ChatInterface(
    respond_with_context_control,
    type="messages",
    multimodal=True,
    chatbot=gr.Chatbot(
        type="messages",
        scale=4,
        height=600,
        show_copy_button=True,
    ),
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        file_count="multiple",
        placeholder="💬 Type your medical question or upload images for analysis...",
        show_label=False,
    ),
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_PROMPT,
            label="System Prompt",
            lines=6,
            max_lines=10,
            info="Customize the AI assistant's behavior and medical expertise",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Radio(
            choices=["Base (General)", "Fine-Tuned (BraTS)"],
            value="Base (General)",
            label="Model Selection",
        ),
    ],
    stop_btn=True,
    cache_examples=False,
)

# Build full interface with Tabs
with gr.Blocks(title="MedGemma AI", css=CUSTOM_CSS, head=MODAL_JS, fill_height=True) as demo:
    gr.Markdown("# 🏥 MedGemma AI Assistant")

    with gr.Tabs() as tabs:
        # Tab 1: Chat Interface
        with gr.Tab("💬 Chat", id=0):
            chat_interface.render()
            chat_textbox = chat_interface.textbox
            chat_chatbot = chat_interface.chatbot

        # Tab 2: Image Analysis (Left: Selection + Images, Right: Predictions)
        with gr.Tab("🔬 Image Analysis", id=1):
            with gr.Row():
                # Left Column: Model Selection + Image Viewer
                with gr.Column(scale=1):
                    gr.Markdown("### 🤖 Medical AI Models")

                    # Classification section
                    with gr.Accordion("🧠 Classification", open=True):
                        brain_tumor_btn = gr.Button("🔬 Brain Tumor", size="sm")
                        chest_xray_btn = gr.Button("🫁 Chest X-Ray", size="sm")
                        lung_cancer_btn = gr.Button("💨 Lung Cancer", size="sm")

                    # Detection section
                    with gr.Accordion("🔍 Detection", open=False):
                        blood_cell_btn = gr.Button("🩸 Blood Cell", size="sm")
                        breast_cancer_btn = gr.Button("🎗️ Breast Cancer", size="sm")
                        fracture_btn = gr.Button("🦴 Fracture", size="sm")

                    # Segmentation section
                    with gr.Accordion("📊 Segmentation", open=False):
                        folders = get_seg3d_folders()
                        seg3d_dropdown = gr.Dropdown(
                            choices=folders,
                            value=None,  # Don't auto-select; let the user choose
                            label="Select Case",
                            interactive=True,
                        )

                    gr.Markdown("---")

                    # Image Viewer
                    gr.Markdown("### 🖼️ Image Viewer")
                    dataset_info = gr.Markdown("", visible=False)  # Hidden, not used anymore
                    image_filename = gr.Markdown("", visible=False, elem_classes="image-filename")
                    viewer_image = gr.Image(label="", show_label=False, height=400,
                                            visible=False, elem_classes="image-preview")

                    with gr.Row(visible=False, elem_classes="nav-buttons") as nav_controls:
                        prev_btn = gr.Button("◀", size="sm")
                        image_counter = gr.Markdown("", elem_classes="counter-text")
                        next_btn = gr.Button("▶", size="sm")

                    action_btn = gr.Button("🔬 Analyze", size="lg", visible=False,
                                           elem_classes="action-predict-btn")

                # Right Column: Predictions/Results
                with gr.Column(scale=1):
                    gr.Markdown("### 📊 Analysis Results")
                    result_display = gr.Image(label="", show_label=False, height=400,
                                              visible=False, elem_classes="image-preview")
                    html_display = gr.HTML(visible=False, elem_classes="html-3d-preview")
                    result_info = gr.Textbox(label="", show_label=False, lines=12, visible=False)

                    # CAM/Segmentation navigation (for multiple views)
                    with gr.Row(visible=False, elem_classes="nav-buttons") as seg_nav_controls:
                        seg_prev_btn = gr.Button("◀", size="sm")
                        seg_counter = gr.Markdown("", elem_classes="counter-text")
                        seg_next_btn = gr.Button("▶", size="sm")

                    with gr.Column(visible=False) as action_buttons:
                        explain_btn = gr.Button("🤖 Explain with AI", size="lg",
                                                elem_classes="action-explain-btn")
                        view_3d_btn = gr.Button("🌐 View 3D", size="lg", visible=False,
                                                elem_classes="action-view3d-btn")

            # Hidden textbox to pass HTML to JavaScript
            html_storage_bridge = gr.Textbox(visible=False, elem_id="html-storage-bridge")

            # State variables
            current_images = gr.State([])
            current_image_idx = gr.State(0)
            current_dataset = gr.State("")
            current_category = gr.State("")
            current_result_image = gr.State(None)
            current_result_text = gr.State("")
            seg3d_images = gr.State([])
            seg3d_image_idx = gr.State(0)
            seg3d_html_storage = gr.State("")

            # Helper functions
            def show_images(folder_name, category):
                """Load and display images automatically - resets to first image"""
                images = get_images_from_folder(f"Images/{category}/{folder_name}")
                action_text = "🔬 Classify" if category == "Classification" else "🔍 Detect"
                if images:
                    # Get filename from path
                    filename = os.path.basename(images[0])
                    return (
                        "",  # Remove dataset info text
                        gr.update(value=f"📄 {filename}", visible=True),
                        gr.update(value=images[0], visible=True),
                        gr.update(visible=True),
                        f"Image 1 of {len(images)}",
                        gr.update(value=action_text, visible=True),
                        gr.update(visible=False),  # Hide result display
                        gr.update(visible=False),  # Hide HTML display
                        gr.update(visible=False),  # Hide result info
                        gr.update(visible=False),  # Hide seg nav controls
                        gr.update(visible=False),  # Hide action buttons
                        images, 0, folder_name, category, None, "",
                    )
                return (
                    f"**{folder_name}** - No images found",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    [], 0, folder_name, category, None, "",
                )

            def show_seg3d_case(case_name):
                """Load segmentation case info - auto shows segment button, hides image viewer"""
                if not case_name:
                    return (
                        "",
                        gr.update(visible=False),  # Hide image filename
                        gr.update(visible=False),  # Hide viewer image
                        gr.update(visible=False),  # Hide nav controls
                        "",
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        gr.update(visible=False),
                        [], 0, case_name, "Segmentation", None, "",
                    )

                folder_path = f"Files_Seg3D/{case_name}"
                if os.path.exists(folder_path):
                    return (
                        "",  # Remove dataset info text
                        gr.update(visible=False),  # Hide image filename - not needed for segmentation
                        gr.update(visible=False),  # Hide viewer image - not needed for segmentation
                        gr.update(visible=False),  # Hide nav controls - not needed for segmentation
                        "",
                        gr.update(value="📊 Segment", visible=True),  # Show segment button
                        gr.update(visible=False),  # Hide result display
                        gr.update(visible=False),  # Hide HTML display
                        gr.update(visible=False),  # Hide result info
                        gr.update(visible=False),  # Hide seg nav controls
                        gr.update(visible=False),  # Hide action buttons
                        [], 0, case_name, "Segmentation", None, "",
                    )
                return (
                    f"**{case_name}** - Not found",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    [], 0, "", "", None, "",
                )
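
            # NOTE (illustrative): show_images and show_seg3d_case each return a
            # 17-item tuple, which must line up positionally with the `outputs=`
            # list of the event handlers wired elsewhere in this app. A sketch of
            # that wiring (component list here is hypothetical and abbreviated):
            #
            #   seg3d_dropdown.change(show_seg3d_case,
            #                         inputs=[seg3d_dropdown],
            #                         outputs=[dataset_info, image_filename, ...])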
"", "", None, "" ) def navigate_image(images, current_idx, direction): """Navigate and hide previous results""" if not images: return current_idx, gr.update(), gr.update(), "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) new_idx = (current_idx + direction) % len(images) filename = os.path.basename(images[new_idx]) return (new_idx, gr.update(value=f"📄 {filename}"), gr.update(value=images[new_idx]), f"Image {new_idx + 1} of {len(images)}", gr.update(visible=False), # Hide result display gr.update(visible=False), # Hide HTML display gr.update(visible=False), # Hide result info gr.update(visible=False), # Hide action buttons gr.update(visible=False)) # Hide seg nav controls async def handle_action(images, seg_images_state, idx, seg_idx, dataset_name, category): """Handle analyze button""" if category in ["Classification", "Detection"] and images: if idx >= len(images): return (gr.update(), "No image", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", None, "", [], 0, "") image_path = images[idx] if category == "Classification": viz_images, text, carousel_label = await run_classification(image_path, dataset_name) else: img, text = await run_detection(image_path, dataset_name) viz_images = [img] if img else [] carousel_label = "" result_path = None result_images = [] if viz_images: # Save first image for "Explain with AI" result_path = save_image_to_temp(viz_images[0]) # Save all images for carousel result_images = save_images_to_temp(viz_images) # Show navigation if multiple visualizations (Classification CAMs) show_nav = len(viz_images) > 1 counter_text = f"View 1 of {len(viz_images)} ({carousel_label})" if show_nav else "" return (gr.update(value=viz_images[0] if viz_images else None, visible=bool(viz_images)), gr.update(visible=False), # html_display - not used for classification gr.update(value=text, visible=True), gr.update(visible=show_nav), # seg_nav_controls (reused for CAM navigation) gr.update(value=counter_text), # seg_counter (reused for CAM counter) gr.update(visible=True), # action_buttons gr.update(visible=False), # view_3d_btn - hide for classification result_path, text, result_images, 0, "") # result_path, text, seg images, idx, html elif category == "Segmentation" and dataset_name: seg_imgs, new_idx, img, counter, text, html = await run_segmentation(dataset_name) result_path = None if img: result_path = save_image_to_temp(img) counter_text = f"View 1 of 3 (1 Slice)" # Decode HTML for inline display html_preview = "" if html: try: decoded_html = base64.b64decode(html).decode('utf-8') # Create iframe wrapper for inline display html_preview = f'''