Spaces:
Sleeping
Sleeping
File size: 6,556 Bytes
c5f4ee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from .processor import Processor, DCPProcessor, JTFNProcessor, JTFNDCPProcessor
from .UNet_p import U_Net_P, R2AttUNetDecoder, UNetDecoder, Prompt_U_Net_P_DCP
from .jtfn import JTFN, JTFNDecoder, JTFN_DCP
from .backbones import build_backbone
def build_model(model_name, model_params, training, dataset_idx, pretrained):
    """Instantiate the model whose factory is named ``model_name`` on ``Models``.

    Args:
        model_name: name of a public static factory method on ``Models``,
            e.g. ``'effi_b3_p_unet'``.
        model_params: hyper-parameter dict forwarded to the factory.
        training: forwarded to the factory.
        dataset_idx: forwarded to the factory.
        pretrained: forwarded to the factory.

    Returns:
        Whatever the selected factory returns.

    Raises:
        AttributeError: if ``model_name`` is not a public factory on ``Models``
            (same exception type a bare ``getattr`` lookup would raise).
    """
    factory = getattr(Models, model_name, None)
    # Underscore names are helpers or dunders (e.g. '__init__'), not factories;
    # calling them would fail later with an unrelated, confusing TypeError.
    if factory is None or model_name.startswith('_'):
        available = ', '.join(sorted(name for name in vars(Models) if not name.startswith('_')))
        raise AttributeError(
            "Unknown model name {!r}; available models: {}".format(model_name, available))
    return factory(model_params=model_params, training=training, dataset_idx=dataset_idx, pretrained=pretrained)
class Models(object):
    """Named model factories dispatched to by ``build_model``.

    Every public static method has the signature
    ``(model_params, training, dataset_idx, pretrained=True)`` so that
    ``build_model`` can call any of them through ``getattr``. Each one:

    1. builds an ``'efficientnet_b3_p'`` backbone via ``build_backbone``;
    2. builds a decoder;
    3. assembles them into a segmentation network;
    4. returns that network wrapped in the matching processor class.

    Attributes whose names start with an underscore are shared helpers and
    constants, not model factories.
    """

    # Channel widths handed to the decoders / networks of every builder below.
    # NOTE(review): presumably these mirror the 'efficientnet_b3_p' feature
    # maps returned by build_backbone — confirm before changing the backbone.
    _EFFI_B3_P_CHANNELS = (24, 12, 40, 120, 384)

    @staticmethod
    def _dcp_prompt_params(model_params):
        """Return the prompt hyper-parameters shared by all ``*_dcp`` builders.

        The result is a dict of keyword arguments. Keys are read in a fixed
        order, so a missing key raises ``KeyError`` naming the first absent one.
        """
        return dict(
            cha_promot_channels=model_params['cha_promot_channels'],
            pos_promot_channels=model_params['pos_promot_channels'],
            local_window_sizes=model_params['local_window_sizes'],
            att_fusion=model_params['att_fusion'],
            embed_ratio=model_params['embed_ratio'],
            strides=model_params['strides'],
            use_conv=model_params['use_conv'],
        )

    @staticmethod
    def _build_u_net_p(decoder_cls, model_params, training, pretrained):
        """Build ``U_Net_P`` with ``decoder_cls`` and wrap it in ``Processor``."""
        n_class = model_params['n_class']
        channels = Models._EFFI_B3_P_CHANNELS
        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = decoder_cls(channels=channels)
        seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
        return Processor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def _build_prompt_u_net_p_dcp(decoder_cls, model_params, training, dataset_idx, pretrained):
        """Build ``Prompt_U_Net_P_DCP`` with ``decoder_cls`` and wrap it in ``DCPProcessor``."""
        n_class = model_params['n_class']
        # A fresh list per call, matching the list literal the builders always used.
        channels = list(Models._EFFI_B3_P_CHANNELS)
        prompt_params = Models._dcp_prompt_params(model_params)
        prompt_init = model_params.get('prompt_init', 'rand')  # rand, zero, one
        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = decoder_cls(channels=channels)
        seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
                                     dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
                                     **prompt_params)
        return DCPProcessor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def effi_b3_p_unet(model_params, training, dataset_idx, pretrained=True):
        """EfficientNet-B3-P encoder + ``UNetDecoder`` in ``U_Net_P``.

        ``dataset_idx`` is unused; it is accepted so all factories share one signature.
        """
        return Models._build_u_net_p(UNetDecoder, model_params, training, pretrained)

    @staticmethod
    def effi_b3_p_r2attunet(model_params, training, dataset_idx, pretrained=True):
        """EfficientNet-B3-P encoder + ``R2AttUNetDecoder`` in ``U_Net_P``.

        ``dataset_idx`` is unused; it is accepted so all factories share one signature.
        """
        return Models._build_u_net_p(R2AttUNetDecoder, model_params, training, pretrained)

    @staticmethod
    def effi_b3_p_jtfn(model_params, training, dataset_idx, pretrained=True):
        """EfficientNet-B3-P encoder + topology-aware ``JTFNDecoder`` in ``JTFN``.

        Reads ``'n_class'`` and ``'steps'`` from ``model_params``. ``dataset_idx``
        is unused; it is accepted so all factories share one signature.
        """
        n_class = model_params['n_class']
        channels = Models._EFFI_B3_P_CHANNELS
        steps = model_params['steps']
        # Forward ``pretrained`` like every other builder (it was previously dropped here).
        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = JTFNDecoder(channels=channels, use_topo=True)
        seg_net = JTFN(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps)
        return JTFNProcessor(model=seg_net, training_params=model_params, training=training)

    @staticmethod
    def prompt_effi_b3_p_unet_dcp(model_params, training, dataset_idx, pretrained=True):
        """Prompted ``Prompt_U_Net_P_DCP`` with ``UNetDecoder``, wrapped in ``DCPProcessor``."""
        return Models._build_prompt_u_net_p_dcp(UNetDecoder, model_params, training, dataset_idx, pretrained)

    @staticmethod
    def prompt_effi_b3_p_r2attunet_dcp(model_params, training, dataset_idx, pretrained=True):
        """Prompted ``Prompt_U_Net_P_DCP`` with ``R2AttUNetDecoder``, wrapped in ``DCPProcessor``."""
        return Models._build_prompt_u_net_p_dcp(R2AttUNetDecoder, model_params, training, dataset_idx, pretrained)

    @staticmethod
    def prompt_effi_b3_p_jtfn_dcp(model_params, training, dataset_idx, pretrained=True):
        """Prompted ``JTFN_DCP`` with ``JTFNDecoder``, wrapped in ``JTFNDCPProcessor``."""
        n_class = model_params['n_class']
        steps = model_params['steps']
        channels = list(Models._EFFI_B3_P_CHANNELS)
        prompt_params = Models._dcp_prompt_params(model_params)
        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = JTFNDecoder(channels=channels, use_topo=True)
        # NOTE(review): unlike the U-Net DCP builders, no ``prompt_init`` is passed;
        # JTFN_DCP presumably falls back to its own default — confirm.
        seg_net = JTFN_DCP(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps,
                           dataset_idx=dataset_idx, encoder_channels=channels, **prompt_params)
        return JTFNDCPProcessor(model=seg_net, training_params=model_params, training=training)
|