sergio-sanz-rodriguez committed
Commit b109b29 · 1 Parent(s): 2bf4af2

updated vision_transformer.py

Files changed (1):
  1. vision_transformer.py +2 -356
vision_transformer.py CHANGED
@@ -1,362 +1,8 @@
- import os
- import random
  import torch
  import torchvision
- import torch._dynamo
- import matplotlib.pyplot as plt
- from typing import List
  from torch import nn
- from torch.utils.data import DataLoader
- from torch.nn.init import trunc_normal_, xavier_normal_, zeros_, orthogonal_, kaiming_normal_
- from torchvision import datasets
- from torchvision.transforms import v2
+ from torch.nn.init import trunc_normal_
+ #, xavier_normal_, zeros_, orthogonal_, kaiming_normal_
-
- def display_random_images(dataset: torch.utils.data.dataset.Dataset, # or torchvision.datasets.ImageFolder?
-                           classes: List[str] = None,
-                           n: int = 10,
-                           display_shape: bool = True,
-                           rows: int = 5,
-                           cols: int = 5,
-                           seed: int = None):
-
-     """Displays a number of random images from a given dataset.
-
-     Args:
-         dataset (torch.utils.data.dataset.Dataset): Dataset to select random images from.
-         classes (List[str], optional): Names of the classes. Defaults to None.
-         n (int, optional): Number of images to display. Defaults to 10.
-         display_shape (bool, optional): Whether to display the shape of the image tensors. Defaults to True.
-         rows (int, optional): Number of rows in the subplot grid. Defaults to 5.
-         cols (int, optional): Number of columns in the subplot grid. Defaults to 5.
-         seed (int, optional): The seed to set before drawing random images. Defaults to None.
-
-     Usage:
-         display_random_images(train_data,
-                               n=16,
-                               classes=class_names,
-                               rows=4,
-                               cols=4,
-                               display_shape=False,
-                               seed=None)
-     """
-
-     # Cap the number of images at the dataset size
-     n = min(n, len(dataset))
-     # Cap n at the number of available subplot cells
-     if n > rows*cols:
-         n = rows*cols
-         print(f"For display purposes, n shouldn't be larger than {rows*cols}; setting n to {n}.")
-
-     # Set random seed
-     if seed is not None:
-         random.seed(seed)
-
-     # Get random sample indexes
-     random_samples_idx = random.sample(range(len(dataset)), k=n)
-
-     # Setup plot
-     plt.figure(figsize=(cols*4, rows*4))
-
-     # Loop through the sampled indexes and display the corresponding images
-     for i, targ_sample in enumerate(random_samples_idx):
-         targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]
-
-         # Adjust image tensor shape for plotting: [color_channels, height, width] -> [height, width, color_channels]
-         targ_image_adjust = targ_image.permute(1, 2, 0)
-
-         # Plot adjusted samples
-         plt.subplot(rows, cols, i+1)
-         plt.imshow(targ_image_adjust)
-         plt.axis("off")
-         if classes:
-             title = f"class: {classes[targ_label]}"
-             if display_shape:
-                 title = title + f"\nshape: {targ_image_adjust.shape}"
-             plt.title(title)
-
- def create_dataloaders(
-     train_dir: str,
-     test_dir: str,
-     train_transform: v2.Compose,
-     test_transform: v2.Compose,
-     batch_size: int,
-     num_workers: int=os.cpu_count()
- ):
-     """Creates training and testing DataLoaders.
-
-     Takes in a training directory and testing directory path and turns
-     them into PyTorch Datasets and then into PyTorch DataLoaders.
-
-     Args:
-         train_dir: Path to training directory.
-         test_dir: Path to testing directory.
-         train_transform: torchvision transforms to perform on training data.
-         test_transform: torchvision transforms to perform on test data.
-         batch_size: Number of samples per batch in each of the DataLoaders.
-         num_workers: Number of workers per DataLoader.
-
-     Returns:
-         A tuple of (train_dataloader, test_dataloader, class_names),
-         where class_names is a list of the target classes.
-
-     Example usage:
-         train_dataloader, test_dataloader, class_names = \
-             create_dataloaders(train_dir=path/to/train_dir,
-                                test_dir=path/to/test_dir,
-                                train_transform=some_transform,
-                                test_transform=some_transform,
-                                batch_size=32,
-                                num_workers=4)
-     """
-     # Use ImageFolder to create dataset(s)
-     train_data = datasets.ImageFolder(train_dir, transform=train_transform)
-     test_data = datasets.ImageFolder(test_dir, transform=test_transform)
-
-     # Get class names
-     class_names = train_data.classes
-
-     # Turn images into data loaders
-     train_dataloader = DataLoader(
-         train_data,
-         batch_size=batch_size,
-         shuffle=True,
-         num_workers=num_workers,
-         pin_memory=True,  # enables fast data transfers to CUDA-enabled GPUs
-     )
-     test_dataloader = DataLoader(
-         test_data,
-         batch_size=batch_size,
-         shuffle=False,
-         num_workers=num_workers,
-         pin_memory=True,  # enables fast data transfers to CUDA-enabled GPUs
-     )
-
-     return train_dataloader, test_dataloader, class_names
-
- def create_dataloader_for_vit(
-     vit_model: str="vitbase16",
-     train_dir: str="./",
-     test_dir: str="./",
-     batch_size: int=64,
-     aug: bool=True,
-     display_imgs: bool=True,
-     num_workers: int=os.cpu_count()
- ):
-     """
-     Creates data loaders for the training and test datasets used to train vision transformers.
-
-     Args:
-         vit_model (str): The name of the ViT model to use. Default is "vitbase16".
-         train_dir (str): The path to the training dataset directory. Default is "./".
-         test_dir (str): The path to the test dataset directory. Default is "./".
-         batch_size (int): The batch size for the data loaders. Default is 64.
-         aug (bool): Whether to apply data augmentation or not. Default is True.
-         display_imgs (bool): Whether to display sample images or not. Default is True.
-         num_workers (int): Number of workers per DataLoader. Default is os.cpu_count().
-
-     Returns:
-         train_dataloader (torch.utils.data.DataLoader): The data loader for the training dataset.
-         test_dataloader (torch.utils.data.DataLoader): The data loader for the test dataset.
-         class_names (list): A list of class names.
-     """
-
-     IMG_SIZE = 224
-     IMG_SIZE_2 = 384
-
-     # Basic transforms, used only to display sample images at the end
-     manual_transforms = v2.Compose([
-         v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
-         v2.ToImage(),
-         v2.ToDtype(torch.float32, scale=True),
-     ])
-
-     # ViT-Base/16 transforms (224x224 input)
-     if vit_model == "vitbase16":
-
-         # Manual transforms for the training dataset
-         if aug:
-             manual_transforms_train_vitb = v2.Compose([
-                 v2.TrivialAugmentWide(),
-                 v2.Resize((256, 256)),
-                 v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-         else:
-             manual_transforms_train_vitb = v2.Compose([
-                 v2.Resize((256, 256)),
-                 v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-
-         # Manual transforms for the test dataset
-         manual_transforms_test_vitb = v2.Compose([
-             v2.Resize((256, 256)),
-             v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-             v2.ToImage(),
-             v2.ToDtype(torch.float32, scale=True),
-             v2.Normalize(mean=[0.485, 0.456, 0.406],
-                          std=[0.229, 0.224, 0.225])
-         ])
-
-         # Create data loaders for ViT-Base/16
-         train_dataloader, test_dataloader, class_names = create_dataloaders(
-             train_dir=train_dir,
-             test_dir=test_dir,
-             train_transform=manual_transforms_train_vitb,
-             test_transform=manual_transforms_test_vitb,
-             batch_size=batch_size,
-             num_workers=num_workers
-         )
-
-     # ViT-Base/16 transforms (384x384 input)
-     elif vit_model == "vitbase16_2":
-
-         # Manual transforms for the training dataset
-         if aug:
-             manual_transforms_train_vitb = v2.Compose([
-                 v2.TrivialAugmentWide(),
-                 v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
-                 v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-         else:
-             manual_transforms_train_vitb = v2.Compose([
-                 v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
-                 v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-
-         # Manual transforms for the test dataset
-         manual_transforms_test_vitb = v2.Compose([
-             v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
-             v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
-             v2.ToImage(),
-             v2.ToDtype(torch.float32, scale=True),
-             v2.Normalize(mean=[0.485, 0.456, 0.406],
-                          std=[0.229, 0.224, 0.225])
-         ])
-
-         # Create data loaders for ViT-Base/16 (384)
-         train_dataloader, test_dataloader, class_names = create_dataloaders(
-             train_dir=train_dir,
-             test_dir=test_dir,
-             train_transform=manual_transforms_train_vitb,
-             test_transform=manual_transforms_test_vitb,
-             batch_size=batch_size,
-             num_workers=num_workers
-         )
-
-     # ViT-Large/16 transforms
-     elif vit_model == "vitlarge16":
-
-         # Manual transforms for the training dataset
-         if aug:
-             manual_transforms_train_vitl = v2.Compose([
-                 v2.TrivialAugmentWide(),
-                 v2.Resize((242, 242)),
-                 v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-         else:
-             manual_transforms_train_vitl = v2.Compose([
-                 v2.Resize((242, 242)),
-                 v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-
-         # Manual transforms for the test dataset
-         manual_transforms_test_vitl = v2.Compose([
-             v2.Resize((242, 242)),
-             v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-             v2.ToImage(),
-             v2.ToDtype(torch.float32, scale=True),
-             v2.Normalize(mean=[0.485, 0.456, 0.406],
-                          std=[0.229, 0.224, 0.225])
-         ])
-
-         # Create data loaders for ViT-Large/16
-         train_dataloader, test_dataloader, class_names = create_dataloaders(
-             train_dir=train_dir,
-             test_dir=test_dir,
-             train_transform=manual_transforms_train_vitl,
-             test_transform=manual_transforms_test_vitl,
-             batch_size=batch_size,
-             num_workers=num_workers
-         )
-
-     # ViT-Large/32 transforms
-     else:
-
-         # Manual transforms for the training dataset
-         if aug:
-             manual_transforms_train_vitl = v2.Compose([
-                 v2.TrivialAugmentWide(),
-                 v2.Resize((256, 256)),
-                 v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-         else:
-             manual_transforms_train_vitl = v2.Compose([
-                 v2.Resize((256, 256)),
-                 v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-                 v2.ToImage(),
-                 v2.ToDtype(torch.float32, scale=True),
-                 v2.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-             ])
-
-         # Manual transforms for the test dataset
-         manual_transforms_test_vitl = v2.Compose([
-             v2.Resize((256, 256)),
-             v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
-             v2.ToImage(),
-             v2.ToDtype(torch.float32, scale=True),
-             v2.Normalize(mean=[0.485, 0.456, 0.406],
-                          std=[0.229, 0.224, 0.225])
-         ])
-
-         # Create data loaders for ViT-Large/32
-         train_dataloader, test_dataloader, class_names = create_dataloaders(
-             train_dir=train_dir,
-             test_dir=test_dir,
-             train_transform=manual_transforms_train_vitl,
-             test_transform=manual_transforms_test_vitl,
-             batch_size=batch_size,
-             num_workers=num_workers
-         )
-
-     # Display images
-     if display_imgs:
-         train_data = datasets.ImageFolder(train_dir, transform=manual_transforms)
-         display_random_images(train_data,
-                               n=25,
-                               classes=class_names,
-                               rows=5,
-                               cols=5,
-                               display_shape=False,
-                               seed=None)
-
-     return train_dataloader, test_dataloader, class_names
-
 
  # Create Pytorch's default ViT models
  def create_vit(
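
Note on the hunk above: the dataset helpers (display_random_images, create_dataloaders, create_dataloader_for_vit) are removed from this module entirely. For anyone who imported them, here is a minimal standalone sketch of an equivalent evaluation dataloader; "path/to/test_dir" is a placeholder, the 256 -> 224 resize/crop sizes come from the removed ViT-Base/16 test branch, and the mean/std are the standard ImageNet statistics used throughout the removed code.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

# Standard ImageNet statistics, as in the removed Normalize() calls
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Evaluation transform matching the removed ViT-Base/16 test pipeline
test_transform = v2.Compose([
    v2.Resize((256, 256)),
    v2.CenterCrop((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=MEAN, std=STD),
])

# "path/to/test_dir" is a placeholder for a class-per-folder dataset root
test_data = datasets.ImageFolder("path/to/test_dir", transform=test_transform)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False,
                             pin_memory=True)  # pin_memory speeds host-to-GPU copies
class_names = test_data.classes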
 
 
 
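The one import the commit keeps from torch.nn.init is trunc_normal_, which suggests the surviving create_vit code re-initializes some layer with a truncated normal. As a hedged illustration only (a common ViT recipe, not necessarily this file's exact logic), swapping and re-initializing the classification head of a torchvision ViT might look like:

import torch
import torchvision
from torch import nn
from torch.nn.init import trunc_normal_

# Hypothetical example: replace a torchvision ViT-Base/16 classification head
# and re-initialize it with truncated-normal weights (std=0.02 is a common ViT choice)
model = torchvision.models.vit_b_16(weights=None)
num_classes = 10  # placeholder class count
in_features = model.heads.head.in_features
model.heads.head = nn.Linear(in_features, num_classes)
trunc_normal_(model.heads.head.weight, std=0.02)
nn.init.zeros_(model.heads.head.bias)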