File size: 1,458 Bytes
034f4b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch

from .DeMoE import DeMoE  

def create_model(opt, device):
    '''
    Build the network selected by the configuration and move it to a device.

    Args:
        opt: mapping with the network settings (the ``network`` key of the
            yaml config). ``opt['name']`` selects the architecture. For
            ``'DeMoE'`` the keys ``img_channels``, ``width``,
            ``middle_blk_num``, ``enc_blk_nums``, ``dec_blk_nums``,
            ``num_experts`` and ``k_used`` are read as well.
        device: target handed to ``model.to``.

    Returns:
        The constructed model, after ``model.to(device)`` has been called.

    Raises:
        NotImplementedError: if ``opt['name']`` is anything but ``'DeMoE'``.
    '''
    # Only one architecture is supported; reject everything else up front.
    if opt['name'] != 'DeMoE':
        raise NotImplementedError('This network is not implemented')

    # Constructor keyword -> value read from the config. Note that a few
    # config keys are spelled differently from the constructor arguments.
    network_kwargs = {
        'img_channel': opt['img_channels'],
        'width': opt['width'],
        'middle_blk_num': opt['middle_blk_num'],
        'enc_blk_nums': opt['enc_blk_nums'],
        'dec_blk_nums': opt['dec_blk_nums'],
        'num_exp': opt['num_experts'],
        'k_used': opt['k_used'],
    }
    model = DeMoE(**network_kwargs)
    model.to(device)

    return model

def load_weights(model, model_weights):
    '''
    Copy pretrained weights into ``model``, keeping only the matching keys.

    Starts from the model's own ``state_dict()`` and overwrites every entry
    whose key also appears in ``model_weights``. Keys of ``model_weights``
    that the model does not have are ignored, and entries of the model with
    no counterpart in ``model_weights`` keep their current values.

    Args:
        model: object exposing ``state_dict()`` and ``load_state_dict()``.
        model_weights: mapping from state-dict key to the value to load.

    Returns:
        The same ``model`` object, after ``load_state_dict`` has been called.
    '''
    merged_state = model.state_dict()

    # Overwrite only keys the model already has; never introduce new ones.
    for key, value in model_weights.items():
        if key in merged_state:
            merged_state[key] = value

    model.load_state_dict(merged_state)

    return model

def resume_model(model,
                 path_model,
                 device):
    '''
    Load a checkpoint file and copy its ``'params'`` entry into ``model``.

    Args:
        model: model that receives the weights (forwarded to ``load_weights``).
        path_model: location of the checkpoint, passed to ``torch.load``.
        device: used as ``map_location`` when reading the checkpoint.

    Returns:
        The model returned by ``load_weights``.
    '''
    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so only checkpoints from a trusted source should be loaded.
    checkpoint = torch.load(path_model, map_location=device, weights_only=False)

    return load_weights(model, model_weights=checkpoint['params'])


# Public API of this module: the names exposed by a star-import of it.
__all__ = ['create_model', 'resume_model', 'load_weights']