Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import itertools | |
| from collections.abc import Sized | |
| from typing import Any, List, Union | |
| import numpy as np | |
| import torch | |
| from mmengine.device import get_device | |
| from .base_data_element import BaseDataElement | |
# Type aliases describing which tensor classes are accepted as indexes.
# They are declared up front and bound below, because the accelerator-specific
# tensor classes depend on the device reported by ``get_device()``.
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]

# Only touch ``torch.npu`` / ``torch.mlu`` inside the matching branch, so the
# attribute is never looked up unless that device was reported.
if get_device() == 'npu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
elif get_device() == 'mlu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
else:
    # Every other device string falls back to the CUDA tensor classes.
    BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

# Everything that ``InstanceData.__getitem__`` accepts as an index.
IndexType: Union[Any] = Union[
    str,
    slice,
    int,
    list,
    LongTypeTensor,
    BoolTypeTensor,
    np.ndarray,
]
| # Modified from | |
| # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |
class InstanceData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in `data_fields`
    should have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
    InstanceData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of value
    in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
    or a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.

    Examples:
        >>> # custom data structure
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> from mmengine.structures import InstanceData
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = InstanceData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Sized):
        """Set a data field, enforcing the shared-length invariant.

        The value must have the attribute of `__len__` and have the same
        length as this `InstanceData` (when it already holds data).

        Args:
            name (str): Name of the attribute to set.
            value (Sized): Value to store under ``name``.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
            AssertionError: If ``value`` is not ``Sized`` or its length
                differs from the current length of this object.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The two bookkeeping attributes may be created once and are
            # read-only afterwards.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')

        else:
            assert isinstance(value,
                              Sized), 'value must contain `__len__` attribute'

            # An empty container accepts a value of any length; afterwards
            # every new value has to match the existing length.
            if len(self) > 0:
                assert len(value) == len(self), 'The length of ' \
                                                f'values {len(value)} is ' \
                                                'not consistent with ' \
                                                'the length of this ' \
                                                ':obj:`InstanceData` ' \
                                                f'{len(self)}'
            super().__setattr__(name, value)

    # ``obj[name] = value`` goes through exactly the same checks.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'InstanceData':
        """Index this object by field name or along the instance dimension.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item. A ``str``
                returns the attribute with that name; any other index
                selects instances from every data field.

        Returns:
            :obj:`InstanceData`: Corresponding values (or the raw attribute
            when ``item`` is a ``str``).

        Raises:
            IndexError: If an ``int`` index is out of range.
            ValueError: If a tensor index is applied to a field whose type
                supports neither native indexing nor ``cat``.
            AssertionError: If ``item`` has an unsupported type, a tensor
                index is not 1-D, or a boolean mask has the wrong length.
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            # ``np.array([])`` defaults to float64, which torch rejects as an
            # index, so an empty list is turned into an empty int64 array.
            item = np.array(item) if item else np.empty(0, dtype=np.int64)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)

        if isinstance(item, str):
            return getattr(self, item)

        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # keep the dimension
                item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the' \
                                    ' values along the first dimension.'
            if isinstance(item, BoolTypeTensor.__args__):
                assert len(item) == len(self), 'The shape of the ' \
                                               'input(BoolTensor) ' \
                                               f'{len(item)} ' \
                                               'does not match the shape ' \
                                               'of the indexed tensor ' \
                                               'in results_field ' \
                                               f'{len(self)} at ' \
                                               'first dimension.'

            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    # convert to indexes from BoolTensor
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(
                            -1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    # Pick each index with a one-element slice so the
                    # pieces keep the container type and can be joined.
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        # Nothing selected: take an empty slice so the
                        # result is an empty value of the same type.
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        'support slice with `bool`')

        else:
            # item is a slice
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data  # type:ignore

    @staticmethod
    def cat(instances_list: List['InstanceData']) -> 'InstanceData':
        """Concat the instances of all :obj:`InstanceData` in the list.

        Note: To ensure that cat returns as expected, make sure that
        all elements in the list must have exactly the same keys.

        Args:
            instances_list (list[:obj:`InstanceData`]): A list
                of :obj:`InstanceData`.

        Returns:
            :obj:`InstanceData`: The concatenation. When the list holds a
            single element, that element itself is returned. The metainfo
            of the result is taken from the first element.

        Raises:
            AssertionError: If the list is empty, holds something that is
                not an :obj:`InstanceData`, or the elements have different
                keys.
            ValueError: If a field's type is not a tensor, ndarray, str,
                list or tuple and has no ``cat`` attribute.
        """
        assert all(
            isinstance(results, InstanceData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        field_keys_list = [
            instances.all_keys() for instances in instances_list
        ]
        assert len({len(field_keys) for field_keys in field_keys_list}) \
               == 1 and len(set(itertools.chain(*field_keys_list))) \
               == len(field_keys_list[0]), 'There are different keys in ' \
                                           '`instances_list`, which may ' \
                                           'cause the cat operation ' \
                                           'to fail. Please make sure all ' \
                                           'elements in `instances_list` ' \
                                           'have the exact same key.'

        new_data = instances_list[0].__class__(
            metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Copy first so that ``+=`` never mutates the list stored
                # in the first input element.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, 'cat'):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f'The type of `{k}` is `{type(v0)}` which has no '
                    'attribute of `cat`')
            new_data[k] = new_values
        return new_data  # type:ignore

    def __len__(self) -> int:
        """int: The length of InstanceData.

        This is the length of the first data field value, or 0 when no
        data field has been set.
        """
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        else:
            return 0