# ecg-ai-classifier / model.py
# Last commit 3963964: "Update model.py" (Sayed223, verified)
import torch.nn as nn
from torchvision.models import resnet18
def get_model(num_classes, pretrained=True):
    """Return a ResNet-18 adapted for single-channel (grayscale) ECG images.

    The stem convolution is replaced so the network accepts 1-channel input,
    and the final fully connected layer is replaced so it emits
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new ``fc`` layer.
        pretrained: If True, start from torchvision's ImageNet weights
            (``IMAGENET1K_V1``, the weights the legacy ``pretrained=True``
            flag loads). The new 1-channel ``conv1`` is then initialised from
            the pretrained RGB filters instead of randomly, so the pretrained
            stem is not thrown away.

    Returns:
        A torchvision ``ResNet`` instance with a 1-channel ``conv1`` and a
        ``num_classes``-way ``fc`` head.
    """
    model = _load_resnet18(pretrained)
    # Change first layer to accept 1-channel input (grayscale).
    model.conv1 = _grayscale_stem(model.conv1, copy_weights=pretrained)
    # Change the output layer for our number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def _load_resnet18(pretrained):
    """Build torchvision's ResNet-18, optionally with ImageNet weights."""
    try:
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        from torchvision.models import ResNet18_Weights
    except ImportError:
        # Older torchvision only understands the legacy boolean flag.
        return resnet18(pretrained=pretrained)
    weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    return resnet18(weights=weights)


def _grayscale_stem(conv, copy_weights):
    """Return a bias-free 1-channel conv with ``conv``'s geometry, optionally seeded from its weights."""
    stem = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=False,
    )
    if copy_weights:
        # Summing the RGB filters over the input-channel axis yields a filter
        # whose response to a gray image equals the original filter's response
        # to that image replicated across all three channels (convolution is
        # linear and the stem has no bias).
        stem.weight = nn.Parameter(conv.weight.detach().sum(dim=1, keepdim=True))
    return stem
import torch.nn as nn
from torchvision.models import resnet18
def get_model(num_classes, pretrained=True):
    """Return a ResNet-18 adapted for single-channel (grayscale) ECG images.

    The stem convolution is replaced so the network accepts 1-channel input,
    and the final fully connected layer is replaced so it emits
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new ``fc`` layer.
        pretrained: If True, start from torchvision's ImageNet weights
            (``IMAGENET1K_V1``, the weights the legacy ``pretrained=True``
            flag loads). The new 1-channel ``conv1`` is then initialised from
            the pretrained RGB filters instead of randomly, so the pretrained
            stem is not thrown away.

    Returns:
        A torchvision ``ResNet`` instance with a 1-channel ``conv1`` and a
        ``num_classes``-way ``fc`` head.
    """
    model = _load_resnet18(pretrained)
    # Change first layer to accept 1-channel input (grayscale).
    model.conv1 = _grayscale_stem(model.conv1, copy_weights=pretrained)
    # Change the output layer for our number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def _load_resnet18(pretrained):
    """Build torchvision's ResNet-18, optionally with ImageNet weights."""
    try:
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        from torchvision.models import ResNet18_Weights
    except ImportError:
        # Older torchvision only understands the legacy boolean flag.
        return resnet18(pretrained=pretrained)
    weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    return resnet18(weights=weights)


def _grayscale_stem(conv, copy_weights):
    """Return a bias-free 1-channel conv with ``conv``'s geometry, optionally seeded from its weights."""
    stem = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=False,
    )
    if copy_weights:
        # Summing the RGB filters over the input-channel axis yields a filter
        # whose response to a gray image equals the original filter's response
        # to that image replicated across all three channels (convolution is
        # linear and the stem has no bias).
        stem.weight = nn.Parameter(conv.weight.detach().sum(dim=1, keepdim=True))
    return stem