# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Reference: https://arxiv.org/abs/1610.02391
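# In the paper's notation, the channel weights are the spatially averaged
# gradients, alpha_k = (1/Z) * sum_{i,j} dY / dA_k[i, j], and the map is
# L = ReLU(sum_k alpha_k * A_k); the pooling, weighted sum, and clamp in
# the function below implement these steps before upsampling.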
def gradCAM(
        model: nn.Module,
        input: torch.Tensor,
        target: torch.Tensor,
        layer: nn.Module
) -> torch.Tensor:
    # Zero out any gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Disable gradient settings.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    # Attach a hook to the model at the desired layer.
    assert isinstance(layer, nn.Module)
    with Hook(layer) as hook:
        # Do a forward and backward pass.
        output = model(input)
        output.backward(target)

        grad = hook.gradient.float()
        act = hook.activation.float()

        # Global average pool the gradient across the spatial dimensions
        # to obtain importance weights.
        alpha = grad.mean(dim=(2, 3), keepdim=True)
        # Weighted combination of activation maps over the channel
        # dimension.
        gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
        # We only want neurons with positive influence, so we
        # clamp any negative ones.
        gradcam = torch.clamp(gradcam, min=0)

    # Resize gradcam to input resolution.
    gradcam = F.interpolate(gradcam, input.shape[2:], mode='bicubic', align_corners=False)

    # Restore gradient settings.
    for name, param in model.named_parameters():
        param.requires_grad_(requires_grad[name])

    return gradcam


class Hook:
    """Attaches to a module and records its activations and gradients."""

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        self.data = output
        output.requires_grad_(True)
        output.retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        return self.data.grad
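

# Illustrative usage sketch, not part of the original module: the torchvision
# ResNet-50, the choice of `layer4` as the target layer, the dummy input, and
# the one-hot target below are all assumptions for demonstration; any CNN with
# a 4D feature map at `layer` works the same way.
if __name__ == "__main__":
    from torchvision.models import resnet50

    net = resnet50(weights=None).eval()
    image = torch.randn(1, 3, 224, 224)

    # Backpropagate a one-hot vector selecting the class of interest (class 0 here).
    target = torch.zeros(1, 1000)
    target[0, 0] = 1.0

    heatmap = gradCAM(net, image, target, layer=net.layer4)
    print(heatmap.shape)  # torch.Size([1, 1, 224, 224])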