File size: 4,214 Bytes
b64811b
18add7e
4df6cc7
c781e06
 
59126da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cfaed9
59126da
 
 
c781e06
 
 
 
7c933ef
 
c781e06
59126da
 
 
 
8df1ddb
 
 
 
59126da
 
 
 
c781e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a4d06
7c933ef
 
c2a4d06
c781e06
 
 
 
 
 
 
 
8df1ddb
 
 
 
 
bbc7a9d
8df1ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63cd96e
8df1ddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd79e1e
8df1ddb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torchvision
from pathlib import Path
from vision_transformer import ViT

def load_model(model: torch.nn.Module,
               model_weights_dir: str,
               model_weights_name: str):

    """Loads a saved state dict from a target directory into a PyTorch model (on CPU).

    The checkpoint may come from either a plain module or a ``torch.compile``
    wrapper (whose state-dict keys carry an ``_orig_mod.`` prefix). ``model``
    may likewise be plain or compiled. The prefix is normalised so any
    combination loads.

    Args:
    model: A target PyTorch model to load the weights into (modified in place).
    model_weights_dir: A directory where the model weights file is located.
    model_weights_name: The file name of the weights to load.
      Should include either ".pth" or ".pt" as the file extension.

    Example usage:
    model = load_model(model=model,
                       model_weights_dir="models",
                       model_weights_name="05_going_modular_tingvgg_model.pth")

    Returns:
    The same ``model`` object that was passed in, now holding the loaded weights.

    Raises:
    ValueError: If ``model_weights_name`` does not end with ".pt" or ".pth".
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    # Validate with an explicit raise rather than ``assert`` so the check
    # still runs when Python is started with -O.
    if not model_weights_name.endswith((".pth", ".pt")):
        raise ValueError(
            "model_weights_name should end with '.pt' or '.pth', "
            f"got {model_weights_name!r}"
        )
    model_path = Path(model_weights_dir) / model_weights_name

    print(f"[INFO] Loading model from: {model_path}")

    # weights_only=True restricts unpickling to tensors and plain containers;
    # map_location forces every tensor onto the CPU.
    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))

    # torch.compile returns a wrapper that holds the original module as the
    # submodule ``_orig_mod``, so its state-dict keys gain an "_orig_mod." prefix.
    # Strip that prefix from the checkpoint and load into the underlying module
    # so compiled/uncompiled checkpoints and models can be mixed.
    consume_prefix_in_state_dict_if_present(state_dict, "_orig_mod.")
    target = getattr(model, "_orig_mod", model)
    target.load_state_dict(state_dict)

    return model

def create_vitbase_model(
    model_weights_dir:Path,
    model_weights_name:str,
    img_size:int=224,
    num_classes:int=101,
    compile:bool=False
    ):
    """
    Builds a ViT-B/16 classifier and loads trained weights into it.

    The transformer hyperparameters are fixed to the values below. Only the
    input resolution and the number of output classes are configurable.

    Args:
        model_weights_dir: Directory that contains the weights file.
        model_weights_name: File name of the weights inside ``model_weights_dir``.
        img_size: Side length of the input images passed to ``ViT``.
        num_classes: Number of outputs of the classification head.
        compile: If True, wrap the model with ``torch.compile`` using the
            ``"aot_eager"`` backend before the weights are loaded.

    Returns:
    The model as returned by ``load_model``. When ``compile`` is True, this is
    the ``torch.compile`` wrapper.
    """
    # ViT-Base/16 architecture settings.
    vit_b16_config = dict(
        img_size=img_size,
        in_channels=3,
        patch_size=16,
        num_transformer_layers=12,
        emb_dim=768,
        mlp_size=3072,
        num_heads=12,
        attn_dropout=0,
        mlp_dropout=0.1,
        emb_dropout=0.1,
        num_classes=num_classes,
    )
    model = ViT(**vit_b16_config)

    # Compilation happens before loading, so the weights go into the wrapper.
    # NOTE(review): a torch.compile wrapper's state-dict keys are prefixed with
    # "_orig_mod.", so with compile=True the checkpoint layout must match that
    # unless load_model normalises the prefix — confirm.
    if compile:
        model = torch.compile(model, backend="aot_eager")

    return load_model(
        model=model,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )

# Create an EfficientNet-B0 Model
def create_effnetb0(
        model_weights_dir: Path,
        model_weights_name: str,
        num_classes: int=2,
        dropout: float=0.2
        ):
    """Creates an EfficientNet-B0 classifier and loads trained weights into it.

    Args:
        model_weights_dir: A directory where the model weights file is located.
        model_weights_name: The file name of the weights to load
            (".pt" or ".pth").
        num_classes (int, optional): Number of classes in the classifier head.
            Defaults to 2.
        dropout (float, optional): Dropout rate before the classifier head.
            Defaults to 0.2.

    Returns:
        effnetb0_model (torch.nn.Module): EfficientNet-B0 with a
        ``num_classes``-way head and the loaded weights, on the CPU.

    Raises:
        Whatever ``load_model`` raises for an invalid weights file name or a
        checkpoint that does not match the architecture.
    """
    # Build the bare architecture. No pretrained weights are requested:
    # every parameter and buffer is overwritten by the strict
    # load_state_dict inside load_model, so downloading the ImageNet weights
    # (which needs network access) would be wasted work.
    effnetb0_model = torchvision.models.efficientnet_b0(weights=None).to('cpu')

    # Replace the classifier with one sized for this task. 1280 is the number
    # of features the EfficientNet-B0 backbone feeds into its classifier.
    effnetb0_model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=dropout, inplace=True),
        torch.nn.Linear(in_features=1280,
                        out_features=num_classes,
                        bias=True))

    # Reuse the shared loader (file-name validation + CPU map_location),
    # matching how the ViT builder loads its weights.
    return load_model(
        model=effnetb0_model,
        model_weights_dir=model_weights_dir,
        model_weights_name=model_weights_name,
    )