DeMoE / archs /__init__.py
danifei's picture
basic functionality
034f4b8
raw
history blame
1.46 kB
import torch
from .DeMoE import DeMoE
def create_model(opt, device):
    '''
    Creates the model described by ``opt`` and moves it to ``device``.

    Args:
        opt: a dictionary from the yaml config key ``network``. It must
            contain ``'name'``. For ``'DeMoE'`` it must also contain
            ``'img_channels'``, ``'width'``, ``'middle_blk_num'``,
            ``'enc_blk_nums'``, ``'dec_blk_nums'``, ``'num_experts'`` and
            ``'k_used'``.
        device: the value passed to ``model.to``.

    Returns:
        The constructed model, after ``model.to(device)`` has been called.

    Raises:
        NotImplementedError: if ``opt['name']`` is not a supported network.
        KeyError: if a required key is missing from ``opt``.
    '''
    name = opt['name']
    if name == 'DeMoE':
        # Config keys are mapped onto the constructor's keyword names here
        # (e.g. 'img_channels' -> img_channel, 'num_experts' -> num_exp).
        model = DeMoE(img_channel=opt['img_channels'],
                      width=opt['width'],
                      middle_blk_num=opt['middle_blk_num'],
                      enc_blk_nums=opt['enc_blk_nums'],
                      dec_blk_nums=opt['dec_blk_nums'],
                      num_exp=opt['num_experts'],
                      k_used=opt['k_used'])
    else:
        # Report the offending name so a typo in the config is easy to spot.
        raise NotImplementedError(
            f'This network is not implemented: {name!r}')
    # The return value of .to() is ignored. This relies on DeMoE being a
    # torch.nn.Module, whose .to() moves parameters in place (TODO confirm).
    model.to(device)
    return model
def load_weights(model, model_weights):
    '''
    Copies the entries of ``model_weights`` whose keys exist in ``model``.

    Keys of ``model_weights`` that are absent from ``model.state_dict()``
    are ignored. Entries of the model's state dict that ``model_weights``
    does not provide keep their current values. The merged dict is then
    applied with ``model.load_state_dict`` using its default arguments.

    Args:
        model: the model to load into.
        model_weights: a mapping from state-dict key to value, e.g. the
            weights stored in a checkpoint.

    Returns:
        The same ``model`` object, after loading.
    '''
    merged = model.state_dict()
    for key, value in model_weights.items():
        # Only overwrite keys the model already has, so the key set of
        # ``merged`` stays exactly that of the model.
        if key in merged:
            merged[key] = value
    model.load_state_dict(merged)
    return model
def resume_model(model,
                 path_model,
                 device,
                 weights_only=False):
    '''
    Loads the ``'params'`` entry of a checkpoint file into ``model``.

    Args:
        model: the model to load the weights into. Entries are merged with
            ``load_weights``.
        path_model: path of the checkpoint file, passed to ``torch.load``.
        device: passed as ``map_location`` to ``torch.load``.
        weights_only: forwarded to ``torch.load``. It defaults to ``False``
            to keep the previous behaviour. Pass ``True`` when the checkpoint
            may come from an untrusted source.

    Returns:
        ``model`` after the checkpoint's ``'params'`` weights were loaded.

    Raises:
        KeyError: if the loaded checkpoint mapping has no ``'params'`` key.
    '''
    # SECURITY NOTE(review): with weights_only=False, torch.load unpickles
    # arbitrary objects, so a malicious checkpoint can execute code. Only
    # load trusted files this way, or call with weights_only=True.
    checkpoints = torch.load(path_model, map_location=device,
                             weights_only=weights_only)
    weights = checkpoints['params']
    model = load_weights(model, model_weights=weights)
    return model
# Names exported by a star-import of this package.
__all__ = [
    'create_model',
    'resume_model',
    'load_weights',
]