Spaces:
Sleeping
Sleeping
File size: 535 Bytes
72647b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
# model.py
import torch.nn as nn
from torchvision.models import resnet18
def get_model(num_classes, pretrained=True):
    """Return a ResNet-18 CNN adapted for grayscale ECG images.

    The stock ``resnet18`` is modified in two places: ``conv1`` is replaced
    by a convolution that accepts a single input channel, and ``fc`` is
    replaced by a linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the final linear layer.
        pretrained: Forwarded to ``resnet18``. When truthy, the new
            single-channel ``conv1`` is also initialised from the loaded
            RGB filters instead of being left randomly initialised.

    Returns:
        The modified ``resnet18`` module.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; the call is kept unchanged so older installs
    # continue to work. Confirm against the pinned torchvision version.
    model = resnet18(pretrained=pretrained)

    # Build the 1-channel stem with the same geometry as the layer it
    # replaces (for ResNet-18: 64 filters, 7x7 kernel, stride 2, padding 3,
    # no bias).
    rgb_stem = model.conv1
    gray_stem = nn.Conv2d(
        1,
        rgb_stem.out_channels,
        kernel_size=rgb_stem.kernel_size,
        stride=rgb_stem.stride,
        padding=rgb_stem.padding,
        bias=rgb_stem.bias is not None,
    )
    if pretrained:
        # Keep the pretrained first-layer filters rather than discarding
        # them: summing the kernels over the input-channel axis gives the
        # same response the RGB layer would produce for a grayscale image
        # replicated across its three channels.
        gray_stem.weight = nn.Parameter(
            rgb_stem.weight.detach().sum(dim=1, keepdim=True)
        )
    model.conv1 = gray_stem

    # Change the output layer for our number of classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
|