def set_bn_eval(m):
    """Switch ``m`` to eval mode if it is a BatchNorm layer.

    The layer is recognised by the substring ``'BatchNorm'`` in its class
    name; any other module is left untouched.

    Args:
        m: The module to inspect. It only needs an ``eval()`` method when
            its class name contains ``'BatchNorm'``.
    """
    # Guard clause: anything that is not a BatchNorm variant is ignored.
    if 'BatchNorm' not in m.__class__.__name__:
        return
    m.eval()
def set_bn_non_trainable(m):
    """Stop gradient updates for the affine parameters of a BatchNorm layer.

    The layer is recognised by the substring ``'BatchNorm'`` in its class
    name; any other module is left untouched.

    Args:
        m: The module to inspect. For BatchNorm layers, ``weight`` and
            ``bias`` get ``requires_grad = False`` when they are present.
    """
    if 'BatchNorm' not in m.__class__.__name__:
        return
    for attr in ('weight', 'bias'):
        # BatchNorm layers without affine parameters (e.g. PyTorch layers
        # built with affine=False) hold None here; there is nothing to
        # freeze, and touching .requires_grad on None would raise.
        param = getattr(m, attr, None)
        if param is not None:
            param.requires_grad = False
def freeze_bn_statistics(model):
    """Freeze the running mean and variance of the BatchNorm layers.

    Passes ``set_bn_eval`` to ``model.apply``, which puts the BatchNorm
    layers it reaches into eval mode.

    NOTE(review): a later ``model.train()`` call presumably switches these
    layers back to training mode, so this may need to be re-applied
    afterwards -- confirm against the training loop.

    Args:
        model (nn.Module): The model whose BatchNorm statistics are frozen.
    """
    model.apply(set_bn_eval)
def freeze_bn_parameters(model):
    """Freeze the learnable parameters of the BatchNorm layers.

    Passes ``set_bn_non_trainable`` to ``model.apply``, which disables
    gradients for the ``weight`` and ``bias`` of the BatchNorm layers it
    reaches.

    Args:
        model (nn.Module): The model whose BatchNorm parameters are frozen.

    Returns:
        None. The model is modified in place.
    """
    model.apply(set_bn_non_trainable)