"""MedGemma AI Assistant - Gradio UI for medical image analysis and chat"""

import os
import warnings

# Import spaces BEFORE any torch/CUDA imports in HF Spaces
IS_HF_SPACE = os.getenv("SPACE_ID") is not None
if IS_HF_SPACE:
    import spaces

# Now safe to import torch-dependent modules
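# max_split_size_mb caps how large a cached block PyTorch's CUDA allocator
# will split; 128 MB is a conservative choice to limit fragmentation OOMs.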
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
warnings.filterwarnings('ignore', message="Can't initialize NVML")

import gradio as gr
import logging
import base64
import json
from PIL import Image

from src.logger import setup_logger
from src.models_client import classify_image, segment_case, detect_objects
from src.config import MAX_CONCURRENT_USERS
from src.styles import CUSTOM_CSS, MODAL_JS
from src.server import respond_stream

# Import server_hf early in HF Spaces to register @spaces.GPU decorator
if IS_HF_SPACE:
    from src import server_hf
from src.image_utils import (
    get_images_from_folder, 
    get_seg3d_folders, 
    base64_to_image,
    save_image_to_temp,
    process_cam_visualizations,
    save_images_to_temp
)
from src.ui_helpers import (
    hide_all_results,
    hide_viewer_components,
    show_viewer_with_image,
    create_empty_state,
    format_classification_result,
    format_detection_result,
    format_segmentation_result
)

# Setup logger
logger = setup_logger(__name__)

# Suppress verbose logging
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("gradio").setLevel(logging.WARNING)

if IS_HF_SPACE:
    logger.info("🌐 Running in HF Spaces mode (transformers)")
else:
    logger.info("πŸ”§ Running in local GGUF mode (llama-cpp)")

# Load prompts from info.json
with open("info.json", "r") as f:
    INFO_CONFIG = json.load(f)
    SYSTEM_PROMPT = INFO_CONFIG["system_prompt"]
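
# info.json is also expected to carry the per-model explanation prompts used
# by get_explain_message() below (classification_*, detection_*, and
# segmentation_brats_* keys).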


# Helper functions are now imported from src.image_utils and src.ui_helpers


async def run_classification(image_path, model_name):
    """Run classification with multiple CAM visualizations"""
    try:
        result = await classify_image(image_path, model_name)
        
        if not result or not result.get('top_prediction'):
            return [], "❌ Classification failed", ""
        
        # Format result text
        info_text = format_classification_result(result)
        
        # Process CAM visualizations
        viz_images, viz_labels = process_cam_visualizations(result.get('cam_visualizations', {}))
        
        # Create carousel label
        carousel_label = f"CAM Visualizations ({len(viz_images)} methods)" if viz_images else "No visualizations"
        
        return viz_images, info_text, carousel_label
        
    except Exception as e:
        logger.error(f"Classification error: {e}")
        return [], f"❌ Error: {str(e)}", ""


async def run_detection(image_path, model_name):
    """Run YOLO object detection"""
    try:
        result = await detect_objects(image_path, model_name)
        
        if not result or not result.get('annotated_image'):
            return None, "❌ Detection failed"
        
        viz_img = base64_to_image(result['annotated_image'])
        info_text = format_detection_result(result)
        
        return viz_img, info_text
        
    except Exception as e:
        logger.error(f"Detection error: {e}")
        return None, f"❌ Error: {str(e)}"


async def run_segmentation(case_folder):
    """Run 3D brain tumor segmentation"""
    try:
        case_path = f"Files_Seg3D/{case_folder}"
        result = await segment_case(case_path, "brats")
        
        if result and result.get('visualization'):
            single_slice = base64_to_image(result['additional_visualizations']['single_slice'])
            three_slices = base64_to_image(result['additional_visualizations']['three_slices'])
            five_slices = base64_to_image(result['additional_visualizations']['five_slices'])
            
            seg_images = [single_slice, three_slices, five_slices]
            html_3d = result.get('3d_html_visualization', '')
            
            # Check if 3D HTML is available
            if not html_3d:
                logger.warning("⚠️ No 3D HTML in result")
                logger.debug(f"Result keys: {result.keys()}")
            
            dice_scores = result['dice_scores']
            avg_dice = result['average_dice']
            vol_analysis = result.get('volumetric_analysis', {})
            spatial = result.get('spatial_analysis', {})
            
            info_text = f"🎯 SEGMENTATION RESULT\n\n"
            info_text += f"πŸ“Š Dice Scores:\n  Average: {avg_dice:.4f}\n"
            for k, v in dice_scores.items():
                info_text += f"  β€’ {k}: {v:.4f}\n"
            
            if vol_analysis:
                info_text += f"\nπŸ“ Volumetric Analysis:\n"
                info_text += f"  Whole Tumor: {vol_analysis['whole_tumor_volume_cm3']:.2f} cmΒ³\n"
                info_text += f"  Tumor Core: {vol_analysis['tumor_core_volume_cm3']:.2f} cmΒ³\n"
                info_text += f"  Enhancing Tumor: {vol_analysis['enhancing_tumor_volume_cm3']:.2f} cmΒ³\n"
            
            if spatial and spatial.get('center_of_mass'):
                com = spatial['center_of_mass']
                info_text += f"\nπŸ“ Tumor Location:\n"
                info_text += f"  Sagittal: {com['sagittal']:.1f}, Coronal: {com['coronal']:.1f}, Axial: {com['axial']:.1f}\n"
            
            # Start with single slice view (index 0)
            return seg_images, 0, single_slice, "View 1 of 3 (1 Slice)", info_text, html_3d
        else:
            return [], 0, None, "No images", "❌ Segmentation failed", ""
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return [], 0, None, "No images", f"❌ Error: {str(e)}", ""


async def respond_with_context_control(message, history, system_message, max_tokens, temperature, top_p, model_choice):
    """Wrapper for VLM streaming with session management"""
    from src.config import IS_HF_SPACE
    
    # Extract session ID
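    # NOTE: gr.context.LocalContext is an internal Gradio API and can change
    # between releases, hence the broad fallback to a shared "default" session.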
    try:
        request: gr.Request = gr.context.LocalContext.request.get()
        session_hash = request.session_hash if request and hasattr(request, 'session_hash') else None
        session_id = session_hash[:8] if session_hash and len(session_hash) > 8 else "default"
    except Exception:
        session_id = "default"
    
    # Clear session on new conversation (only for local mode with session management)
    if not IS_HF_SPACE and (not history or len(history) == 0):
        logger.info(f"πŸ”„ NEW CONVERSATION | Session: {session_id}")
        from src.session_manager import session_manager
        session_manager.clear_session(session_id)
    
    # Log model selection
    model_type = "FT" if model_choice == "Fine-Tuned (BraTS)" else "Base"
    logger.info(f"πŸ€– MODEL | Session: {session_id} | Type: {model_type}")
    
    if IS_HF_SPACE:
        # HF Spaces: Use transformers inference
        async for response in respond_stream(message, history, system_message, max_tokens, temperature, top_p, session_id, model_choice=model_choice):
            yield response
    else:
        # Local: Use GGUF llama-cpp servers
        from src.config import FT_SERVER_URL, BASE_SERVER_URL
        server_url = FT_SERVER_URL if model_choice == "Fine-Tuned (BraTS)" else BASE_SERVER_URL
        async for response in respond_stream(message, history, system_message, max_tokens, temperature, top_p, session_id, server_url, model_choice):
            yield response


# Helper functions
def get_explain_message(category, model_name, view_type=None):
    """Get the appropriate explanation prompt from info.json"""
    if category == "Classification":
        if model_name == "Brain_Tumor":
            return INFO_CONFIG["classification_brain_tumor_message"]
        elif model_name == "Chest_X-Ray":
            return INFO_CONFIG["classification_chest_xray_message"]
        elif model_name == "Lung_Cancer":
            return INFO_CONFIG["classification_lung_histopathology_message"]
    elif category == "Detection":
        if model_name == "Blood_Cell":
            return INFO_CONFIG["detection_blood_message"]
        elif model_name == "Breast_Cancer":
            return INFO_CONFIG["detection_breast_cancer_message"]
        elif model_name == "Fracture":
            return INFO_CONFIG["detection_fracture_message"]
    elif category == "Segmentation":
        if view_type == "single":
            return INFO_CONFIG["segmentation_brats_single_message"]
        elif view_type == "three":
            return INFO_CONFIG["segmentation_brats_three_message"]
        elif view_type == "five":
            return INFO_CONFIG["segmentation_brats_five_message"]
    return "Analyze this medical image and provide a detailed report."


# Create chat interface - clean and simple
chat_interface = gr.ChatInterface(
    respond_with_context_control,
    type="messages",
    multimodal=True,
    chatbot=gr.Chatbot(
        type="messages",
        scale=4,
        height=600,
        show_copy_button=True,
    ),
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        file_count="multiple",
        placeholder="πŸ’¬ Type your medical question or upload images for analysis...",
        show_label=False,
    ),
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_PROMPT, 
            label="System Prompt",
            lines=6,
            max_lines=10,
            info="Customize the AI assistant's behavior and medical expertise"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Radio(
            choices=["Base (General)", "Fine-Tuned (BraTS)"],
            value="Base (General)",
            label="Model Selection",
        ),
    ],
    stop_btn=True,
    cache_examples=False,
)


# Build full interface with Tabs
with gr.Blocks(title="MedGemma AI", css=CUSTOM_CSS, head=MODAL_JS, fill_height=True) as demo:
    gr.Markdown("# πŸ₯ MedGemma AI Assistant")
    
    with gr.Tabs() as tabs:
        # Tab 1: Chat Interface
        with gr.Tab("πŸ’¬ Chat", id=0):
            chat_interface.render()
            chat_textbox = chat_interface.textbox
            chat_chatbot = chat_interface.chatbot
        
        # Tab 2: Image Analysis (Left: Selection + Images, Right: Predictions)
        with gr.Tab("πŸ”¬ Image Analysis", id=1):
            with gr.Row():
                # Left Column: Model Selection + Image Viewer
                with gr.Column(scale=1):
                    gr.Markdown("### πŸ€– Medical AI Models")
                    
                    # Classification section
                    with gr.Accordion("🧠 Classification", open=True):
                        brain_tumor_btn = gr.Button("🔬 Brain Tumor", size="sm")
                        chest_xray_btn = gr.Button("🫁 Chest X-Ray", size="sm")
                        lung_cancer_btn = gr.Button("💨 Lung Cancer", size="sm")
                    
                    # Detection section
                    with gr.Accordion("πŸ” Detection", open=False):
                        blood_cell_btn = gr.Button("🩸 Blood Cell", size="sm")
                        breast_cancer_btn = gr.Button("🎗️ Breast Cancer", size="sm")
                        fracture_btn = gr.Button("🦴 Fracture", size="sm")
                    
                    # Segmentation section
                    with gr.Accordion("πŸ“Š Segmentation", open=False):
                        folders = get_seg3d_folders()
                        seg3d_dropdown = gr.Dropdown(
                            choices=folders,
                            value=None,  # Don't auto-select, let user choose
                            label="Select Case",
                            interactive=True,
                        )
                    
                    gr.Markdown("---")
                    
                    # Image Viewer
                    gr.Markdown("### πŸ–ΌοΈ Image Viewer")
                    dataset_info = gr.Markdown("", visible=False)  # Hidden, not used anymore
                    image_filename = gr.Markdown("", visible=False, elem_classes="image-filename")
                    viewer_image = gr.Image(label="", show_label=False, height=400, visible=False, elem_classes="image-preview")
                    
                    with gr.Row(visible=False, elem_classes="nav-buttons") as nav_controls:
                        prev_btn = gr.Button("◀", size="sm")
                        image_counter = gr.Markdown("", elem_classes="counter-text")
                        next_btn = gr.Button("▶", size="sm")
                    
                    action_btn = gr.Button("🔬 Analyze", size="lg", visible=False, elem_classes="action-predict-btn")
                
                # Right Column: Predictions/Results
                with gr.Column(scale=1):
                    gr.Markdown("### πŸ“Š Analysis Results")
                    result_display = gr.Image(label="", show_label=False, height=400, visible=False, elem_classes="image-preview")
                    html_display = gr.HTML(visible=False, elem_classes="html-3d-preview")
                    result_info = gr.Textbox(label="", show_label=False, lines=12, visible=False)
                    
                    # CAM/Segmentation navigation (for multiple views)
                    with gr.Row(visible=False, elem_classes="nav-buttons") as seg_nav_controls:
                        seg_prev_btn = gr.Button("◀", size="sm")
                        seg_counter = gr.Markdown("", elem_classes="counter-text")
                        seg_next_btn = gr.Button("▶", size="sm")
                    
                    with gr.Column(visible=False) as action_buttons:
                        explain_btn = gr.Button("🤖 Explain with AI", size="lg", elem_classes="action-explain-btn")
                        view_3d_btn = gr.Button("🌐 View 3D", size="lg", visible=False, elem_classes="action-view3d-btn")
                    
                    # Hidden textbox to pass HTML to JavaScript
                    html_storage_bridge = gr.Textbox(visible=False, elem_id="html-storage-bridge")
    
    # State variables
    current_images = gr.State([])
    current_image_idx = gr.State(0)
    current_dataset = gr.State("")
    current_category = gr.State("")
    current_result_image = gr.State(None)
    current_result_text = gr.State("")
    seg3d_images = gr.State([])
    seg3d_image_idx = gr.State(0)
    seg3d_html_storage = gr.State("")
    
    # Helper functions
    def show_images(folder_name, category):
        """Load and display images automatically - resets to first image"""
        images = get_images_from_folder(f"Images/{category}/{folder_name}")
        action_text = "🔬 Classify" if category == "Classification" else "🔍 Detect"
        
        if images:
            # Get filename from path
            filename = os.path.basename(images[0])
            return (
                "",  # Remove dataset info text
                gr.update(value=f"📄 {filename}", visible=True),
                gr.update(value=images[0], visible=True),
                gr.update(visible=True),
                f"Image 1 of {len(images)}",
                gr.update(value=action_text, visible=True),
                gr.update(visible=False),  # Hide result display
                gr.update(visible=False),  # Hide HTML display
                gr.update(visible=False),  # Hide result info
                gr.update(visible=False),  # Hide seg nav controls
                gr.update(visible=False),  # Hide action buttons
                images, 0, folder_name, category, None, ""
            )
        return (
            f"**{folder_name}** - No images found",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            "",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            [], 0, folder_name, category, None, ""
        )
    
    def show_seg3d_case(case_name):
        """Load segmentation case info - auto shows segment button, hides image viewer"""
        if not case_name:
            return (
                "",
                gr.update(visible=False),  # Hide image filename
                gr.update(visible=False),  # Hide viewer image
                gr.update(visible=False),  # Hide nav controls
                "",
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                [], 0, case_name, "Segmentation", None, ""
            )
        
        folder_path = f"Files_Seg3D/{case_name}"
        if os.path.exists(folder_path):
            return (
                "",  # Remove dataset info text
                gr.update(visible=False),  # Hide image filename - not needed for segmentation
                gr.update(visible=False),  # Hide viewer image - not needed for segmentation
                gr.update(visible=False),  # Hide nav controls - not needed for segmentation
                "",
                gr.update(value="πŸ“Š Segment", visible=True),  # Show segment button
                gr.update(visible=False),  # Hide result display
                gr.update(visible=False),  # Hide HTML display
                gr.update(visible=False),  # Hide result info
                gr.update(visible=False),  # Hide seg nav controls
                gr.update(visible=False),  # Hide action buttons
                [], 0, case_name, "Segmentation", None, ""
            )
        return (
            f"**{case_name}** - Not found",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            "",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            [], 0, "", "", None, ""
        )
    
    def navigate_image(images, current_idx, direction):
        """Navigate and hide previous results"""
        if not images:
            return current_idx, gr.update(), gr.update(), "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
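        # direction is -1 or +1; modulo arithmetic wraps navigation at both ends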
        new_idx = (current_idx + direction) % len(images)
        filename = os.path.basename(images[new_idx])
        return (new_idx, gr.update(value=f"📄 {filename}"), gr.update(value=images[new_idx]),
                f"Image {new_idx + 1} of {len(images)}",
                gr.update(visible=False),  # Hide result display
                gr.update(visible=False),  # Hide HTML display
                gr.update(visible=False),  # Hide result info
                gr.update(visible=False),  # Hide action buttons
                gr.update(visible=False))  # Hide seg nav controls
    
    async def handle_action(images, seg_images_state, idx, seg_idx, dataset_name, category):
        """Handle analyze button"""
        if category in ["Classification", "Detection"] and images:
            if idx >= len(images):
                # Keep the 12-output contract; the message belongs in result_info
                return (gr.update(visible=False), gr.update(visible=False),
                        gr.update(value="No image", visible=True),
                        gr.update(visible=False), gr.update(), gr.update(visible=False),
                        gr.update(visible=False), None, "", [], 0, "")
            
            image_path = images[idx]
            if category == "Classification":
                viz_images, text, carousel_label = await run_classification(image_path, dataset_name)
            else:
                img, text = await run_detection(image_path, dataset_name)
                viz_images = [img] if img else []
                carousel_label = ""
            
            result_path = None
            result_images = []
            
            if viz_images:
                # Save first image for "Explain with AI"
                result_path = save_image_to_temp(viz_images[0])
                
                # Save all images for carousel
                result_images = save_images_to_temp(viz_images)
            
            # Show navigation if multiple visualizations (Classification CAMs)
            show_nav = len(viz_images) > 1
            counter_text = f"View 1 of {len(viz_images)} ({carousel_label})" if show_nav else ""
            
            return (gr.update(value=viz_images[0] if viz_images else None, visible=bool(viz_images)), 
                    gr.update(visible=False),  # html_display - not used for classification
                    gr.update(value=text, visible=True),
                    gr.update(visible=show_nav),  # seg_nav_controls (reused for CAM navigation)
                    gr.update(value=counter_text),  # seg_counter (reused for CAM counter)
                    gr.update(visible=True),  # action_buttons
                    gr.update(visible=False),  # view_3d_btn - hide for classification
                    result_path, text, result_images, 0, "")  # result_path, text, seg images, idx, html
        
        elif category == "Segmentation" and dataset_name:
            seg_imgs, new_idx, img, counter, text, html = await run_segmentation(dataset_name)
            
            result_path = None
            if img:
                result_path = save_image_to_temp(img)
            
            counter_text = f"View 1 of 3 (1 Slice)"
            
            # Decode HTML for inline display
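            # (the backend returns the interactive 3D figure as base64-encoded
            # HTML; iframe srcdoc keeps its scripts/styles out of the Gradio DOM)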
            html_preview = ""
            if html:
                try:
                    decoded_html = base64.b64decode(html).decode('utf-8')
                    # Create iframe wrapper for inline display
                    html_preview = f'''
                    <div style="width: 100%; height: 700px; background: #000;">
                        <iframe srcdoc="{decoded_html.replace('"', '&quot;')}" 
                                style="width: 100%; height: 100%; border: none; background: white;">
                        </iframe>
                    </div>
                    '''
                except Exception as e:
                    logger.error(f"❌ Failed to decode HTML: {e}")
            
            return (gr.update(value=img, visible=True),
                    gr.update(value=html_preview, visible=False),  # HTML hidden initially
                    gr.update(value=text, visible=True),
                    gr.update(visible=True),  # seg_nav_controls
                    gr.update(value=counter_text),  # seg_counter
                    gr.update(visible=True),  # action_buttons
                    gr.update(visible=True),  # view_3d_btn - show for segmentation
                    result_path, text, seg_imgs, new_idx, html)  # result_path, text, seg images, idx, html
        
        # Nothing selected: show the hint in result_info and hide everything else
        return (gr.update(visible=False), gr.update(visible=False),
                gr.update(value="Please select a dataset", visible=True),
                gr.update(visible=False), gr.update(), gr.update(visible=False),
                gr.update(visible=False), None, "", [], 0, "")
    
    def navigate_seg_image(images, current_idx, direction, category):
        """Navigate through segmentation views or CAM visualizations"""
        if not images:
            return current_idx, gr.update(), ""
        new_idx = (current_idx + direction) % len(images)
        
        # Load image (handle both PIL Image and file paths)
        img_value = images[new_idx]
        if isinstance(img_value, str):
            img_value = Image.open(img_value)
        
        # Different labels for segmentation vs classification
        if category == "Segmentation":
            view_names = ["1 Slice", "3 Slices", "5 Slices"]
            heights = [500, 600, 700]
            height = heights[new_idx] if new_idx < len(heights) else 700
            counter_text = f"View {new_idx + 1} of 3 ({view_names[new_idx]})"
        else:
            # Classification CAM visualizations
            height = 400
            counter_text = f"View {new_idx + 1} of {len(images)}"
        
        return new_idx, gr.update(value=img_value, height=height), counter_text
    
    def show_3d_view(html_storage):
        """Show 3D HTML visualization in popup modal"""
        if not html_storage:
            gr.Warning("No 3D visualization available")
            return gr.update(), ""
        try:
            html_content = base64.b64decode(html_storage).decode('utf-8')
            # Escape for HTML attribute
            escaped_html = html_content.replace('"', '&quot;').replace("'", '&#39;')
            # Create an iframe with the HTML content
            iframe_html = f'''
            <div style="width: 100%; height: 700px; background: #000;">
                <iframe srcdoc="{escaped_html}" 
                        style="width: 100%; height: 100%; border: none; background: white;">
                </iframe>
            </div>
            '''
            gr.Info("3D visualization loaded! Click to expand.")
            return gr.update(), iframe_html
        except Exception as e:
            logger.error(f"Error decoding 3D HTML: {e}")
            gr.Warning("Failed to load 3D visualization")
            return gr.update(), ""
    
    def explain_with_ai(seg_images, seg_idx, result_text, model_name, category):
        """Send CURRENT view to chat - uses the actual image being displayed"""
        # Get the current image being viewed
        if not seg_images or seg_idx >= len(seg_images):
            gr.Warning("No result image available")
            return {"text": "", "files": []}, gr.update(selected=0)
        
        # Get the current image path (it's already saved as a file path)
        current_image_path = seg_images[seg_idx]
        
        # If it's a PIL Image, save it
        if not isinstance(current_image_path, str):
            current_image_path = save_image_to_temp(current_image_path)
        
        if not os.path.exists(current_image_path):
            gr.Warning("Image file not found")
            return {"text": "", "files": []}, gr.update(selected=0)
        
        # Map the current segmentation view (1/3/5 slices) to its matching prompt
        view_type = None
        if category == "Segmentation" and seg_idx < 3:
            view_type = ["single", "three", "five"][seg_idx]
        message = get_explain_message(category, model_name, view_type)
        full_message = f"Here are the prediction results:\n\n{result_text}\n\n---\n\n{message}" if result_text else message
        gr.Info("πŸ“ Switching to chat with analysis loaded!")
        return {"text": full_message, "files": [current_image_path]}, gr.update(selected=0)
    
    # Event handlers
    outputs_list = [dataset_info, image_filename, viewer_image, nav_controls, image_counter, action_btn,
                   result_display, html_display, result_info, seg_nav_controls, action_buttons,
                   current_images, current_image_idx, current_dataset, current_category,
                   current_result_image, current_result_text]
    
    brain_tumor_btn.click(lambda: show_images("Brain_Tumor", "Classification"), outputs=outputs_list)
    chest_xray_btn.click(lambda: show_images("Chest_X-Ray", "Classification"), outputs=outputs_list)
    lung_cancer_btn.click(lambda: show_images("Lung_Cancer", "Classification"), outputs=outputs_list)
    
    blood_cell_btn.click(lambda: show_images("Blood_Cell", "Detection"), outputs=outputs_list)
    breast_cancer_btn.click(lambda: show_images("Breast_Cancer", "Detection"), outputs=outputs_list)
    fracture_btn.click(lambda: show_images("Fracture", "Detection"), outputs=outputs_list)
    
    seg3d_dropdown.change(show_seg3d_case, inputs=[seg3d_dropdown], outputs=outputs_list)
    
    prev_btn.click(lambda imgs, idx: navigate_image(imgs, idx, -1),
                  inputs=[current_images, current_image_idx],
                  outputs=[current_image_idx, image_filename, viewer_image, image_counter, result_display, html_display, result_info, action_buttons, seg_nav_controls])
    
    next_btn.click(lambda imgs, idx: navigate_image(imgs, idx, 1),
                  inputs=[current_images, current_image_idx],
                  outputs=[current_image_idx, image_filename, viewer_image, image_counter, result_display, html_display, result_info, action_buttons, seg_nav_controls])
    
    action_btn.click(
        fn=lambda cat: gr.update(value=f"⏳ {'Classifying' if cat == 'Classification' else 'Detecting' if cat == 'Detection' else 'Segmenting'}...", interactive=False),
        inputs=[current_category],
        outputs=[action_btn]
    ).then(
        handle_action,
        inputs=[current_images, seg3d_images, current_image_idx, seg3d_image_idx, current_dataset, current_category],
        outputs=[result_display, html_display, result_info, seg_nav_controls, seg_counter, action_buttons, view_3d_btn,
                current_result_image, current_result_text, seg3d_images, seg3d_image_idx, seg3d_html_storage]
    ).then(
        fn=lambda cat: gr.update(value=f"🔬 {'Classify' if cat == 'Classification' else 'Detect' if cat == 'Detection' else 'Segment'}", interactive=True),
        inputs=[current_category],
        outputs=[action_btn]
    )
    
    # CAM/Segmentation navigation (reused for both)
    seg_prev_btn.click(
        lambda imgs, idx, cat: navigate_seg_image(imgs, idx, -1, cat),
        inputs=[seg3d_images, seg3d_image_idx, current_category],
        outputs=[seg3d_image_idx, result_display, seg_counter]
    )
    
    seg_next_btn.click(
        lambda imgs, idx, cat: navigate_seg_image(imgs, idx, 1, cat),
        inputs=[seg3d_images, seg3d_image_idx, current_category],
        outputs=[seg3d_image_idx, result_display, seg_counter]
    )
    
    # First update the bridge, then trigger the modal
    view_3d_btn.click(
        fn=lambda html: html,
        inputs=[seg3d_html_storage],
        outputs=[html_storage_bridge]
    ).then(
        fn=None,
        inputs=[html_storage_bridge],
        outputs=None,
        js="""
        (html_storage) => {
            console.log('View 3D button clicked, html_storage length:', html_storage ? html_storage.length : 0);
            if (html_storage && html_storage.length > 0) {
                try {
                    const decoded = atob(html_storage);
                    console.log('Decoded HTML length:', decoded.length);
                    if (typeof window.showHTMLModal === 'function') {
                        console.log('Calling showHTMLModal...');
                        window.showHTMLModal(decoded);
                    } else {
                        console.error('showHTMLModal function not found on window object');
                        alert('Modal function not available. Please refresh the page.');
                    }
                } catch (e) {
                    console.error('Failed to decode or show HTML:', e);
                    alert('Failed to show 3D visualization: ' + e.message);
                }
            } else {
                console.error('No HTML storage provided');
                alert('No 3D visualization available');
            }
        }
        """
    )
    
    explain_btn.click(explain_with_ai,
                     inputs=[seg3d_images, seg3d_image_idx, current_result_text, current_dataset, current_category],
                     outputs=[chat_textbox, tabs])


if __name__ == "__main__":
    logger.info("Starting MedGemma AI Assistant")
    if IS_HF_SPACE:
        logger.info("Mode: HF Spaces (transformers)")
        demo.queue(default_concurrency_limit=MAX_CONCURRENT_USERS)
        demo.launch()  # HF Spaces handles server config
    else:
        logger.info("Mode: Local GGUF (llama-cpp)")
        demo.queue(default_concurrency_limit=MAX_CONCURRENT_USERS)
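        # share=True publishes a temporary public gradio.live URL even though
        # the server itself binds only to localhost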
        demo.launch(share=True, server_name="127.0.0.1", server_port=7860)