| import torch | |
| from .DeMoE import DeMoE | |
def create_model(opt, device):
    """Build the network described by ``opt`` and move it onto ``device``.

    Args:
        opt: dictionary taken from the ``network`` key of the yaml config. The
            ``'name'`` entry selects the architecture; the remaining entries
            read below are forwarded to its constructor.
        device: target passed to ``Module.to``.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        KeyError: if ``opt`` lacks ``'name'`` or a constructor entry.
        NotImplementedError: if ``opt['name']`` is not a supported network.
    """
    arch = opt['name']
    # Guard clause: only DeMoE is currently supported.
    if arch != 'DeMoE':
        raise NotImplementedError('This network is not implemented')

    # Map yaml config keys onto DeMoE constructor keyword arguments.
    # The entries are evaluated in this order, so the first missing key raises.
    ctor_kwargs = {
        'img_channel': opt['img_channels'],
        'width': opt['width'],
        'middle_blk_num': opt['middle_blk_num'],
        'enc_blk_nums': opt['enc_blk_nums'],
        'dec_blk_nums': opt['dec_blk_nums'],
        'num_exp': opt['num_experts'],
        'k_used': opt['k_used'],
    }
    net = DeMoE(**ctor_kwargs)
    net.to(device)
    return net
def load_weights(model, model_weights):
    """Copy the matching entries of ``model_weights`` into ``model``.

    Only keys that exist in ``model.state_dict()`` are copied. Any other key in
    ``model_weights`` is ignored, and parameters/buffers with no counterpart in
    ``model_weights`` keep their current values.

    Keys carrying the ``module.`` prefix are also matched after stripping it.
    ``torch.nn.DataParallel`` / ``DistributedDataParallel`` add this prefix when
    a wrapped model is saved. An exact (unprefixed) key present in
    ``model_weights`` always takes precedence over its prefixed twin.

    Args:
        model: ``torch.nn.Module`` to update in place.
        model_weights: mapping from state-dict key to tensor.

    Returns:
        The same ``model`` instance.

    Raises:
        RuntimeError: propagated from ``load_state_dict``, e.g. when a matched
            tensor has a different shape than the model's entry.
    """
    new_weights = model.state_dict()
    prefix = 'module.'
    for key, value in model_weights.items():
        if key in new_weights:
            new_weights[key] = value
            continue
        if not key.startswith(prefix):
            continue
        bare_key = key[len(prefix):]
        # Without this fallback, a checkpoint from a (Distributed)DataParallel
        # model would match nothing and silently leave the weights untouched.
        if bare_key in new_weights and bare_key not in model_weights:
            new_weights[bare_key] = value
    model.load_state_dict(new_weights)
    return model
def resume_model(model,
                 path_model,
                 device):
    """Load the ``'params'`` weights stored in a checkpoint file into ``model``.

    Args:
        model: network to update; it is passed to ``load_weights`` together
            with the checkpoint's ``'params'`` entry.
        path_model: path of the checkpoint file to read with ``torch.load``.
        device: ``map_location`` handed to ``torch.load``.

    Returns:
        The value returned by ``load_weights`` for ``model``.

    Raises:
        KeyError: if the checkpoint has no ``'params'`` entry. The message
            names the file and lists the keys that are present.
    """
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary objects, which can execute code. Only load checkpoints from a
    # trusted source; consider weights_only=True if the files hold only
    # tensors and plain containers.
    checkpoint = torch.load(path_model, map_location=device, weights_only=False)
    try:
        weights = checkpoint['params']
    except KeyError as err:
        raise KeyError(
            f"checkpoint {path_model!r} has no 'params' entry; "
            f"available keys: {list(checkpoint.keys())}"
        ) from err
    return load_weights(model, model_weights=weights)
# Public API of this module.
__all__ = [
    'create_model',
    'resume_model',
    'load_weights',
]