| | import torch |
| | import os |
| | import torchvision.transforms as transforms |
| |
|
| |
|
class Augment_RGB_torch:
    """Deterministic geometric augmentations for image tensors.

    ``transform0`` is the identity. Together with it, ``transform2``-``transform8``
    form the eight symmetries of the square: rotations by 0/90/180/270 degrees
    and those same rotations followed by a vertical flip.
    ``transform1`` rotates by the fixed angle ``self.rotate``, zooms in, and
    center-crops back to the input size.

    Every method acts on the last two dimensions (height, width). Both an
    unbatched ``(C, H, W)`` tensor and a batched ``(..., C, H, W)`` tensor are
    accepted.
    """

    # Class-level fallback so instances that never ran ``__init__`` (e.g.
    # unpickled from an older version of this class) still have a zoom factor.
    zoom = 1.3

    def __init__(self, rotate=0, zoom=1.3):
        """
        Args:
            rotate: angle in degrees applied by ``transform1``.
            zoom: upscale factor applied by ``transform1`` before
                center-cropping back to the original size. The default of 1.3
                matches the previously hard-coded value.
        """
        self.rotate = rotate
        self.zoom = zoom

    def transform0(self, torch_tensor):
        """Return ``torch_tensor`` unchanged (identity)."""
        return torch_tensor

    def transform1(self, torch_tensor):
        """Rotate by ``self.rotate`` degrees, upscale by ``self.zoom``, and
        center-crop back to the input height and width.

        The rotation uses bilinear interpolation with ``expand=False``, so the
        rotated image keeps the input size. The zoom + crop presumably exists to
        trim the empty corners the rotation introduces; this only fully works
        for small angles (TODO confirm the intended range of ``rotate``).
        """
        # Index from the end so batched (..., C, H, W) input also works; the
        # previous shape[1]/shape[2] picked (C, H) for a 4-D batch.
        height, width = torch_tensor.shape[-2], torch_tensor.shape[-1]
        pipeline = transforms.Compose([
            # degrees=(a, a) makes RandomRotation deterministic: always rotate by a.
            transforms.RandomRotation(
                (self.rotate, self.rotate),
                interpolation=transforms.InterpolationMode.BILINEAR,
                expand=False,
            ),
            transforms.Resize(
                (int(height * self.zoom), int(width * self.zoom)),
                antialias=True,
            ),
            transforms.CenterCrop([height, width]),
        ])
        return pipeline(torch_tensor)

    def transform2(self, torch_tensor):
        """Rotate 90 degrees clockwise in the (H, W) plane."""
        return torch.rot90(torch_tensor, k=1, dims=[-1, -2])

    def transform3(self, torch_tensor):
        """Rotate 180 degrees in the (H, W) plane."""
        return torch.rot90(torch_tensor, k=2, dims=[-1, -2])

    def transform4(self, torch_tensor):
        """Rotate 270 degrees clockwise (i.e. 90 counter-clockwise) in the (H, W) plane."""
        return torch.rot90(torch_tensor, k=3, dims=[-1, -2])

    def transform5(self, torch_tensor):
        """Flip vertically (reverse the row order along dim -2)."""
        return torch_tensor.flip(-2)

    def transform6(self, torch_tensor):
        """Rotate 90 degrees clockwise, then flip vertically (anti-transpose)."""
        return torch.rot90(torch_tensor, k=1, dims=[-1, -2]).flip(-2)

    def transform7(self, torch_tensor):
        """Rotate 180 degrees, then flip vertically (equivalent to a horizontal flip)."""
        return torch.rot90(torch_tensor, k=2, dims=[-1, -2]).flip(-2)

    def transform8(self, torch_tensor):
        """Rotate 270 degrees clockwise, then flip vertically (swaps the H and W axes)."""
        return torch.rot90(torch_tensor, k=3, dims=[-1, -2]).flip(-2)
| |
|
| |
|
| |
|