File size: 535 Bytes
72647b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# model.py
import torch.nn as nn
from torchvision.models import resnet18

def get_model(num_classes, pretrained=True):
    """Return a ResNet-18 adapted for single-channel (grayscale) ECG images.

    The stem convolution (``conv1``) is replaced so the network accepts
    1-channel input, and the final fully connected layer (``fc``) is resized
    to produce ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits produced by the final layer.
        pretrained: If True, start from torchvision's ImageNet weights for
            ResNet-18. The new 1-channel stem is then initialized by summing
            the pretrained RGB filters over the input-channel axis. A
            grayscale image therefore yields the same stem output as that
            image replicated across three channels, instead of the stem
            starting from random initialization.

    Returns:
        A torchvision ResNet-18 ``nn.Module``. Its ``fc`` layer is always
        freshly initialized.
    """
    try:
        # torchvision >= 0.13 replaced ``pretrained=`` with the ``weights=``
        # enum and deprecates the old keyword; fall back on older releases.
        from torchvision.models import ResNet18_Weights
    except ImportError:
        model = resnet18(pretrained=pretrained)
    else:
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        model = resnet18(weights=weights)

    # Swap the 3-channel stem for a 1-channel one with identical geometry.
    old_conv = model.conv1
    new_conv = nn.Conv2d(
        1,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    if pretrained:
        # Fold the pretrained RGB filters into one channel instead of
        # discarding them: conv(gray, sum_c W_c) == conv([gray]*3, W).
        new_conv.weight = nn.Parameter(
            old_conv.weight.detach().sum(dim=1, keepdim=True)
        )
    model.conv1 = new_conv

    # Replace the classifier head for our number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model