sergio-sanz-rodriguez committed on
Commit 2bf4af2 · 1 Parent(s): 95510f9

updated app with 101 + unknown trained classes

Files changed (2):
  1. app.py +7 -46
  2. vision_transformer.py +353 -0
app.py CHANGED
@@ -31,36 +31,17 @@ effnetb0_model = create_effnetb0(
     num_classes=2
 )
 
-# Load the ViT-Base/16 transformer with input image of 224x224 pixels
-vitbase_model_1 = create_vitbase_model(
+# Load the ViT-Base/16 transformer with input image of 384x384 pixels and 101 + unknown classes
+vitbase_model = create_vitbase_model(
     model_weights_dir=".",
-    model_weights_name="vitbase16_5.pth",
-    img_size=224,
-    num_classes=num_classes,
-    compile=False
-)
-
-# Specify manual transforms for model_1
-transforms_1 = v2.Compose([
-    v2.Resize((242, 242)),
-    v2.CenterCrop((224, 224)),
-    v2.ToImage(),
-    v2.ToDtype(torch.float32, scale=True),
-    v2.Normalize(mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225])
-])
-
-# Load the ViT-Base/16 transformer with input image of 384x384 pixels
-vitbase_model_2 = create_vitbase_model(
-    model_weights_dir=".",
-    model_weights_name="vitbase16_2_2024-12-31.pth",
+    model_weights_name="vitbase16_102_2025-01-03.pth",
     img_size=384,
     num_classes=num_classes,
     compile=True
 )
 
 # Specify manual transforms for model_2
-transforms_2 = v2.Compose([
+transforms = v2.Compose([
     v2.Resize(384), #v2.Resize((384, 384)),
     v2.CenterCrop((384, 384)),
     v2.ToImage(),
@@ -69,13 +50,10 @@ transforms_2 = v2.Compose([
                  std=[0.229, 0.224, 0.225])
 ])
 
+
 # Put models into evaluation mode and turn on inference mode
 effnetb0_model.eval()
-vitbase_model_1.eval()
-vitbase_model_2.eval()
-
-# Specify default ViT model
-default_model = "Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)" # "Vision Transformer - 224x224 pixels (lower accuracy, faster predictions)"
+vitbase_model.eval()
 
 # Predict function
 def predict(image) -> Tuple[Dict, str, str]:
@@ -86,14 +64,6 @@ def predict(image) -> Tuple[Dict, str, str]:
     # Start the timer
     start_time = timer()
 
-    # Select the appropriate model based on the user's choice
-    if default_model == "Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)":
-        vitbase_model = vitbase_model_2
-        transforms = transforms_2
-    else:
-        vitbase_model = vitbase_model_1
-        transforms = transforms_1
-
     # Transform the target image and add a batch dimension
     image = transforms(image).unsqueeze(0)
 
@@ -104,14 +74,13 @@ def predict(image) -> Tuple[Dict, str, str]:
     if effnetb0_model(image)[:,1].cpu() >= 0.9981166124343872:
 
         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
-        pred_probs = torch.softmax(vitbase_model(image), dim=1) # 101 classes
+        pred_probs = torch.softmax(vitbase_model(image), dim=1)
 
         # Calculate entropy
         entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
 
         # Create a prediction label and prediction probability dictionary for each prediction class
         pred_classes_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(num_classes)}
-        pred_classes_and_probs["unknown"] = 0.0
 
         # Get the top predicted class
         top_class = max(pred_classes_and_probs, key=pred_classes_and_probs.get)
@@ -164,14 +133,6 @@ A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food typ
 # Configure the upload image area
 upload_input = gr.Image(type="pil", label="Upload Image", sources=['upload'], show_label=True, mirror_webcam=False)
 
-# Configure the dropdown option
-#model_dropdown = gr.Dropdown(
-#    choices=["Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)",
-#             "Vision Transformer - 224x224 pixels (lower accuracy, faster predictions)"],
-#    value="Vision Transformer - 384x384 pixels (higher accuracy, slower predictions)",
-#    label="Select Model:"
-#)
-
 # Configure the sample image area
 food_vision_examples = [["examples/" + example] for example in os.listdir("examples")]
 
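The updated predict path scores a single ViT head over 102 classes (101 foods plus the trained "unknown" class) and still computes the entropy of the softmax distribution. As a rough illustration of how such an entropy score can flag ambiguous inputs, here is a minimal, self-contained sketch; the random 102-way logits and the ENTROPY_THRESHOLD value are illustrative assumptions, not values taken from the app.

import torch

def softmax_entropy(pred_probs: torch.Tensor) -> float:
    # Shannon entropy of a softmax distribution (batch of one image)
    return -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()

# Illustrative 102-way output: 101 food classes + "unknown" (hypothetical logits)
logits = torch.randn(1, 102)
pred_probs = torch.softmax(logits, dim=1)

ENTROPY_THRESHOLD = 2.0  # illustrative cut-off, not the value used by the app
if softmax_entropy(pred_probs) > ENTROPY_THRESHOLD:
    print("High entropy: the prediction is ambiguous")
else:
    print("Top class index:", pred_probs.argmax(dim=1).item())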
vision_transformer.py CHANGED
@@ -1,8 +1,361 @@
 
1
+ import os
2
+ import random
3
  import torch
4
  import torchvision
5
  import torch._dynamo
6
+ import matplotlib.pyplot as plt
7
+ from typing import List
8
  from torch import nn
9
+ from torch.utils.data import DataLoader
10
  from torch.nn.init import trunc_normal_, xavier_normal_, zeros_, orthogonal_, kaiming_normal_
11
+ from torchvision import datasets
12
+ from torchvision.transforms import v2
13
+
14
+ def display_random_images(dataset: torch.utils.data.dataset.Dataset, # or torchvision.datasets.ImageFolder?
15
+ classes: List[str] = None,
16
+ n: int = 10,
17
+ display_shape: bool = True,
18
+ rows: int = 5,
19
+ cols: int = 5,
20
+ seed: int = None):
21
+
22
+
23
+ """Displays a number of random images from a given dataset.
24
+
25
+ Args:
26
+ dataset (torch.utils.data.dataset.Dataset): Dataset to select random images from.
27
+ classes (List[str], optional): Names of the classes. Defaults to None.
28
+ n (int, optional): Number of images to display. Defaults to 10.
29
+ display_shape (bool, optional): Whether to display the shape of the image tensors. Defaults to True.
30
+ rows: number of rows of the subplot
31
+ cols: number of columns of the subplot
32
+ seed (int, optional): The seed to set before drawing random images. Defaults to None.
33
+
34
+ Usage:
35
+ display_random_images(train_data,
36
+ n=16,
37
+ classes=class_names,
38
+ rows=4,
39
+ cols=4,
40
+ display_shape=False,
41
+ seed=None)
42
+ """
43
+
44
+ # Setup the range to select images
45
+ n = min(n, len(dataset))
46
+ # Adjust display if n too high
47
+ if n > rows*cols:
48
+ n = rows*cols
49
+ #display_shape = False
50
+ print(f"For display purposes, n shouldn't be larger than {rows*cols}, setting to {n} and removing shape display.")
51
+
52
+ # Set random seed
53
+ if seed:
54
+ random.seed(seed)
55
+
56
+ # Get random sample indexes
57
+ random_samples_idx = random.sample(range(len(dataset)), k=n)
58
+
59
+ # Setup plot
60
+ plt.figure(figsize=(cols*4, rows*4))
61
+
62
+ #Loop through samples and display random samples
63
+ for i, targ_sample in enumerate(random_samples_idx):
64
+ targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]
65
+
66
+ # 7. Adjust image tensor shape for plotting: [color_channels, height, width] -> [color_channels, height, width]
67
+ targ_image_adjust = targ_image.permute(1, 2, 0)
68
+
69
+ # Plot adjusted samples
70
+ plt.subplot(rows, cols, i+1)
71
+ plt.imshow(targ_image_adjust)
72
+ plt.axis("off")
73
+ if classes:
74
+ title = f"class: {classes[targ_label]}"
75
+ if display_shape:
76
+ title = title + f"\nshape: {targ_image_adjust.shape}"
77
+ plt.title(title)
78
+
79
+ def create_dataloaders(
80
+ train_dir: str,
81
+ test_dir: str,
82
+ train_transform: v2.Compose,
83
+ test_transform: v2.Compose,
84
+ batch_size: int,
85
+ num_workers: int=os.cpu_count()
86
+ ):
87
+ """Creates training and testing DataLoaders.
88
+
89
+ Takes in a training directory and testing directory path and turns
90
+ them into PyTorch Datasets and then into PyTorch DataLoaders.
91
+
92
+ Args:
93
+ train_dir: Path to training directory.
94
+ test_dir: Path to testing directory.
95
+ train_transform: torchvision transforms to perform on training data.
96
+ test_transform: torchvision transforms to perform on test data.
97
+ batch_size: Number of samples per batch in each of the DataLoaders.
98
+ num_workers: An integer for number of workers per DataLoader.
99
+
100
+ Returns:
101
+ A tuple of (train_dataloader, test_dataloader, class_names).
102
+ Where class_names is a list of the target classes.
103
+ Example usage:
104
+ train_dataloader, test_dataloader, class_names = \
105
+ = create_dataloaders(train_dir=path/to/train_dir,
106
+ test_dir=path/to/test_dir,
107
+ transform=some_transform,
108
+ batch_size=32,
109
+ num_workers=4)
110
+ """
111
+ # Use ImageFolder to create dataset(s)
112
+ train_data = datasets.ImageFolder(train_dir, transform=train_transform)
113
+ test_data = datasets.ImageFolder(test_dir, transform=test_transform)
114
+
115
+ # Get class names
116
+ class_names = train_data.classes
117
+
118
+ # Turn images into data loaders
119
+ train_dataloader = DataLoader(
120
+ train_data,
121
+ batch_size=batch_size,
122
+ shuffle=True,
123
+ num_workers=num_workers,
124
+ pin_memory=True, #enables fast data transfers to CUDA-enabled GPU
125
+ )
126
+ test_dataloader = DataLoader(
127
+ test_data,
128
+ batch_size=batch_size,
129
+ shuffle=False,
130
+ num_workers=num_workers,
131
+ pin_memory=True, #enables fast data transfers to CUDA-enabled GPU
132
+ )
133
+
134
+ return train_dataloader, test_dataloader, class_names
135
+
136
+ def create_dataloader_for_vit(
137
+ vit_model: str="bitbase16",
138
+ train_dir: str="./",
139
+ test_dir: str="./",
140
+ batch_size: int=64,
141
+ aug: bool=True,
142
+ display_imgs: bool=True,
143
+ num_workers: int=os.cpu_count()
144
+ ):
145
+
146
+ """
147
+ Creates data loaders for the training and test datasets to be used to traing visiton transformers.
148
+
149
+ Args:
150
+ vit_model (str): The name of the ViT model to use. Default is "bitbase16".
151
+ train_dir (str): The path to the training dataset directory. Default is TRAIN_DIR.
152
+ test_dir (str): The path to the test dataset directory. Default is TEST_DIR.
153
+ batch_size (int): The batch size for the data loaders. Default is BATCH_SIZE.
154
+ aug (bool): Whether to apply data augmentation or not. Default is True.
155
+ display_imgs (bool): Whether to display sample images or not. Default is True.
156
+
157
+ Returns:
158
+ train_dataloader (torch.utils.data.DataLoader): The data loader for the training dataset.
159
+ test_dataloader (torch.utils.data.DataLoader): The data loader for the test dataset.
160
+ class_names (list): A list of class names.
161
+ """
162
+
163
+ IMG_SIZE = 224
164
+ IMG_SIZE_2 = 384
165
+
166
+ # Manual transforms for the training dataset
167
+ manual_transforms = v2.Compose([
168
+ v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
169
+ v2.ToImage(),
170
+ v2.ToDtype(torch.float32, scale=True),
171
+ ])
172
+
173
+ # ViT-Base/16 transforms
174
+ if vit_model == "vitbase16":
175
+
176
+ # Manual transforms for the training dataset
177
+ if aug:
178
+ manual_transforms_train_vitb = v2.Compose([
179
+ v2.TrivialAugmentWide(),
180
+ v2.Resize((256, 256)),
181
+ v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
182
+ v2.ToImage(),
183
+ v2.ToDtype(torch.float32, scale=True),
184
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
185
+ std=[0.229, 0.224, 0.225])
186
+ ])
187
+ else:
188
+ manual_transforms_train_vitb = v2.Compose([
189
+ v2.Resize((256, 256)),
190
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
191
+ v2.ToImage(),
192
+ v2.ToDtype(torch.float32, scale=True),
193
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
194
+ std=[0.229, 0.224, 0.225])
195
+ ])
196
+
197
+ # Manual transforms for the test dataset
198
+ manual_transforms_test_vitb = v2.Compose([
199
+ v2.Resize((256, 256)),
200
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
201
+ v2.ToImage(),
202
+ v2.ToDtype(torch.float32, scale=True),
203
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
204
+ std=[0.229, 0.224, 0.225])
205
+ ])
206
+
207
+ # Create data loaders for ViT-Base
208
+ train_dataloader, test_dataloader, class_names = create_dataloaders(
209
+ train_dir=train_dir,
210
+ test_dir=test_dir,
211
+ train_transform=manual_transforms_train_vitb,
212
+ test_transform=manual_transforms_test_vitb,
213
+ batch_size=batch_size,
214
+ num_workers=num_workers
215
+ )
216
+
217
+ if vit_model == "vitbase16_2":
218
+
219
+ # Manual transforms for the training dataset
220
+ if aug:
221
+ manual_transforms_train_vitb = v2.Compose([
222
+ v2.TrivialAugmentWide(),
223
+ v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
224
+ v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
225
+ v2.ToImage(),
226
+ v2.ToDtype(torch.float32, scale=True),
227
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
228
+ std=[0.229, 0.224, 0.225])
229
+ ])
230
+ else:
231
+ manual_transforms_train_vitb = v2.Compose([
232
+ v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
233
+ v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
234
+ v2.ToImage(),
235
+ v2.ToDtype(torch.float32, scale=True),
236
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
237
+ std=[0.229, 0.224, 0.225])
238
+ ])
239
+
240
+ # Manual transforms for the test dataset
241
+ manual_transforms_test_vitb = v2.Compose([
242
+ v2.Resize((IMG_SIZE_2, IMG_SIZE_2)),
243
+ v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)),
244
+ v2.ToImage(),
245
+ v2.ToDtype(torch.float32, scale=True),
246
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
247
+ std=[0.229, 0.224, 0.225])
248
+ ])
249
+
250
+ # Create data loaders for ViT-Base
251
+ train_dataloader, test_dataloader, class_names = create_dataloaders(
252
+ train_dir=train_dir,
253
+ test_dir=test_dir,
254
+ train_transform=manual_transforms_train_vitb,
255
+ test_transform=manual_transforms_test_vitb,
256
+ batch_size=batch_size,
257
+ num_workers=num_workers
258
+ )
259
+
260
+ # ViT-Large/16 transforms
261
+ elif vit_model == "vitlarge16":
262
+
263
+ # Manual transforms for the training dataset
264
+ if aug:
265
+ manual_transforms_train_vitl = v2.Compose([
266
+ v2.TrivialAugmentWide(),
267
+ v2.Resize((242, 242)),
268
+ v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
269
+ v2.ToImage(),
270
+ v2.ToDtype(torch.float32, scale=True),
271
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
272
+ std=[0.229, 0.224, 0.225])
273
+ ])
274
+ else:
275
+ manual_transforms_train_vitl = v2.Compose([
276
+ v2.Resize((242, 242)),
277
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
278
+ v2.ToImage(),
279
+ v2.ToDtype(torch.float32, scale=True),
280
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
281
+ std=[0.229, 0.224, 0.225])
282
+ ])
283
+
284
+ # Manual transforms for the test dataset
285
+ manual_transforms_test_vitl = v2.Compose([
286
+ v2.Resize((242, 242)),
287
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
288
+ v2.ToImage(),
289
+ v2.ToDtype(torch.float32, scale=True),
290
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
291
+ std=[0.229, 0.224, 0.225])
292
+ ])
293
+
294
+ # Create data loaders for ViT-Large/16
295
+ train_dataloader, test_dataloader, class_names = create_dataloaders(
296
+ train_dir=train_dir,
297
+ test_dir=test_dir,
298
+ train_transform=manual_transforms_train_vitl,
299
+ test_transform=manual_transforms_test_vitl,
300
+ batch_size=batch_size,
301
+ num_workers=num_workers
302
+ )
303
+
304
+ # ViT-Large/32 transforms
305
+ else:
306
+ # Manual transforms for the training dataset
307
+ if aug:
308
+ manual_transforms_train_vitl = v2.Compose([
309
+ v2.TrivialAugmentWide(),
310
+ v2.Resize((256, 256)),
311
+ v2.RandomCrop((IMG_SIZE, IMG_SIZE)),
312
+ v2.ToImage(),
313
+ v2.ToDtype(torch.float32, scale=True),
314
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
315
+ std=[0.229, 0.224, 0.225])
316
+ ])
317
+ else:
318
+ manual_transforms_train_vitl = v2.Compose([
319
+ v2.Resize((256, 256)),
320
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
321
+ v2.ToImage(),
322
+ v2.ToDtype(torch.float32, scale=True),
323
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
324
+ std=[0.229, 0.224, 0.225])
325
+ ])
326
+
327
+ # Manual transforms for the test dataset
328
+ manual_transforms_test_vitl = v2.Compose([
329
+ v2.Resize((256, 256)),
330
+ v2.CenterCrop((IMG_SIZE, IMG_SIZE)),
331
+ v2.ToImage(),
332
+ v2.ToDtype(torch.float32, scale=True),
333
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
334
+ std=[0.229, 0.224, 0.225])
335
+ ])
336
+
337
+ # Create data loaders for ViT-Large/32
338
+ train_dataloader, test_dataloader, class_names = create_dataloaders(
339
+ train_dir=train_dir,
340
+ test_dir=test_dir,
341
+ train_transform=manual_transforms_train_vitl,
342
+ test_transform=manual_transforms_test_vitl,
343
+ batch_size=batch_size,
344
+ num_workers=num_workers
345
+ )
346
+
347
+ # Display images
348
+ if display_imgs:
349
+ train_data = datasets.ImageFolder(train_dir, transform=manual_transforms)
350
+ display_random_images(train_data,
351
+ n=25,
352
+ classes=class_names,
353
+ rows=5,
354
+ cols=5,
355
+ display_shape=False,
356
+ seed=None)
357
+
358
+ return train_dataloader, test_dataloader, class_names
359
 
360
 
361
  # Create Pytorch's default ViT models
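For reference, a minimal usage sketch of the new create_dataloader_for_vit helper added in this commit; the directory paths below are placeholders (not part of the commit) and assume ImageFolder-style folders with one subdirectory per class.

from vision_transformer import create_dataloader_for_vit

# Placeholder directories; each must contain one subfolder per class.
train_dataloader, test_dataloader, class_names = create_dataloader_for_vit(
    vit_model="vitbase16_2",   # 384x384 ViT-Base/16 transforms
    train_dir="data/train",
    test_dir="data/test",
    batch_size=64,
    aug=True,
    display_imgs=False,        # skip the matplotlib preview grid
)
print(f"{len(class_names)} classes, {len(train_dataloader)} train batches, {len(test_dataloader)} test batches")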