Spaces:
Build error
| import torch | |
| from torch import nn | |
| from torchvision.models import densenet169 | |
| from config.finetune_config import set_args | |
# Fine-tuning configuration, built once at import time by the project's
# set_args() (from config.finetune_config). In this file it is read only by
# M_DenseNet when pretrain != 'IN' (args.finetune_path).
# NOTE(review): if set_args() parses sys.argv, importing this module from a
# different CLI entry point may fail on unrecognised flags - confirm.
args = set_args()
class Classifier(nn.Module):
    """Multi-branch head that turns a stacked feature map into class logits.

    Three dilated convolutions, each with a different kernel/dilation/padding
    setting, collapse the input map to a ``hidden_dim x 1 x 1`` tensor. Each
    branch is layer-normalised and flattened. The three vectors are
    concatenated, passed through GELU and dropout, and projected to
    ``num_classes`` logits.

    The padding/dilation settings are chosen so that every branch reduces a
    7x7 map to exactly 1x1. The head therefore only accepts 7x7 inputs, which
    is what DenseNet-169 (stride 32) yields for 224x224 images. Any other
    spatial size makes the ``LayerNorm([hidden_dim, 1, 1])`` shape check raise
    at runtime.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the incoming feature map. The default,
            1664 * 2, matches two DenseNet-169 feature maps concatenated on
            the channel dimension.
        hidden_dim: output channels of each convolution branch.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes, in_channels=1664 * 2, hidden_dim=1024,
                 dropout=0.2):
        super(Classifier, self).__init__()
        # Effective extent = dilation * (kernel - 1) + 1. On a 7x7 input:
        #   GDConv1: 2 * 3 + 1 = 7, padding 0       -> 1x1
        #   GDConv2: 2 * 4 + 1 = 9, padding 1 (7+2) -> 1x1
        #   GDConv3: 3 * 2 + 1 = 7, padding 0       -> 1x1
        # Attribute names (and creation order) are unchanged so existing
        # checkpoints/state_dicts still load and seeded init is reproducible.
        self.GDConv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=4,
                                 padding=0, dilation=2)
        self.GDConv2 = nn.Conv2d(in_channels, hidden_dim, kernel_size=5,
                                 padding=1, dilation=2)
        self.GDConv3 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3,
                                 padding=0, dilation=3)
        # Normalise over (C, H, W) = (hidden_dim, 1, 1) of each branch output.
        self.LN1 = nn.LayerNorm([hidden_dim, 1, 1])
        self.LN2 = nn.LayerNorm([hidden_dim, 1, 1])
        self.LN3 = nn.LayerNorm([hidden_dim, 1, 1])
        self.gelu = nn.GELU()
        self.fc_dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 3, num_classes)

        # Kaiming-normal conv weights and zero linear bias. The head contains
        # no BatchNorm layers, so no BatchNorm init branch is needed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (batch, in_channels, 7, 7) tensor to (batch, num_classes) logits."""
        branches = (
            (self.GDConv1, self.LN1),
            (self.GDConv2, self.LN2),
            (self.GDConv3, self.LN3),
        )
        # Each branch yields (batch, hidden_dim, 1, 1) -> (batch, hidden_dim).
        feats = [norm(conv(x)).flatten(1) for conv, norm in branches]
        X = torch.cat(feats, 1)
        X = self.gelu(X)
        return self.fc(self.fc_dropout(X))
class M_DenseNet(nn.Module):
    """Two-input network: a shared DenseNet trunk followed by ``Classifier``.

    Both inputs are passed through the *same* ``self.feature`` module, so the
    weights are shared. Their feature maps are concatenated on dim 1 and fed
    to the classifier head.

    Args:
        pretrain: ``'IN'`` builds the trunk from torchvision's pretrained
            ``densenet169``. Any other value loads a whole pickled model from
            the module-level ``args.finetune_path``.
        num_classes: number of output logits.
    """

    def __init__(self, pretrain='IN', num_classes=8):
        super(M_DenseNet, self).__init__()
        # Feature extractor.
        if pretrain == 'IN':
            # `model` already has the pretrained weights loaded. Dropping the
            # last child (torchvision's `classifier`) keeps only `features`.
            # NOTE(review): torchvision applies the final ReLU and pooling
            # functionally in DenseNet.forward, so they are not part of this
            # Sequential. `pretrained=True` is deprecated since
            # torchvision 0.13 in favour of `weights=`.
            model = densenet169(pretrained=True)
            self.feature = nn.Sequential(*list(model.children())[:-1])
        else:
            model = self._load_finetuned(args.finetune_path)
            # NOTE(review): assumes the fine-tuned model's last two children
            # form its old head - confirm against the training script.
            self.feature = nn.Sequential(*list(model.children())[:-2])
        self.classifier = Classifier(num_classes)

    @staticmethod
    def _load_finetuned(path):
        """Load a whole ``nn.Module`` saved with ``torch.save(model, path)``.

        Tensors are mapped to CPU, so a checkpoint written on a GPU machine
        also loads on CPU-only hosts. Move the finished model with
        ``.to(device)``, as with the ``'IN'`` branch.
        """
        import inspect

        kwargs = {'map_location': 'cpu'}
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle a full nn.Module. Opt out explicitly, but only on versions
        # that know the argument; older ones reject unknown kwargs.
        # NOTE(security): this unpickles arbitrary Python objects - only
        # point finetune_path at checkpoints you trust.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def forward(self, left, right):
        """Return (batch, num_classes) logits for a pair of input images."""
        left = self.feature(left)
        right = self.feature(right)
        x = torch.cat((left, right), 1)
        X = self.classifier(x)
        return X
if __name__ == '__main__':
    # Smoke test: one forward pass on a random pair of 224x224 RGB batches.
    # eval() disables dropout and makes BatchNorm use (not update) its running
    # statistics, so the printed logits are deterministic for a given input.
    # no_grad() skips building the autograd graph, which is not needed here.
    model = M_DenseNet()
    model.eval()
    input1 = torch.normal(0, 1, size=(4, 3, 224, 224))
    input2 = torch.normal(0, 1, size=(4, 3, 224, 224))
    with torch.no_grad():
        output = model(input1, input2)
    print(output.shape)
    print(output)