Delete detr_and_interp.py
detr_and_interp.py  DELETED  (+0 -442)
@@ -1,442 +0,0 @@
'''
This script combines DETR object detection with interpretability methods:
Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for
uncertainty estimation. It provides a Gradio-based web interface for uploading
images, selecting detected objects, and visualizing explanations and
uncertainty maps.

How to run it:

```bash
python detr_and_interp.py
```

'''

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageFilter
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
from torchvision.transforms.functional import resize
from captum.attr import IntegratedGradients
import logging
import os
from datetime import datetime

# ---------- Logging Setup ----------
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"detr_interp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

logger.info("Starting DETR Interpretability Dashboard")

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

model_name = "facebook/detr-resnet-50"
logger.info(f"Loading model: {model_name}")
model = DetrForObjectDetection.from_pretrained(model_name).to(device)
extractor = DetrImageProcessor.from_pretrained(model_name)
model.eval()
logger.info("Model loaded and set to evaluation mode")

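# Note: DetrImageProcessor handles resizing, normalization, and padding, so
# extractor(images=...) yields the pixel_values tensor the model expects.
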
# ---------- Grad-CAM / Grad-CAM++ ----------
def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
    """
    Compute a Grad-CAM (or Grad-CAM++) heatmap for a selected detection.

    What it computes:
    - Captures feature-map activations from a late conv layer and the gradients of the
      detection score w.r.t. those activations. Channel-wise weights are computed from
      the gradients and used to combine the feature maps into a spatial heatmap.

    Why this matters:
    - Highlights which spatial regions the model used to make the prediction. Useful to
      check whether the detector is attending to the object versus irrelevant background.

    How to interpret results:
    - High values in the returned heatmap indicate regions that contributed positively to
      the detection score. Grad-CAM++ (use_pp=True) computes a refined weighting that often
      yields sharper, better-localized maps when multiple instances overlap.

    Caveats & tips:
    - Choosing a layer too early gives fine-grained but semantically weak maps; choosing
      one too late gives coarse maps. We pick a late backbone conv block (layer4[-1]) as a
      sensible default.
    - Hooks must be removed after use to avoid memory leaks; we do that below.

    References:
    - Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
    """
    logger.info(f"Running {'Grad-CAM++' if use_pp else 'Grad-CAM'} for detection {det_idx}")
    try:
        # pick a late conv layer that still retains spatial info
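        # layer4 is the final ResNet-50 stage (2048 channels at stride 32), so the raw
        # CAM is coarse (~H/32 x W/32) and is upsampled to image size further below.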
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
        activations, gradients = {}, {}

        def fwd(m, i, o):
            activations["v"] = o.detach()

        def bwd(m, gi, go):
            gradients["v"] = go[0].detach()

        h1 = conv_layer.register_forward_hook(fwd)
        h2 = (conv_layer.register_full_backward_hook(bwd)
              if hasattr(conv_layer, "register_full_backward_hook")
              else conv_layer.register_backward_hook(bwd))
        logger.debug("Hooks registered for Grad-CAM")

        outputs_for_attr = model(pixel_values)
        logits = outputs_for_attr.logits
        labels = logits.argmax(-1).squeeze(0)
        label_id = labels[keep.nonzero()[det_idx]].item()
        score = logits[0, keep.nonzero()[det_idx], label_id]
        logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")

        model.zero_grad()
        score.backward()

        acts = activations["v"].squeeze(0)
        grads = gradients["v"].squeeze(0)
        logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")

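        # Channel-weighting schemes (sketch):
        #   Grad-CAM:   w_k = mean_ij dS/dA_k(i,j)   (global-average-pooled gradients)
        #   Grad-CAM++: alpha-weighted gradients, alpha = g^2 / (2*g^2 + sum(A * g^3))
        # The Grad-CAM++ branch below approximates the paper's spatial sums with means
        # and omits the relu(g) factor, so treat it as a simplified variant of
        # Chattopadhay et al.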
        if use_pp:  # Grad-CAM++
            weights = (grads ** 2).mean(dim=(1, 2)) / (
                2 * (grads ** 2).mean(dim=(1, 2)) + (acts * grads ** 3).mean(dim=(1, 2)) + 1e-8
            )
        else:  # vanilla Grad-CAM
            weights = grads.mean(dim=(1, 2))

        cam = torch.relu((weights[:, None, None] * acts).sum(0))
        cam = cam / (cam.max() + 1e-8)
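        # PIL reports img.size as (width, height); torchvision's resize expects
        # (height, width), hence the [::-1] on the next line.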
        cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()

        h1.remove(); h2.remove()
        logger.info(f"{'Grad-CAM++' if use_pp else 'Grad-CAM'} completed successfully")
        return cam_resized
    except Exception as e:
        logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
        raise

# ---------- Integrated Gradients ----------
def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black"):
    """
    Compute an Integrated Gradients attribution map for a detection's logit.

    What it computes:
    - Integrates gradients along a straight-line path from a baseline input to the actual
      input in pixel space, producing per-pixel (or per-channel) attributions.

    Why baseline choice matters:
    - The baseline defines what the model should consider 'no signal'. Common choices:
      black (zeros), a blurred version of the image, or a neutral/mean image. Different
      baselines highlight different aspects of the input.

    How to read the output:
    - Values > 0 indicate pixels that increase the detection logit relative to the baseline;
      values < 0 reduce it. We normalize the result to [0, 1] for visualization convenience.

    Tips:
    - Increase n_steps for smoother attributions (costlier). Check the convergence delta to
      validate IG's completeness property.

    References:
    - Distill article on baselines: https://distill.pub/2020/attribution-baselines
    - Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
    """
    logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
    try:
        logits = outputs_for_attr.logits
        labels = logits.argmax(-1).squeeze(0)
        label_id = labels[keep.nonzero()[det_idx]].item()
        logger.debug(f"IG target label_id: {label_id}")

        # Baselines
        if baseline == "black":
            base = torch.zeros_like(pixel_values)
            logger.debug("Using black baseline")
        elif baseline == "blur":
            blur = img.filter(ImageFilter.GaussianBlur(radius=15))
            base = extractor(images=blur, return_tensors="pt")["pixel_values"].to(device)
            logger.debug("Using blurred baseline")
        else:
            base = torch.zeros_like(pixel_values)
            logger.debug("Defaulting to black baseline")

        def forward_func(pix):
            return model(pix).logits[:, keep.nonzero()[det_idx], label_id]
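
        # Captum evaluates forward_func on interpolated images between `base` and
        # `pixel_values`; the returned logit column is the scalar being attributed.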
        ig = IntegratedGradients(forward_func)
        attr, delta = ig.attribute(pixel_values, baselines=base, n_steps=25, return_convergence_delta=True)
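        # Completeness check: IG attributions should sum to approximately
        # f(input) - f(baseline); a small delta suggests n_steps=25 is adequate.
        logger.debug(f"IG max convergence delta: {delta.abs().max().item():.6f}")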
        arr = attr.squeeze().mean(0).cpu().detach().numpy()
        logger.info(f"Integrated Gradients with {baseline} baseline completed")
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    except Exception as e:
        logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
        raise

# ---------- Monte Carlo Dropout Uncertainty ----------
def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
    """
    Estimate uncertainty by running multiple stochastic forward passes with dropout active.

    What it computes:
    - Runs the model multiple times with dropout enabled and computes a CAM per run.
    - Returns the per-pixel mean and standard deviation across CAMs. A high std indicates
      that the model's focus is unstable across stochastic perturbations.

    Why this helps:
    - If the heatmaps vary a lot, the interpretability output is less reliable. Use this to
      flag detections whose explanations may not be trustworthy.

    Practical tips:
    - Increasing n_samples reduces variance in the estimate but increases runtime.
    - Temporarily sets the model to train mode to activate dropout modules; restores eval mode.
    """
    logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")
    try:
        def enable_dropout(m):
            if isinstance(m, torch.nn.Dropout):
                m.p = dropout_p  # apply the requested dropout rate to each dropout module
                m.train()

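        # MC Dropout in the sense of Gal & Ghahramani (2016), "Dropout as a Bayesian
        # Approximation": keeping dropout stochastic at inference time approximates
        # sampling from a posterior over the network weights.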
        model.train()
        model.apply(enable_dropout)

        cams = []
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]

        for i in range(n_samples):
            acts, grads = {}, {}

            def fwd(m, i, o):
                acts['v'] = o.detach()

            def bwd(m, gi, go):
                grads['v'] = go[0].detach()

            # Register the hooks before the forward pass; otherwise the forward hook
            # never fires and no activations are captured.
            h1 = conv_layer.register_forward_hook(fwd)
            h2 = (conv_layer.register_full_backward_hook(bwd)
                  if hasattr(conv_layer, 'register_full_backward_hook')
                  else conv_layer.register_backward_hook(bwd))

            outputs = model(pixel_values)
            logits = outputs.logits
            labels = logits.argmax(-1).squeeze(0)
            label_id = labels[keep.nonzero()[det_idx]].item()
            score = logits[0, keep.nonzero()[det_idx], label_id]

            model.zero_grad()
            score.backward(retain_graph=False)

            if 'v' not in acts:
                logger.warning(f"No activations captured in sample {i}, using fallback zero map")
                cam_resized = np.zeros((img.size[1], img.size[0]))
            else:
                act = acts['v'].squeeze(0)
                grad = grads['v'].squeeze(0)
                weights = grad.mean(dim=(1, 2))
                cam = torch.relu((weights[:, None, None] * act).sum(0))
                cam = cam / (cam.max() + 1e-8)
                cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()

            cams.append(cam_resized)
            h1.remove(); h2.remove()

        model.eval()

        if len(cams) == 0:
            logger.error("No valid CAM maps generated")
            return np.zeros((img.size[1], img.size[0])), np.zeros((img.size[1], img.size[0]))

        cams_arr = np.stack(cams, axis=0)
        mean_map = cams_arr.mean(0)
        std_map = cams_arr.std(0)

        mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
        std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)
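        # Note: min-max normalizing each map independently makes the values relative to
        # this image only; std magnitudes are not comparable across images or runs.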

        logger.info("MC Dropout uncertainty completed")
        return mean_map, std_map
    except Exception as e:
        logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
        model.eval()
        raise

# ---------- Full pipeline ----------
def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
    logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
    try:
        inputs = extractor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)
        target_sizes = [img.size[::-1]]
        results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
        keep = results["scores"] > conf_thresh
        labels, scores = results["labels"][keep], results["scores"][keep]
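        # Post-processing with threshold=0.0 keeps every query, so the boolean `keep`
        # mask stays aligned with raw query indices; gradcam/mc_dropout rely on
        # keep.nonzero() to map a kept detection back to its query's logits.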

        logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")

        if len(labels) == 0:
            logger.warning("No detections found above threshold")
            return None, "No detections above threshold", None, ""

        if det_choice is None:
            det_idx = 0
        else:
            try:
                det_idx = int(str(det_choice).split(":")[0])
            except (ValueError, IndexError):
                det_idx = 0

        label = model.config.id2label[labels[det_idx].item()]
        logger.info(f"Selected detection {det_idx}: {label}")

        # Grad-CAM / Grad-CAM++ (single deterministic pass)
        cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant == "Grad-CAM++"))
        fig1, ax1 = plt.subplots()
        ax1.imshow(img); ax1.imshow(cam, cmap="jet", alpha=0.5); ax1.axis("off")
        ax1.set_title(f"{cam_variant}: {label}")
        plt.close(fig1)
        logger.debug(f"{cam_variant} visualization created")

        # MC Dropout Uncertainty analysis
        mean_map, std_map = mc_dropout_uncertainty(img, det_idx, keep, pixel_values_attr, n_samples=int(mc_samples), dropout_p=float(dropout_p))
        # Create a composite figure: mean map and std map side-by-side
        fig2, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(img); axes[0].imshow(mean_map, cmap='hot', alpha=0.5); axes[0].axis('off'); axes[0].set_title('Predictive Mean')
        axes[1].imshow(img); axes[1].imshow(std_map, cmap='viridis', alpha=0.5); axes[1].axis('off'); axes[1].set_title('Predictive Std (Uncertainty)')
        plt.close(fig2)
        logger.debug("MC Dropout uncertainty visualization created")

        exp1 = f"🔎 {cam_variant}:\nGradient-weighted feature maps → highlights where DETR focused."
        exp2 = f"🔎 MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."

        logger.info("Interpretation completed successfully")
        return fig1, exp1, fig2, exp2
    except Exception as e:
        logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
        return None, f"Error: {str(e)}", None, ""

# ---------- Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 DETR Interpretability Dashboard with Controls")
    gr.Markdown(
        """
        **How to use this dashboard**

        - Upload an image using the left panel. The model will run object detection and list detected objects.
        - Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden.
        - Pick a detection from the dropdown to generate explanations for that object.
        - Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
        - `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
        - `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.

        Tooltips are provided on each control (hover or focus) for quick hints.
        """
    )

    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload an image")
        det_out = gr.Label(label="Detections")
        det_fig = gr.Plot(label="Detections visualization")
        det_choice = gr.Dropdown(label="Pick a detection for explanation")

    with gr.Row():
        conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
        cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
        mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
        dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")

    btn = gr.Button("Explain")

    gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
    gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
    unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
    unc_txt = gr.Textbox(label="Explanation (Uncertainty)")

    # Visible control tooltips section (for environments where hover tooltips are not available)
    gr.Markdown(
        """
        **Control tooltips (quick reference)**

        - Confidence Threshold: Filter out detections with confidence below this value.
        - Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
        - MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
        - Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
        - Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
        """
    )

    # ---------- Key interpretability choices (Feynman-style) ----------
    gr.Markdown(
        """
        **Key interpretability choices & why they matter**

        - **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. Black (zeros) is simple, but blurred or neutral baselines may give more meaningful attributions.
        - **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (the default used here) is a good compromise.
        - **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
        - **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
        """
    )

    # ---------- Further reading / Feynman-style references ----------
    # Add short, clickable references so users can read the original papers and deep-dive articles.
    gr.Markdown(
        """
        **Further reading (recommended)**

        - [Grad-CAM — Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) — the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
        - [Grad-CAM++ — Chattopadhay et al.](https://arxiv.org/abs/1710.11063) — an improved variant that often produces sharper maps and handles multiple instances better.
        - [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) — an accessible deep dive on baseline choices for Integrated Gradients.
        - [Captum docs — IntegratedGradients](https://captum.ai/api/integrated_gradients.html) — practical API notes for baseline, n_steps, and convergence delta.
        - [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) — discussion and techniques for choosing baselines beyond a black image.
        - [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) — recent research on improved baselines.
        """
    )

    # Helper: safe label getter in case model.config.id2label is missing or not a dict
    def safe_label_lookup(idx):
        try:
            id2label = getattr(model.config, 'id2label', None)
            if id2label is None:
                return f"Class {idx}"
            return id2label.get(int(idx), f"Class {idx}")
        except Exception:
            return f"Class {idx}"

    def run_detect(img, conf_thresh):
        logger.info(f"Running detection with confidence threshold: {conf_thresh}")
        try:
            inputs = extractor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            target_sizes = [img.size[::-1]]
            results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
            keep = results["scores"] > conf_thresh
            boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]

            logger.info(f"Detection found {len(labels)} objects above threshold")

            det_list = [f"{i}: {safe_label_lookup(l.item())} ({s:.2f})" for i, (l, s) in enumerate(zip(labels, scores))]
            fig, ax = plt.subplots()
            ax.imshow(img); ax.axis("off")
            for box, label, score in zip(boxes, labels, scores):
                xmin, ymin, xmax, ymax = box
                ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color="red", lw=2))
                ax.text(xmin, ymin, f"{safe_label_lookup(label.item())}:{score:.2f}", color="black",
                        bbox=dict(facecolor="yellow", alpha=0.5))
            plt.close(fig)
            default_val = det_list[0] if len(det_list) > 0 else None
            logger.debug("Detection visualization created")
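            # Returning a dict keyed by output components lets one handler update
            # several outputs at once (standard Gradio Blocks pattern).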
            return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
        except Exception as e:
            logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
            return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}

    img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
    btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
              outputs=[gc_fig, gc_txt, unc_fig, unc_txt])

logger.info("Gradio interface configured, launching demo")
demo.launch()