Spaces:
Sleeping
Sleeping
| import random | |
| import numpy as np | |
| from torchvision import transforms | |
| from PIL import Image | |
class CutPaste(object):
    """CutPaste and CutPaste-Scar augmentations for PIL images.

    Both augmentations cut a rectangular patch out of an image and paste it
    back at another random location of the same image. Each augmentation
    returns the augmented image together with an 'L'-mode mask (255 where the
    patch was pasted, 0 elsewhere).
    """

    # NOTE: the parameter name ``type`` shadows the builtin, but it is part of
    # the public constructor signature and is kept for backward compatibility.
    def __init__(self, transform=True, type='binary'):
        """
        :transform: [bool] - if truthy, a ColorJitter (brightness, contrast,
                    saturation and hue all 0.1) is applied to every cut patch
        :type: [str] - classification type, 'binary' or '3way'
        """
        self.type = type
        if transform:
            self.transform = transforms.ColorJitter(brightness=0.1,
                                                    contrast=0.1,
                                                    saturation=0.1,
                                                    hue=0.1)
        else:
            self.transform = None

    @staticmethod
    def crop_and_paste_patch(image, patch_w, patch_h, transform, rotation=False):
        """
        Crop a patch from the image and paste it at a random place of a copy
        of the same image.

        Declared as a static method because it is invoked through ``self``
        without using any instance state.

        :image: [PIL] - original image (left unmodified)
        :patch_w: [int] - width of the patch
        :patch_h: [int] - height of the patch
        :transform: [callable or falsy] - if truthy, called on the cut patch
        :rotation: [tuple or False] - if truthy, (min, max) range in degrees
                   from which the patch rotation angle is drawn
        :return: tuple (augmented image, 'L' mask of the pasted area)
        """
        org_w, org_h = image.size

        # Keep the patch inside the image and at least one pixel wide/high;
        # otherwise random.randint below would get an empty range and raise.
        patch_w = max(1, min(int(patch_w), org_w))
        patch_h = max(1, min(int(patch_h), org_h))

        alpha = None

        # source location of the patch
        patch_left = random.randint(0, org_w - patch_w)
        patch_top = random.randint(0, org_h - patch_h)
        patch_right, patch_bottom = patch_left + patch_w, patch_top + patch_h
        patch = image.crop((patch_left, patch_top, patch_right, patch_bottom))

        if transform:
            patch = transform(patch)

        if rotation:
            random_rotate = random.uniform(*rotation)
            # RGBA + expand=True: the area added around the rotated patch is
            # transparent, so the alpha band marks the rotated patch only.
            patch = patch.convert("RGBA").rotate(random_rotate, expand=True)
            alpha = patch.split()[-1]

        # new location
        paste_left = random.randint(0, org_w - patch_w)
        paste_top = random.randint(0, org_h - patch_h)
        aug_image = image.copy()
        aug_image.paste(patch, (paste_left, paste_top), mask=alpha)

        # Mask of the pasted area. The box is taken from the patch that was
        # actually pasted (it grows when rotated) and the same alpha band is
        # used, so the mask matches the pixels changed in aug_image.
        pasted_w, pasted_h = patch.size
        paste_box = (paste_left, paste_top,
                     paste_left + pasted_w, paste_top + pasted_h)
        paste_mask = Image.new('L', image.size, 0)
        paste_mask.paste(255, paste_box, mask=alpha)

        return aug_image, paste_mask

    def cutpaste(self, image, area_ratio=(0.02, 0.15), aspect_ratio=((0.3, 1), (1, 3.3))):
        """
        CutPaste augmentation.

        :image: [PIL] - original image
        :area_ratio: [tuple] - (min, max) fraction of the image area used
                     for the patch
        :aspect_ratio: [tuple] - two (min, max) ranges; one value is drawn
                       from each and one of the two is picked at random
        :return: tuple (PIL image after CutPaste, 'L' mask of the pasted area)
        """
        img_w, img_h = image.size
        patch_area = random.uniform(*area_ratio) * (img_w * img_h)
        patch_aspect = random.choice([random.uniform(*aspect_ratio[0]),
                                      random.uniform(*aspect_ratio[1])])
        # width * height == patch_area and width / height == patch_aspect
        patch_w = int(np.sqrt(patch_area * patch_aspect))
        patch_h = int(np.sqrt(patch_area / patch_aspect))
        return self.crop_and_paste_patch(image, patch_w, patch_h,
                                         self.transform, rotation=False)

    def cutpaste_scar(self, image, width=(2, 16), length=(10, 25), rotation=(-45, 45)):
        """
        CutPaste-Scar augmentation: a small, randomly rotated patch.

        :image: [PIL] - original image
        :width: [sequence] - inclusive (min, max) range for patch width
        :length: [sequence] - inclusive (min, max) range for patch length
        :rotation: [tuple] - (min, max) range for the rotation angle
        :return: tuple (PIL image after CutPaste-Scar, 'L' mask of the
                 pasted area)
        """
        patch_w = random.randint(*width)
        patch_h = random.randint(*length)
        return self.crop_and_paste_patch(image, patch_w, patch_h,
                                         self.transform, rotation=rotation)

    def __call__(self, image):
        """
        :image: [PIL] - original image
        :return: if type == 'binary': (image, (augmented, mask)) for one
                 randomly chosen augmentation; if type == '3way':
                 (image, (cutpaste, mask), (scar, mask))
        :raises ValueError: if type is neither 'binary' nor '3way'
        """
        if self.type == 'binary':
            aug = random.choice([self.cutpaste, self.cutpaste_scar])
            return image, aug(image)
        if self.type == '3way':
            cutpaste = self.cutpaste(image)
            scar = self.cutpaste_scar(image)
            return image, cutpaste, scar
        # Fail loudly instead of silently returning None.
        raise ValueError(
            "Unknown classification type {!r}; expected 'binary' or '3way'".format(self.type)
        )