Skier8402 commited on
Commit
a66e3ec
·
verified ·
1 Parent(s): aafda8b

Delete detr_and_interp.py

Browse files
Files changed (1) hide show
  1. detr_and_interp.py +0 -442
detr_and_interp.py DELETED
@@ -1,442 +0,0 @@
1
- '''
2
- this is a combined script that implements DETR object detection with interpretability methods
3
- using Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for uncertainty estimation.
4
- It provides a Gradio-based web interface for users to upload images, select detected objects
5
- and visualize explanations and uncertainty maps.
6
-
7
- How to run it:
8
-
9
- ```python
10
- python detr_and_interp.py
11
- ```
12
-
13
- '''
14
-
15
- import torch, requests, numpy as np
16
- import matplotlib.pyplot as plt
17
- import matplotlib.patches as patches
18
- from PIL import Image, ImageFilter
19
- import gradio as gr
20
- from transformers import DetrImageProcessor, DetrForObjectDetection
21
- from torchvision.transforms.functional import resize
22
- from captum.attr import IntegratedGradients
23
- import torch.nn.functional as F
24
- import logging
25
- import os
26
- from datetime import datetime
27
-
28
- # ---------- Logging Setup ----------
29
- log_dir = os.path.join(os.path.dirname(__file__), "logs")
30
- os.makedirs(log_dir, exist_ok=True)
31
- log_file = os.path.join(log_dir, f"detr_interp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
32
-
33
- logging.basicConfig(
34
- level=logging.INFO,
35
- format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
36
- handlers=[
37
- logging.FileHandler(log_file),
38
- logging.StreamHandler()
39
- ]
40
- )
41
- logger = logging.getLogger(__name__)
42
-
43
- logger.info("Starting DETR Interpretability Dashboard")
44
-
45
- device = "cuda" if torch.cuda.is_available() else "cpu"
46
- logger.info(f"Using device: {device}")
47
-
48
- model_name = "facebook/detr-resnet-50"
49
- logger.info(f"Loading model: {model_name}")
50
- model = DetrForObjectDetection.from_pretrained(model_name).to(device)
51
- extractor = DetrImageProcessor.from_pretrained(model_name)
52
- model.eval()
53
- logger.info("Model loaded and set to evaluation mode")
54
-
55
- # ---------- Grad-CAM / Grad-CAM++ ----------
56
- def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
57
- """
58
- Compute Grad-CAM (or Grad-CAM++) heatmap for a selected detection.
59
-
60
- What it computes:
61
- - Captures feature-map activations from a late conv layer and the gradients of the
62
- detection score w.r.t. those activations. Channel-wise weights are computed from
63
- gradients and used to combine feature maps into a spatial heatmap.
64
-
65
- Why this matters:
66
- - Highlights which spatial regions the model used to make the prediction. Useful to
67
- check whether the detector is attending to the object vs irrelevant background.
68
-
69
- How to interpret results:
70
- - High values in the returned heatmap indicate regions that contributed positively to
71
- the detection score. Grad-CAM++ (use_pp=True) computes a refined weighting that often
72
- yields sharper, better-localized maps when multiple instances overlap.
73
-
74
- Caveats & tips:
75
- - Choosing a layer too early will give fine-grained but semantically weak maps; too late
76
- will be coarse. We pick a late backbone conv block (layer4[-1]) as a sensible default.
77
- - Hooks must be removed after use to avoid memory leaks; we do that below.
78
-
79
- References:
80
- - Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
81
- """
82
- logger.info(f"Running {'Grad-CAM++' if use_pp else 'Grad-CAM'} for detection {det_idx}")
83
- try:
84
- # pick a late conv layer that still retains spatial info
85
- conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
86
- activations, gradients = {}, {}
87
-
88
- def fwd(m, i, o):
89
- activations["v"] = o.detach()
90
-
91
- def bwd(m, gi, go):
92
- gradients["v"] = go[0].detach()
93
-
94
- h1 = conv_layer.register_forward_hook(fwd)
95
- h2 = conv_layer.register_full_backward_hook(bwd) if hasattr(conv_layer, "register_full_backward_hook") else conv_layer.register_backward_hook(bwd)
96
- logger.debug("Hooks registered for Grad-CAM")
97
-
98
- outputs_for_attr = model(pixel_values)
99
- logits = outputs_for_attr.logits
100
- labels = logits.argmax(-1).squeeze(0)
101
- label_id = labels[keep.nonzero()[det_idx]].item()
102
- score = logits[0, keep.nonzero()[det_idx], label_id]
103
- logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")
104
-
105
- model.zero_grad()
106
- score.backward()
107
-
108
- acts = activations["v"].squeeze(0)
109
- grads = gradients["v"].squeeze(0)
110
- logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")
111
-
112
- if use_pp: # Grad-CAM++
113
- weights = (grads ** 2).mean(dim=(1, 2)) / (2 * (grads ** 2).mean(dim=(1, 2)) + (acts * grads ** 3).mean(dim=(1, 2)) + 1e-8)
114
- else: # vanilla Grad-CAM
115
- weights = grads.mean(dim=(1, 2))
116
-
117
- cam = torch.relu((weights[:, None, None] * acts).sum(0))
118
- cam = cam / (cam.max() + 1e-8)
119
- cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
120
-
121
- h1.remove(); h2.remove()
122
- logger.info(f"{'Grad-CAM++' if use_pp else 'Grad-CAM'} completed successfully")
123
- return cam_resized
124
- except Exception as e:
125
- logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
126
- raise
127
-
128
- # ---------- Integrated Gradients ----------
129
- def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black"):
130
- """
131
- Compute Integrated Gradients attribution map for a detection's logit.
132
-
133
- What it computes:
134
- - Integrates gradients along a path from a baseline input to the real input in embedding
135
- space, producing per-pixel (or per-channel) attributions.
136
-
137
- Why baseline choice matters:
138
- - The baseline defines what the model should consider as 'no signal'. Common choices:
139
- black (zeros), a blurred version of the image, or a neutral/mean image. Different
140
- baselines highlight different aspects of the input.
141
-
142
- How to read the output:
143
- - Values > 0 indicate pixels that increase the detection logit vs baseline; values < 0
144
- reduce it. We normalize the result to [0,1] for visualization convenience.
145
-
146
- Tips:
147
- - Increase n_steps for smoother attributions (costlier). Check convergence_delta to
148
- validate IG's completeness property.
149
-
150
- References:
151
- - Distill article on baselines: https://distill.pub/2020/attribution-baselines
152
- - Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
153
- """
154
- logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
155
- try:
156
- logits = outputs_for_attr.logits
157
- labels = logits.argmax(-1).squeeze(0)
158
- label_id = labels[keep.nonzero()[det_idx]].item()
159
- logger.debug(f"IG target label_id: {label_id}")
160
-
161
- # Baselines
162
- if baseline == "black":
163
- base = torch.zeros_like(pixel_values)
164
- logger.debug("Using black baseline")
165
- elif baseline == "blur":
166
- blur = img.filter(ImageFilter.GaussianBlur(radius=15))
167
- base = extractor(images=blur, return_tensors="pt")["pixel_values"].to(device)
168
- logger.debug("Using blurred baseline")
169
- else:
170
- base = torch.zeros_like(pixel_values)
171
- logger.debug("Defaulting to black baseline")
172
-
173
- def forward_func(pix):
174
- return model(pix).logits[:, keep.nonzero()[det_idx], label_id]
175
-
176
- ig = IntegratedGradients(forward_func)
177
- attr, _ = ig.attribute(pixel_values, baselines=base, n_steps=25, return_convergence_delta=True)
178
- arr = attr.squeeze().mean(0).cpu().detach().numpy()
179
- logger.info(f"Integrated Gradients with {baseline} baseline completed")
180
- return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
181
- except Exception as e:
182
- logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
183
- raise
184
-
185
- # ---------- Monte Carlo Dropout Uncertainty ----------
186
- def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
187
- """
188
- Estimate uncertainty by running multiple stochastic forward passes with dropout active.
189
-
190
- What it computes:
191
- - Runs the model multiple times with dropout enabled and computes a CAM per run.
192
- - Returns the per-pixel mean and standard deviation across CAMs. High std indicates
193
- the model's focus is unstable across stochastic perturbations.
194
-
195
- Why this helps:
196
- - If heatmaps vary a lot, the interpretability output is less reliable. Use this to flag
197
- detections where explanations may not be trustworthy.
198
-
199
- Practical tips:
200
- - Increasing n_samples reduces variance in the estimate but increases runtime.
201
- - Temporarily sets the model to train mode to activate dropout modules; restores eval mode.
202
- """
203
- logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")
204
- try:
205
- def enable_dropout(m):
206
- if isinstance(m, torch.nn.Dropout):
207
- m.train()
208
-
209
- model.train()
210
- model.apply(enable_dropout)
211
-
212
- cams = []
213
- conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
214
-
215
- for i in range(n_samples):
216
- outputs = model(pixel_values)
217
- logits = outputs.logits
218
- labels = logits.argmax(-1).squeeze(0)
219
- label_id = labels[keep.nonzero()[det_idx]].item()
220
- score = logits[0, keep.nonzero()[det_idx], label_id]
221
-
222
- acts, grads = {}, {}
223
-
224
- def fwd(m, i, o):
225
- acts['v'] = o.detach()
226
-
227
- def bwd(m, gi, go):
228
- grads['v'] = go[0].detach()
229
-
230
- h1 = conv_layer.register_forward_hook(fwd)
231
- h2 = (conv_layer.register_full_backward_hook(bwd)
232
- if hasattr(conv_layer, 'register_full_backward_hook')
233
- else conv_layer.register_backward_hook(bwd))
234
-
235
- model.zero_grad()
236
- score.backward(retain_graph=False)
237
-
238
- if 'v' not in acts:
239
- logger.warning(f"No activations captured in sample {i}, using fallback zero map")
240
- cam_resized = np.zeros((img.size[1], img.size[0]))
241
- else:
242
- act = acts['v'].squeeze(0)
243
- grad = grads['v'].squeeze(0)
244
- weights = grad.mean(dim=(1, 2))
245
- cam = torch.relu((weights[:, None, None] * act).sum(0))
246
- cam = cam / (cam.max() + 1e-8)
247
- cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
248
-
249
- cams.append(cam_resized)
250
- h1.remove(); h2.remove()
251
-
252
- model.eval()
253
-
254
- if len(cams) == 0:
255
- logger.error("No valid CAM maps generated")
256
- return np.zeros((img.size[1], img.size[0])), np.zeros((img.size[1], img.size[0]))
257
-
258
- cams_arr = np.stack(cams, axis=0)
259
- mean_map = cams_arr.mean(0)
260
- std_map = cams_arr.std(0)
261
-
262
- mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
263
- std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)
264
-
265
- logger.info("MC Dropout uncertainty completed")
266
- return mean_map, std_map
267
- except Exception as e:
268
- logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
269
- model.eval()
270
- raise
271
-
272
- # ---------- Full pipeline ----------
273
- def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
274
- logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
275
- try:
276
- inputs = extractor(images=img, return_tensors="pt").to(device)
277
- with torch.no_grad(): outputs = model(**inputs)
278
- pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)
279
- target_sizes = [img.size[::-1]]
280
- results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
281
- keep = results["scores"] > conf_thresh
282
- labels, scores = results["labels"][keep], results["scores"][keep]
283
-
284
- logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")
285
-
286
- if len(labels) == 0:
287
- logger.warning("No detections found above threshold")
288
- return None, "No detections above threshold", None, ""
289
-
290
- if det_choice is None:
291
- det_idx = 0
292
- else:
293
- try: det_idx = int(str(det_choice).split(":")[0])
294
- except: det_idx = 0
295
-
296
- label = model.config.id2label[labels[det_idx].item()]
297
- logger.info(f"Selected detection {det_idx}: {label}")
298
-
299
- # Grad-CAM / Grad-CAM++ (single deterministic pass)
300
- cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant=="Grad-CAM++"))
301
- fig1, ax1 = plt.subplots(); ax1.imshow(img); ax1.imshow(cam, cmap="jet", alpha=0.5); ax1.axis("off")
302
- ax1.set_title(f"{cam_variant}: {label}"); plt.close(fig1)
303
- logger.debug(f"{cam_variant} visualization created")
304
-
305
- # MC Dropout Uncertainty analysis
306
- mean_map, std_map = mc_dropout_uncertainty(img, det_idx, keep, pixel_values_attr, n_samples=int(mc_samples), dropout_p=float(dropout_p))
307
- # Create a composite figure: mean map and std map side-by-side
308
- fig2, axes = plt.subplots(1,2, figsize=(8,4))
309
- axes[0].imshow(img); axes[0].imshow(mean_map, cmap='hot', alpha=0.5); axes[0].axis('off'); axes[0].set_title('Predictive Mean')
310
- axes[1].imshow(img); axes[1].imshow(std_map, cmap='viridis', alpha=0.5); axes[1].axis('off'); axes[1].set_title('Predictive Std (Uncertainty)')
311
- plt.close(fig2)
312
- logger.debug("MC Dropout uncertainty visualization created")
313
-
314
- exp1 = f"🔎 {cam_variant}:\nGradient-weighted feature maps → highlights where DETR focused."
315
- exp2 = f"🔎 MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."
316
-
317
- logger.info("Interpretation completed successfully")
318
- return fig1, exp1, fig2, exp2
319
- except Exception as e:
320
- logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
321
- return None, f"Error: {str(e)}", None, ""
322
-
323
- # ---------- Gradio UI ----------
324
- with gr.Blocks() as demo:
325
- gr.Markdown("## 🧠 DETR Interpretability Dashboard with Controls")
326
- gr.Markdown(
327
- """
328
- **How to use this dashboard**
329
-
330
- - Upload an image using the left panel. The model will run object detection and list detected objects.
331
- - Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden.
332
- - Pick a detection from the dropdown to generate explanations for that object.
333
- - Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
334
- - `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
335
- - `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.
336
-
337
- Tooltips are provided on each control (hover or focus) for quick hints.
338
- """
339
- )
340
-
341
- with gr.Row():
342
- img_in = gr.Image(type="pil", label="Upload an image")
343
- det_out = gr.Label(label="Detections")
344
- det_fig = gr.Plot(label="Detections visualization")
345
- det_choice = gr.Dropdown(label="Pick a detection for explanation")
346
-
347
- with gr.Row():
348
- conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
349
- cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
350
- mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
351
- dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")
352
-
353
- btn = gr.Button("Explain")
354
-
355
- gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
356
- gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
357
- unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
358
- unc_txt = gr.Textbox(label="Explanation (Uncertainty)")
359
-
360
- # Visible control tooltips section (for environments where hovering tooltips are not available)
361
- gr.Markdown(
362
- """
363
- **Control tooltips (quick reference)**
364
-
365
- - Confidence Threshold: Filter out detections with confidence below this value.
366
- - Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
367
- - MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
368
- - Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
369
- - Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
370
- """
371
- )
372
-
373
- # ---------- Key interpretability choices (Feynman-style) ----------
374
- gr.Markdown(
375
- """
376
- **Key interpretability choices & why they matter**
377
-
378
- - **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. Black (zeros) is simple, but blurred or neutral baselines may give more meaningful attributions.
379
- - **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (default used) is a good compromise.
380
- - **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
381
- - **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
382
- """
383
- )
384
-
385
- # ---------- Further reading / Feynman-style references ----------
386
- # Add short, clickable references so users can read the original papers and deep-dive articles.
387
- gr.Markdown(
388
- """
389
- **Further reading (recommended)**
390
-
391
- - [Grad-CAM — Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) — the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
392
- - [Grad-CAM++ — Chattopadhay et al.](https://arxiv.org/abs/1710.11063) — an improved variant that often produces sharper maps and handles multiple instances better.
393
- - [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) — an accessible deep dive on baseline choices for Integrated Gradients.
394
- - [Captum docs — IntegratedGradients](https://captum.ai/api/integrated_gradients.html) — practical API notes for baseline, n_steps, and convergence delta.
395
- - [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) — discussion and techniques for choosing baselines beyond a black image.
396
- - [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) — recent research on improved baselines.
397
- """
398
- )
399
-
400
- # Helper: safe label getter in case model.config.id2label is missing or not a dict
401
- def safe_label_lookup(idx):
402
- try:
403
- id2label = getattr(model.config, 'id2label', None)
404
- if id2label is None:
405
- return f"Class {idx}"
406
- return id2label.get(int(idx), f"Class {idx}")
407
- except Exception:
408
- return f"Class {idx}"
409
-
410
- def run_detect(img, conf_thresh):
411
- logger.info(f"Running detection with confidence threshold: {conf_thresh}")
412
- try:
413
- inputs = extractor(images=img, return_tensors="pt").to(device)
414
- with torch.no_grad(): outputs = model(**inputs)
415
- target_sizes = [img.size[::-1]]
416
- results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
417
- keep = results["scores"] > conf_thresh
418
- boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
419
-
420
- logger.info(f"Detection found {len(labels)} objects above threshold")
421
-
422
- det_list = [f"{i}: {safe_label_lookup(l.item())} ({s:.2f})" for i,(l,s) in enumerate(zip(labels,scores))]
423
- fig, ax = plt.subplots(); ax.imshow(img); ax.axis("off")
424
- for box,label,score in zip(boxes,labels,scores):
425
- xmin,ymin,xmax,ymax = box
426
- ax.add_patch(patches.Rectangle((xmin,ymin),xmax-xmin,ymax-ymin,fill=False,color="red",lw=2))
427
- ax.text(xmin,ymin,f"{safe_label_lookup(label.item())}:{score:.2f}",color="black",
428
- bbox=dict(facecolor="yellow",alpha=0.5))
429
- plt.close(fig)
430
- default_val = det_list[0] if len(det_list) > 0 else None
431
- logger.debug("Detection visualization created")
432
- return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
433
- except Exception as e:
434
- logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
435
- return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}
436
-
437
- img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
438
- btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
439
- outputs=[gc_fig, gc_txt, unc_fig, unc_txt])
440
-
441
- logger.info("Gradio interface configured, launching demo")
442
- demo.launch()