import torch


class TestTimeAugmentation:
    """Test-Time Augmentation for image restoration models.

    The inputs are flipped / rotated (and optionally rescaled), the model is
    run on every variant through a sliding-window callable, each prediction
    is mapped back to the original orientation and resolution, and the
    predictions are averaged.
    """

    # Name of the spatial augmentation that undoes each augmentation.
    _INVERSE_AUG = {
        'h_flip': 'h_flip',
        'v_flip': 'v_flip',
        'rot90': 'rot270',
        'rot180': 'rot180',
        'rot270': 'rot90',
    }

    def __init__(
        self,
        model,
        dino_net,
        device,
        use_flip=True,
        use_rot=True,
        use_multi_scale=False,
        scales=None,
    ):
        """
        Args:
            model: The model to apply TTA to
            dino_net: DINO feature extractor
            device: Device to run inference on
            use_flip: Whether to use horizontal and vertical flips
            use_rot: Whether to use 90-degree rotations
            use_multi_scale: Whether to use multi-scale testing. When False,
                ``scales`` is ignored and only the native resolution is used.
            scales: List of scales to use for multi-scale testing,
                e.g. [0.8, 1.0, 1.2]. Defaults to [1.0].
        """
        self.model = model
        self.dino_net = dino_net
        self.device = device
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.use_multi_scale = use_multi_scale
        self.scales = scales or [1.0]

    @staticmethod
    def _spatial_transform(tensor, aug_type):
        """Flip or rotate ``tensor`` over dims 2 and 3 according to ``aug_type``.

        Args:
            tensor: Tensor with at least four dimensions
            aug_type: One of 'h_flip', 'v_flip', 'rot90', 'rot180', 'rot270'

        Returns:
            The transformed tensor

        Raises:
            ValueError: If ``aug_type`` is not one of the names above
        """
        if aug_type == 'h_flip':
            return torch.flip(tensor, dims=[3])
        if aug_type == 'v_flip':
            return torch.flip(tensor, dims=[2])
        if aug_type == 'rot90':
            return torch.rot90(tensor, k=1, dims=[2, 3])
        if aug_type == 'rot180':
            return torch.rot90(tensor, k=2, dims=[2, 3])
        if aug_type == 'rot270':
            return torch.rot90(tensor, k=3, dims=[2, 3])
        raise ValueError(f"Unknown augmentation type: {aug_type}")

    def _apply_augmentation(self, image, point, normal, aug_type):
        """Apply a single augmentation to the input tensors.

        All three tensors receive the same flip/rotation. In addition,
        channels 0 and 1 of the normal map (treated as the x and y
        components) are negated / swapped to follow the transform.

        Args:
            image: Input RGB image
            point: Point map
            normal: Normal map
            aug_type: Augmentation type string ('original', 'h_flip',
                'v_flip', 'rot90', 'rot180' or 'rot270')

        Returns:
            Augmented versions of image, point map and normal map. For
            'original' the inputs are returned unchanged (same objects).

        Raises:
            ValueError: If ``aug_type`` is unknown
        """
        if aug_type == 'original':
            return image, point, normal

        # Raises ValueError for an unknown aug_type before anything else runs.
        img_aug = self._spatial_transform(image, aug_type)
        point_aug = self._spatial_transform(point, aug_type)
        normal_aug = self._spatial_transform(normal, aug_type)

        # Snapshot the x/y components so the swaps below do not read values
        # that were already overwritten.
        old_x = normal_aug[:, 0].clone()
        old_y = normal_aug[:, 1].clone()

        if aug_type == 'h_flip':
            normal_aug[:, 0] = -old_x
        elif aug_type == 'v_flip':
            normal_aug[:, 1] = -old_y
        elif aug_type == 'rot90':
            normal_aug[:, 0] = -old_y
            normal_aug[:, 1] = old_x
        elif aug_type == 'rot180':
            normal_aug[:, 0] = -old_x
            normal_aug[:, 1] = -old_y
        else:  # 'rot270'
            normal_aug[:, 0] = old_y
            normal_aug[:, 1] = -old_x

        return img_aug, point_aug, normal_aug

    def _reverse_augmentation(self, result, aug_type):
        """Reverse the spatial augmentation on a model output.

        Only the flip/rotation is undone; channel values are left untouched.

        Args:
            result: Model output to reverse augmentation on
            aug_type: Augmentation type string

        Returns:
            De-augmented result

        Raises:
            ValueError: If ``aug_type`` is unknown
        """
        if aug_type == 'original':
            return result

        inverse = self._INVERSE_AUG.get(aug_type)
        if inverse is None:
            raise ValueError(f"Unknown augmentation type: {aug_type}")
        return self._spatial_transform(result, inverse)

    def __call__(self, sliding_window, input_img, point, normal):
        """
        Apply TTA to the model and return ensemble result

        Args:
            sliding_window: SlidingWindowInference class instance; called with
                keyword arguments ``model``, ``input_``, ``point``, ``normal``,
                ``dino_net`` and ``device``
            input_img: Input RGB image [B, C, H, W]
            point: Point map [B, C, H, W]
            normal: Normal map [B, C, H, W]

        Returns:
            Ensemble result with TTA [B, C, H, W] (accumulated in a tensor
            shaped like ``input_img``)
        """
        augmentations = ['original']
        if self.use_flip:
            augmentations.extend(['h_flip', 'v_flip'])
        if self.use_rot:
            augmentations.extend(['rot90', 'rot180', 'rot270'])

        # Honour the use_multi_scale switch: without it only the native
        # resolution is evaluated, whatever `scales` contains.
        scales = self.scales if self.use_multi_scale else [1.0]

        resize_fn = torch.nn.functional.interpolate
        height, width = input_img.shape[2], input_img.shape[3]

        ensemble_result = torch.zeros_like(input_img)
        ensemble_weight = 0.0

        for scale in scales:
            # Every prediction currently contributes equally to the average.
            scale_weight = 1.0
            rescaled = scale != 1.0

            if rescaled:
                # Clamp so a very small scale cannot request an empty image.
                new_size = (max(1, int(height * scale)), max(1, int(width * scale)))
                input_img_scaled = resize_fn(
                    input_img, size=new_size, mode='bilinear', align_corners=False
                )
                point_scaled = resize_fn(
                    point, size=new_size, mode='bilinear', align_corners=False
                )
                normal_scaled = resize_fn(
                    normal, size=new_size, mode='bilinear', align_corners=False
                )

                # Interpolation changes vector lengths; re-normalise per pixel.
                normal_norm = torch.sqrt(
                    torch.sum(normal_scaled ** 2, dim=1, keepdim=True) + 1e-6
                )
                normal_scaled = normal_scaled / normal_norm
            else:
                input_img_scaled = input_img
                point_scaled = point
                normal_scaled = normal

            for aug_type in augmentations:
                img_aug, point_aug, normal_aug = self._apply_augmentation(
                    input_img_scaled, point_scaled, normal_scaled, aug_type
                )

                # Same behaviour as the deprecated torch.cuda.amp.autocast().
                with torch.autocast(device_type='cuda'):
                    result_aug = sliding_window(
                        model=self.model,
                        input_=img_aug,
                        point=point_aug,
                        normal=normal_aug,
                        dino_net=self.dino_net,
                        device=self.device,
                    )

                # Back to the original orientation, then the original size.
                result_aug = self._reverse_augmentation(result_aug, aug_type)
                if rescaled:
                    result_aug = resize_fn(
                        result_aug,
                        size=(height, width),
                        mode='bilinear',
                        align_corners=False,
                    )

                ensemble_result += result_aug * scale_weight
                ensemble_weight += scale_weight

        ensemble_result = ensemble_result / ensemble_weight

        return ensemble_result