# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Mapping, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_seq_of

from ..utils import stack_batch

CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
                 None]

@MODELS.register_module()
class BaseDataPreprocessor(nn.Module):
    """Base data pre-processor used for copying data to the target device.

    Subclasses inheriting from ``BaseDataPreprocessor`` could override the
    forward method to implement custom data pre-processing, such as
    batch-resize, MixUp, or CutMix.

    Args:
        non_blocking (bool): Whether to do non-blocking data transfer when
            moving data to the device.
            New in version 0.3.0.

    Note:
        The data dictionary returned by the dataloader must be a dict and
        at least contain the ``inputs`` key.
    """

    def __init__(self, non_blocking: Optional[bool] = False):
        super().__init__()
        self._non_blocking = non_blocking
        self._device = torch.device('cpu')

    def cast_data(self, data: CastData) -> CastData:
        """Copies data to the target device.

        Args:
            data (dict): Data returned by ``DataLoader``.

        Returns:
            CastData: Inputs and data samples on the target device.
        """
        if isinstance(data, Mapping):
            return {key: self.cast_data(data[key]) for key in data}
        elif isinstance(data, (str, bytes)) or data is None:
            return data
        elif isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple
            return type(data)(*(self.cast_data(sample) for sample in data))  # type: ignore  # noqa: E501  # yapf:disable
        elif isinstance(data, Sequence):
            return type(data)(self.cast_data(sample) for sample in data)  # type: ignore  # noqa: E501  # yapf:disable
        elif isinstance(data, (torch.Tensor, BaseDataElement)):
            return data.to(self.device, non_blocking=self._non_blocking)
        else:
            return data
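
    # Illustrative sketch (not part of the original implementation; the
    # variable names below are hypothetical): ``cast_data`` walks nested
    # containers recursively and only moves ``torch.Tensor`` and
    # ``BaseDataElement`` leaves, so a ``pseudo_collate``-style batch keeps
    # its structure::
    #
    #     preprocessor = BaseDataPreprocessor()
    #     batch = dict(
    #         inputs=[torch.rand(3, 224, 224), torch.rand(3, 224, 224)],
    #         data_samples=None)
    #     batch = preprocessor.cast_data(batch)
    #     # ``batch['inputs']`` is still a list; each tensor now lives on
    #     # ``preprocessor.device`` (CPU until ``.to()``/``.cuda()`` is used).
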
    def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
        """Preprocesses the data into the model input format.

        After the data pre-processing of :meth:`cast_data`, ``forward``
        will stack the input tensor list to a batch tensor at the first
        dimension.

        Args:
            data (dict): Data returned by the dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict or list: Data in the same format as the model input.
        """
        return self.cast_data(data)  # type: ignore

    @property
    def device(self):
        return self._device

    def to(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        # Since Torch has not officially merged the npu-related fields,
        # using the _parse_to function directly will cause the NPU to not
        # be found. Here, the input parameters are processed to avoid errors.
        if args and isinstance(args[0], str) and 'npu' in args[0]:
            args = tuple(
                [list(args)[0].replace('npu', torch.npu.native_device)])
        if kwargs and 'npu' in str(kwargs.get('device', '')):
            kwargs['device'] = kwargs['device'].replace(
                'npu', torch.npu.native_device)

        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        if device is not None:
            self._device = torch.device(device)
        return super().to(*args, **kwargs)

    def cuda(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        self._device = torch.device(torch.cuda.current_device())
        return super().cuda()

    def npu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        self._device = torch.device(torch.npu.current_device())
        return super().npu()

    def mlu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        self._device = torch.device(torch.mlu.current_device())
        return super().mlu()

    def cpu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to set the :attr:`device`.

        Returns:
            nn.Module: The model itself.
        """
        self._device = torch.device('cpu')
        return super().cpu()
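
# A minimal usage sketch (illustrative only, names are hypothetical): the
# overrides above keep ``self._device`` in sync with wherever the module is
# moved, so later ``cast_data`` calls transfer data to the same device:
#
#     preprocessor = BaseDataPreprocessor(non_blocking=True)
#     if torch.cuda.is_available():
#         preprocessor = preprocessor.cuda()
#     batch = preprocessor(dict(inputs=[torch.rand(3, 32, 32)]))
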
@MODELS.register_module()
class ImgDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for normalization and bgr to rgb conversion.

    Accepts the data sampled by the dataloader, and preprocesses it into the
    format of the model input. ``ImgDataPreprocessor`` provides the basic data
    pre-processing as follows:

    - Collates and moves data to the target device.
    - Converts inputs from bgr to rgb if the shape of input is (3, H, W).
    - Normalizes images with the defined std and mean.
    - Pads inputs to the maximum size of the current batch with the defined
      ``pad_value``. The padded size will be divisible by the defined
      ``pad_size_divisor``.
    - Stacks inputs into a batch tensor.

    For ``ImgDataPreprocessor``, the dimension of each single input must be
    (3, H, W).

    Note:
        ``ImgDataPreprocessor`` and its subclasses are built in the
        constructor of :class:`BaseModel`.

    Args:
        mean (Sequence[float or int], optional): The pixel mean of image
            channels. If ``bgr_to_rgb=True`` it means the mean value of R,
            G, B channels. If the length of ``mean`` is 1, it means all
            channels have the same mean value, or the input is a gray image.
            If it is not specified, images will not be normalized. Defaults
            to None.
        std (Sequence[float or int], optional): The pixel standard deviation
            of image channels. If ``bgr_to_rgb=True`` it means the standard
            deviation of R, G, B channels. If the length of ``std`` is 1,
            it means all channels have the same standard deviation, or the
            input is a gray image. If it is not specified, images will
            not be normalized. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert the image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert the image from RGB to BGR.
            Defaults to False.
        non_blocking (bool): Whether to do non-blocking data transfer when
            moving data to the device.
            New in version v0.3.0.

    Note:
        If images do not need to be normalized, ``std`` and ``mean`` should
        both be set to None; otherwise both of them should be set to a tuple
        of corresponding values.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 non_blocking: Optional[bool] = False):
        super().__init__(non_blocking)
        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr_to_rgb` and `rgb_to_bgr` cannot be set to True at the '
            'same time')
        assert (mean is None) == (std is None), (
            'mean and std should be both None or tuple')
        if mean is not None:
            assert len(mean) == 3 or len(mean) == 1, (
                '`mean` should have 1 or 3 values, to be compatible with '
                f'RGB or gray image, but got {len(mean)} values')
            assert len(std) == 3 or len(std) == 1, (  # type: ignore
                '`std` should have 1 or 3 values, to be compatible with RGB '
                f'or gray image, but got {len(std)} values')  # type: ignore
            self._enable_normalize = True
            self.register_buffer('mean',
                                 torch.tensor(mean).view(-1, 1, 1), False)
            self.register_buffer('std',
                                 torch.tensor(std).view(-1, 1, 1), False)
        else:
            self._enable_normalize = False
        self._channel_conversion = rgb_to_bgr or bgr_to_rgb
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
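
    # Illustrative configuration sketch (the values are assumptions, not part
    # of the original file): an ImageNet-style setup would normalize with
    # per-channel statistics, flip BGR inputs to RGB, and pad each side to a
    # multiple of 32, e.g.::
    #
    #     data_preprocessor = ImgDataPreprocessor(
    #         mean=[123.675, 116.28, 103.53],
    #         std=[58.395, 57.12, 57.375],
    #         bgr_to_rgb=True,
    #         pad_size_divisor=32,
    #         pad_value=0)
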
    def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
        """Performs normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataset. If the collate
                function of the DataLoader is :obj:`pseudo_collate`, data
                will be a list of dict. If the collate function is
                :obj:`default_collate`, data will be a tuple with a batched
                input tensor and a list of data samples.
            training (bool): Whether to enable training time augmentation. If
                subclasses override this method, they can perform different
                preprocessing strategies for training and testing based on the
                value of ``training``.

        Returns:
            dict or list: Data in the same format as the model input.
        """
        data = self.cast_data(data)  # type: ignore
        _batch_inputs = data['inputs']
        # Process data with `pseudo_collate`.
        if is_seq_of(_batch_inputs, torch.Tensor):
            batch_inputs = []
            for _batch_input in _batch_inputs:
                # channel transform
                if self._channel_conversion:
                    _batch_input = _batch_input[[2, 1, 0], ...]
                # Convert to float after channel conversion to ensure
                # efficiency
                _batch_input = _batch_input.float()
                # Normalization.
                if self._enable_normalize:
                    if self.mean.shape[0] == 3:
                        assert _batch_input.dim(
                        ) == 3 and _batch_input.shape[0] == 3, (
                            'If the mean has 3 values, the input tensor '
                            'should be in the shape of (3, H, W), but got '
                            f'the tensor with shape {_batch_input.shape}')
                    _batch_input = (_batch_input - self.mean) / self.std
                batch_inputs.append(_batch_input)
            # Pad and stack Tensor.
            batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
                                       self.pad_value)
        # Process data with `default_collate`.
        elif isinstance(_batch_inputs, torch.Tensor):
            assert _batch_inputs.dim() == 4, (
                'The input of `ImgDataPreprocessor` should be an NCHW tensor '
                'or a list of tensors, but got a tensor with shape: '
                f'{_batch_inputs.shape}')
            if self._channel_conversion:
                _batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
            # Convert to float after channel conversion to ensure
            # efficiency
            _batch_inputs = _batch_inputs.float()
            if self._enable_normalize:
                _batch_inputs = (_batch_inputs - self.mean) / self.std
            h, w = _batch_inputs.shape[2:]
            target_h = math.ceil(
                h / self.pad_size_divisor) * self.pad_size_divisor
            target_w = math.ceil(
                w / self.pad_size_divisor) * self.pad_size_divisor
            pad_h = target_h - h
            pad_w = target_w - w
            batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h),
                                 'constant', self.pad_value)
        else:
            raise TypeError('Output of `cast_data` should be a dict of '
                            'list/tuple with inputs and data_samples, '
                            f'but got {type(data)}: {data}')
        data['inputs'] = batch_inputs
        data.setdefault('data_samples', None)
        return data
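

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # file). It assumes this module is importable as part of its package,
    # since the relative import of ``stack_batch`` above prevents running the
    # file as a plain script. Mean/std values and image sizes are arbitrary.
    preprocessor = ImgDataPreprocessor(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32)

    # `pseudo_collate`-style batch: a list of (3, H, W) tensors with
    # different spatial sizes.
    batch = dict(
        inputs=[
            torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8),
            torch.randint(0, 256, (3, 280, 352), dtype=torch.uint8),
        ],
        data_samples=None)
    out = preprocessor(batch)
    # Images are padded to the per-batch maximum rounded up to a multiple of
    # 32: ceil(300 / 32) * 32 = 320 and ceil(400 / 32) * 32 = 416.
    assert out['inputs'].shape == (2, 3, 320, 416)

    # `default_collate`-style batch: a single NCHW tensor.
    batch = dict(inputs=torch.randint(0, 256, (2, 3, 300, 400)).float())
    out = preprocessor(batch)
    assert out['inputs'].shape == (2, 3, 320, 416)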