| | import numpy as np |
| | import cv2 |
| | import torch |
| |
|
class Compose(object):
    """Compose several transforms into one callable applied in sequence.

    Args:
        transforms (list of ``Transform`` objects): transforms to run in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        out = data
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self):
        body = ''.join('\n    {0}'.format(t) for t in self.transforms)
        return '{0}({1}\n)'.format(self.__class__.__name__, body)
| |
|
| |
|
class ConvertUcharToFloat(object):
    """Cast every array in the sample from uchar (uint8) to float32."""

    def __call__(self, data):
        # Each item is converted with astype, which always returns a copy.
        return list(map(lambda item: item.astype(np.float32), data))
| |
|
| |
|
class RandomContrast(object):
    """Randomly scale image intensities by a multiplicative contrast factor.

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for
            (img1, label1, img2, label2) quadruples.
        lower (float): lower bound of the contrast factor.
        upper (float): upper bound of the contrast factor.
        prob (float): probability of jittering each image.

    Raises:
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "contrast upper must be >= lower!"
        assert self.lower > 0, "contrast lower must be positive!"

    def _jitter(self, img):
        # Multiply in place by a factor drawn uniformly from [lower, upper].
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomBrightness(object):
    """Randomly shift image intensities by an additive brightness delta.

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for quadruples.
        delta (float): max absolute brightness shift, must be in [0, 255).
        prob (float): probability of shifting each image.

    Raises:
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _shift(self, img):
        # Add a delta drawn uniformly from [-delta, delta], in place.
        if torch.rand(1) < self.prob:
            delta = torch.FloatTensor(1).uniform_(- self.delta, self.delta)
            img += delta.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shift(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shift(img1), label1, self._shift(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class ConvertColor(object):
    """Convert image color space RGB <-> HSV for later color distortion.

    Note: the old docstring said BGR, but the code has always converted
    between RGB and HSV (cv2.COLOR_RGB2HSV / cv2.COLOR_HSV2RGB).

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for quadruples.
        current (str): source color space, 'RGB' or 'HSV'.
        target (str): destination color space, 'HSV' or 'RGB'.

    Raises:
        NotImplementedError: for an unsupported current/target combination.
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, current='RGB', target='HSV'):
        self.phase = phase
        self.current = current
        self.target = target

    def _convert(self, img):
        # Map a single image between RGB and HSV.
        if self.current == 'RGB' and self.target == 'HSV':
            return cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        elif self.current == 'HSV' and self.target == 'RGB':
            return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
        raise NotImplementedError("Convert color fail!")

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._convert(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._convert(img1), label1, self._convert(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomSaturation(object):
    """Randomly scale the saturation channel of HSV image(s).

    Operates on channel 1 in place, so the image is expected to be in HSV
    layout (see ConvertColor) — TODO confirm against the calling pipeline.

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for quadruples.
        lower (float): lower bound of the saturation factor.
        upper (float): upper bound of the saturation factor.
        prob (float): probability of jittering each image.

    Raises:
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "saturation upper must be >= lower!"
        assert self.lower > 0, "saturation lower must be positive!"

    def _scale_saturation(self, img):
        # Scale the S channel in place by a factor from [lower, upper].
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img[:, :, 1] *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._scale_saturation(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._scale_saturation(img1), label1, self._scale_saturation(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomHue(object):
    """Randomly shift the hue channel of HSV image(s).

    Operates on channel 0 in place and wraps it back into [0, 360), which
    matches OpenCV's float32 HSV hue range — assumes the image is HSV.

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for quadruples.
        delta (float): max absolute hue shift in degrees, in [0, 360).
        prob (float): probability of shifting each image.

    Raises:
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _shift_hue(self, img):
        # Shift H in place, then wrap out-of-range values back into [0, 360).
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img[:, :, 0] += alpha.numpy()
            img[:, :, 0][img[:, :, 0] > 360.0] -= 360.0
            img[:, :, 0][img[:, :, 0] < 0.0] += 360.0
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shift_hue(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shift_hue(img1), label1, self._shift_hue(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomChannelNoise(object):
    """Randomly permute the color channels of the image(s).

    Args:
        phase (str): 'od'/'seg' for (img, label) pairs, 'cd' for quadruples.
        prob (float): probability of shuffling each image's channels.

    Raises:
        ValueError: on an unsupported phase (previously this path crashed
            with an UnboundLocalError on ``return_data``).
    """

    def __init__(self, phase, prob=0.4):
        self.phase = phase
        self.prob = prob
        # All six permutations of three channels.
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def _shuffle(self, img):
        # Reindex the channel axis with a randomly chosen permutation.
        if torch.rand(1) < self.prob:
            shuffle_factor = self.perms[torch.randint(0, len(self.perms), size=[])]
            img = img[:, :, shuffle_factor]
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shuffle(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shuffle(img1), label1, self._shuffle(img2), label2
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class ImgDistortion(object):
    """Photometric distortion: brightness, then contrast/saturation/hue
    jitter in HSV space, then a random channel shuffle.

    The contrast jitter is randomly applied either before or after the
    HSV round-trip (never both), mirroring the SSD augmentation recipe.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob
        self.operation = [
            RandomContrast(phase),
            ConvertColor(phase, current='RGB', target='HSV'),
            RandomSaturation(phase),
            RandomHue(phase),
            ConvertColor(phase, current='HSV', target='RGB'),
            RandomContrast(phase)
        ]
        self.random_brightness = RandomBrightness(phase)
        self.random_light_noise = RandomChannelNoise(phase)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            data = self.random_brightness(data)
        # Drop either the trailing or the leading RandomContrast.
        ops = self.operation[:-1] if torch.rand(1) < self.prob else self.operation[1:]
        data = Compose(ops)(data)
        return self.random_light_noise(data)
| |
|
| |
|
class ExpandImg(object):
    """Randomly pad the sample on all sides with the prior-mean color
    (a "zoom-out" expansion).

    Bug fix: the vertical padding (top/bottom) was previously computed
    from ``width * ratio_height``; it now correctly uses the image height.

    Args:
        phase (str): 'seg', 'cd' or 'od'.
        prior_mean: per-channel mean in [0, 1]; scaled by 255 for padding.
        prob (float): probability of returning the sample unchanged.
        expand_ratio (float): max padding as a fraction of each side.
    """

    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        self.prior_mean = np.array(prior_mean) * 255
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_padding(self, height, width):
        # Draw independent left/right and top/bottom paddings, each bounded
        # by a random fraction (up to expand_ratio) of the matching side.
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    def _pad_img(self, img, top, bottom, left, right):
        # Pad with the (de-normalized) prior mean color.
        return cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)

    @staticmethod
    def _pad_label(label, top, bottom, left, right):
        # Dense labels are padded with background (0).
        return cv2.copyMakeBorder(
            label, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            return data
        if self.phase == 'seg':
            img, label = data
            height, width, _ = img.shape
            top, bottom, left, right = self._sample_padding(height, width)
            return (self._pad_img(img, top, bottom, left, right),
                    self._pad_label(label, top, bottom, left, right))
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            height, width, _ = img1.shape
            top, bottom, left, right = self._sample_padding(height, width)
            return (self._pad_img(img1, top, bottom, left, right),
                    self._pad_label(label1, top, bottom, left, right),
                    self._pad_img(img2, top, bottom, left, right),
                    self._pad_label(label2, top, bottom, left, right))
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            top, bottom, left, right = self._sample_padding(height, width)
            img = self._pad_img(img, top, bottom, left, right)
            # Shift box coords ([cls, x1, y1, x2, y2]) by the padding offset.
            label[:, 1::2] += left
            label[:, 2::2] += top
            return img, label
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomSampleCrop(object):
    """Randomly crop a window out of the sample.

    The crop size is drawn around ``original_size`` scaled by
    ``crop_scale_ratios_range`` and must satisfy ``aspect_ratio_range``.

    Bug fix: the shape was unpacked as ``w, h, c = img.shape`` although
    numpy images are (height, width, channels); height/width are now read
    in the correct order so non-square images are cropped correctly.
    Robustness fix: the 'od' branch previously retried in an unbounded
    ``while True``; it now shares the seg branch's max_try_times fallback.

    Arguments:
        img (Image): the image being input during training
        boxes (Tensor): the original bounding boxes in pt form
        label (Tensor): the class label for each bbox
    Return:
        (img, label) — for 'od', label rows are [cls, x1, y1, x2, y2] and
        boxes whose center falls outside the crop are dropped.
    """

    def __init__(self,
                 phase,
                 original_size=[512, 512],
                 prob=0.5,
                 crop_scale_ratios_range=[0.8, 1.2],
                 aspect_ratio_range=[4./5, 5./4]):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size
        self.aspect_ratio_range = aspect_ratio_range
        self.max_try_times = 500

    def _sample_crop_size(self, height, width):
        # Draw (crop_h, crop_w) until the aspect ratio fits; None on failure.
        try_times = 0
        while try_times < self.max_try_times:
            crop_w = torch.randint(
                min(width, int(self.scale_range[0] * self.original_size[0])),
                min(width + 1, int(self.scale_range[1] * self.original_size[0])),
                size=[]
            )
            crop_h = torch.randint(
                min(height, int(self.scale_range[0] * self.original_size[1])),
                min(height + 1, int(self.scale_range[1] * self.original_size[1])),
                size=[]
            )
            if self.aspect_ratio_range[0] < crop_h / crop_w < self.aspect_ratio_range[1]:
                return int(crop_h), int(crop_w)
            try_times += 1
        return None

    def __call__(self, data):
        # With probability `prob` the sample is returned untouched.
        if torch.rand(1) < self.prob:
            return data
        if self.phase == 'seg':
            img, label = data
            height, width, _ = img.shape
            size = self._sample_crop_size(height, width)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = int(torch.randint(0, width - crop_w + 1, size=[]))
            top = int(torch.randint(0, height - crop_h + 1, size=[]))
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            label = label[top:(top + crop_h), left:(left + crop_w)]
            return img, label
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            size = self._sample_crop_size(height, width)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = int(torch.randint(0, width - crop_w + 1, size=[]))
            top = int(torch.randint(0, height - crop_h + 1, size=[]))
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            if len(label):
                # Keep only boxes whose center lies inside the crop window.
                centers = (label[:, 1:3] + label[:, 3:]) / 2.0
                m1 = (left <= centers[:, 0]) * (top <= centers[:, 1])
                m2 = ((left + crop_w) >= centers[:, 0]) * ((top + crop_h) > centers[:, 1])
                mask = m1 * m2
                current_label = label[mask, :]
                # Translate surviving boxes into crop coordinates.
                current_label[:, 1::2] -= left
                current_label[:, 2::2] -= top
                label = current_label
            return img, label
        raise ValueError("unsupported phase: {}".format(self.phase))
| |
|
| |
|
class RandomMirror(object):
    """Randomly mirror the sample horizontally (left-right flip)."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    @staticmethod
    def _flip_lr(arr):
        # Reverse the column (second) axis.
        return arr[:, ::-1]

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img = self._flip_lr(img)
                label = self._flip_lr(label)
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1 = self._flip_lr(img1)
                label1 = self._flip_lr(label1)
                img2 = self._flip_lr(img2)
                label2 = self._flip_lr(label2)
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                width = img.shape[1]
                img = self._flip_lr(img)
                # Mirror box x-coords in place: new (x1, x2) = (w - x2, w - x1).
                label[:, 1::2] = width - label[:, 3::-2]
            return img, label
| |
|
| |
|
class RandomFlipV(object):
    """Randomly flip the sample vertically (up-down flip)."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    @staticmethod
    def _flip_ud(arr):
        # Reverse the row (first) axis.
        return arr[::-1, :]

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img = self._flip_ud(img)
                label = self._flip_ud(label)
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1 = self._flip_ud(img1)
                label1 = self._flip_ud(label1)
                img2 = self._flip_ud(img2)
                label2 = self._flip_ud(label2)
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                height = img.shape[0]
                img = self._flip_ud(img)
                # Flip box y-coords in place: new (y1, y2) = (h - y2, h - y1).
                label[:, 2::2] = height - label[:, 4:1:-2]
            return img, label
| |
|
| |
|
class Resize(object):
    """Resize the sample to a fixed size: bilinear for images, nearest for
    dense labels, coordinate rescaling for detection boxes."""

    def __init__(self, phase, size):
        self.phase = phase
        self.size = size

    def _resize_img(self, img):
        return cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)

    def _resize_label(self, label):
        # Nearest-neighbour keeps label values discrete.
        return cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            return self._resize_img(img), self._resize_label(label)
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return (self._resize_img(img1), self._resize_label(label1),
                    self._resize_img(img2), self._resize_label(label2))
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            img = self._resize_img(img)
            # Rescale box coords ([cls, x1, y1, x2, y2]) to the new resolution.
            label[:, 1::2] = label[:, 1::2] / width * self.size[0]
            label[:, 2::2] = label[:, 2::2] / height * self.size[1]
            return img, label
| |
|
| |
|
class Normalize(object):
    """Scale images to [0, 1] then standardize with the prior mean/std."""

    def __init__(self, phase, prior_mean, prior_std):
        self.phase = phase
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def _normalize(self, img):
        # Epsilon guards against a zero std.
        return (img / 255. - self.prior_mean) / (self.prior_std + 1e-10)

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._normalize(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._normalize(img1), label1, self._normalize(img2), label2
| |
|
| |
|
class InvNormalize(object):
    """Undo Normalize: de-standardize, rescale to [0, 255] and clip."""

    def __init__(self, prior_mean, prior_std):
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def __call__(self, img):
        restored = (img * self.prior_std + self.prior_mean) * 255.
        return np.clip(restored, a_min=0, a_max=255)
| |
|
| |
|
class Augmentations(object):
    """Build the augmentation pipeline for one pattern and task phase.

    Args:
        size: target (width, height) passed to crop/resize transforms.
        prior_mean: dataset mean used by ExpandImg padding and Normalize.
        prior_std: dataset std used by Normalize.
        pattern (str): 'train' (full augmentation) or 'val'/'test'
            (cast + resize + normalize only).
        phase (str): task type forwarded to every transform ('seg', 'od', 'cd').
    """

    def __init__(self, size, prior_mean=0, prior_std=1, pattern='train', phase='seg', *args, **kwargs):
        self.size = size
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.phase = phase

        eval_pipeline = Compose([
            ConvertUcharToFloat(),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        train_pipeline = Compose([
            ConvertUcharToFloat(),
            ImgDistortion(self.phase),
            ExpandImg(self.phase, self.prior_mean),
            RandomSampleCrop(self.phase, original_size=self.size),
            RandomMirror(self.phase),
            RandomFlipV(self.phase),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        # 'val' and 'test' share the same deterministic pipeline.
        self.augment = {
            'train': train_pipeline,
            'val': eval_pipeline,
            'test': eval_pipeline,
        }[pattern]

    def __call__(self, data):
        return self.augment(data)
| |
|
| |
|