File size: 6,556 Bytes
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from .processor import Processor, DCPProcessor, JTFNProcessor, JTFNDCPProcessor
from .UNet_p import U_Net_P, R2AttUNetDecoder, UNetDecoder, Prompt_U_Net_P_DCP
from .jtfn import JTFN, JTFNDecoder, JTFN_DCP
from .backbones import build_backbone


def build_model(model_name, model_params, training, dataset_idx, pretrained):
    """Build the model registered under ``model_name`` on ``Models``.

    Args:
        model_name: Name of a factory attribute on ``Models``.
        model_params: Parameter mapping handed to the factory unchanged.
        training: Forwarded to the factory unchanged.
        dataset_idx: Forwarded to the factory unchanged.
        pretrained: Forwarded to the factory unchanged.

    Returns:
        Whatever the selected factory returns.

    Raises:
        AttributeError: If ``Models`` has no attribute named ``model_name``.
    """
    factory = getattr(Models, model_name)
    return factory(
        model_params=model_params,
        training=training,
        dataset_idx=dataset_idx,
        pretrained=pretrained,
    )


class Models(object):
    """Namespace of static factory methods, one per supported model name.

    Every public factory takes ``(model_params, training, dataset_idx,
    pretrained)`` so that it can be selected by name and called uniformly.
    Each one builds an ``'efficientnet_b3_p'`` backbone as encoder, pairs it
    with a decoder, assembles the segmentation network and returns it wrapped
    in the matching processor class.
    """

    @staticmethod
    def _dcp_prompt_kwargs(model_params):
        # Collect the prompt settings that every *_dcp factory reads and
        # forwards verbatim as keyword arguments to its network constructor.
        return {
            'cha_promot_channels': model_params['cha_promot_channels'],
            'pos_promot_channels': model_params['pos_promot_channels'],
            'local_window_sizes': model_params['local_window_sizes'],
            'att_fusion': model_params['att_fusion'],
            'embed_ratio': model_params['embed_ratio'],
            'strides': model_params['strides'],
            'use_conv': model_params['use_conv'],
        }

    @staticmethod
    def effi_b3_p_unet(model_params, training, dataset_idx, pretrained=True):
        """Build ``U_Net_P`` with a ``UNetDecoder``, wrapped in ``Processor``.

        Args:
            model_params: Mapping that must contain ``'n_class'``; also passed
                to the processor as ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Unused here; accepted for the shared factory signature.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``Processor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = UNetDecoder(channels=channels)

        seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
        return Processor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def effi_b3_p_r2attunet(model_params, training, dataset_idx, pretrained=True):
        """Build ``U_Net_P`` with an ``R2AttUNetDecoder``, wrapped in ``Processor``.

        Args:
            model_params: Mapping that must contain ``'n_class'``; also passed
                to the processor as ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Unused here; accepted for the shared factory signature.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``Processor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = R2AttUNetDecoder(channels=channels)

        seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
        return Processor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def effi_b3_p_jtfn(model_params, training, dataset_idx, pretrained=True):
        """Build ``JTFN`` with a ``JTFNDecoder``, wrapped in ``JTFNProcessor``.

        Args:
            model_params: Mapping that must contain ``'n_class'`` and
                ``'steps'``; also passed to the processor as
                ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Unused here; accepted for the shared factory signature.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``JTFNProcessor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)
        steps = model_params['steps']

        # Forward ``pretrained`` like every other factory does; previously the
        # argument was dropped and the backbone's own default was always used.
        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = JTFNDecoder(channels=channels, use_topo=True)

        seg_net = JTFN(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps)
        return JTFNProcessor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def prompt_effi_b3_p_unet_dcp(model_params, training, dataset_idx, pretrained=True):
        """Build ``Prompt_U_Net_P_DCP`` with a ``UNetDecoder``, wrapped in ``DCPProcessor``.

        Args:
            model_params: Mapping that must contain ``'n_class'``,
                ``'cha_promot_channels'``, ``'pos_promot_channels'``,
                ``'local_window_sizes'``, ``'att_fusion'``, ``'embed_ratio'``,
                ``'strides'`` and ``'use_conv'``. ``'prompt_init'`` is optional
                and defaults to ``'rand'``. Also passed to the processor as
                ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Forwarded to the network constructor.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``DCPProcessor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        # A list here (the non-prompt factories use a tuple); kept as-is.
        channels = [24, 12, 40, 120, 384]

        prompt_kwargs = Models._dcp_prompt_kwargs(model_params)
        prompt_init = model_params.get('prompt_init', 'rand')  # rand, zero, one

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = UNetDecoder(channels=channels)

        seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
                                     dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
                                     **prompt_kwargs)

        return DCPProcessor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def prompt_effi_b3_p_r2attunet_dcp(model_params, training, dataset_idx, pretrained=True):
        """Build ``Prompt_U_Net_P_DCP`` with an ``R2AttUNetDecoder``, wrapped in ``DCPProcessor``.

        Args:
            model_params: Mapping that must contain ``'n_class'``,
                ``'cha_promot_channels'``, ``'pos_promot_channels'``,
                ``'local_window_sizes'``, ``'att_fusion'``, ``'embed_ratio'``,
                ``'strides'`` and ``'use_conv'``. ``'prompt_init'`` is optional
                and defaults to ``'rand'``. Also passed to the processor as
                ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Forwarded to the network constructor.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``DCPProcessor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        # A list here (the non-prompt factories use a tuple); kept as-is.
        channels = [24, 12, 40, 120, 384]

        prompt_kwargs = Models._dcp_prompt_kwargs(model_params)
        prompt_init = model_params.get('prompt_init', 'rand')  # rand, zero, one

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = R2AttUNetDecoder(channels=channels)

        seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
                                     dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
                                     **prompt_kwargs)

        return DCPProcessor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def prompt_effi_b3_p_jtfn_dcp(model_params, training, dataset_idx, pretrained=True):
        """Build ``JTFN_DCP`` with a ``JTFNDecoder``, wrapped in ``JTFNDCPProcessor``.

        Args:
            model_params: Mapping that must contain ``'n_class'``, ``'steps'``,
                ``'cha_promot_channels'``, ``'pos_promot_channels'``,
                ``'local_window_sizes'``, ``'att_fusion'``, ``'embed_ratio'``,
                ``'strides'`` and ``'use_conv'``. Also passed to the processor
                as ``training_params``.
            training: Forwarded to the processor.
            dataset_idx: Forwarded to the network constructor.
            pretrained: Forwarded to ``build_backbone``.

        Returns:
            The ``JTFNDCPProcessor`` wrapping the segmentation network.
        """
        n_class = model_params['n_class']
        steps = model_params['steps']
        # A list here (the non-prompt factories use a tuple); kept as-is.
        channels = [24, 12, 40, 120, 384]

        prompt_kwargs = Models._dcp_prompt_kwargs(model_params)

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = JTFNDecoder(channels=channels, use_topo=True)
        seg_net = JTFN_DCP(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps,
                           dataset_idx=dataset_idx, encoder_channels=channels,
                           **prompt_kwargs)

        return JTFNDCPProcessor(model=seg_net, training_params=model_params, training=training)