| | |
| | import torch |
| | import torch.nn as nn |
| | from transformers import PretrainedConfig, PreTrainedModel |
| | from transformers.modeling_outputs import ImageClassifierOutput |
| |
|
| |
|
class PrunedResNetConfig(PretrainedConfig):
    """Configuration for a channel-pruned ResNet-50 image classifier.

    Args:
        channel_config: Mapping from layer name (e.g. ``"conv1"``,
            ``"layer2.1.conv3"``, ``"layer3.0.downsample.0"``) to the number of
            channels retained after pruning. ``None`` means not yet configured.
        num_classes: Size of the classification head.
    """

    model_type = "resnet"

    def __init__(
        self,
        channel_config: dict[str, int] | None = None,
        num_classes=1000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.channel_config = channel_config
| |
|
| |
|
class PrunedResNet50(PreTrainedModel):
    """ResNet-50 whose per-layer channel widths come from ``config.channel_config``.

    The channel mapping uses torchvision-style ResNet parameter names
    (``"conv1"``, ``"layer{s}.{i}.conv{k}"``, ``"layer{s}.{i}.in"``,
    ``"layer{s}.{i}.downsample.0"``) so a pruned checkpoint can be rebuilt
    with exactly the pruned widths.
    """

    config_class = PrunedResNetConfig
    _tied_weights_keys = []

    def __init__(self, config: PrunedResNetConfig):
        super().__init__(config)
        self.config = config
        c = config.channel_config
        # Fail fast with an actionable message instead of the opaque
        # "'NoneType' object is not subscriptable" the subscript below would raise.
        if c is None:
            raise ValueError(
                "PrunedResNetConfig.channel_config must be set: a dict mapping "
                "layer names (e.g. 'conv1', 'layer1.0.conv2') to channel counts."
            )
        self.conv1 = nn.Conv2d(
            3, c["conv1"], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(c["conv1"])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Standard ResNet-50 stage layout: 3/4/6/3 bottleneck blocks.
        self.layer1 = self._make_layer(c, stage_idx=1, layers=3, stride=1)
        self.layer2 = self._make_layer(c, stage_idx=2, layers=4, stride=2)
        self.layer3 = self._make_layer(c, stage_idx=3, layers=6, stride=2)
        self.layer4 = self._make_layer(c, stage_idx=4, layers=3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The classifier input width equals the last block's conv3 output width.
        last_channel = c["layer4.2.conv3"]
        self.fc = nn.Linear(last_channel, config.num_classes)
        self.post_init()

    def _make_layer(self, c, stage_idx, layers, stride):
        """Build one ResNet stage of `layers` Bottleneck blocks.

        Only the first block may stride and carry a projection shortcut;
        subsequent blocks use stride 1 and an identity shortcut. The config
        only provides a ``downsample.0`` key for block 0, so ``c.get`` yields
        ``None`` (identity shortcut) for the rest — matching ResNet-50.
        """
        blocks = []
        for i in range(layers):
            blocks.append(
                Bottleneck(
                    inplanes=c[f"layer{stage_idx}.{i}.in"],
                    planes=[
                        c[f"layer{stage_idx}.{i}.conv1"],
                        c[f"layer{stage_idx}.{i}.conv2"],
                        c[f"layer{stage_idx}.{i}.conv3"],
                    ],
                    stride=stride if i == 0 else 1,
                    downsample_planes=c.get(f"layer{stage_idx}.{i}.downsample.0"),
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Image tensor; assumes shape (batch, 3, H, W) — the
                stem conv takes 3 input channels.
            labels: Optional class-index tensor; when given, cross-entropy
                loss is computed and returned in the output.

        Returns:
            ImageClassifierOutput with ``logits`` and optional ``loss``.

        Raises:
            ValueError: If ``pixel_values`` is not provided.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        x = self.conv1(pixel_values)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        return ImageClassifierOutput(logits=logits, loss=loss)
| |
|
| |
|
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand) with
    explicit, possibly-pruned channel widths per convolution.

    Args:
        inplanes: Number of input channels.
        planes: Three-element sequence giving the output widths of
            conv1, conv2, and conv3 respectively.
        stride: Stride applied by the 3x3 convolution (and the projection
            shortcut, if any).
        downsample_planes: When not ``None``, a 1x1 conv + BN projection
            shortcut is created with this many output channels; otherwise
            the identity shortcut is used.
    """

    def __init__(self, inplanes, planes, stride=1, downsample_planes=None):
        super().__init__()
        width_reduce, width_mid, width_expand = planes

        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, width_reduce, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width_reduce)
        # 3x3 spatial convolution (carries the block's stride).
        self.conv2 = nn.Conv2d(
            width_reduce,
            width_mid,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(width_mid)
        # 1x1 channel expansion.
        self.conv3 = nn.Conv2d(width_mid, width_expand, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_expand)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when the config asks for one.
        if downsample_planes is None:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    downsample_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(downsample_planes),
            )

    def forward(self, x):
        """Apply the bottleneck transform and add the (projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
| |
|