Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from ..version_utils import digit_version | |
| from .parrots_wrapper import TORCH_VERSION | |
# Whether ``torch.meshgrid`` accepts the ``indexing`` keyword: requires a
# non-parrots build whose version is at least 1.10.0a0.
_torch_version_meshgrid_indexing = False
if 'parrots' not in TORCH_VERSION:
    _torch_version_meshgrid_indexing = (
        digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
def torch_meshgrid(*tensors):
    """Version-tolerant front end for ``torch.meshgrid``.

    The ``indexing`` argument of ``torch.meshgrid`` is available from
    PyTorch 1.10.0a0 onward. When the running version supports it, it is
    passed explicitly as ``'ij'`` so that high-version PyTorch does not emit
    a warning; on previous versions the argument is left out to avoid
    compatibility issues.

    Args:
        *tensors (Tensor): Scalars or 1 dimensional tensors, forwarded
            unchanged to ``torch.meshgrid``.

    Returns:
        Sequence[Tensor]: Sequence of meshgrid tensors.
    """
    if not _torch_version_meshgrid_indexing:
        # Legacy path: no ``indexing`` keyword; 'ij' is the default there.
        return torch.meshgrid(*tensors)
    return torch.meshgrid(*tensors, indexing='ij')