Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import numpy as np | |
| import torch | |
class XYMasking:
    """Overwrite random vertical and horizontal stripes of a 3-D tensor.

    "x" masks are stripes along the last axis (columns) spanning the whole
    middle axis; "y" masks are stripes along the middle axis (rows) spanning
    the whole last axis. The first axis is always masked in full
    (presumably the channel axis of a ``(C, H, W)`` image -- confirm against
    the caller).

    Each count/length argument is either a fixed ``int`` or a ``(low, high)``
    pair that is passed to ``np.random.randint`` on every use, i.e. sampled
    from the half-open range ``[low, high)``.

    Args:
        num_masks_x: Number of column stripes, or range to sample it from.
        num_masks_y: Number of row stripes, or range to sample it from.
        mask_x_length: Width of each column stripe, or range to sample from.
        mask_y_length: Height of each row stripe, or range to sample from.
        fill_value: Value written into the masked regions.
        p: Probability of applying the transform on a given call.
    """

    def __init__(
        self,
        num_masks_x: int | tuple[int, int],
        num_masks_y: int | tuple[int, int],
        mask_x_length: int | tuple[int, int],
        mask_y_length: int | tuple[int, int],
        fill_value: int,
        p: float = 1.0,
    ):
        self.num_masks_x = num_masks_x
        self.num_masks_y = num_masks_y
        self.mask_x_length = mask_x_length
        self.mask_y_length = mask_y_length
        self.fill_value = fill_value
        self.p = p

    @staticmethod
    def _resolve(value: int | tuple[int, int]) -> int:
        """Return ``value`` itself, or a draw from ``[low, high)`` for a pair."""
        if isinstance(value, (tuple, list)):
            return int(np.random.randint(*value))
        return value

    @staticmethod
    def _random_start(size: int, length: int) -> int:
        """Pick a start offset so a stripe of ``length`` fits inside ``size``.

        The upper bound is inclusive of ``size - length`` so the stripe can
        reach the far edge; a stripe at least as long as the axis starts at 0.
        """
        last_start = max(size - length, 0)
        return int(np.random.randint(0, last_start + 1))

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Apply the masks to ``img`` in place and return it.

        Args:
            img: Tensor with exactly three dimensions.

        Returns:
            The same tensor object; unchanged when the transform is skipped.
        """
        # rand() is in [0, 1): p=1.0 always applies, p=0.0 never applies.
        if np.random.rand() >= self.p:
            return img

        # Middle axis is indexed by y-masks, last axis by x-masks.
        _, height, width = img.shape

        for _ in range(self._resolve(self.num_masks_x)):
            length = self._resolve(self.mask_x_length)
            x = self._random_start(width, length)
            img[:, :, x : x + length] = self.fill_value

        for _ in range(self._resolve(self.num_masks_y)):
            length = self._resolve(self.mask_y_length)
            y = self._random_start(height, length)
            img[:, y : y + length, :] = self.fill_value

        return img