Stefano01 commited on
Commit
6344100
·
verified ·
1 Parent(s): 4fcd615

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -23
app.py CHANGED
@@ -32,7 +32,7 @@ app_state = {
32
 
33
  custom_theme = gr.themes.Soft(
34
  primary_hue="green", # main brand color
35
- secondary_hue="purple", # accent color
36
  neutral_hue="slate" # backgrounds/borders/text neutrals
37
  )
38
 
@@ -300,8 +300,8 @@ def load_checkpoint_from_url(url, preset_name):
300
  max_samples = len(dataset) - 1 if app_state["dataset"] else 0
301
 
302
  return (f"✅ Loaded: {ckpt_path}", json.dumps(meta_info, indent=2),
303
- gr.update(visible=True), gr.update(choices=class_choices, value="(any)"),
304
- gr.update(visible=True, maximum=max_samples, value=0))
305
 
306
  except Exception as e:
307
  return f"❌ Failed: {str(e)}", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
@@ -362,29 +362,45 @@ def load_checkpoint_from_file(file):
362
  max_samples = len(dataset) - 1 if app_state["dataset"] else 0
363
 
364
  return (f"✅ Loaded: {local_path}", json.dumps(meta_info, indent=2),
365
- gr.update(visible=True), gr.update(choices=class_choices, value="(any)"),
366
- gr.update(visible=True, maximum=max_samples, value=0))
367
 
368
  except Exception as e:
369
  return f"❌ Failed: {str(e)}", "", gr.update(visible=False)
370
 
371
 
372
- def get_random_sample():
373
- """Get a random sample from the loaded dataset."""
374
  if app_state["dataset"] is None:
375
  return None, "No dataset loaded", gr.update(visible=False)
376
-
377
  dataset = app_state["dataset"]
378
- idx = random.randint(0, len(dataset) - 1)
379
- img_tensor, label = dataset[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
381
-
382
- class_name = app_state["dataset_classes"][label] if app_state["dataset_classes"] else str(label)
383
- caption = f"Sample from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name} • idx: {idx}"
384
-
385
- # Update slider maximum and current value
386
- max_idx = len(dataset) - 1
387
- return sample_img, caption, gr.update(visible=True, maximum=max_idx, value=idx)
388
 
389
 
390
  def get_sample_by_index(idx, class_filter):
@@ -414,7 +430,7 @@ def get_sample_by_index(idx, class_filter):
414
 
415
  img_tensor, label = dataset[actual_idx]
416
  sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
417
-
418
  class_name = dataset_classes[label] if dataset_classes else str(label)
419
  caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"
420
 
@@ -440,6 +456,12 @@ def update_class_filter(class_filter):
440
  return gr.update(visible=True, maximum=max_idx, value=0)
441
 
442
 
 
 
 
 
 
 
443
  def process_image(image, method, topk, alpha):
444
  """Process image and generate Grad-CAM visualizations."""
445
  if app_state["model"] is None:
@@ -501,7 +523,17 @@ def create_interface():
501
  presets = load_release_presets()
502
  preset_choices = ["None"] + list(presets.keys()) if presets else ["None"]
503
 
504
- with gr.Blocks(title="🔍 Grad-CAM Demo", theme=custom_theme) as demo:
 
 
 
 
 
 
 
 
 
 
505
  gr.Markdown("# 🔍 Grad-CAM Demo — Upload an image, get top-k predictions + heatmaps")
506
 
507
  with gr.Row():
@@ -559,11 +591,22 @@ def create_interface():
559
 
560
  with gr.Column(scale=2):
561
  gr.Markdown("## Image Input")
 
 
 
 
 
 
 
 
 
562
 
563
  with gr.Group():
 
564
  image_input = gr.Image(
565
  label="Upload Image",
566
- type="pil"
 
567
  )
568
 
569
  with gr.Row():
@@ -615,7 +658,6 @@ def create_interface():
615
  show_label=False,
616
  elem_id="gallery",
617
  columns=3,
618
- rows=2,
619
  object_fit="contain",
620
  height="auto"
621
  )
@@ -624,17 +666,18 @@ def create_interface():
624
  url_button.click(
625
  fn=load_checkpoint_from_url,
626
  inputs=[url_input, preset_dropdown],
627
- outputs=[status_text, meta_display, sample_button, class_filter, sample_slider]
628
  )
629
 
630
  file_button.click(
631
  fn=load_checkpoint_from_file,
632
  inputs=[file_input],
633
- outputs=[status_text, meta_display, sample_button, class_filter, sample_slider]
634
  )
635
 
636
  sample_button.click(
637
  fn=get_random_sample,
 
638
  outputs=[image_input, sample_info, sample_slider]
639
  )
640
 
 
32
 
33
  custom_theme = gr.themes.Soft(
34
  primary_hue="green", # main brand color
35
+ secondary_hue="green", # accent color
36
  neutral_hue="slate" # backgrounds/borders/text neutrals
37
  )
38
 
 
300
  max_samples = len(dataset) - 1 if app_state["dataset"] else 0
301
 
302
  return (f"✅ Loaded: {ckpt_path}", json.dumps(meta_info, indent=2),
303
+ gr.update(visible=True), gr.update(choices=class_choices, value="(any)", visible=True),
304
+ gr.update(visible=True, maximum=max_samples, value=0), gr.update(visible=True, value=""))
305
 
306
  except Exception as e:
307
  return f"❌ Failed: {str(e)}", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
 
362
  max_samples = len(dataset) - 1 if app_state["dataset"] else 0
363
 
364
  return (f"✅ Loaded: {local_path}", json.dumps(meta_info, indent=2),
365
+ gr.update(visible=True), gr.update(choices=class_choices, value="(any)", visible=True),
366
+ gr.update(visible=True, maximum=max_samples, value=0), gr.update(visible=True, value=""))
367
 
368
  except Exception as e:
369
  return f"❌ Failed: {str(e)}", "", gr.update(visible=False)
370
 
371
 
372
+ def get_random_sample(class_filter="(any)"):
373
+ """Get a random sample from the (optionally filtered) dataset."""
374
  if app_state["dataset"] is None:
375
  return None, "No dataset loaded", gr.update(visible=False)
376
+
377
  dataset = app_state["dataset"]
378
+ dataset_classes = app_state["dataset_classes"]
379
+
380
+ # Build candidate indices according to filter
381
+ if class_filter != "(any)":
382
+ targets = np.array([dataset[i][1] for i in range(len(dataset))])
383
+ class_id = dataset_classes.index(class_filter)
384
+ filtered_indices = np.where(targets == class_id)[0]
385
+ if len(filtered_indices) == 0:
386
+ return None, f"No samples found for class: {class_filter}", gr.update(visible=True, maximum=0, value=0)
387
+ actual_idx = int(random.choice(filtered_indices))
388
+ # slider index is relative to the filtered list length
389
+ slider_max = len(filtered_indices) - 1
390
+ slider_value = int(np.where(filtered_indices == actual_idx)[0][0])
391
+ else:
392
+ actual_idx = random.randint(0, len(dataset) - 1)
393
+ slider_max = len(dataset) - 1
394
+ slider_value = actual_idx
395
+
396
+ img_tensor, label = dataset[actual_idx]
397
  sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
398
+ sample_img = double_height(sample_img)
399
+ class_name = dataset_classes[label] if dataset_classes else str(label)
400
+ caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"
401
+
402
+ # Update slider to the picked index inside the current filter's range
403
+ return sample_img, caption, gr.update(visible=True, maximum=slider_max, value=slider_value)
 
404
 
405
 
406
  def get_sample_by_index(idx, class_filter):
 
430
 
431
  img_tensor, label = dataset[actual_idx]
432
  sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
433
+ sample_img = double_height(sample_img)
434
  class_name = dataset_classes[label] if dataset_classes else str(label)
435
  caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"
436
 
 
456
  return gr.update(visible=True, maximum=max_idx, value=0)
457
 
458
 
459
+ def double_height(img: Image.Image) -> Image.Image:
460
+ """Return a copy of the image with doubled height."""
461
+ w, h = img.size
462
+ return img.resize((w * 10, h * 10), Image.Resampling.NEAREST)
463
+
464
+
465
  def process_image(image, method, topk, alpha):
466
  """Process image and generate Grad-CAM visualizations."""
467
  if app_state["model"] is None:
 
523
  presets = load_release_presets()
524
  preset_choices = ["None"] + list(presets.keys()) if presets else ["None"]
525
 
526
+ with gr.Blocks(css="""
527
+ .alert {
528
+ padding: 10px 15px;
529
+ background-color: #FFF3CD;
530
+ color: #856404;
531
+ border: 1px solid #FFEEBA;
532
+ border-radius: 6px;
533
+ position: relative;
534
+ text-color: #856404;
535
+ }
536
+ """, theme=custom_theme) as demo:
537
  gr.Markdown("# 🔍 Grad-CAM Demo — Upload an image, get top-k predictions + heatmaps")
538
 
539
  with gr.Row():
 
591
 
592
  with gr.Column(scale=2):
593
  gr.Markdown("## Image Input")
594
+
595
+ size_alert = gr.Markdown(
596
+ value="""
597
+ <div class="alert">
598
+ ⚠️ Image was resized for better visualization — not equal to the dataset’s original size.
599
+ </div>
600
+ """,
601
+ elem_id="size-alert"
602
+ )
603
 
604
  with gr.Group():
605
+
606
  image_input = gr.Image(
607
  label="Upload Image",
608
+ type="pil",
609
+ height=400,
610
  )
611
 
612
  with gr.Row():
 
658
  show_label=False,
659
  elem_id="gallery",
660
  columns=3,
 
661
  object_fit="contain",
662
  height="auto"
663
  )
 
666
  url_button.click(
667
  fn=load_checkpoint_from_url,
668
  inputs=[url_input, preset_dropdown],
669
+ outputs=[status_text, meta_display, sample_button, class_filter, sample_slider, sample_info]
670
  )
671
 
672
  file_button.click(
673
  fn=load_checkpoint_from_file,
674
  inputs=[file_input],
675
+ outputs=[status_text, meta_display, sample_button, class_filter, sample_slider, sample_info]
676
  )
677
 
678
  sample_button.click(
679
  fn=get_random_sample,
680
+ inputs=[class_filter],
681
  outputs=[image_input, sample_info, sample_slider]
682
  )
683