Spaces:
Runtime error
Runtime error
| '''Basic cortex_DIM encoder. | |
| ''' | |
| import torch | |
| from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet | |
| from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet | |
def create_encoder(Module):
    '''Build an encoder class that subclasses ``Module``.

    The returned class calls the parent's ``forward`` with
    ``return_full_list=True`` and picks selected intermediate outputs out of
    the lists it returns.

    Args:
        Module: Parent network class. Its ``forward`` must accept
            ``return_full_list=True`` and return either 2 sequences
            ``(conv_outs, fc_outs)`` or 4 sequences
            ``(conv_before_outs, res_outs, conv_after_outs, fc_outs)``.

    Returns:
        type: A new subclass of ``Module`` named ``Encoder``.
    '''
    class Encoder(Module):
        '''Encoder used for cortex_DIM.
        '''

        def __init__(self, *args, local_idx=None, multi_idx=None,
                     conv_idx=None, fc_idx=None, **kwargs):
            '''
            Args:
                args: Arguments for parent class.
                local_idx: Index in list of convolutional layers for local
                    features. Required.
                multi_idx: Index in list of convolutional layers for multiple
                    globals. Optional.
                conv_idx: Index in list of convolutional layers for
                    intermediate features. Defaults to ``local_idx`` when
                    None.
                fc_idx: Index in list of fully-connected layers for
                    intermediate features. Optional.
                kwargs: Keyword arguments for the parent class.

            Raises:
                ValueError: If ``local_idx`` is None.
            '''
            # Check the required argument before the parent constructor runs,
            # so a missing ``local_idx`` fails immediately.
            if local_idx is None:
                raise ValueError('`local_idx` must be set')
            super().__init__(*args, **kwargs)
            # Explicit None check: 0 is a valid layer index, and the previous
            # ``conv_idx or local_idx`` silently replaced it with local_idx.
            if conv_idx is None:
                conv_idx = local_idx
            self.local_idx = local_idx
            self.multi_idx = multi_idx
            self.conv_idx = conv_idx
            self.fc_idx = fc_idx

        def forward(self, x: torch.Tensor) -> tuple:
            '''Run the parent network and select the configured outputs.

            Args:
                x: Input tensor.

            Returns:
                tuple: ``(local_out, conv_out, multi_out, hidden_out,
                global_out)`` where

                - ``local_out`` is conv output ``local_idx``;
                - ``conv_out`` is conv output ``conv_idx``;
                - ``multi_out`` is conv output ``multi_idx``, or None if
                  ``multi_idx`` is None;
                - ``hidden_out`` is fc output ``fc_idx``, or None if
                  ``fc_idx`` is None or there are no fc outputs;
                - ``global_out`` is the last fc output, or None if there are
                  no fc outputs.
            '''
            outs = super().forward(x, return_full_list=True)
            if len(outs) == 2:
                conv_outs, fc_outs = outs
            else:
                # Four-part output: join the conv stages so one index can
                # address any of them. Assumes these are lists, so ``+``
                # concatenates them -- TODO confirm against the parent classes.
                conv_before_outs, res_outs, conv_after_outs, fc_outs = outs
                conv_outs = conv_before_outs + res_outs + conv_after_outs

            local_out = conv_outs[self.local_idx]
            multi_out = (conv_outs[self.multi_idx]
                         if self.multi_idx is not None else None)

            if len(fc_outs) > 0:
                hidden_out = (fc_outs[self.fc_idx]
                              if self.fc_idx is not None else None)
                global_out = fc_outs[-1]
            else:
                hidden_out = None
                global_out = None

            conv_out = conv_outs[self.conv_idx]
            return local_out, conv_out, multi_out, hidden_out, global_out

    return Encoder
class ConvnetEncoder(create_encoder(Convnet)):
    '''cortex_DIM encoder whose backbone is ``Convnet``.'''
class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
    '''cortex_DIM encoder whose backbone is ``FoldedConvnet``.'''
class ResnetEncoder(create_encoder(ResNet)):
    '''cortex_DIM encoder whose backbone is ``ResNet``.'''
class FoldedResnetEncoder(create_encoder(FoldedResNet)):
    '''cortex_DIM encoder whose backbone is ``FoldedResNet``.'''