buffaX committed
Commit a5ba72b · verified · 1 Parent(s): 4f5a128

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +163 -0
model.py ADDED

import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L


class BasicBlock(nn.Module):
    expansion = 1  # ResNet-18/34 use expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1-conv downsample on the identity path when shapes mismatch
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual connection
        out = F.relu(out)
        return out
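
# Shape sanity check for BasicBlock (a sketch, not part of the original file;
# sizes assume CIFAR-10's 32x32 inputs):
#
#     blk = BasicBlock(64, 128, stride=2)     # strided block, so the 1x1 shortcut kicks in
#     y = blk(torch.randn(2, 64, 32, 32))
#     print(y.shape)                          # torch.Size([2, 128, 16, 16])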


class ResNet18_CIFAR10(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()

        # CIFAR-10-friendly stem: a 3x3 conv with stride 1 and no max pool
        # (the ImageNet version uses a 7x7 stride-2 conv followed by max pool)
        self.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # ResNet stages
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)   # 32x32 -> 16x16
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)  # 16x16 -> 8x8
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)  # 8x8 -> 4x4

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(512 * BasicBlock.expansion, num_classes)
        )

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        layers = []
        layers.append(BasicBlock(in_c, out_c, stride))  # only the first block may downsample
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c, stride=1))  # remaining blocks use stride=1
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # note the ReLU on the stem

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avg_pool(out)     # [B, 512, 1, 1]
        out = torch.flatten(out, 1)  # [B, 512]
        out = self.fc(out)           # [B, num_classes]
        return out
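
# Quick standalone check of the backbone (a sketch, not part of the original
# file; with this configuration the model should come out to roughly 11M
# parameters, in line with standard ResNet-18):
#
#     net = ResNet18_CIFAR10()
#     print(sum(p.numel() for p in net.parameters()))  # ~11.2M
#     print(net(torch.randn(2, 3, 32, 32)).shape)      # torch.Size([2, 10])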


class CIFARCNN(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # exposes lr as self.hparams.lr
        self.example_input_array = torch.Tensor(64, 3, 32, 32)

        self.net = ResNet18_CIFAR10(num_classes=10)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)
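
    # Note (added): example_input_array lets Lightning's ModelSummary report
    # per-layer input/output shapes during fit; the tensor's values are never
    # used for training, only its shape.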

    def training_step(self, batch, batch_idx):  # batch_idx is unused here
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("train_loss", loss, on_step=True, prog_bar=True)  # log every step
        self.log("train_acc", acc, on_step=True, prog_bar=True)
        return loss
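
    # Note (added): inside training_step, self.log defaults to on_step=True,
    # on_epoch=False, so these are raw per-batch values; pass on_epoch=True as
    # well if an epoch-level average is wanted. In validation_step/test_step
    # below, self.log defaults to epoch-level aggregation instead.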

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # Validation-specific logging: prog_bar=True shows the metric on the
        # Lightning progress bar; sync_dist=True averages it across devices
        # in distributed training.
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_acc", acc, prog_bar=True, sync_dist=True)

        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

        return {"test_loss": loss, "test_acc": acc}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch  # labels are ignored at prediction time
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
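
    # Note (added): Lightning also accepts the more explicit scheduler dict,
    # e.g. if per-step stepping were ever wanted (a sketch, not in the
    # original commit):
    #
    #     return {
    #         "optimizer": optimizer,
    #         "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    #     }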


if __name__ == "__main__":
    # Quick smoke test of the forward pass
    model = CIFARCNN()
    x = torch.randn(4, 3, 32, 32).to(model.device)
    logits = model(x)
    print(logits.shape)  # torch.Size([4, 10])
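
    # Minimal training sketch (added; not part of the original commit).
    # train_loader / val_loader / test_loader are assumed CIFAR-10 DataLoaders
    # built elsewhere, e.g. from torchvision.datasets.CIFAR10:
    #
    #     trainer = L.Trainer(max_epochs=30)
    #     trainer.fit(model, train_loader, val_loader)
    #     trainer.test(model, test_loader)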