ArchCoder committed on
Commit cd271b0 · verified · 1 Parent(s): 5ad7df6

Update app.py

Files changed (1)
  1. app.py +141 -632
app.py CHANGED
@@ -10,14 +10,9 @@ from torchvision import transforms
10
  import torchvision.transforms.functional as TF
11
  import urllib.request
12
  import os
13
- import kagglehub
14
- import random
15
- from pathlib import Path
16
- import seaborn as sns
17
 
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
  model = None
20
- dataset_path = None
21
 
22
  # Define your Attention U-Net architecture (from your training code)
23
  class DoubleConv(nn.Module):
@@ -61,7 +56,7 @@ class AttentionBlock(nn.Module):
61
  x1 = self.W_x(x)
62
  psi = self.relu(g1 + x1)
63
  psi = self.psi(psi)
64
- return x * psi, psi # Return attention coefficients for visualization
65
 
66
  class AttentionUNET(nn.Module):
67
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -88,9 +83,8 @@ class AttentionUNET(nn.Module):
88
 
89
  self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
90
 
91
- def forward(self, x, return_attention=False):
92
  skip_connections = []
93
- attention_maps = []
94
 
95
  for down in self.downs:
96
  x = down(x)
@@ -98,39 +92,20 @@ class AttentionUNET(nn.Module):
98
  x = self.pool(x)
99
 
100
  x = self.bottleneck(x)
101
- skip_connections = skip_connections[::-1]
102
 
103
- for idx in range(0, len(self.ups), 2):
104
  x = self.ups[idx](x)
105
  skip_connection = skip_connections[idx//2]
106
 
107
  if x.shape != skip_connection.shape:
108
  x = TF.resize(x, size=skip_connection.shape[2:])
109
 
110
- skip_connection, attention_coeff = self.attentions[idx // 2](skip_connection, x)
111
- if return_attention:
112
- attention_maps.append(attention_coeff)
113
-
114
  concat_skip = torch.cat((skip_connection, x), dim=1)
115
  x = self.ups[idx+1](concat_skip)
116
 
117
- output = self.final_conv(x)
118
-
119
- if return_attention:
120
- return output, attention_maps
121
- return output
122
-
123
- def download_dataset():
124
- """Download Brain Tumor Segmentation dataset from Kaggle"""
125
- global dataset_path
126
- try:
127
- print("πŸ“₯ Downloading Brain Tumor Segmentation dataset...")
128
- dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
129
- print(f"βœ… Dataset downloaded to: {dataset_path}")
130
- return dataset_path
131
- except Exception as e:
132
- print(f"❌ Failed to download dataset: {e}")
133
- return None
134
 
135
  def download_model():
136
  """Download your trained model from HuggingFace"""
@@ -138,7 +113,7 @@ def download_model():
138
  model_path = "best_attention_model.pth.tar"
139
 
140
  if not os.path.exists(model_path):
141
- print("πŸ“₯ Downloading trained model...")
142
  try:
143
  urllib.request.urlretrieve(model_url, model_path)
144
  print("βœ… Model downloaded successfully!")
@@ -150,323 +125,88 @@ def download_model():
150
 
151
  return model_path
152
 
153
- def load_attention_model():
154
- """Load trained Attention U-Net model"""
155
  global model
156
  if model is None:
157
  try:
158
- print("πŸ”„ Loading Attention U-Net model...")
159
 
 
160
  model_path = download_model()
161
  if model_path is None:
162
  return None
163
 
 
164
  model = AttentionUNET(in_channels=1, out_channels=1).to(device)
 
 
165
  checkpoint = torch.load(model_path, map_location=device, weights_only=True)
166
  model.load_state_dict(checkpoint["state_dict"])
167
  model.eval()
168
 
169
- print("βœ… Attention U-Net model loaded successfully!")
170
  except Exception as e:
171
- print(f"❌ Error loading model: {e}")
172
  model = None
173
  return model
174
 
175
- def get_random_sample_from_dataset():
176
- """Get a random sample image and ground truth mask from the dataset"""
177
- global dataset_path
178
-
179
- if dataset_path is None:
180
- dataset_path = download_dataset()
181
- if dataset_path is None:
182
- return None, None
183
-
184
- try:
185
- images_path = Path(dataset_path) / "images"
186
- masks_path = Path(dataset_path) / "masks"
187
-
188
- if not images_path.exists() or not masks_path.exists():
189
- print("❌ Dataset structure not found")
190
- return None, None
191
-
192
- # Get all image files
193
- image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.png")) + list(images_path.glob("*.tif"))
194
-
195
- if not image_files:
196
- print("❌ No image files found in dataset")
197
- return None, None
198
-
199
- # Select random image
200
- random_image_file = random.choice(image_files)
201
- image_name = random_image_file.stem
202
-
203
- # Find corresponding mask
204
- possible_mask_extensions = ['.jpg', '.png', '.tif', '.gif']
205
- mask_file = None
206
-
207
- for ext in possible_mask_extensions:
208
- potential_mask = masks_path / f"{image_name}{ext}"
209
- if potential_mask.exists():
210
- mask_file = potential_mask
211
- break
212
-
213
- if mask_file is None:
214
- print(f"❌ No corresponding mask found for {image_name}")
215
- return None, None
216
-
217
- # Load image and mask
218
- image = Image.open(random_image_file).convert('L')
219
- mask = Image.open(mask_file).convert('L')
220
-
221
- print(f"βœ… Loaded random sample: {image_name}")
222
- return image, mask
223
-
224
- except Exception as e:
225
- print(f"❌ Error loading random sample: {e}")
226
- return None, None
227
-
228
- def test_time_augmentation(model, image_tensor):
229
- """Apply Test-Time Augmentation (TTA) for robust predictions"""
230
- augmentations = [
231
- lambda x: x, # Original
232
- lambda x: torch.flip(x, dims=[3]), # Horizontal flip
233
- lambda x: torch.flip(x, dims=[2]), # Vertical flip
234
- lambda x: torch.flip(x, dims=[2, 3]), # Both flips
235
- lambda x: torch.rot90(x, k=1, dims=[2, 3]), # 90° rotation
236
- lambda x: torch.rot90(x, k=3, dims=[2, 3]), # 270° rotation
237
- ]
238
-
239
- reverse_augmentations = [
240
- lambda x: x, # Original
241
- lambda x: torch.flip(x, dims=[3]), # Reverse horizontal flip
242
- lambda x: torch.flip(x, dims=[2]), # Reverse vertical flip
243
- lambda x: torch.flip(x, dims=[2, 3]), # Reverse both flips
244
- lambda x: torch.rot90(x, k=3, dims=[2, 3]), # Reverse 90° rotation
245
- lambda x: torch.rot90(x, k=1, dims=[2, 3]), # Reverse 270° rotation
246
- ]
247
-
248
- predictions = []
249
-
250
- with torch.no_grad():
251
- for aug, rev_aug in zip(augmentations, reverse_augmentations):
252
- # Apply augmentation
253
- aug_input = aug(image_tensor)
254
-
255
- # Get prediction
256
- pred = torch.sigmoid(model(aug_input))
257
-
258
- # Reverse augmentation on prediction
259
- pred = rev_aug(pred)
260
-
261
- predictions.append(pred)
262
-
263
- # Average all predictions
264
- tta_prediction = torch.mean(torch.stack(predictions), dim=0)
265
-
266
- return tta_prediction
267
-
268
- def generate_attention_heatmaps(model, image_tensor):
269
- """Generate attention heatmaps for interpretability"""
270
- with torch.no_grad():
271
- pred, attention_maps = model(image_tensor, return_attention=True)
272
-
273
- # Convert attention maps to numpy for visualization
274
- heatmaps = []
275
- for i, att_map in enumerate(attention_maps):
276
- # Resize attention map to match input size
277
- att_map_resized = TF.resize(att_map, (256, 256))
278
- att_np = att_map_resized.cpu().squeeze().numpy()
279
- heatmaps.append(att_np)
280
-
281
- return heatmaps
282
-
283
- def preprocess_image(image):
284
- """Preprocessing exactly like training code"""
285
  if image.mode != 'L':
286
  image = image.convert('L')
287
 
 
288
  val_test_transform = transforms.Compose([
289
- transforms.Resize((256, 256)),
290
  transforms.ToTensor()
291
  ])
292
 
293
- return val_test_transform(image).unsqueeze(0)
294
 
295
- def calculate_metrics(pred_mask, ground_truth_mask):
296
- """Calculate Dice and IoU metrics"""
297
- pred_binary = (pred_mask > 0.5).float()
298
- gt_binary = (ground_truth_mask > 0.5).float()
299
-
300
- # Dice coefficient
301
- intersection = torch.sum(pred_binary * gt_binary)
302
- dice = (2.0 * intersection) / (torch.sum(pred_binary) + torch.sum(gt_binary) + 1e-8)
303
-
304
- # IoU
305
- union = torch.sum(pred_binary) + torch.sum(gt_binary) - intersection
306
- iou = intersection / (union + 1e-8)
307
-
308
- return dice.item(), iou.item()
309
-
310
- def predict_with_enhancements(image, ground_truth=None, use_tta=True, show_attention=True):
311
- """Enhanced prediction with TTA and attention visualization"""
312
- current_model = load_attention_model()
313
 
314
  if current_model is None:
315
- return None, "❌ Failed to load trained model."
316
 
317
  if image is None:
318
  return None, "⚠️ Please upload an image first."
319
 
320
  try:
321
- print("🧠 Processing with enhanced Attention U-Net...")
322
 
323
- input_tensor = preprocess_image(image).to(device)
 
324
 
325
- # Standard prediction
326
  with torch.no_grad():
327
- standard_pred = torch.sigmoid(current_model(input_tensor))
 
328
 
329
- # Test-Time Augmentation
330
- if use_tta:
331
- tta_pred = test_time_augmentation(current_model, input_tensor)
332
- final_pred = tta_pred
333
- else:
334
- final_pred = standard_pred
335
 
336
- # Generate attention heatmaps
337
- attention_heatmaps = []
338
- if show_attention:
339
- attention_heatmaps = generate_attention_heatmaps(current_model, input_tensor)
340
 
341
- # Convert predictions to binary
342
- pred_mask_binary = (final_pred > 0.5).float()
343
- pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
344
- standard_mask_np = (standard_pred > 0.5).float().cpu().squeeze().numpy()
345
 
346
- # Prepare images for visualization
347
- original_np = np.array(image.convert('L').resize((256, 256)))
 
348
 
349
- # Create comprehensive visualization
350
- if ground_truth is not None:
351
- # With ground truth comparison
352
- gt_np = np.array(ground_truth.convert('L').resize((256, 256)))
353
- gt_binary = (gt_np > 127).astype(np.float32) # Threshold ground truth
354
-
355
- # Calculate metrics
356
- gt_tensor = torch.tensor(gt_binary).unsqueeze(0).unsqueeze(0).to(device)
357
- dice_score, iou_score = calculate_metrics(final_pred, gt_tensor)
358
-
359
- # Create figure with ground truth comparison
360
- n_cols = 6 if show_attention and attention_heatmaps else 5
361
- fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
362
- fig.suptitle('🧠 Enhanced Attention U-Net Analysis with Ground Truth Comparison', fontsize=16, weight='bold')
363
-
364
- # Top row - Standard analysis
365
- axes[0, 0].imshow(original_np, cmap='gray')
366
- axes[0, 0].set_title('Original Image', fontsize=12, weight='bold')
367
- axes[0, 0].axis('off')
368
-
369
- axes[0, 1].imshow(standard_mask_np * 255, cmap='hot')
370
- axes[0, 1].set_title('Standard Prediction', fontsize=12, weight='bold')
371
- axes[0, 1].axis('off')
372
-
373
- axes[0, 2].imshow(pred_mask_np * 255, cmap='hot')
374
- axes[0, 2].set_title(f'{"TTA Enhanced" if use_tta else "Final Prediction"}', fontsize=12, weight='bold')
375
- axes[0, 2].axis('off')
376
-
377
- axes[0, 3].imshow(gt_binary * 255, cmap='hot')
378
- axes[0, 3].set_title('Ground Truth', fontsize=12, weight='bold')
379
- axes[0, 3].axis('off')
380
-
381
- # Overlay comparison
382
- overlay = original_np.copy()
383
- overlay = np.stack([overlay, overlay, overlay], axis=-1)
384
- overlay[pred_mask_np > 0.5] = [255, 0, 0] # Red for prediction
385
- overlay[gt_binary > 0.5] = [0, 255, 0] # Green for ground truth
386
- overlap = (pred_mask_np > 0.5) & (gt_binary > 0.5)
387
- overlay[overlap] = [255, 255, 0] # Yellow for overlap
388
-
389
- axes[0, 4].imshow(overlay.astype(np.uint8))
390
- axes[0, 4].set_title('Overlay (Red:Pred, Green:GT, Yellow:Match)', fontsize=10, weight='bold')
391
- axes[0, 4].axis('off')
392
-
393
- if show_attention and attention_heatmaps:
394
- # Show combined attention
395
- combined_attention = np.mean(attention_heatmaps, axis=0)
396
- axes[0, 5].imshow(combined_attention, cmap='jet', alpha=0.7)
397
- axes[0, 5].imshow(original_np, cmap='gray', alpha=0.3)
398
- axes[0, 5].set_title('Attention Heatmap', fontsize=12, weight='bold')
399
- axes[0, 5].axis('off')
400
-
401
- # Bottom row - Individual attention maps or detailed analysis
402
- if show_attention and attention_heatmaps:
403
- for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
404
- axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
405
- axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
406
- axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
407
- axes[1, i].axis('off')
408
- else:
409
- # Show tumor extraction and analysis
410
- tumor_only = np.where(pred_mask_np == 1, original_np, 255)
411
- inv_mask = np.where(pred_mask_np == 1, 0, 255)
412
-
413
- axes[1, 0].imshow(tumor_only, cmap='gray')
414
- axes[1, 0].set_title('Tumor Extraction', fontsize=12, weight='bold')
415
- axes[1, 0].axis('off')
416
-
417
- axes[1, 1].imshow(inv_mask, cmap='gray')
418
- axes[1, 1].set_title('Inverted Mask', fontsize=12, weight='bold')
419
- axes[1, 1].axis('off')
420
-
421
- # Difference map
422
- diff_map = np.abs(pred_mask_np - gt_binary)
423
- axes[1, 2].imshow(diff_map, cmap='Reds')
424
- axes[1, 2].set_title('Difference Map', fontsize=12, weight='bold')
425
- axes[1, 2].axis('off')
426
-
427
- # Clear remaining axes
428
- for j in range(3, n_cols):
429
- axes[1, j].axis('off')
430
- else:
431
- # Without ground truth
432
- n_cols = 5 if show_attention and attention_heatmaps else 4
433
- fig, axes = plt.subplots(2, n_cols, figsize=(4*n_cols, 8))
434
- fig.suptitle('🧠 Enhanced Attention U-Net Analysis', fontsize=16, weight='bold')
435
-
436
- # Top row
437
- images = [original_np, standard_mask_np * 255, pred_mask_np * 255]
438
- titles = ["Original Image", "Standard Prediction", f'{"TTA Enhanced" if use_tta else "Final Prediction"}']
439
- cmaps = ['gray', 'hot', 'hot']
440
-
441
- for i in range(3):
442
- axes[0, i].imshow(images[i], cmap=cmaps[i])
443
- axes[0, i].set_title(titles[i], fontsize=12, weight='bold')
444
- axes[0, i].axis('off')
445
-
446
- # Tumor extraction
447
- tumor_only = np.where(pred_mask_np == 1, original_np, 255)
448
- axes[0, 3].imshow(tumor_only, cmap='gray')
449
- axes[0, 3].set_title('Tumor Extraction', fontsize=12, weight='bold')
450
- axes[0, 3].axis('off')
451
-
452
- if show_attention and attention_heatmaps:
453
- combined_attention = np.mean(attention_heatmaps, axis=0)
454
- axes[0, 4].imshow(combined_attention, cmap='jet', alpha=0.7)
455
- axes[0, 4].imshow(original_np, cmap='gray', alpha=0.3)
456
- axes[0, 4].set_title('Combined Attention', fontsize=12, weight='bold')
457
- axes[0, 4].axis('off')
458
-
459
- # Bottom row - Individual attention maps
460
- if show_attention and attention_heatmaps:
461
- for i, heatmap in enumerate(attention_heatmaps[:n_cols]):
462
- axes[1, i].imshow(heatmap, cmap='jet', alpha=0.7)
463
- axes[1, i].imshow(original_np, cmap='gray', alpha=0.3)
464
- axes[1, i].set_title(f'Attention Gate {i+1}', fontsize=10, weight='bold')
465
- axes[1, i].axis('off')
466
- else:
467
- # Clear bottom row
468
- for j in range(n_cols):
469
- axes[1, j].axis('off')
470
 
471
  plt.tight_layout()
472
 
@@ -478,418 +218,187 @@ def predict_with_enhancements(image, ground_truth=None, use_tta=True, show_atten
478
 
479
  result_image = Image.open(buf)
480
 
481
- # Calculate statistics
482
  tumor_pixels = np.sum(pred_mask_np)
483
  total_pixels = pred_mask_np.size
484
  tumor_percentage = (tumor_pixels / total_pixels) * 100
485
 
486
- max_confidence = torch.max(final_pred).item()
487
- mean_confidence = torch.mean(final_pred).item()
 
488
 
489
- # Enhanced analysis text
490
  analysis_text = f"""
491
- ## 🧠 Enhanced Attention U-Net Analysis Results
492
 
493
- ### πŸ“Š Detection Summary
494
  - **Status**: {'πŸ”΄ TUMOR DETECTED' if tumor_pixels > 50 else '🟒 NO SIGNIFICANT TUMOR'}
495
- - **Tumor Coverage**: {tumor_percentage:.2f}% of brain region
496
  - **Tumor Pixels**: {tumor_pixels:,} pixels
497
  - **Max Confidence**: {max_confidence:.4f}
498
  - **Mean Confidence**: {mean_confidence:.4f}
499
- """
500
 
501
- if ground_truth is not None:
502
- analysis_text += f"""
503
- ### 🎯 Ground Truth Comparison
504
- - **Dice Score**: {dice_score:.4f} {'βœ… Excellent' if dice_score > 0.8 else '⚠️ Good' if dice_score > 0.6 else '❌ Poor'}
505
- - **IoU Score**: {iou_score:.4f} {'βœ… Excellent' if iou_score > 0.7 else '⚠️ Good' if iou_score > 0.5 else '❌ Poor'}
506
- - **Model Accuracy**: {'High precision match' if dice_score > 0.8 else 'Reasonable match' if dice_score > 0.6 else 'Needs improvement'}
507
- """
508
-
509
- analysis_text += f"""
510
- ### πŸš€ Enhancement Features
511
- - **Test-Time Augmentation**: {'βœ… Applied (6 augmentations averaged)' if use_tta else '❌ Disabled'}
512
- - **Attention Visualization**: {'βœ… Generated attention heatmaps' if show_attention else '❌ Disabled'}
513
- - **Boundary Enhancement**: {'βœ… TTA improves edge detection' if use_tta else '⚠️ Standard prediction only'}
514
- - **Interpretability**: {'βœ… Attention gates show focus areas' if show_attention else '❌ Black box mode'}
515
-
516
- ### πŸ”¬ Model Architecture
517
- - **Base Model**: Attention U-Net with skip connections
518
- - **Training Performance**: Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%
519
- - **Attention Gates**: 4 levels with soft attention mechanism
520
- - **Features Channels**: [32, 64, 128, 256] progression
521
  - **Device**: {device.type.upper()}
522
 
523
- ### πŸ“ˆ Enhanced Processing Pipeline
524
- - **Preprocessing**: Resize(256×256) + Normalization
525
- - **Augmentations**: Flips (H,V), Rotations (90°,270°), Combined
526
- - **Attention Fusion**: Multi-scale attention coefficient extraction
527
- - **Post-processing**: Ensemble averaging + Binary thresholding (0.5)
528
 
529
- ### ⚠️ Medical Disclaimer
530
- This enhanced AI model is for **research and educational purposes only**.
531
- Results include advanced features for better accuracy and interpretability.
532
- Always consult medical professionals for clinical applications.
 
533
 
534
- ### πŸ† Research Contributions
535
- βœ… **Attention Gates**: Enhanced boundary detection through selective feature passing
536
- βœ… **Test-Time Augmentation**: Robust predictions via ensemble averaging
537
- βœ… **Interpretability**: Attention heatmaps for clinical trust and validation
538
- βœ… **Efficiency**: No retraining required, minimal computational overhead
539
- """
 
540
 
541
- print(f"βœ… Enhanced analysis completed! Tumor coverage: {tumor_percentage:.2f}%")
542
  return result_image, analysis_text
543
 
544
  except Exception as e:
545
- error_msg = f"❌ Error during enhanced analysis: {str(e)}"
546
  print(error_msg)
547
  return None, error_msg
548
 
549
- def load_random_sample():
550
- """Load a random sample from the dataset"""
551
- image, mask = get_random_sample_from_dataset()
552
- if image is None:
553
- return None, None, "❌ Failed to load random sample from dataset"
554
- return image, mask, "βœ… Random sample loaded from dataset"
555
-
556
  def clear_all():
557
- return None, None, None, "Upload a brain MRI image or load a random sample to test the enhanced model"
558
 
559
- # Enhanced professional CSS
560
  css = """
561
  .gradio-container {
562
- max-width: 1600px !important;
563
  margin: auto !important;
564
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
565
  }
566
-
567
  #title {
568
  text-align: center;
569
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
570
  color: white;
571
- padding: 40px;
572
- border-radius: 20px;
573
- margin-bottom: 30px;
574
- box-shadow: 0 12px 24px rgba(102, 126, 234, 0.4);
575
- }
576
-
577
- .feature-box {
578
- background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
579
  border-radius: 15px;
580
- padding: 25px;
581
- margin: 15px 0;
582
- color: white;
583
- box-shadow: 0 8px 16px rgba(240, 147, 251, 0.3);
584
- }
585
-
586
- .metric-card {
587
- background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
588
- border-radius: 12px;
589
- padding: 20px;
590
- text-align: center;
591
- margin: 10px;
592
- box-shadow: 0 6px 12px rgba(79, 172, 254, 0.3);
593
- }
594
-
595
- .enhancement-badge {
596
- display: inline-block;
597
- background: linear-gradient(45deg, #fa709a 0%, #fee140 100%);
598
- color: white;
599
- padding: 8px 16px;
600
- border-radius: 25px;
601
- margin: 5px;
602
- font-weight: bold;
603
- box-shadow: 0 4px 8px rgba(250, 112, 154, 0.3);
604
  }
605
  """
606
 
607
- # Create enhanced Gradio interface
608
- with gr.Blocks(css=css, title="🧠 Enhanced Brain Tumor Segmentation", theme=gr.themes.Soft()) as app:
609
 
610
  gr.HTML("""
611
  <div id="title">
612
- <h1>🧠 Enhanced Attention U-Net Brain Tumor Segmentation</h1>
613
- <p style="font-size: 20px; margin-top: 20px; font-weight: 300;">
614
- πŸš€ Advanced Medical AI with Test-Time Augmentation & Attention Visualization
615
  </p>
616
- <p style="font-size: 16px; margin-top: 15px; opacity: 0.9;">
617
- πŸ“Š Performance: Dice 0.8420 β€’ IoU 0.7297 β€’ Accuracy 98.90% |
618
- πŸ”¬ Research-Grade Interpretability & Robustness
619
  </p>
620
  </div>
621
  """)
622
 
623
  with gr.Row():
624
  with gr.Column(scale=1):
625
- gr.Markdown("### πŸ“€ Input & Controls")
626
-
627
- with gr.Tab("πŸ“Έ Upload Image"):
628
- image_input = gr.Image(
629
- label="Brain MRI Scan",
630
- type="pil",
631
- sources=["upload", "webcam"],
632
- height=300
633
- )
634
-
635
- with gr.Tab("🎲 Random Sample"):
636
- random_image = gr.Image(
637
- label="Sample Image",
638
- type="pil",
639
- height=300,
640
- interactive=False
641
- )
642
- random_ground_truth = gr.Image(
643
- label="Ground Truth Mask",
644
- type="pil",
645
- height=300,
646
- interactive=False
647
- )
648
- load_sample_btn = gr.Button("🎲 Load Random Sample", variant="secondary", size="lg")
649
- sample_status = gr.Textbox(label="Sample Status", interactive=False)
650
-
651
- gr.Markdown("### βš™οΈ Enhancement Options")
652
-
653
- use_tta = gr.Checkbox(
654
- label="πŸ”„ Test-Time Augmentation",
655
- value=True,
656
- info="Apply multiple augmentations for robust predictions"
657
- )
658
 
659
- show_attention = gr.Checkbox(
660
- label="πŸ”₯ Attention Visualization",
661
- value=True,
662
- info="Generate attention heatmaps for interpretability"
 
663
  )
664
 
665
  with gr.Row():
666
- analyze_btn = gr.Button(
667
- "🧠 Analyze with Enhanced Model",
668
- variant="primary",
669
- scale=3,
670
- size="lg"
671
- )
672
- clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary", scale=1)
673
 
674
  gr.HTML("""
675
- <div class="feature-box">
676
- <h4 style="margin-bottom: 15px;">🎯 Research Innovations</h4>
677
- <div class="enhancement-badge">Attention Gates</div>
678
- <div class="enhancement-badge">Test-Time Augmentation</div>
679
- <div class="enhancement-badge">Interpretability</div>
680
- <div class="enhancement-badge">Ground Truth Comparison</div>
681
- <p style="margin-top: 15px; font-size: 14px; opacity: 0.9;">
682
- Advanced medical AI combining accuracy, robustness, and clinical interpretability
683
- </p>
684
  </div>
685
  """)
686
 
687
  with gr.Column(scale=2):
688
- gr.Markdown("### πŸ“Š Enhanced Analysis Results")
689
 
690
  output_image = gr.Image(
691
- label="Comprehensive Analysis Visualization",
692
  type="pil",
693
- height=600
694
  )
695
 
696
- with gr.Accordion("πŸ“ˆ Detailed Analysis Report", open=True):
697
- analysis_output = gr.Markdown(
698
- value="Upload a brain MRI image or load a random sample to test the enhanced Attention U-Net model.",
699
- elem_id="analysis"
700
- )
701
-
702
- # Performance metrics section
703
- gr.HTML("""
704
- <div style="margin-top: 40px;">
705
- <h3 style="text-align: center; color: #4a5568; margin-bottom: 25px;">πŸ“Š Model Performance & Research Contributions</h3>
706
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; margin-bottom: 30px;">
707
-
708
- <div class="metric-card">
709
- <h4 style="color: white; margin-bottom: 10px;">🎯 Segmentation Accuracy</h4>
710
- <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">98.90%</div>
711
- <p style="font-size: 14px; opacity: 0.9;">Training accuracy on brain tumor dataset</p>
712
- </div>
713
-
714
- <div class="metric-card">
715
- <h4 style="color: white; margin-bottom: 10px;">πŸ“ Dice Score</h4>
716
- <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.8420</div>
717
- <p style="font-size: 14px; opacity: 0.9;">Overlap similarity coefficient</p>
718
- </div>
719
-
720
- <div class="metric-card">
721
- <h4 style="color: white; margin-bottom: 10px;">πŸ”² IoU Score</h4>
722
- <div style="font-size: 24px; font-weight: bold; margin: 10px 0;">0.7297</div>
723
- <p style="font-size: 14px; opacity: 0.9;">Intersection over Union metric</p>
724
- </div>
725
-
726
- <div class="metric-card">
727
- <h4 style="color: white; margin-bottom: 10px;">⚑ Enhancement Features</h4>
728
- <div style="font-size: 20px; font-weight: bold; margin: 10px 0;">TTA + Attention</div>
729
- <p style="font-size: 14px; opacity: 0.9;">Advanced robustness & interpretability</p>
730
- </div>
731
-
732
- </div>
733
- </div>
734
- """)
735
-
736
- # Research contributions section
737
- gr.HTML("""
738
- <div style="margin-top: 30px; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white;">
739
- <h3 style="text-align: center; margin-bottom: 25px; color: white;">πŸš€ Novel Research Contributions</h3>
740
-
741
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px; margin-bottom: 20px;">
742
-
743
- <div>
744
- <h4 style="margin-bottom: 15px; color: #ffd700;">πŸ” 1. Enhanced Boundary Detection</h4>
745
- <ul style="line-height: 1.8; margin-left: 20px;">
746
- <li><strong>Problem:</strong> Traditional U-Net passes noisy features through skip connections</li>
747
- <li><strong>Solution:</strong> Attention gates filter irrelevant encoder features</li>
748
- <li><strong>Impact:</strong> Cleaner boundaries, reduced false positives</li>
749
- </ul>
750
- </div>
751
-
752
- <div>
753
- <h4 style="margin-bottom: 15px; color: #ffd700;">πŸ”„ 2. Test-Time Augmentation</h4>
754
- <ul style="line-height: 1.8; margin-left: 20px;">
755
- <li><strong>Problem:</strong> Medical datasets are small, MRI scans vary across centers</li>
756
- <li><strong>Solution:</strong> Multiple augmentations averaged for robust predictions</li>
757
- <li><strong>Impact:</strong> Improved robustness without retraining</li>
758
- </ul>
759
- </div>
760
-
761
- <div>
762
- <h4 style="margin-bottom: 15px; color: #ffd700;">πŸ”₯ 3. Attention Visualization</h4>
763
- <ul style="line-height: 1.8; margin-left: 20px;">
764
- <li><strong>Problem:</strong> Deep networks are "black boxes" for clinicians</li>
765
- <li><strong>Solution:</strong> Extract attention coefficients as interpretable heatmaps</li>
766
- <li><strong>Impact:</strong> Build clinical trust through transparency</li>
767
- </ul>
768
- </div>
769
-
770
- <div>
771
- <h4 style="margin-bottom: 15px; color: #ffd700;">⚑ 4. Efficient Implementation</h4>
772
- <ul style="line-height: 1.8; margin-left: 20px;">
773
- <li><strong>Problem:</strong> Complex architectures are hard to deploy</li>
774
- <li><strong>Solution:</strong> Low-overhead enhancements within existing backbone</li>
775
- <li><strong>Impact:</strong> Practical for real-world medical workflows</li>
776
- </ul>
777
- </div>
778
-
779
- </div>
780
-
781
- <div style="text-align: center; padding-top: 20px; border-top: 2px solid rgba(255,255,255,0.3);">
782
- <p style="font-size: 16px; font-weight: 600; margin-bottom: 10px;">
783
- 🎯 Research Gap Addressed: Accuracy + Robustness + Interpretability
784
- </p>
785
- <p style="font-size: 14px; opacity: 0.9;">
786
- This combination tackles three major challenges in medical AI with minimal architectural changes
787
- </p>
788
- </div>
789
- </div>
790
- """)
791
 
792
- # Dataset and disclaimer section
793
  gr.HTML("""
794
- <div style="margin-top: 30px; padding: 25px; background-color: #f7fafc; border-radius: 15px; border-left: 5px solid #667eea;">
795
  <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
796
-
797
  <div>
798
- <h4 style="color: #667eea; margin-bottom: 15px;">πŸ“š Dataset Information</h4>
799
- <p><strong>Source:</strong> Brain Tumor Segmentation (Kaggle)</p>
800
- <p><strong>Author:</strong> nikhilroxtomar</p>
801
- <p><strong>Structure:</strong> Images + Ground Truth Masks</p>
802
- <p><strong>Format:</strong> Grayscale MRI scans</p>
803
- <p><strong>Use Case:</strong> Medical image segmentation research</p>
804
- <p><strong>Ground Truth:</strong> Available for metric calculation</p>
805
  </div>
806
-
807
  <div>
808
- <h4 style="color: #dc2626; margin-bottom: 15px;">⚠️ Medical Disclaimer</h4>
809
- <p style="color: #dc2626; font-weight: 600; line-height: 1.5;">
810
- This enhanced AI system is designed for <strong>research and educational purposes only</strong>.<br><br>
811
-
812
- While the model includes advanced features like attention visualization and test-time augmentation
813
- for improved accuracy and interpretability, all results must be validated by qualified medical professionals.<br><br>
814
-
815
- <strong>Not approved for clinical diagnosis or medical decision making.</strong>
816
  </p>
817
  </div>
818
-
819
  </div>
820
-
821
- <hr style="margin: 25px 0; border: none; border-top: 2px solid #e2e8f0;">
822
-
823
- <p style="text-align: center; color: #4a5568; margin: 15px 0; font-weight: 600;">
824
- πŸ”¬ Research-Grade Medical AI β€’ Enhanced Interpretability β€’ Robust Predictions β€’ Ground Truth Validation
825
  </p>
826
  </div>
827
  """)
828
 
829
  # Event handlers
830
- def analyze_with_ground_truth(image, gt_mask, use_tta, show_attention):
831
- """Wrapper function to handle ground truth comparison"""
832
- return predict_with_enhancements(image, gt_mask, use_tta, show_attention)
833
-
834
- def analyze_uploaded_image(image, use_tta, show_attention):
835
- """Wrapper function for uploaded images without ground truth"""
836
- return predict_with_enhancements(image, None, use_tta, show_attention)
837
-
838
- # Button event handlers
839
  analyze_btn.click(
840
- fn=lambda img, rand_img, rand_gt, tta, attention: (
841
- analyze_with_ground_truth(rand_img, rand_gt, tta, attention)
842
- if rand_img is not None
843
- else analyze_uploaded_image(img, tta, attention)
844
- ),
845
- inputs=[image_input, random_image, random_ground_truth, use_tta, show_attention],
846
  outputs=[output_image, analysis_output],
847
  show_progress=True
848
  )
849
 
850
- load_sample_btn.click(
851
- fn=load_random_sample,
852
- inputs=[],
853
- outputs=[random_image, random_ground_truth, sample_status],
854
- show_progress=True
855
- )
856
-
857
  clear_btn.click(
858
  fn=clear_all,
859
  inputs=[],
860
- outputs=[image_input, random_image, random_ground_truth, analysis_output]
861
  )
862
 
863
- # Auto-load dataset on startup
864
- gr.HTML("""
865
- <script>
866
- document.addEventListener('DOMContentLoaded', function() {
867
- console.log('Enhanced Brain Tumor Segmentation App Loaded');
868
- console.log('Features: TTA + Attention Visualization + Ground Truth Comparison');
869
- });
870
- </script>
871
- """)
872
-
873
  if __name__ == "__main__":
874
- print("πŸš€ Starting Enhanced Brain Tumor Segmentation System...")
875
- print("πŸ“Š Model Performance: Dice 0.8420, IoU 0.7297, Accuracy 98.90%")
876
- print("πŸ”¬ Research Features: Attention Gates + TTA + Interpretability")
877
- print("πŸ“₯ Auto-downloading dataset and model...")
878
-
879
- # Initialize dataset download
880
- print("πŸ“š Initializing dataset...")
881
- try:
882
- dataset_path = download_dataset()
883
- if dataset_path:
884
- print(f"βœ… Dataset ready at: {dataset_path}")
885
- else:
886
- print("⚠️ Dataset download failed, random samples unavailable")
887
- except Exception as e:
888
- print(f"⚠️ Dataset initialization error: {e}")
889
 
890
  app.launch(
891
  server_name="0.0.0.0",
892
  server_port=7860,
893
  show_error=True,
894
  share=False
895
- )
 
10
  import torchvision.transforms.functional as TF
11
  import urllib.request
12
  import os
 
 
 
 
13
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  model = None
 
16
 
17
  # Define your Attention U-Net architecture (from your training code)
18
  class DoubleConv(nn.Module):
 
56
  x1 = self.W_x(x)
57
  psi = self.relu(g1 + x1)
58
  psi = self.psi(psi)
59
+ return x * psi
60
 
61
  class AttentionUNET(nn.Module):
62
  def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
 
83
 
84
  self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
85
 
86
+ def forward(self, x):
87
  skip_connections = []
 
88
 
89
  for down in self.downs:
90
  x = down(x)
 
92
  x = self.pool(x)
93
 
94
  x = self.bottleneck(x)
95
+ skip_connections = skip_connections[::-1] #reverse list
96
 
97
+ for idx in range(0, len(self.ups), 2): #do up and double_conv
98
  x = self.ups[idx](x)
99
  skip_connection = skip_connections[idx//2]
100
 
101
  if x.shape != skip_connection.shape:
102
  x = TF.resize(x, size=skip_connection.shape[2:])
103
 
104
+ skip_connection = self.attentions[idx // 2](skip_connection, x)
 
 
 
105
  concat_skip = torch.cat((skip_connection, x), dim=1)
106
  x = self.ups[idx+1](concat_skip)
107
 
108
+ return self.final_conv(x)
109
 
110
  def download_model():
111
  """Download your trained model from HuggingFace"""
 
113
  model_path = "best_attention_model.pth.tar"
114
 
115
  if not os.path.exists(model_path):
116
+ print("πŸ“₯ Downloading your trained model...")
117
  try:
118
  urllib.request.urlretrieve(model_url, model_path)
119
  print("βœ… Model downloaded successfully!")
 
125
 
126
  return model_path
127
 
128
+ def load_your_attention_model():
129
+ """Load YOUR trained Attention U-Net model"""
130
  global model
131
  if model is None:
132
  try:
133
+ print("πŸ”„ Loading your trained Attention U-Net model...")
134
 
135
+ # Download model if needed
136
  model_path = download_model()
137
  if model_path is None:
138
  return None
139
 
140
+ # Initialize your model architecture
141
  model = AttentionUNET(in_channels=1, out_channels=1).to(device)
142
+
143
+ # Load your trained weights
144
  checkpoint = torch.load(model_path, map_location=device, weights_only=True)
145
  model.load_state_dict(checkpoint["state_dict"])
146
  model.eval()
147
 
148
+ print("βœ… Your Attention U-Net model loaded successfully!")
149
  except Exception as e:
150
+ print(f"❌ Error loading your model: {e}")
151
  model = None
152
  return model
153
 
154
+ def preprocess_for_your_model(image):
155
+ """Preprocessing exactly like your Colab code"""
156
+ # Convert to grayscale (like your Colab code)
157
  if image.mode != 'L':
158
  image = image.convert('L')
159
 
160
+ # Use the exact same transform as your Colab code
161
  val_test_transform = transforms.Compose([
162
+ transforms.Resize((256,256)),
163
  transforms.ToTensor()
164
  ])
165
 
166
+ return val_test_transform(image).unsqueeze(0) # Add batch dimension
167
 
168
+ def predict_tumor(image):
169
+ current_model = load_your_attention_model()
170
 
171
  if current_model is None:
172
+ return None, "❌ Failed to load your trained model."
173
 
174
  if image is None:
175
  return None, "⚠️ Please upload an image first."
176
 
177
  try:
178
+ print("🧠 Processing with YOUR trained Attention U-Net...")
179
 
180
+ # Use the exact preprocessing from your Colab code
181
+ input_tensor = preprocess_for_your_model(image).to(device)
182
 
183
+ # Predict using your model (exactly like your Colab code)
184
  with torch.no_grad():
185
+ pred_mask = torch.sigmoid(current_model(input_tensor))
186
+ pred_mask_binary = (pred_mask > 0.5).float()
187
 
188
+ # Convert to numpy (like your Colab code)
189
+ pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
190
+ original_np = np.array(image.convert('L').resize((256, 256)))
 
 
 
191
 
192
+ # Create inverted mask for visualization (like your Colab code)
193
+ inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
 
 
194
 
195
+ # Create tumor-only image (like your Colab code)
196
+ tumor_only = np.where(pred_mask_np == 1, original_np, 255)
 
 
197
 
198
+ # Create visualization (matching your Colab 4-panel layout)
199
+ fig, axes = plt.subplots(1, 4, figsize=(20, 5))
200
+ fig.suptitle('🧠 Your Attention U-Net Results', fontsize=16, fontweight='bold')
201
 
202
+ titles = ["Original Image", "Tumor Segmentation", "Inverted Mask", "Tumor Only"]
203
+ images = [original_np, pred_mask_np * 255, inv_pred_mask_np, tumor_only]
204
+ cmaps = ['gray', 'hot', 'gray', 'gray']
205
+
206
+ for i, ax in enumerate(axes):
207
+ ax.imshow(images[i], cmap=cmaps[i])
208
+ ax.set_title(titles[i], fontsize=12, fontweight='bold')
209
+ ax.axis('off')
210
 
211
  plt.tight_layout()
212
 
 
218
 
219
  result_image = Image.open(buf)
220
 
221
+ # Calculate statistics (like your Colab code)
222
  tumor_pixels = np.sum(pred_mask_np)
223
  total_pixels = pred_mask_np.size
224
  tumor_percentage = (tumor_pixels / total_pixels) * 100
225
 
226
+ # Calculate confidence metrics
227
+ max_confidence = torch.max(pred_mask).item()
228
+ mean_confidence = torch.mean(pred_mask).item()
229
 
 
230
  analysis_text = f"""
231
+ ## 🧠 Your Attention U-Net Analysis Results
232
 
233
+ ### πŸ“Š Detection Summary:
234
  - **Status**: {'πŸ”΄ TUMOR DETECTED' if tumor_pixels > 50 else '🟒 NO SIGNIFICANT TUMOR'}
235
+ - **Tumor Area**: {tumor_percentage:.2f}% of brain region
236
  - **Tumor Pixels**: {tumor_pixels:,} pixels
237
  - **Max Confidence**: {max_confidence:.4f}
238
  - **Mean Confidence**: {mean_confidence:.4f}
 
239
 
240
+ ### πŸ”¬ Your Model Information:
241
+ - **Architecture**: YOUR trained Attention U-Net
242
+ - **Training Performance**: Dice: 0.8420, IoU: 0.7297
243
+ - **Input**: Grayscale (single channel)
244
+ - **Output**: Binary segmentation mask
245
  - **Device**: {device.type.upper()}
246
 
247
+ ### 🎯 Model Performance:
248
+ - **Training Accuracy**: 98.90%
249
+ - **Best Dice Score**: 0.8420
250
+ - **Best IoU Score**: 0.7297
251
+ - **Training Dataset**: Brain tumor segmentation dataset
252
 
253
+ ### πŸ“ˆ Processing Details:
254
+ - **Preprocessing**: Resize(256×256) + ToTensor (your exact method)
255
+ - **Threshold**: 0.5 (sigmoid > 0.5)
256
+ - **Architecture**: Attention gates + Skip connections
257
+ - **Features**: [32, 64, 128, 256] channels
258
 
259
+ ### ⚠️ Medical Disclaimer:
260
+ This is YOUR trained AI model for **research and educational purposes only**.
261
+ Results should be validated by medical professionals. Not for clinical diagnosis.
262
+
263
+ ### πŸ† Model Quality:
264
+ ✅ This is your own trained model; predicted tumor coverage for this scan: {tumor_percentage:.2f}%!
265
+ """
266
 
267
+ print(f"βœ… Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
268
  return result_image, analysis_text
269
 
270
  except Exception as e:
271
+ error_msg = f"❌ Error with your model: {str(e)}"
272
  print(error_msg)
273
  return None, error_msg
274
 
 
 
 
 
 
 
 
275
  def clear_all():
276
+ return None, None, "Upload a brain MRI image to test YOUR trained Attention U-Net model"
277
 
278
+ # Enhanced CSS for your model
279
  css = """
280
  .gradio-container {
281
+ max-width: 1400px !important;
282
  margin: auto !important;
 
283
  }
 
284
  #title {
285
  text-align: center;
286
+ background: linear-gradient(135deg, #8B5CF6 0%, #7C3AED 100%);
287
  color: white;
288
+ padding: 30px;
289
  border-radius: 15px;
290
+ margin-bottom: 25px;
291
+ box-shadow: 0 8px 16px rgba(139, 92, 246, 0.3)
292
  }
293
  """
294
 
295
+ # Create Gradio interface for your model
296
+ with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model", theme=gr.themes.Soft()) as app:
297
 
298
  gr.HTML("""
299
  <div id="title">
300
+ <h1>🧠 YOUR Attention U-Net Model</h1>
301
+ <p style="font-size: 18px; margin-top: 15px;">
302
+ Using Your Own Trained Model β€’ Dice: 0.8420 β€’ IoU: 0.7297
303
  </p>
304
+ <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
305
+ Loaded from: ArchCoder/the-op-segmenter HuggingFace Space
 
306
  </p>
307
  </div>
308
  """)
309
 
310
  with gr.Row():
311
  with gr.Column(scale=1):
312
+ gr.Markdown("### πŸ“€ Upload Brain MRI")
313
 
314
+ image_input = gr.Image(
315
+ label="Brain MRI Scan",
316
+ type="pil",
317
+ sources=["upload", "webcam"],
318
+ height=350
319
  )
320
 
321
  with gr.Row():
322
+ analyze_btn = gr.Button("πŸ” Analyze with YOUR Model", variant="primary", scale=2, size="lg")
323
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
 
 
 
 
 
324
 
325
  gr.HTML("""
326
+ <div style="margin-top: 20px; padding: 20px; background: linear-gradient(135deg, #F3E8FF 0%, #EDE9FE 100%); border-radius: 10px; border-left: 4px solid #8B5CF6;">
327
+ <h4 style="color: #8B5CF6; margin-bottom: 15px;">πŸ† Your Model Features:</h4>
328
+ <ul style="margin: 10px 0; padding-left: 20px; line-height: 1.6;">
329
+ <li><strong>Personal Model:</strong> Your own trained Attention U-Net</li>
330
+ <li><strong>Proven Performance:</strong> 84.2% Dice Score, 72.97% IoU</li>
331
+ <li><strong>Attention Gates:</strong> Advanced feature selection</li>
332
+ <li><strong>Clean Output:</strong> Binary segmentation masks</li>
333
+ <li><strong>4-Panel View:</strong> Complete analysis like your Colab</li>
334
+ </ul>
335
  </div>
336
  """)
337
 
338
  with gr.Column(scale=2):
339
+ gr.Markdown("### πŸ“Š Your Model Results")
340
 
341
  output_image = gr.Image(
342
+ label="Your Attention U-Net Analysis",
343
  type="pil",
344
+ height=500
345
  )
346
 
347
+ analysis_output = gr.Markdown(
348
+ value="Upload a brain MRI image to test YOUR trained Attention U-Net model.",
349
+ elem_id="analysis"
350
+ )
351
 
352
+ # Footer highlighting your model
353
  gr.HTML("""
354
+ <div style="margin-top: 30px; padding: 25px; background-color: #F8FAFC; border-radius: 15px; border: 2px solid #8B5CF6;">
355
  <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
 
356
  <div>
357
+ <h4 style="color: #8B5CF6; margin-bottom: 15px;">πŸ† Your Personal AI Model</h4>
358
+ <p><strong>Architecture:</strong> Attention U-Net with skip connections</p>
359
+ <p><strong>Performance:</strong> Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%</p>
360
+ <p><strong>Training:</strong> Your own dataset-specific training</p>
361
+ <p><strong>Features:</strong> [32, 64, 128, 256] channel progression</p>
 
 
362
  </div>
 
363
  <div>
364
+ <h4 style="color: #DC2626; margin-bottom: 15px;">⚠️ Your Model Disclaimer</h4>
365
+ <p style="color: #DC2626; font-weight: 600; line-height: 1.4;">
366
+ This is YOUR personally trained AI model for <strong>research purposes only</strong>.<br>
367
+ Results reflect your model's training performance.<br>
368
+ Always validate with medical professionals for any clinical application.
 
 
 
369
  </p>
370
  </div>
 
371
  </div>
372
+ <hr style="margin: 20px 0; border: none; border-top: 2px solid #E5E7EB;">
373
+ <p style="text-align: center; color: #6B7280; margin: 10px 0; font-weight: 600;">
374
+ πŸš€ Your Personal Attention U-Net β€’ Downloaded from HuggingFace β€’ Research-Grade Performance
 
 
375
  </p>
376
  </div>
377
  """)
378
 
379
  # Event handlers
 
 
 
 
 
 
 
 
 
380
  analyze_btn.click(
381
+ fn=predict_tumor,
382
+ inputs=[image_input],
 
 
 
 
383
  outputs=[output_image, analysis_output],
384
  show_progress=True
385
  )
386
387
  clear_btn.click(
388
  fn=clear_all,
389
  inputs=[],
390
+ outputs=[image_input, output_image, analysis_output]
391
  )
392
393
  if __name__ == "__main__":
394
+ print("πŸš€ Starting YOUR Attention U-Net Model System...")
395
+ print("πŸ† Using your personally trained model")
396
+ print("πŸ“₯ Auto-downloading from HuggingFace...")
397
+ print("🎯 Expected performance: Dice 0.8420, IoU 0.7297")
398
 
399
  app.launch(
400
  server_name="0.0.0.0",
401
  server_port=7860,
402
  show_error=True,
403
  share=False
404
+ )
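
For reference, a minimal sketch of driving the same inference path outside the Gradio UI. It assumes the predict_tumor function and module-level setup defined in app.py above, plus an illustrative input file name scan.png; the import style and the file name are assumptions for illustration, not part of this commit.

    # Hypothetical standalone use of the functions defined in app.py above.
    # Assumes app.py is importable from the working directory and that a
    # grayscale brain MRI slice is available as scan.png (illustrative name).
    from PIL import Image
    from app import predict_tumor  # returns (PIL figure image, markdown report)

    mri = Image.open("scan.png")        # app.py converts to 'L' and resizes to 256x256
    panel, report = predict_tumor(mri)  # 4-panel matplotlib visualization + analysis text
    if panel is not None:
        panel.save("segmentation_panels.png")
    print(report)

Importing app.py builds the Gradio Blocks but does not launch the server, since app.launch is guarded by the __main__ check above, so the model can also be exercised programmatically this way.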