Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # X-Decoder -- Generalized Decoding for Pixel, Image, and Language | |
| # Copyright (c) 2022 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # Written by Xueyan Zou (xueyan@cs.wisc.edu) | |
| # -------------------------------------------------------- | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from typing import Dict | |
| from torch import nn | |
| from detectron2.layers import ShapeSpec | |
| from .build import register_body | |
| from ..vision.encoder import build_encoder | |
| from ..interface import build_decoder | |
| from ..utils import configurable | |
class XdecoderHead(nn.Module):
    """Head that runs a pixel decoder and feeds its outputs to a transformer predictor.

    The head can be built with explicit keyword arguments, or from
    ``(cfg, input_shape, lang_encoder, extra)``; in the latter case
    ``from_config`` turns the config into constructor arguments.

    NOTE(review): this assumes ``..utils.configurable`` follows the detectron2
    ``configurable`` contract (decorated ``__init__`` + ``from_config``
    classmethod) -- confirm against the helper's implementation.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        binary_classes: bool,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
            binary_classes: when true, ``num_classes`` is overridden to 1
        """
        super().__init__()

        # Feature names are stored ordered by increasing stride.
        ordered_shapes = sorted(input_shape.items(), key=lambda item: item[1].stride)
        self.in_features = [name for name, _ in ordered_shapes]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        # Sub-modules are assigned in the original order so the order of
        # registered parameters (and state-dict keys) is unchanged.
        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = 1 if binary_classes else num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
        """Translate a config into the keyword arguments expected by ``__init__``.

        Args:
            cfg: nested mapping; ``cfg['MODEL']['ENCODER']`` and
                ``cfg['MODEL']['DECODER']`` are read here.
            input_shape: shapes of all available input features.
            lang_encoder: forwarded unchanged to ``build_decoder``.
            extra: forwarded unchanged to ``build_decoder``.

        Returns:
            dict of constructor keyword arguments.
        """
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']
        in_feature = dec_cfg['TRANSFORMER_IN_FEATURE']

        # figure out in_channels to transformer predictor
        if in_feature == "transformer_encoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        elif in_feature == "pixel_embedding":
            transformer_predictor_in_channels = enc_cfg['MASK_DIM']
        elif in_feature == "multi_scale_pixel_decoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        else:
            # Any other value is treated as the name of a raw input feature.
            transformer_predictor_in_channels = input_shape[in_feature].channels

        return {
            # Only the features listed in the encoder config are kept here;
            # the pixel decoder below still receives the full input_shape.
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_in_feature": in_feature,
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                lang_encoder,
                mask_classification=True,
                extra=extra,
            ),
            "binary_classes": enc_cfg['BINARY_CLASSES'],
        }

    def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra=None):
        """Run the head; thin wrapper that delegates to ``layers``.

        ``extra`` defaults to ``None`` (treated as an empty dict) so that no
        dict object is shared between calls.
        """
        return self.layers(features, mask, target_queries, target_vlp, task, extra)

    def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra=None):
        """Decode pixel features and route them to the predictor.

        The first predictor argument is selected by
        ``self.transformer_in_feature``. Only the
        ``"multi_scale_pixel_decoder"`` route forwards ``target_queries``,
        ``target_vlp``, ``task`` and ``extra``; the other routes pass just the
        features, the mask features and ``mask``.

        Args:
            features: mapping of feature name to feature, given to
                ``pixel_decoder.forward_features`` and, for a raw-feature
                route, indexed by ``self.transformer_in_feature``.
            mask: forwarded to the predictor.
            target_queries: forwarded on the multi-scale route only.
            target_vlp: forwarded on the multi-scale route only.
            task: forwarded on the multi-scale route only.
            extra: optional dict forwarded on the multi-scale route only;
                ``None`` is replaced by a fresh empty dict.

        Returns:
            Whatever the predictor returns.
        """
        if extra is None:
            # Fresh dict per call instead of a shared mutable default.
            extra = {}

        mask_features, transformer_encoder_features, multi_scale_features = (
            self.pixel_decoder.forward_features(features)
        )
        route = self.transformer_in_feature

        if route == "multi_scale_pixel_decoder":
            return self.predictor(
                multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra
            )

        if route == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            return self.predictor(transformer_encoder_features, mask_features, mask)

        if route == "pixel_embedding":
            return self.predictor(mask_features, mask_features, mask)

        # Fallback: use the named raw input feature directly.
        return self.predictor(features[route], mask_features, mask)
@register_body
def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
    """Factory that builds an ``XdecoderHead`` from a config.

    The arguments are passed straight to ``XdecoderHead`` in config-call form
    ``(cfg, input_shape, lang_encoder, extra)``.

    NOTE(review): ``register_body`` presumably records this factory in the
    body registry of ``.build`` -- confirm against that module.
    """
    return XdecoderHead(cfg, input_shape, lang_encoder, extra)