File size: 4,214 Bytes
b64811b 18add7e 4df6cc7 c781e06 59126da 6cfaed9 59126da c781e06 7c933ef c781e06 59126da 8df1ddb 59126da c781e06 c2a4d06 7c933ef c2a4d06 c781e06 8df1ddb bbc7a9d 8df1ddb 63cd96e 8df1ddb cd79e1e 8df1ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import torch
import torchvision
from pathlib import Path
from vision_transformer import ViT
def load_model(model: torch.nn.Module,
               model_weights_dir: str,
               model_weights_name: str):
    """Loads saved weights (a state dict) into a PyTorch model.

    Args:
        model: A target PyTorch model to load the weights into. The model is
            updated in place and the same instance is returned.
        model_weights_dir: A directory where the weights file is located.
            It is wrapped in ``pathlib.Path``, so a ``str`` or a ``Path`` works.
        model_weights_name: The file name of the weights to load.
            Should include either ".pth" or ".pt" as the file extension.

    Example usage:
        model = load_model(model=model,
                           model_weights_dir="models",
                           model_weights_name="05_going_modular_tinyvgg_model.pth")

    Returns:
        The PyTorch model with the loaded state dict applied.

    Raises:
        AssertionError: If ``model_weights_name`` does not end with
            ".pt" or ".pth".
    """
    # Create the model directory path
    model_dir_path = Path(model_weights_dir)

    # Validate the file extension. The error is raised explicitly (instead of
    # via an `assert` statement) so the check still runs under `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if not model_weights_name.endswith((".pth", ".pt")):
        raise AssertionError("model_name should end with '.pt' or '.pth'")
    model_path = model_dir_path / model_weights_name

    # Load the state dict onto the CPU and copy it into the model
    print(f"[INFO] Loading model from: {model_path}")
    state_dict = torch.load(
        model_path,
        weights_only=True,
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(state_dict)
    return model
def create_vitbase_model(
    model_weights_dir: Path,
    model_weights_name: str,
    img_size: int = 224,
    num_classes: int = 101,
    compile: bool = False
):
    """
    Builds a ViT-B/16 model and loads trained weights into it.

    Args:
        model_weights_dir: A directory where the model weights are located.
        model_weights_name: The file name of the weights to load.
        img_size: The size of the input image.
        num_classes: The number of classes for the classification task.
        compile: If True, the model is wrapped with ``torch.compile``
            (backend "aot_eager") before the weights are loaded.

    Returns:
        The ViT-B/16 model with the trained weights loaded.
    """
    # Architecture settings passed to the ViT constructor
    architecture = dict(
        img_size=img_size,
        in_channels=3,
        patch_size=16,
        num_transformer_layers=12,
        emb_dim=768,
        mlp_size=3072,
        num_heads=12,
        attn_dropout=0,
        mlp_dropout=0.1,
        emb_dropout=0.1,
        num_classes=num_classes,
    )
    net = ViT(**architecture)

    # Optional compilation happens BEFORE loading, so the state dict is loaded
    # through the compiled wrapper.
    # NOTE(review): compiled wrappers may expose state-dict keys under a
    # different prefix than the plain module - confirm the checkpoint matches
    # the chosen `compile` setting.
    if compile:
        net = torch.compile(net, backend="aot_eager")

    # Load the trained weights and hand the model back
    return load_model(
        model=net,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )
# Create an EfficientNet-B0 Model
def create_effnetb0(
    model_weights_dir: Path,
    model_weights_name: str,
    num_classes: int = 2,
    dropout: float = 0.2
):
    """Creates an EfficientNetB0 model and loads trained weights into it.

    The torchvision EfficientNet-B0 architecture is built, its classifier is
    replaced by a Dropout + Linear head with ``num_classes`` outputs, and the
    state dict stored at ``model_weights_dir / model_weights_name`` is loaded
    onto the CPU.

    Args:
        model_weights_dir: A directory where the model weights are located.
        model_weights_name: The file name of the weights to load.
            Should include either ".pth" or ".pt" as the file extension.
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 2.
        dropout (float, optional): Dropout rate. Defaults to 0.2.

    Returns:
        effnetb0_model (torch.nn.Module): EffNetB0 model with the trained
        weights loaded.

    Raises:
        AssertionError: If ``model_weights_name`` does not end with
            ".pt" or ".pth".
    """
    # Build and validate the weights path first so a bad file name fails
    # before any model is constructed. The error is raised explicitly (instead
    # of via an `assert` statement) so the check still runs under `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    model_dir_path = Path(model_weights_dir)
    if not model_weights_name.endswith((".pth", ".pt")):
        raise AssertionError("model_name should end with '.pt' or '.pth'")
    model_path = model_dir_path / model_weights_name

    # Build the bare architecture. No pretrained weights are requested:
    # load_state_dict below is strict by default and replaces every parameter
    # and buffer, so fetching pretrained weights would be wasted work.
    effnetb0_model = torchvision.models.efficientnet_b0(weights=None).to('cpu')

    # Recreate the classifier layer so it matches the number of target classes
    effnetb0_model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=dropout, inplace=True),
        torch.nn.Linear(in_features=1280,
                        out_features=num_classes,
                        bias=True))

    # Load the state dictionary into the model (tensors are mapped to the CPU)
    state_dict = torch.load(
        model_path,
        weights_only=True,
        map_location=torch.device('cpu'),
    )
    effnetb0_model.load_state_dict(state_dict)
    return effnetb0_model
|