|
|
import os |
|
|
import random |
|
|
import torch |
|
|
import torchvision |
|
|
import torch._dynamo |
|
|
import matplotlib.pyplot as plt |
|
|
from typing import List, Optional
|
|
from torch import nn |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.nn.init import trunc_normal_, xavier_normal_, zeros_, orthogonal_, kaiming_normal_ |
|
|
from torchvision import datasets |
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
def display_random_images(dataset: torch.utils.data.dataset.Dataset, |
|
|
                          classes: Optional[List[str]] = None,
|
|
n: int = 10, |
|
|
display_shape: bool = True, |
|
|
rows: int = 5, |
|
|
cols: int = 5, |
|
|
                          seed: Optional[int] = None):
|
|
|
|
|
|
|
|
"""Displays a number of random images from a given dataset. |
|
|
|
|
|
Args: |
|
|
dataset (torch.utils.data.dataset.Dataset): Dataset to select random images from. |
|
|
classes (List[str], optional): Names of the classes. Defaults to None. |
|
|
n (int, optional): Number of images to display. Defaults to 10. |
|
|
display_shape (bool, optional): Whether to display the shape of the image tensors. Defaults to True. |
|
|
        rows (int, optional): Number of rows in the subplot grid. Defaults to 5.
        cols (int, optional): Number of columns in the subplot grid. Defaults to 5.
|
|
seed (int, optional): The seed to set before drawing random images. Defaults to None. |
|
|
|
|
|
Usage: |
|
|
display_random_images(train_data, |
|
|
n=16, |
|
|
classes=class_names, |
|
|
rows=4, |
|
|
cols=4, |
|
|
display_shape=False, |
|
|
seed=None) |
|
|
""" |
|
|
|
|
|
|
|
|
n = min(n, len(dataset)) |
|
|
|
|
|
if n > rows*cols: |
|
|
n = rows*cols |
|
|
|
|
|
        print(f"For display purposes, n shouldn't be larger than {rows*cols}; setting n to {n} and disabling shape display.")
        display_shape = False
|
|
|
|
|
|
|
|
    if seed is not None:
|
|
random.seed(seed) |
|
|
|
|
|
|
|
|
random_samples_idx = random.sample(range(len(dataset)), k=n) |
|
|
|
|
|
|
|
|
plt.figure(figsize=(cols*4, rows*4)) |
|
|
|
|
|
|
|
|
for i, targ_sample in enumerate(random_samples_idx): |
|
|
targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1] |
|
|
|
|
|
|
|
|
targ_image_adjust = targ_image.permute(1, 2, 0) |
|
|
|
|
|
|
|
|
plt.subplot(rows, cols, i+1) |
|
|
plt.imshow(targ_image_adjust) |
|
|
plt.axis("off") |
|
|
if classes: |
|
|
title = f"class: {classes[targ_label]}" |
|
|
if display_shape: |
|
|
title = title + f"\nshape: {targ_image_adjust.shape}" |
|
|
plt.title(title) |
|
|
|
|
|
def create_dataloaders( |
|
|
train_dir: str, |
|
|
test_dir: str, |
|
|
train_transform: v2.Compose, |
|
|
test_transform: v2.Compose, |
|
|
batch_size: int, |
|
|
num_workers: int=os.cpu_count() |
|
|
): |
|
|
"""Creates training and testing DataLoaders. |
|
|
|
|
|
Takes in a training directory and testing directory path and turns |
|
|
them into PyTorch Datasets and then into PyTorch DataLoaders. |
|
|
|
|
|
Args: |
|
|
train_dir: Path to training directory. |
|
|
test_dir: Path to testing directory. |
|
|
train_transform: torchvision transforms to perform on training data. |
|
|
test_transform: torchvision transforms to perform on test data. |
|
|
batch_size: Number of samples per batch in each of the DataLoaders. |
|
|
num_workers: An integer for number of workers per DataLoader. |
|
|
|
|
|
Returns: |
|
|
A tuple of (train_dataloader, test_dataloader, class_names). |
|
|
Where class_names is a list of the target classes. |
|
|
Example usage: |
|
|
        train_dataloader, test_dataloader, class_names = create_dataloaders(
            train_dir="path/to/train_dir",
            test_dir="path/to/test_dir",
            train_transform=some_train_transform,
            test_transform=some_test_transform,
            batch_size=32,
            num_workers=4)
|
|
""" |
|
|
|
|
|
train_data = datasets.ImageFolder(train_dir, transform=train_transform) |
|
|
test_data = datasets.ImageFolder(test_dir, transform=test_transform) |
|
|
|
|
|
|
|
|
class_names = train_data.classes |
|
|
|
|
|
|
|
|
train_dataloader = DataLoader( |
|
|
train_data, |
|
|
batch_size=batch_size, |
|
|
shuffle=True, |
|
|
num_workers=num_workers, |
|
|
pin_memory=True, |
|
|
) |
|
|
test_dataloader = DataLoader( |
|
|
test_data, |
|
|
batch_size=batch_size, |
|
|
shuffle=False, |
|
|
num_workers=num_workers, |
|
|
pin_memory=True, |
|
|
) |
|
|
|
|
|
return train_dataloader, test_dataloader, class_names |
|
|
|
|
|
def create_dataloader_for_vit( |
|
|
    vit_model: str="vitbase16",
|
|
train_dir: str="./", |
|
|
test_dir: str="./", |
|
|
batch_size: int=64, |
|
|
aug: bool=True, |
|
|
display_imgs: bool=True, |
|
|
num_workers: int=os.cpu_count() |
|
|
): |
|
|
|
|
|
""" |
|
|
    Creates DataLoaders for the training and test datasets used to train vision transformers (ViTs).
|
|
|
|
|
Args: |
|
|
        vit_model (str): The name of the ViT variant whose preprocessing to use. Default is "vitbase16".
        train_dir (str): The path to the training dataset directory. Default is "./".
        test_dir (str): The path to the test dataset directory. Default is "./".
        batch_size (int): The batch size for the data loaders. Default is 64.
        aug (bool): Whether to apply data augmentation to the training data. Default is True.
        display_imgs (bool): Whether to display sample images. Default is True.
        num_workers (int): Number of workers per DataLoader. Default is os.cpu_count().
|
|
|
|
|
Returns: |
|
|
train_dataloader (torch.utils.data.DataLoader): The data loader for the training dataset. |
|
|
test_dataloader (torch.utils.data.DataLoader): The data loader for the test dataset. |
|
|
class_names (list): A list of class names. |
|
|
""" |
|
|
|
|
|
IMG_SIZE = 224 |
|
|
IMG_SIZE_2 = 384 |
|
|
|
|
|
|
|
|
manual_transforms = v2.Compose([ |
|
|
v2.RandomCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
]) |
|
|
|
|
|
|
|
|
if vit_model == "vitbase16": |
|
|
|
|
|
|
|
|
if aug: |
|
|
manual_transforms_train_vitb = v2.Compose([ |
|
|
v2.TrivialAugmentWide(), |
|
|
v2.Resize((256, 256)), |
|
|
v2.RandomCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
else: |
|
|
manual_transforms_train_vitb = v2.Compose([ |
|
|
v2.Resize((256, 256)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
manual_transforms_test_vitb = v2.Compose([ |
|
|
v2.Resize((256, 256)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
train_dataloader, test_dataloader, class_names = create_dataloaders( |
|
|
train_dir=train_dir, |
|
|
test_dir=test_dir, |
|
|
train_transform=manual_transforms_train_vitb, |
|
|
test_transform=manual_transforms_test_vitb, |
|
|
batch_size=batch_size, |
|
|
num_workers=num_workers |
|
|
) |
|
|
|
|
|
    elif vit_model == "vitbase16_2":
|
|
|
|
|
|
|
|
if aug: |
|
|
manual_transforms_train_vitb = v2.Compose([ |
|
|
v2.TrivialAugmentWide(), |
|
|
v2.Resize((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
else: |
|
|
manual_transforms_train_vitb = v2.Compose([ |
|
|
v2.Resize((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
manual_transforms_test_vitb = v2.Compose([ |
|
|
v2.Resize((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.CenterCrop((IMG_SIZE_2, IMG_SIZE_2)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
train_dataloader, test_dataloader, class_names = create_dataloaders( |
|
|
train_dir=train_dir, |
|
|
test_dir=test_dir, |
|
|
train_transform=manual_transforms_train_vitb, |
|
|
test_transform=manual_transforms_test_vitb, |
|
|
batch_size=batch_size, |
|
|
num_workers=num_workers |
|
|
) |
|
|
|
|
|
|
|
|
elif vit_model == "vitlarge16": |
|
|
|
|
|
|
|
|
if aug: |
|
|
manual_transforms_train_vitl = v2.Compose([ |
|
|
v2.TrivialAugmentWide(), |
|
|
v2.Resize((242, 242)), |
|
|
v2.RandomCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
else: |
|
|
manual_transforms_train_vitl = v2.Compose([ |
|
|
v2.Resize((242, 242)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
manual_transforms_test_vitl = v2.Compose([ |
|
|
v2.Resize((242, 242)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
train_dataloader, test_dataloader, class_names = create_dataloaders( |
|
|
train_dir=train_dir, |
|
|
test_dir=test_dir, |
|
|
train_transform=manual_transforms_train_vitl, |
|
|
test_transform=manual_transforms_test_vitl, |
|
|
batch_size=batch_size, |
|
|
num_workers=num_workers |
|
|
) |
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
if aug: |
|
|
manual_transforms_train_vitl = v2.Compose([ |
|
|
v2.TrivialAugmentWide(), |
|
|
v2.Resize((256, 256)), |
|
|
v2.RandomCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
else: |
|
|
manual_transforms_train_vitl = v2.Compose([ |
|
|
v2.Resize((256, 256)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
manual_transforms_test_vitl = v2.Compose([ |
|
|
v2.Resize((256, 256)), |
|
|
v2.CenterCrop((IMG_SIZE, IMG_SIZE)), |
|
|
v2.ToImage(), |
|
|
v2.ToDtype(torch.float32, scale=True), |
|
|
v2.Normalize(mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
|
|
|
train_dataloader, test_dataloader, class_names = create_dataloaders( |
|
|
train_dir=train_dir, |
|
|
test_dir=test_dir, |
|
|
train_transform=manual_transforms_train_vitl, |
|
|
test_transform=manual_transforms_test_vitl, |
|
|
batch_size=batch_size, |
|
|
num_workers=num_workers |
|
|
) |
|
|
|
|
|
|
|
|
if display_imgs: |
|
|
train_data = datasets.ImageFolder(train_dir, transform=manual_transforms) |
|
|
display_random_images(train_data, |
|
|
n=25, |
|
|
classes=class_names, |
|
|
rows=5, |
|
|
cols=5, |
|
|
display_shape=False, |
|
|
seed=None) |
|
|
|
|
|
return train_dataloader, test_dataloader, class_names |
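
# Example usage of create_dataloader_for_vit (a minimal sketch; the directory
# paths below are placeholders, not paths shipped with this module):
#
#     train_dl, test_dl, class_names = create_dataloader_for_vit(
#         vit_model="vitbase16",
#         train_dir="data/train",
#         test_dir="data/test",
#         batch_size=32,
#         aug=True,
#         display_imgs=False,
#     )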
|
|
|
|
|
|
|
|
|
|
|
def create_vit( |
|
|
vit_model: str="vitbase16", |
|
|
num_classes: int=1000, |
|
|
dropout: float=0.1, |
|
|
    seed: int=42,
|
|
device: torch.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
) -> torchvision.models.VisionTransformer: |
|
|
|
|
|
"""Creates a pretrained PyTorch's default ViT model. |
|
|
|
|
|
Args: |
|
|
vit_model (str, optional): Name of ViT model to create. Default is "vitbase16". |
|
|
num_classes (int, optional): Number of classes in the classifier head. Default is 1000. |
|
|
        dropout (float, optional): Dropout rate in the classifier head. Default is 0.1.
        seed (int, optional): Random seed set before creating the new classifier head. Default is 42.
|
|
device (torch.device, optional): Device to run model on. Default is "cuda" if available else "cpu". |
|
|
|
|
|
Returns: |
|
|
torchvision.models.VisionTransformer: A pretrained ViT model. |
|
|
""" |
|
|
|
|
|
|
|
|
if vit_model == "vitbase16": |
|
|
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT |
|
|
|
|
|
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vitbase16_2": |
|
|
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 |
|
|
|
|
|
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vitbase32": |
|
|
pretrained_vit_weights = torchvision.models.ViT_B_32_Weights.DEFAULT |
|
|
|
|
|
pretrained_vit = torchvision.models.vit_b_32(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vitlarge16": |
|
|
pretrained_vit_weights = torchvision.models.ViT_L_16_Weights.DEFAULT |
|
|
pretrained_vit = torchvision.models.vit_l_16(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vitlarge16_2": |
|
|
pretrained_vit_weights = torchvision.models.ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 |
|
|
pretrained_vit = torchvision.models.vit_l_16(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vitlarge32": |
|
|
pretrained_vit_weights = torchvision.models.ViT_L_32_Weights.DEFAULT |
|
|
pretrained_vit = torchvision.models.vit_l_32(weights=pretrained_vit_weights, dropout=dropout).to(device) |
|
|
|
|
|
|
|
|
elif vit_model == "vithuge14": |
|
|
pretrained_vit_weights = torchvision.models.ViT_H_14_Weights.DEFAULT |
|
|
        pretrained_vit = torchvision.models.vit_h_14(weights=pretrained_vit_weights, dropout=dropout).to(device)
|
|
|
|
|
else: |
|
|
print("Invalid model name, exiting...") |
|
|
exit() |
|
|
|
|
|
|
|
|
    # Keep all backbone parameters trainable by default (full fine-tuning).
    for parameter in pretrained_vit.parameters():
|
|
parameter.requires_grad = True |
|
|
|
|
|
|
|
|
torch.manual_seed(seed) |
|
|
|
|
|
|
|
|
torch.cuda.manual_seed(seed) |
|
|
|
|
|
|
|
|
if "vitbase" in vit_model: |
|
|
pretrained_vit.heads = nn.Linear(in_features=768, out_features=num_classes).to(device) |
|
|
elif "vitlarge" in vit_model: |
|
|
pretrained_vit.heads = nn.Linear(in_features=1024, out_features=num_classes).to(device) |
|
|
else: |
|
|
pretrained_vit.heads = nn.Linear(in_features=1280, out_features=num_classes).to(device) |
|
|
|
|
|
return pretrained_vit |
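
# Example usage of create_vit (a sketch; "vitbase16" is one of the names handled
# above, the 3-class head is arbitrary, and pretrained weights are downloaded by
# torchvision on first use):
#
#     model = create_vit(vit_model="vitbase16", num_classes=3, dropout=0.1)
#     n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f"Trainable parameters: {n_trainable}")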
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbedding(nn.Module): |
|
|
|
|
|
""" |
|
|
Turns a 2D input image into a 1D sequence learnable embedding vector. |
|
|
|
|
|
Args: |
|
|
in_channels (int): Number of color channels for the input images. Defaults to 3. |
|
|
patch_size (int): Size of patches to convert input image into. Defaults to 16. |
|
|
emb_dim (int): Size of embedding to turn image into. Defaults to 768. |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
img_size:int=224, |
|
|
in_channels:int=3, |
|
|
patch_size:int=16, |
|
|
emb_dim:int=768, |
|
|
emb_dropout:float=0.1): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}." |
|
|
|
|
|
|
|
|
self.conv_proj = nn.Conv2d(in_channels=in_channels, |
|
|
out_channels=emb_dim, |
|
|
kernel_size=patch_size, |
|
|
stride=patch_size, |
|
|
padding=0) |
|
|
|
|
|
|
|
|
self.flatten = nn.Flatten(start_dim=2, |
|
|
end_dim=3) |
|
|
|
|
|
|
|
|
self.class_token = trunc_normal_(nn.Parameter(torch.zeros(1, 1, emb_dim), requires_grad=True), std=0.02) |
|
|
|
|
|
|
|
|
num_patches = (img_size * img_size) // patch_size**2 |
|
|
self.pos_embedding = trunc_normal_(nn.Parameter(torch.zeros(1, num_patches+1, emb_dim), requires_grad=True), std=0.02) |
|
|
|
|
|
|
|
|
self.emb_dropout = nn.Dropout(p=emb_dropout) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
x = self.conv_proj(x) |
|
|
|
|
|
|
|
|
x = self.flatten(x) |
|
|
|
|
|
|
|
|
x = x.permute(0, 2, 1) |
|
|
|
|
|
|
|
|
class_token = self.class_token.expand(x.shape[0], -1, -1) |
|
|
x = torch.cat((class_token, x), dim=1) |
|
|
|
|
|
|
|
|
x = x + self.pos_embedding |
|
|
|
|
|
|
|
|
x = self.emb_dropout(x) |
|
|
|
|
|
return x |
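
# Quick shape check for PatchEmbedding (a sketch; the random batch below is
# purely illustrative):
#
#     patcher = PatchEmbedding(img_size=224, patch_size=16, emb_dim=768)
#     dummy = torch.randn(2, 3, 224, 224)
#     print(patcher(dummy).shape)  # expected: torch.Size([2, 197, 768]) -> 196 patches + 1 class token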
|
|
|
|
|
|
|
|
class MultiheadSelfAttentionBlock(nn.Module): |
|
|
|
|
|
""" |
|
|
Creates a multi-head self-attention block ("MSA block" for short). |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
emb_dim:int=768, |
|
|
num_heads:int=12, |
|
|
dropout:float=0): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.layer_norm = nn.LayerNorm(normalized_shape=emb_dim) |
|
|
|
|
|
|
|
|
self.self_attention = nn.MultiheadAttention(embed_dim=emb_dim, |
|
|
num_heads=num_heads, |
|
|
dropout=dropout, |
|
|
batch_first=True) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
x = self.layer_norm(x) |
|
|
|
|
|
|
|
|
x, _ = self.self_attention(query=x, |
|
|
key=x, |
|
|
value=x, |
|
|
need_weights=False) |
|
|
return x |
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
class MultiheadSelfAttentionBlockV2(nn.Module): |
|
|
|
|
|
""" |
|
|
Creates a custom multi-head self-attention block using scaled dot-product attention. |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
emb_dim: int = 768, |
|
|
num_heads: int = 12, |
|
|
dropout: float = 0.0): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.layer_norm = nn.LayerNorm(normalized_shape=emb_dim) |
|
|
|
|
|
|
|
|
self.num_heads = num_heads |
|
|
self.emb_dim = emb_dim |
|
|
self.dropout = dropout |
|
|
self.head_dim = emb_dim // num_heads |
|
|
|
|
|
assert emb_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." |
|
|
|
|
|
def split_into_heads(self, x): |
|
|
"""Split input tensor into multiple heads.""" |
|
|
batch_size, seq_len, emb_dim = x.shape |
|
|
x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) |
|
|
return x.permute(0, 2, 1, 3) |
|
|
|
|
|
def combine_heads(self, x): |
|
|
"""Combine the heads back into a single tensor.""" |
|
|
batch_size, num_heads, seq_len, head_dim = x.shape |
|
|
x = x.permute(0, 2, 1, 3) |
|
|
return x.contiguous().view(batch_size, seq_len, self.emb_dim) |
|
|
|
|
|
def forward(self, x): |
|
|
"""Forward pass for the MSA block.""" |
|
|
|
|
|
normed_x = self.layer_norm(x) |
|
|
|
|
|
|
|
|
        # This simplified variant reuses the normed input as query, key and value
        # without learned Q/K/V projections (unlike nn.MultiheadAttention in V1).
        query = self.split_into_heads(normed_x)
        key = self.split_into_heads(normed_x)
        value = self.split_into_heads(normed_x)
|
|
|
|
|
|
|
|
attn_output = F.scaled_dot_product_attention(query=query, |
|
|
key=key, |
|
|
value=value, |
|
|
                                                     dropout_p=self.dropout if self.training else 0.0,  # only drop during training
|
|
is_causal=False) |
|
|
|
|
|
|
|
|
output = self.combine_heads(attn_output) |
|
|
|
|
|
|
|
|
        # Residual connection (V1 leaves the skip connection to TransformerEncoderBlock).
        output = x + output
|
|
|
|
|
return output |
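
# Sketch: both MSA blocks preserve the token sequence shape (the dummy tensor is
# illustrative only):
#
#     msa = MultiheadSelfAttentionBlockV2(emb_dim=768, num_heads=12)
#     tokens = torch.randn(2, 197, 768)
#     print(msa(tokens).shape)  # expected: torch.Size([2, 197, 768])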
|
|
|
|
|
|
|
|
class MLPBlock(nn.Module): |
|
|
|
|
|
""" |
|
|
Creates a layer normalized multilayer perceptron block ("MLP block" for short). |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
emb_dim:int=768, |
|
|
mlp_size:int=3072, |
|
|
dropout:float=0.1): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.layer_norm = nn.LayerNorm(normalized_shape=emb_dim) |
|
|
|
|
|
|
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(in_features=emb_dim, |
|
|
out_features=mlp_size), |
|
|
nn.GELU(), |
|
|
nn.Dropout(p=dropout), |
|
|
nn.Linear(in_features=mlp_size, |
|
|
out_features=emb_dim), |
|
|
nn.Dropout(p=dropout) |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
return self.mlp(self.layer_norm(x)) |
|
|
|
|
|
|
|
|
class TransformerEncoderBlock(nn.Module): |
|
|
|
|
|
""" |
|
|
Creates a Transformer Encoder block. |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
emb_dim:int=768, |
|
|
num_heads:int=12, |
|
|
mlp_size:int=3072, |
|
|
attn_dropout:float=0, |
|
|
mlp_dropout:float=0.1, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.msa_block = MultiheadSelfAttentionBlock(emb_dim=emb_dim, |
|
|
num_heads=num_heads, |
|
|
dropout=attn_dropout) |
|
|
|
|
|
|
|
|
self.mlp_block = MLPBlock(emb_dim=emb_dim, |
|
|
mlp_size=mlp_size, |
|
|
dropout=mlp_dropout) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
x = self.msa_block(x) + x |
|
|
|
|
|
|
|
|
x = self.mlp_block(x) + x |
|
|
|
|
|
return x |
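
# Sketch: a single encoder block keeps the [batch, tokens, emb_dim] shape, so
# blocks can be stacked with nn.Sequential as done in ViT below (dummy tensor is
# illustrative only):
#
#     block = TransformerEncoderBlock(emb_dim=768, num_heads=12, mlp_size=3072)
#     print(block(torch.randn(2, 197, 768)).shape)  # torch.Size([2, 197, 768])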
|
|
|
|
|
|
|
|
|
|
|
class ViT(nn.Module): |
|
|
|
|
|
""" |
|
|
Creates a Vision Transformer architecture with ViT-Base hyperparameters by default. |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
img_size:int=224, |
|
|
in_channels:int=3, |
|
|
patch_size:int=16, |
|
|
num_transformer_layers:int=12, |
|
|
emb_dim:int=768, |
|
|
mlp_size:int=3072, |
|
|
num_heads:int=12, |
|
|
emb_dropout:float=0.1, |
|
|
attn_dropout:float=0, |
|
|
mlp_dropout:float=0.1, |
|
|
classif_head_hidden_units:int=0, |
|
|
num_classes:int=1000): |
|
|
|
|
|
""" |
|
|
Initializes a Vision Transformer (ViT) model with specified hyperparameters (ViT-Base parameters by default). |
|
|
|
|
|
The constructor sets up the ViT model by configuring the input image size, number of transformer layers, |
|
|
embedding dimension, number of attention heads, MLP size, and dropout rates, based on the ViT-Base configuration |
|
|
as detailed in the original ViT paper. These parameters are also customizable to suit different downstream tasks. |
|
|
|
|
|
Args: |
|
|
- img_size (int, optional): The resolution of the input images. Default is 224. |
|
|
- in_channels (int, optional): The number of input image channels. Default is 3 (RGB). |
|
|
- patch_size (int, optional): The size of patches to divide the input image into. Default is 16. |
|
|
- num_transformer_layers (int, optional): The number of transformer layers. Default is 12 for ViT-Base. |
|
|
- emb_dim (int, optional): The dimensionality of the embedding space. Default is 768 for ViT-Base. |
|
|
- mlp_size (int, optional): The size of the MLP hidden layers. Default is 3072 for ViT-Base. |
|
|
- num_heads (int, optional): The number of attention heads in each transformer layer. Default is 12. |
|
|
- emb_dropout (float, optional): The dropout rate applied to patch and position embeddings. Default is 0.1. |
|
|
- attn_dropout (float, optional): The dropout rate applied to attention layers. Default is 0. |
|
|
- mlp_dropout (float, optional): The dropout rate applied to the MLP layers. Default is 0.1. |
|
|
        - classif_head_hidden_units (int, optional): The number of hidden units in the classification head. Default is 0 (no extra hidden layer).
|
|
- num_classes (int, optional): The number of output classes. Default is 1000 for ImageNet, but can be customized. |
|
|
|
|
|
Note: |
|
|
This initialization is based on the ViT-Base/16 model as described in the Vision Transformer paper. Custom values can |
|
|
be provided for these parameters based on the specific task or dataset. |
|
|
""" |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.embedder = PatchEmbedding(img_size=img_size, |
|
|
in_channels=in_channels, |
|
|
patch_size=patch_size, |
|
|
emb_dim=emb_dim, |
|
|
emb_dropout=emb_dropout) |
|
|
|
|
|
|
|
|
|
|
|
self.encoder = nn.Sequential(*[TransformerEncoderBlock(emb_dim=emb_dim, |
|
|
num_heads=num_heads, |
|
|
mlp_size=mlp_size, |
|
|
attn_dropout=attn_dropout, |
|
|
mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if classif_head_hidden_units: |
|
|
self.classifier = nn.Sequential( |
|
|
nn.LayerNorm(normalized_shape=emb_dim), |
|
|
nn.Linear(in_features=emb_dim, out_features=classif_head_hidden_units), |
|
|
nn.GELU(), |
|
|
nn.Dropout(p=mlp_dropout), |
|
|
nn.Linear(in_features=classif_head_hidden_units, out_features=num_classes) |
|
|
) |
|
|
else: |
|
|
self.classifier = nn.Sequential( |
|
|
nn.LayerNorm(normalized_shape=emb_dim), |
|
|
nn.Linear(in_features=emb_dim, out_features=num_classes) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def copy_weights(self, |
|
|
model_weights: torchvision.models.ViT_B_16_Weights): |
|
|
|
|
|
""" |
|
|
Copies the pretrained weights from a ViT model (Vision Transformer) to the current model. |
|
|
This method assumes that the current model has a structure compatible with the ViT-base architecture. |
|
|
|
|
|
Args: |
|
|
model_weights (torchvision.models.ViT_B_16_Weights): The pretrained weights of the ViT model. |
|
|
This should be a state dictionary from a ViT-B_16 architecture, such as the one provided |
|
|
by torchvision's ViT_B_16_Weights.DEFAULT. |
|
|
|
|
|
Notes: |
|
|
- This method manually copies weights from the pretrained ViT model to the corresponding layers of the current model. |
|
|
- It supports the ViT-base architecture with 12 transformer encoder layers and expects a similar |
|
|
structure in the target model (e.g., embedder, encoder layers, classifier). |
|
|
- This method does not update the optimizer state or any other model parameters beyond the weights. |
|
|
|
|
|
Example: |
|
|
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT |
|
|
model.copy_weights(pretrained_vit_weights) |
|
|
""" |
|
|
|
|
|
|
|
|
state_dict = self.state_dict() |
|
|
|
|
|
|
|
|
pretrained_state_dict = model_weights.get_state_dict() |
|
|
|
|
|
|
|
|
state_dict['embedder.class_token'].copy_(pretrained_state_dict['class_token']) |
|
|
state_dict['embedder.pos_embedding'].copy_(pretrained_state_dict['encoder.pos_embedding']) |
|
|
state_dict['embedder.conv_proj.weight'].copy_(pretrained_state_dict['conv_proj.weight']) |
|
|
state_dict['embedder.conv_proj.bias'].copy_(pretrained_state_dict['conv_proj.bias']) |
|
|
|
|
|
|
|
|
encoder_layer_keys = [key for key in pretrained_state_dict.keys() if 'encoder.layers' in key] |
|
|
num_encoder_layers = len(set([key.split('.')[2] for key in encoder_layer_keys])) |
|
|
|
|
|
|
|
|
for layer in range(num_encoder_layers): |
|
|
state_dict[f'encoder.{layer}.msa_block.layer_norm.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_1.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.layer_norm.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_1.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.in_proj_weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.in_proj_weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.in_proj_bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.in_proj_bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.out_proj.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.out_proj.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.out_proj.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.out_proj.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.layer_norm.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_2.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.layer_norm.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_2.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.0.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_1.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.0.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_1.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.3.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_2.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.3.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_2.bias'] |
|
|
) |
|
|
|
|
|
|
|
|
state_dict['classifier.0.weight'].copy_(pretrained_state_dict['encoder.ln.weight']) |
|
|
state_dict['classifier.0.bias'].copy_(pretrained_state_dict['encoder.ln.bias']) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.load_state_dict(state_dict) |
|
|
|
|
|
print("[INFO] Model weights copied successfully.") |
|
|
print("[INFO] Model weights are trainable by default. Use function set_params_frozen to freeze them.") |
|
|
|
|
|
def set_params_frozen(self, |
|
|
except_head:bool=True): |
|
|
""" |
|
|
Freezes parameters of different components, allowing exceptions. |
|
|
|
|
|
Args: |
|
|
except_head (bool): If True, excludes the classifier head from being frozen. |
|
|
""" |
|
|
|
|
|
for param in self.embedder.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
for param in self.encoder.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
for param in self.classifier.parameters(): |
|
|
param.requires_grad = except_head |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
""" |
|
|
Forward pass of the Vision Transformer model. |
|
|
|
|
|
Args: |
|
|
x (torch.Tensor): Input tensor of shape [batch_size, in_channels, img_size, img_size]. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Output tensor of shape [batch_size, num_classes]. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.classifier(self.encoder(self.embedder(x))[:,0]) |
|
|
|
|
|
return x |
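
# Example fine-tuning setup for ViT (a sketch; the 3-class head is arbitrary and
# the pretrained weights are downloaded by torchvision on first use):
#
#     vit = ViT(num_classes=3)
#     vit.copy_weights(torchvision.models.ViT_B_16_Weights.DEFAULT)
#     vit.set_params_frozen(except_head=True)
#     logits = vit(torch.randn(1, 3, 224, 224))  # shape: [1, 3]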
|
|
|
|
|
|
|
|
|
|
|
class ViTv2(nn.Module): |
|
|
|
|
|
def __init__(self, |
|
|
img_size:int=224, |
|
|
in_channels:int=3, |
|
|
patch_size:int=16, |
|
|
num_transformer_layers:int=12, |
|
|
emb_dim:int=768, |
|
|
mlp_size:int=3072, |
|
|
num_heads:int=12, |
|
|
emb_dropout:float=0.1, |
|
|
attn_dropout:float=0, |
|
|
mlp_dropout:float=0.1, |
|
|
classif_heads:nn.Module=None, |
|
|
num_classes:int=1000): |
|
|
|
|
|
""" |
|
|
Initializes a Vision Transformer (ViT) model with specified hyperparameters (ViT-Base parameters by default). |
|
|
V2 is identical to V1 except that the classification head can be passed as an argument, allowing for customization |
|
|
of the number of hidden layers and units per layer. |
|
|
|
|
|
The constructor sets up the ViT model by configuring the input image size, number of transformer layers, |
|
|
embedding dimension, number of attention heads, MLP size, and dropout rates, based on the ViT-Base configuration |
|
|
as detailed in the original ViT paper. These parameters are also customizable to suit different downstream tasks. |
|
|
|
|
|
Args: |
|
|
- img_size (int, optional): The resolution of the input images. Default is 224. |
|
|
- in_channels (int, optional): The number of input image channels. Default is 3 (RGB). |
|
|
- patch_size (int, optional): The size of patches to divide the input image into. Default is 16. |
|
|
- num_transformer_layers (int, optional): The number of transformer layers. Default is 12 for ViT-Base. |
|
|
- emb_dim (int, optional): The dimensionality of the embedding space. Default is 768 for ViT-Base. |
|
|
- mlp_size (int, optional): The size of the MLP hidden layers. Default is 3072 for ViT-Base. |
|
|
- num_heads (int, optional): The number of attention heads in each transformer layer. Default is 12. |
|
|
- emb_dropout (float, optional): The dropout rate applied to patch and position embeddings. Default is 0.1. |
|
|
- attn_dropout (float, optional): The dropout rate applied to attention layers. Default is 0. |
|
|
- mlp_dropout (float, optional): The dropout rate applied to the MLP layers. Default is 0.1. |
|
|
        - classif_heads (list[nn.Module], optional): Optional list of classification head modules. Default is None, in which case a single default head (LayerNorm followed by a Linear layer) is used.
|
|
- num_classes (int, optional): The number of output classes. Default is 1000 for ImageNet, but can be customized. |
|
|
|
|
|
Note: |
|
|
This initialization is based on the ViT-Base/16 model as described in the Vision Transformer paper. Custom values can |
|
|
be provided for these parameters based on the specific task or dataset. |
|
|
|
|
|
Usage of classif_heads: |
|
|
- If provided, it will be used as the final classification layer(s) of the model. |
|
|
- If None, a default single-layer classification head will be used with the specified number of classes. |
|
|
- This allows for flexibility in the final layer(s) of the model, enabling customization based on the task requirements. |
|
|
|
|
|
def create_classification_heads(num_heads: int, emb_dim: int, num_classes: int) -> list: |
|
|
heads = [] |
|
|
for i in range(num_heads): |
|
|
head = nn.Sequential( |
|
|
nn.LayerNorm(normalized_shape=emb_dim), |
|
|
nn.Linear(in_features=emb_dim, out_features=emb_dim // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(in_features=emb_dim // 2, out_features=num_classes) |
|
|
) |
|
|
                heads.append(head)
            return heads

        classif_heads = create_classification_heads(num_heads=2, emb_dim=768, num_classes=10)
|
|
""" |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.embedder = PatchEmbedding(img_size=img_size, |
|
|
in_channels=in_channels, |
|
|
patch_size=patch_size, |
|
|
emb_dim=emb_dim, |
|
|
emb_dropout=emb_dropout) |
|
|
|
|
|
|
|
|
|
|
|
self.encoder = nn.Sequential(*[TransformerEncoderBlock(emb_dim=emb_dim, |
|
|
num_heads=num_heads, |
|
|
mlp_size=mlp_size, |
|
|
attn_dropout=attn_dropout, |
|
|
mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if classif_heads: |
|
|
self.classifier = nn.ModuleList(classif_heads) |
|
|
else: |
|
|
classifier = nn.Sequential( |
|
|
nn.LayerNorm(normalized_shape=emb_dim), |
|
|
nn.Linear(in_features=emb_dim, out_features=num_classes) |
|
|
) |
|
|
self.classifier = nn.ModuleList([classifier]) |
|
|
|
|
|
|
|
|
def copy_weights(self, |
|
|
model_weights: torchvision.models.ViT_B_16_Weights): |
|
|
|
|
|
""" |
|
|
Copies the pretrained weights from a ViT model (Vision Transformer) to the current model. |
|
|
This method assumes that the current model has a structure compatible with the ViT-base architecture. |
|
|
|
|
|
Args: |
|
|
model_weights (torchvision.models.ViT_B_16_Weights): The pretrained weights of the ViT model. |
|
|
This should be a state dictionary from a ViT-B_16 architecture, such as the one provided |
|
|
by torchvision's ViT_B_16_Weights.DEFAULT. |
|
|
|
|
|
Notes: |
|
|
- This method manually copies weights from the pretrained ViT model to the corresponding layers of the current model. |
|
|
- It supports the ViT-base architecture with 12 transformer encoder layers and expects a similar |
|
|
structure in the target model (e.g., embedder, encoder layers, classifier). |
|
|
- This method does not update the optimizer state or any other model parameters beyond the weights. |
|
|
|
|
|
Example: |
|
|
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT |
|
|
model.copy_weights(pretrained_vit_weights) |
|
|
""" |
|
|
|
|
|
|
|
|
state_dict = self.state_dict() |
|
|
|
|
|
|
|
|
pretrained_state_dict = model_weights.get_state_dict() |
|
|
|
|
|
|
|
|
state_dict['embedder.class_token'].copy_(pretrained_state_dict['class_token']) |
|
|
state_dict['embedder.pos_embedding'].copy_(pretrained_state_dict['encoder.pos_embedding']) |
|
|
state_dict['embedder.conv_proj.weight'].copy_(pretrained_state_dict['conv_proj.weight']) |
|
|
state_dict['embedder.conv_proj.bias'].copy_(pretrained_state_dict['conv_proj.bias']) |
|
|
|
|
|
|
|
|
encoder_layer_keys = [key for key in pretrained_state_dict.keys() if 'encoder.layers' in key] |
|
|
num_encoder_layers = len(set([key.split('.')[2] for key in encoder_layer_keys])) |
|
|
|
|
|
|
|
|
for layer in range(num_encoder_layers): |
|
|
state_dict[f'encoder.{layer}.msa_block.layer_norm.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_1.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.layer_norm.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_1.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.in_proj_weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.in_proj_weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.in_proj_bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.in_proj_bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.out_proj.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.out_proj.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.msa_block.self_attention.out_proj.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.self_attention.out_proj.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.layer_norm.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_2.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.layer_norm.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.ln_2.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.0.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_1.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.0.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_1.bias'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.3.weight'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_2.weight'] |
|
|
) |
|
|
state_dict[f'encoder.{layer}.mlp_block.mlp.3.bias'].copy_( |
|
|
pretrained_state_dict[f'encoder.layers.encoder_layer_{layer}.mlp.linear_2.bias'] |
|
|
) |
|
|
|
|
|
|
|
|
        # The classifier is an nn.ModuleList, so the first head's LayerNorm lives at 'classifier.0.0'
        # (this assumes the first module of the first head is a LayerNorm, as in the default head).
        state_dict['classifier.0.0.weight'].copy_(pretrained_state_dict['encoder.ln.weight'])
        state_dict['classifier.0.0.bias'].copy_(pretrained_state_dict['encoder.ln.bias'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.load_state_dict(state_dict) |
|
|
|
|
|
def set_params_frozen(self, |
|
|
except_head:bool=True): |
|
|
|
|
|
""" |
|
|
Freezes parameters of different components, allowing exceptions. |
|
|
|
|
|
Args: |
|
|
except_head (bool): If True, excludes the classifier head from being frozen. |
|
|
""" |
|
|
|
|
|
for param in self.embedder.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
for param in self.encoder.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
for param in self.classifier.parameters(): |
|
|
param.requires_grad = except_head |
|
|
|
|
|
def compile(self): |
|
|
"""Compile the model using torch.compile for optimization.""" |
|
|
self.__compiled__ = torch.compile(self) |
|
|
print("Model compiled successfully with torch.compile.") |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
""" |
|
|
Forward pass of the Vision Transformer model. |
|
|
|
|
|
Args: |
|
|
x (torch.Tensor): Input tensor of shape [batch_size, in_channels, img_size, img_size]. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Output tensor of shape [batch_size, num_classes]. |
|
|
""" |
|
|
|
|
|
|
|
|
x = self.embedder(x) |
|
|
|
|
|
|
|
|
x = self.encoder(x) |
|
|
|
|
|
|
|
|
x_list = [head(x[:, 0]) for head in self.classifier] |
|
|
x = torch.mean(torch.stack(x_list), dim=0) |
|
|
|
|
|
return x |
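

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the random tensors are placeholders, not real
    # data, no pretrained weights are downloaded, and the 2-layer configuration is
    # only chosen to keep the check fast).
    dummy_batch = torch.randn(2, 3, 224, 224)

    vit = ViT(num_transformer_layers=2, num_classes=3)
    print("ViT logits:", vit(dummy_batch).shape)  # expected: torch.Size([2, 3])

    heads = [
        nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 3)),
        nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 3)),
    ]
    vit2 = ViTv2(num_transformer_layers=2, num_classes=3, classif_heads=heads)
    print("ViTv2 logits:", vit2(dummy_batch).shape)  # expected: torch.Size([2, 3])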