finhdev commited on
Commit
775a719
·
verified ·
1 Parent(s): 589a581

Update reparam.py

Browse files
Files changed (1) hide show
  1. reparam.py +341 -0
reparam.py CHANGED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+ from typing import Union, Tuple
6
+
7
+ import copy
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ __all__ = ["MobileOneBlock", "reparameterize_model"]
13
+
14
+
15
+ class SEBlock(nn.Module):
16
+ """Squeeze and Excite module.
17
+
18
+ Pytorch implementation of `Squeeze-and-Excitation Networks` -
19
+ https://arxiv.org/pdf/1709.01507.pdf
20
+ """
21
+
22
+ def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
23
+ """Construct a Squeeze and Excite Module.
24
+
25
+ Args:
26
+ in_channels: Number of input channels.
27
+ rd_ratio: Input channel reduction ratio.
28
+ """
29
+ super(SEBlock, self).__init__()
30
+ self.reduce = nn.Conv2d(
31
+ in_channels=in_channels,
32
+ out_channels=int(in_channels * rd_ratio),
33
+ kernel_size=1,
34
+ stride=1,
35
+ bias=True,
36
+ )
37
+ self.expand = nn.Conv2d(
38
+ in_channels=int(in_channels * rd_ratio),
39
+ out_channels=in_channels,
40
+ kernel_size=1,
41
+ stride=1,
42
+ bias=True,
43
+ )
44
+
45
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
46
+ """Apply forward pass."""
47
+ b, c, h, w = inputs.size()
48
+ x = F.avg_pool2d(inputs, kernel_size=[h, w])
49
+ x = self.reduce(x)
50
+ x = F.relu(x)
51
+ x = self.expand(x)
52
+ x = torch.sigmoid(x)
53
+ x = x.view(-1, c, 1, 1)
54
+ return inputs * x
55
+
56
+
57
+ class MobileOneBlock(nn.Module):
58
+ """MobileOne building block.
59
+
60
+ This block has a multi-branched architecture at train-time
61
+ and plain-CNN style architecture at inference time
62
+ For more details, please refer to our paper:
63
+ `An Improved One millisecond Mobile Backbone` -
64
+ https://arxiv.org/pdf/2206.04040.pdf
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ in_channels: int,
70
+ out_channels: int,
71
+ kernel_size: int,
72
+ stride: int = 1,
73
+ padding: int = 0,
74
+ dilation: int = 1,
75
+ groups: int = 1,
76
+ inference_mode: bool = False,
77
+ use_se: bool = False,
78
+ use_act: bool = True,
79
+ use_scale_branch: bool = True,
80
+ num_conv_branches: int = 1,
81
+ activation: nn.Module = nn.GELU(),
82
+ ) -> None:
83
+ """Construct a MobileOneBlock module.
84
+
85
+ Args:
86
+ in_channels: Number of channels in the input.
87
+ out_channels: Number of channels produced by the block.
88
+ kernel_size: Size of the convolution kernel.
89
+ stride: Stride size.
90
+ padding: Zero-padding size.
91
+ dilation: Kernel dilation factor.
92
+ groups: Group number.
93
+ inference_mode: If True, instantiates model in inference mode.
94
+ use_se: Whether to use SE-ReLU activations.
95
+ use_act: Whether to use activation. Default: ``True``
96
+ use_scale_branch: Whether to use scale branch. Default: ``True``
97
+ num_conv_branches: Number of linear conv branches.
98
+ """
99
+ super(MobileOneBlock, self).__init__()
100
+ self.inference_mode = inference_mode
101
+ self.groups = groups
102
+ self.stride = stride
103
+ self.padding = padding
104
+ self.dilation = dilation
105
+ self.kernel_size = kernel_size
106
+ self.in_channels = in_channels
107
+ self.out_channels = out_channels
108
+ self.num_conv_branches = num_conv_branches
109
+
110
+ # Check if SE-ReLU is requested
111
+ if use_se:
112
+ self.se = SEBlock(out_channels)
113
+ else:
114
+ self.se = nn.Identity()
115
+
116
+ if use_act:
117
+ self.activation = activation
118
+ else:
119
+ self.activation = nn.Identity()
120
+
121
+ if inference_mode:
122
+ self.reparam_conv = nn.Conv2d(
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ kernel_size=kernel_size,
126
+ stride=stride,
127
+ padding=padding,
128
+ dilation=dilation,
129
+ groups=groups,
130
+ bias=True,
131
+ )
132
+ else:
133
+ # Re-parameterizable skip connection
134
+ self.rbr_skip = (
135
+ nn.BatchNorm2d(num_features=in_channels)
136
+ if out_channels == in_channels and stride == 1
137
+ else None
138
+ )
139
+
140
+ # Re-parameterizable conv branches
141
+ if num_conv_branches > 0:
142
+ rbr_conv = list()
143
+ for _ in range(self.num_conv_branches):
144
+ rbr_conv.append(
145
+ self._conv_bn(kernel_size=kernel_size, padding=padding)
146
+ )
147
+ self.rbr_conv = nn.ModuleList(rbr_conv)
148
+ else:
149
+ self.rbr_conv = None
150
+
151
+ # Re-parameterizable scale branch
152
+ self.rbr_scale = None
153
+ if not isinstance(kernel_size, int):
154
+ kernel_size = kernel_size[0]
155
+ if (kernel_size > 1) and use_scale_branch:
156
+ self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
157
+
158
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
159
+ """Apply forward pass."""
160
+ # Inference mode forward pass.
161
+ if self.inference_mode:
162
+ return self.activation(self.se(self.reparam_conv(x)))
163
+
164
+ # Multi-branched train-time forward pass.
165
+ # Skip branch output
166
+ identity_out = 0
167
+ if self.rbr_skip is not None:
168
+ identity_out = self.rbr_skip(x)
169
+
170
+ # Scale branch output
171
+ scale_out = 0
172
+ if self.rbr_scale is not None:
173
+ scale_out = self.rbr_scale(x)
174
+
175
+ # Other branches
176
+ out = scale_out + identity_out
177
+ if self.rbr_conv is not None:
178
+ for ix in range(self.num_conv_branches):
179
+ out += self.rbr_conv[ix](x)
180
+
181
+ return self.activation(self.se(out))
182
+
183
+ def reparameterize(self):
184
+ """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
185
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
186
+ architecture used at training time to obtain a plain CNN-like structure
187
+ for inference.
188
+ """
189
+ if self.inference_mode:
190
+ return
191
+ kernel, bias = self._get_kernel_bias()
192
+ self.reparam_conv = nn.Conv2d(
193
+ in_channels=self.in_channels,
194
+ out_channels=self.out_channels,
195
+ kernel_size=self.kernel_size,
196
+ stride=self.stride,
197
+ padding=self.padding,
198
+ dilation=self.dilation,
199
+ groups=self.groups,
200
+ bias=True,
201
+ )
202
+ self.reparam_conv.weight.data = kernel
203
+ self.reparam_conv.bias.data = bias
204
+
205
+ # Delete un-used branches
206
+ for para in self.parameters():
207
+ para.detach_()
208
+ self.__delattr__("rbr_conv")
209
+ self.__delattr__("rbr_scale")
210
+ if hasattr(self, "rbr_skip"):
211
+ self.__delattr__("rbr_skip")
212
+
213
+ self.inference_mode = True
214
+
215
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
216
+ """Method to obtain re-parameterized kernel and bias.
217
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
218
+
219
+ Returns:
220
+ Tuple of (kernel, bias) after fusing branches.
221
+ """
222
+ # get weights and bias of scale branch
223
+ kernel_scale = 0
224
+ bias_scale = 0
225
+ if self.rbr_scale is not None:
226
+ kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
227
+ # Pad scale branch kernel to match conv branch kernel size.
228
+ pad = self.kernel_size // 2
229
+ kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
230
+
231
+ # get weights and bias of skip branch
232
+ kernel_identity = 0
233
+ bias_identity = 0
234
+ if self.rbr_skip is not None:
235
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
236
+
237
+ # get weights and bias of conv branches
238
+ kernel_conv = 0
239
+ bias_conv = 0
240
+ if self.rbr_conv is not None:
241
+ for ix in range(self.num_conv_branches):
242
+ _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
243
+ kernel_conv += _kernel
244
+ bias_conv += _bias
245
+
246
+ kernel_final = kernel_conv + kernel_scale + kernel_identity
247
+ bias_final = bias_conv + bias_scale + bias_identity
248
+ return kernel_final, bias_final
249
+
250
+ def _fuse_bn_tensor(
251
+ self, branch: Union[nn.Sequential, nn.BatchNorm2d]
252
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
253
+ """Method to fuse batchnorm layer with preceeding conv layer.
254
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
255
+
256
+ Args:
257
+ branch: Sequence of ops to be fused.
258
+
259
+ Returns:
260
+ Tuple of (kernel, bias) after fusing batchnorm.
261
+ """
262
+ if isinstance(branch, nn.Sequential):
263
+ kernel = branch.conv.weight
264
+ running_mean = branch.bn.running_mean
265
+ running_var = branch.bn.running_var
266
+ gamma = branch.bn.weight
267
+ beta = branch.bn.bias
268
+ eps = branch.bn.eps
269
+ else:
270
+ assert isinstance(branch, nn.BatchNorm2d)
271
+ if not hasattr(self, "id_tensor"):
272
+ input_dim = self.in_channels // self.groups
273
+
274
+ kernel_size = self.kernel_size
275
+ if isinstance(self.kernel_size, int):
276
+ kernel_size = (self.kernel_size, self.kernel_size)
277
+
278
+ kernel_value = torch.zeros(
279
+ (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
280
+ dtype=branch.weight.dtype,
281
+ device=branch.weight.device,
282
+ )
283
+ for i in range(self.in_channels):
284
+ kernel_value[
285
+ i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
286
+ ] = 1
287
+ self.id_tensor = kernel_value
288
+ kernel = self.id_tensor
289
+ running_mean = branch.running_mean
290
+ running_var = branch.running_var
291
+ gamma = branch.weight
292
+ beta = branch.bias
293
+ eps = branch.eps
294
+ std = (running_var + eps).sqrt()
295
+ t = (gamma / std).reshape(-1, 1, 1, 1)
296
+ return kernel * t, beta - running_mean * gamma / std
297
+
298
+ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
299
+ """Helper method to construct conv-batchnorm layers.
300
+
301
+ Args:
302
+ kernel_size: Size of the convolution kernel.
303
+ padding: Zero-padding size.
304
+
305
+ Returns:
306
+ Conv-BN module.
307
+ """
308
+ mod_list = nn.Sequential()
309
+ mod_list.add_module(
310
+ "conv",
311
+ nn.Conv2d(
312
+ in_channels=self.in_channels,
313
+ out_channels=self.out_channels,
314
+ kernel_size=kernel_size,
315
+ stride=self.stride,
316
+ padding=padding,
317
+ groups=self.groups,
318
+ bias=False,
319
+ ),
320
+ )
321
+ mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
322
+ return mod_list
323
+
324
+
325
+ def reparameterize_model(model: torch.nn.Module) -> nn.Module:
326
+ """Method returns a model where a multi-branched structure
327
+ used in training is re-parameterized into a single branch
328
+ for inference.
329
+
330
+ Args:
331
+ model: MobileOne model in train mode.
332
+
333
+ Returns:
334
+ MobileOne model in inference mode.
335
+ """
336
+ # Avoid editing original graph
337
+ model = copy.deepcopy(model)
338
+ for module in model.modules():
339
+ if hasattr(module, "reparameterize"):
340
+ module.reparameterize()
341
+ return model