Spaces:
Build error
Build error
| import torch | |
def get_grad_norm(model, l=2):
    """Return the global L1 or L2 norm of the gradients of a set of parameters.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: Either a ``torch.nn.Module`` (its ``parameters()`` are used) or
            an iterable of tensors/parameters exposing a ``.grad`` attribute.
        l: Order of the norm; only ``1`` and ``2`` are supported.

    Returns:
        The norm as a Python ``float``. ``0.0`` when no parameter carries a
        gradient.

    Raises:
        ValueError: If ``l`` is neither 1 nor 2.
    """
    # Validate up front so an unsupported order is reported even when no
    # parameter currently holds a gradient.
    if l not in (1, 2):
        raise ValueError("Now we only implement l1/l2 norm !")

    if isinstance(model, torch.nn.Module):
        params = model.parameters()
    else:
        params = model

    # Start from a float so the "no gradients at all" case returns 0.0 for
    # both norm orders instead of failing on int.item().
    total = 0.0
    for p in params:
        if p.grad is None:
            continue
        if l == 1:
            # Tensor.abs() takes no positional argument.
            total = total + p.grad.abs().sum()
        else:
            total = total + p.grad.pow(2).sum()

    if l == 2:
        total = total ** 0.5

    # The accumulator is a 0-dim tensor as soon as one gradient was added;
    # convert it to a plain Python number for the caller.
    if isinstance(total, torch.Tensor):
        return total.item()
    return total
class GradBuffer:
    """Accumulate per-parameter gradients and later add them back to a model.

    Gradients are keyed by the names yielded by ``model.named_parameters()``.
    """

    def __init__(self):
        # Maps parameter name -> running sum of the gradients seen by add().
        self.buffer = {}

    def add(self, model):
        """Add the current gradients of ``model`` to the buffer.

        Parameters whose ``.grad`` is ``None`` are skipped.
        """
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.data
            if name in self.buffer:
                # Out-of-place sum: the stored tensor is always a fresh one.
                self.buffer[name] = self.buffer[name] + grad
            else:
                # Copy so later in-place changes to param.grad (e.g. zeroing)
                # do not alter what has been buffered.
                self.buffer[name] = grad.clone()

    def apply(self, model):
        """Add the buffered gradients onto ``model``'s gradients, then reset.

        A parameter that has a buffered gradient but whose ``.grad`` is
        currently ``None`` receives the buffered tensor as its gradient, so
        accumulated values are not silently discarded. The buffer is emptied
        afterwards.
        """
        for name, param in model.named_parameters():
            buffered = self.buffer.get(name)
            if buffered is None:
                continue
            if param.grad is None:
                # Safe to hand over the tensor itself: the buffer is replaced
                # below, so nothing else keeps using it.
                param.grad = buffered
            else:
                param.grad.data += buffered
        self.buffer = {}