Spaces:
Sleeping
Sleeping
File size: 778 Bytes
c5f4ee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import os
import torch
from timm.models import efficientnet, convnext
def build_backbone(model_name, pretrained):
    """Instantiate a backbone by name from the ``Backbones`` registry.

    Args:
        model_name: Name of a public static builder method on ``Backbones``
            (e.g. ``'efficientnet_b3_p'``).
        pretrained: Forwarded unchanged to the builder as ``pretrained=``.

    Returns:
        Whatever the selected builder returns.

    Raises:
        AttributeError: If ``model_name`` does not name a public, callable
            attribute of ``Backbones``; the message lists the valid names.
        TypeError: If ``model_name`` is not a string (raised by ``getattr``).
    """
    # getattr is called first so a non-string name still raises TypeError,
    # exactly as a plain getattr lookup would.
    builder = getattr(Backbones, model_name, None)
    # Only public callables count as builders; this keeps dunder/private
    # attributes (e.g. "__init__", "__class__") from being dispatched to.
    if model_name.startswith('_') or not callable(builder):
        available = sorted(
            name for name in dir(Backbones)
            if not name.startswith('_') and callable(getattr(Backbones, name))
        )
        raise AttributeError(
            f"Unknown backbone {model_name!r}; available: {', '.join(available)}"
        )
    return builder(pretrained=pretrained)
class Backbones(object):
    """Namespace of backbone factory functions, selected by name.

    Each public static method takes ``pretrained`` and returns a model; the
    module looks them up with ``getattr(Backbones, model_name)``.
    """

    @staticmethod
    def efficientnet_b3_p(pretrained, *, checkpoint_path=None):
        """Build timm's pruned EfficientNet-B3 as a feature extractor.

        Author's note on the output feature channels: 24, 12, 40, 120, 384.

        Args:
            pretrained: Passed to timm's ``efficientnet_b3_pruned``. It can be
                set to False for tests.
            checkpoint_path: Optional path to a pre-downloaded state dict
                (e.g. ``os.path.join('checkpoints',
                'effnetb3_pruned-59ecf72d.pth')``). When given, it is loaded
                onto the model on the CPU with ``strict=False``, after any
                weights timm loaded for ``pretrained``.

        Returns:
            The model built with ``features_only=True``.
        """
        model = efficientnet.efficientnet_b3_pruned(
            pretrained=pretrained, features_only=True
        )
        if checkpoint_path is not None:
            # torch.load unpickles the file: only pass trusted checkpoints.
            state_dict = torch.load(
                checkpoint_path, map_location=torch.device('cpu')
            )
            # strict=False tolerates keys absent from the features-only model
            # (presumably the classifier head). TODO: confirm no backbone
            # weights are silently skipped.
            model.load_state_dict(state_dict=state_dict, strict=False)
        return model
|