File size: 1,458 Bytes
034f4b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
from .DeMoE import DeMoE
def create_model(opt, device):
    '''
    Build the network described by ``opt`` and move it to ``device``.

    Args:
        opt: a dictionary from the yaml config key ``network``.
            ``opt['name']`` selects the architecture. For ``'DeMoE'`` the keys
            ``img_channels``, ``width``, ``middle_blk_num``, ``enc_blk_nums``,
            ``dec_blk_nums``, ``num_experts`` and ``k_used`` are read.
        device: target passed to ``model.to``.

    Returns:
        The constructed model, after ``model.to(device)`` has been called.

    Raises:
        KeyError: if ``opt`` lacks ``'name'`` or a key read for the chosen
            network.
        NotImplementedError: if ``opt['name']`` is not a supported network.
    '''
    name = opt['name']
    if name == 'DeMoE':
        # Config key names differ from the constructor's keyword names
        # (img_channels -> img_channel, num_experts -> num_exp).
        model = DeMoE(img_channel=opt['img_channels'],
                      width=opt['width'],
                      middle_blk_num=opt['middle_blk_num'],
                      enc_blk_nums=opt['enc_blk_nums'],
                      dec_blk_nums=opt['dec_blk_nums'],
                      num_exp=opt['num_experts'],
                      k_used=opt['k_used'])
    else:
        # Report the offending name so a typo in the config is easy to spot.
        raise NotImplementedError(
            f'This network is not implemented: {name!r}')
    model.to(device)
    return model
def load_weights(model, model_weights):
    '''
    Copy the pretrained entries that ``model`` already has, then return it.

    Starts from ``model.state_dict()`` and overwrites every entry whose key
    also appears in ``model_weights``. Keys found only in ``model_weights``
    are ignored. Model entries absent from ``model_weights`` keep their
    current values. The merged dict is then passed to
    ``model.load_state_dict`` with its default (strict) settings.
    '''
    merged = model.state_dict()
    for key, tensor in model_weights.items():
        # Only take keys the current architecture knows about.
        if key in merged:
            merged[key] = tensor
    model.load_state_dict(merged)
    return model
def resume_model(model,
                 path_model,
                 device):
    '''
    Load the ``'params'`` weights stored in a checkpoint file into ``model``.

    Args:
        model: the network to receive the weights.
        path_model: path to a file readable by ``torch.load``. It is expected
            to hold a mapping with the model weights under ``'params'``.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        ``model`` after ``load_weights`` has copied in every entry of
        ``checkpoint['params']`` whose key the model already has.

    Raises:
        KeyError: if the loaded checkpoint has no ``'params'`` entry.
    '''
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(path_model, map_location=device, weights_only=False)
    try:
        weights = checkpoint['params']
    except KeyError as err:
        # Name the file so a wrong/incompatible checkpoint is easy to identify.
        raise KeyError(
            f"checkpoint {str(path_model)!r} has no 'params' entry") from err
    model = load_weights(model, model_weights=weights)
    return model
# Public API of this module.
__all__ = [
    'create_model',
    'resume_model',
    'load_weights',
]
|