|
|
import os
|
|
|
|
|
|
from .utils import MODELS, imagenet_weights
|
|
|
from .utils import tome_presets
|
|
|
from .model.base_module import BaseModule
|
|
|
from .configs.config.config import Config
|
|
|
from .utils.build_functions import build_model_from_cfg
|
|
|
|
|
|
|
|
|
class SegFormer(BaseModule):
    """
    SegFormer segmentation network whose backbone may apply token merging.

    The model is a two-step pipeline: the backbone encodes the input and the
    decode head turns the backbone output into the final prediction.

    Attributes:
        backbone (BaseModule): MixVisionTransformer backbone, built from ``cfg.backbone``.
        decode_head (BaseModule): SegFormer head, built from ``cfg.decode_head``.
    """

    def __init__(self, cfg):
        """
        Build the backbone and the decode head from a configuration.

        Args:
            cfg (Config): an mmengine Config object whose ``backbone`` and
                ``decode_head`` sections describe the sub-modules (including the
                token merging strategy used by the backbone).
        """
        super().__init__()
        # Both sub-modules are instantiated through the shared MODELS registry;
        # the backbone is registered first, then the head.
        self.backbone = build_model_from_cfg(cfg.backbone, registry=MODELS)
        self.decode_head = build_model_from_cfg(cfg.decode_head, registry=MODELS)

    def forward(self, x):
        """
        Run the backbone and feed its output to the decode head.

        Args:
            x (torch.Tensor): input tensor of shape [B, C, H, W]

        Returns:
            torch.Tensor: output tensor produced by the decode head
        """
        features = self.backbone(x)
        return self.decode_head(features)
|
|
|
|
|
|
|
|
|
def create_model(
    backbone: str = 'b0',
    tome_strategy: "str | list[dict] | None" = None,
    out_channels: int = 19,
    pretrained: bool = False,
):
    """
    Create a SegFormer model using the predefined SegFormer backbones from the MiT series (b0-b5).

    Args:
        backbone (str): backbone name, case-insensitive ('b0' ... 'b5').
        tome_strategy (str | list(dict) | None): either the name of a preset from
            ``tome_presets`` (e.g. 'bsm_hq', 'bsm_fast', 'n2d_2x2'), or a custom strategy
            given as a list of dictionaries, each defining the merging strategy for a stage.
            ``None`` leaves the token merging setting of the loaded config untouched.
        out_channels (int): number of output channels of the decode head
            (e.g. 19 for the cityscapes semantic segmentation task).
        pretrained (bool): if True, set the backbone's ``init_cfg`` to load the
            (imagenet) checkpoint registered in ``imagenet_weights`` for this backbone.

    Returns:
        BaseModule: SegFormer model

    Raises:
        ValueError: if ``backbone`` is not one of 'b0'-'b5', or if ``tome_strategy`` is a
            string that does not name a known preset.
    """
    backbone = backbone.lower()
    valid_backbones = [f'b{i}' for i in range(6)]
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if backbone not in valid_backbones:
        raise ValueError(
            f"Unknown backbone {backbone!r}; expected one of {valid_backbones}."
        )

    # Predefined configs live next to this module: configs/segformer_mit_<backbone>.py
    wd = os.path.dirname(os.path.abspath(__file__))
    cfg = Config.fromfile(os.path.join(wd, 'configs', f'segformer_mit_{backbone}.py'))

    cfg.decode_head.out_channels = out_channels

    if tome_strategy is not None:
        if isinstance(tome_strategy, str):
            # A string always refers to a preset. Fail loudly on unknown names rather
            # than handing a meaningless string to the backbone as its strategy.
            if tome_strategy not in tome_presets.keys():
                raise ValueError(
                    f"Unknown token merging preset {tome_strategy!r}; expected one of "
                    f"{list(tome_presets.keys())} or a custom list of per-stage dicts."
                )
            cfg.backbone.tome_cfg = tome_presets[tome_strategy]
        else:
            # Custom strategy: use the caller-provided per-stage definitions as-is
            # (looking them up in `tome_presets` would fail, lists are unhashable).
            print("Using custom merging strategy.")
            cfg.backbone.tome_cfg = tome_strategy

    if pretrained:
        cfg.backbone.init_cfg = dict(type='Pretrained', checkpoint=imagenet_weights[backbone])

    return SegFormer(cfg)
|
|
|
|
|
|
|
|
|
def create_custom_model(
    model_cfg: Config,
    tome_strategy: "list[dict] | None" = None,
):
    """
    Create a SegFormer model with customizable backbone and head.

    Args:
        model_cfg (Config): model config with ``backbone`` and ``decode_head`` sections,
            which are passed to ``SegFormer`` to build the sub-modules.
        tome_strategy (list(dict) | None): custom token merging strategy, one dictionary
            per stage. When given, it is written to ``model_cfg.backbone.tome_cfg``;
            when ``None``, whatever ``model_cfg`` already specifies is used.

    Returns:
        BaseModule: SegFormer model

    Note:
        ``model_cfg`` is modified in place when ``tome_strategy`` is provided, so the
        strategy remains set on the caller's config object after this call.
    """
    if tome_strategy is not None:
        model_cfg.backbone.tome_cfg = tome_strategy

    return SegFormer(model_cfg)
|
|
|
|