# NOTE: file-viewer residue removed here (size banner "519 Bytes", short hashes 2a16572 / 302f7b0, line-number gutter 1-19); it was not part of the module.
# model.py 
import torch.nn as nn
from torchvision.models import resnet18

def get_model(num_classes: int, pretrained: bool = True) -> nn.Module:
    """
    Build a ResNet-18 classifier adapted for grayscale ECG images.

    Args:
        num_classes: Number of output classes; sets the output width of the
            final fully connected layer.
        pretrained: Forwarded unchanged to ``resnet18`` as its ``pretrained``
            argument.

    Returns:
        The ``resnet18`` model with its ``conv1`` and ``fc`` layers replaced.
    """
    # NOTE(review): newer torchvision releases prefer a ``weights=`` argument
    # over ``pretrained`` -- confirm against the installed version.
    model = resnet18(pretrained=pretrained)

    # Change first layer to accept 1-channel input (grayscale).
    # This is a freshly constructed layer, so whatever weights ``resnet18``
    # loaded into the original ``conv1`` are not carried over.
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # Change the output layer for our number of classes, reusing the input
    # width of the existing head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model