import torch
import torch.nn as nn
from torchvision.models import resnet18


def _build_resnet18(pretrained):
    """Return a torchvision ResNet-18, with ImageNet weights if *pretrained*.

    Uses the ``weights=`` API (torchvision >= 0.13), where the boolean
    ``pretrained`` flag is deprecated; falls back to ``pretrained=`` on older
    torchvision releases that do not provide ``ResNet18_Weights``.
    """
    try:
        from torchvision.models import ResNet18_Weights
    except ImportError:
        # Older torchvision only understands the boolean flag.
        return resnet18(pretrained=pretrained)
    weights = ResNet18_Weights.DEFAULT if pretrained else None
    return resnet18(weights=weights)


def get_model(num_classes, pretrained=True):
    """Return a ResNet-18 adapted for single-channel (grayscale) ECG images.

    Args:
        num_classes: Number of output features of the final linear layer.
        pretrained: If truthy, start from ImageNet weights. The 1-channel
            input layer is then initialized from the pretrained RGB filters
            rather than from scratch.

    Returns:
        A ResNet-18 ``nn.Module`` whose first convolution takes 1 input
        channel and whose ``fc`` layer outputs ``num_classes`` features.
    """
    model = _build_resnet18(pretrained)

    # Swap the 3-channel stem for a 1-channel conv with identical geometry
    # (copied from the existing layer instead of hard-coded).
    old_conv = model.conv1
    new_conv = nn.Conv2d(
        1,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=False,
    )
    if pretrained:
        # Convolution is linear, so summing the RGB filters over the channel
        # axis gives a grayscale input the same response the pretrained conv
        # gives to that image replicated over 3 channels. This keeps the
        # learned stem instead of discarding it for a random init.
        with torch.no_grad():
            new_conv.weight.copy_(old_conv.weight.sum(dim=1, keepdim=True))
    model.conv1 = new_conv

    # Replace the classifier head to match our number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model