Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import cv2 | |
| import torch | |
| from torch.utils import data | |
| from torchvision import transforms | |
| from torchvision.transforms import functional as F | |
| import numbers | |
| import numpy as np | |
| import random | |
| #re_size = (256, 256) | |
| #cr_size = (224, 224) | |
class ImageDataTrain(data.Dataset):
    """Training dataset yielding (image, saliency label, edge label) triples.

    The list file is read once at construction. Every non-blank line holds
    three whitespace-separated paths, relative to ``sal_root``: the image,
    its saliency label and its edge label.

    Args:
        sal_root: directory that the relative paths in the list file are
            joined onto. Defaults to the original hard-coded DUTS-TR root.
        sal_source: path of the list file. Defaults to the original
            hard-coded ``train_pair_edge.lst``.
    """

    def __init__(self,
                 sal_root='/home/liuj/dataset/DUTS/DUTS-TR',
                 sal_source='/home/liuj/dataset/DUTS/DUTS-TR/train_pair_edge.lst'):
        self.sal_root = sal_root
        self.sal_source = sal_source
        with open(self.sal_source, 'r') as f:
            # Blank lines (e.g. a trailing newline) are skipped: they would
            # otherwise fail with IndexError when split in __getitem__.
            self.sal_list = [line.strip() for line in f if line.strip()]
        self.sal_num = len(self.sal_list)

    def __getitem__(self, item):
        """Load sample ``item`` (wrapped modulo the dataset size).

        Returns:
            dict with keys 'sal_image', 'sal_label' and 'sal_edge', each a
            ``torch.Tensor``. All three receive the same random horizontal
            flip.
        """
        # Split the list entry once instead of once per field.
        fields = self.sal_list[item % self.sal_num].split()
        sal_image = load_image(os.path.join(self.sal_root, fields[0]))
        sal_label = load_sal_label(os.path.join(self.sal_root, fields[1]))
        sal_edge = load_edge_label(os.path.join(self.sal_root, fields[2]))
        sal_image, sal_label, sal_edge = cv_random_flip(sal_image, sal_label, sal_edge)
        return {
            'sal_image': torch.Tensor(sal_image),
            'sal_label': torch.Tensor(sal_label),
            'sal_edge': torch.Tensor(sal_edge),
        }

    def __len__(self):
        """Return the number of entries read from the list file."""
        return self.sal_num
class ImageDataTest(data.Dataset):
    """Test dataset yielding an image tensor with its list entry and size.

    Args:
        test_mode: 0 selects the HED-BSDS test set, 1 selects a saliency
            benchmark chosen by ``sal_mode``, 2 selects the SK-LARGE test set.
        sal_mode: single-letter key into ``_SAL_SETS``; only consulted when
            ``test_mode == 1``.

    Raises:
        ValueError: if ``test_mode`` / ``sal_mode`` name no configured set.

    ``test_fold`` is always defined after construction; it is ``None`` for
    the configurations that have no result folder (test_mode 0 and 2, and
    sal_mode 'm').
    """

    # sal_mode -> (image_root, image_source, test_fold), used when test_mode == 1.
    _SAL_SETS = {
        'e': ('/home/liuj/dataset/saliency_test/ECSSD/Imgs/',
              '/home/liuj/dataset/saliency_test/ECSSD/test.lst',
              '/media/ubuntu/disk/Result/saliency/ECSSD/'),
        'p': ('/home/liuj/dataset/saliency_test/PASCALS/Imgs/',
              '/home/liuj/dataset/saliency_test/PASCALS/test.lst',
              '/media/ubuntu/disk/Result/saliency/PASCALS/'),
        'd': ('/home/liuj/dataset/saliency_test/DUTOMRON/Imgs/',
              '/home/liuj/dataset/saliency_test/DUTOMRON/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTOMRON/'),
        'h': ('/home/liuj/dataset/saliency_test/HKU-IS/Imgs/',
              '/home/liuj/dataset/saliency_test/HKU-IS/test.lst',
              '/media/ubuntu/disk/Result/saliency/HKU-IS/'),
        's': ('/home/liuj/dataset/saliency_test/SOD/Imgs/',
              '/home/liuj/dataset/saliency_test/SOD/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOD/'),
        'm': ('/home/liuj/dataset/saliency_test/MSRA/Imgs/',
              '/home/liuj/dataset/saliency_test/MSRA/test.lst',
              None),
        'o': ('/home/liuj/dataset/saliency_test/SOC/TestSet/Imgs/',
              '/home/liuj/dataset/saliency_test/SOC/TestSet/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOC/'),
        't': ('/home/liuj/dataset/DUTS/DUTS-TE/DUTS-TE-Image/',
              '/home/liuj/dataset/DUTS/DUTS-TE/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTS/'),
    }

    # test_mode -> (image_root, image_source) for the modes that ignore sal_mode.
    _OTHER_SETS = {
        0: ('/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test/',
            '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test.lst'),
        2: ('/home/liuj/dataset/SK-LARGE/images/test/',
            '/home/liuj/dataset/SK-LARGE/test.lst'),
    }

    def __init__(self, test_mode=1, sal_mode='e'):
        if test_mode == 1:
            if sal_mode not in self._SAL_SETS:
                raise ValueError('Unknown sal_mode: {!r}'.format(sal_mode))
            self.image_root, self.image_source, self.test_fold = self._SAL_SETS[sal_mode]
        elif test_mode in self._OTHER_SETS:
            self.image_root, self.image_source = self._OTHER_SETS[test_mode]
            # No result folder is configured for these sets.
            self.test_fold = None
        else:
            raise ValueError('Unknown test_mode: {!r}'.format(test_mode))

        with open(self.image_source, 'r') as f:
            # Skip blank lines so they do not become empty file names.
            self.image_list = [line.strip() for line in f if line.strip()]
        self.image_num = len(self.image_list)

    def __getitem__(self, item):
        """Load image ``item``.

        Returns:
            dict with 'image' (``torch.Tensor``), 'name' (the list entry) and
            'size' (the size tuple reported by ``load_image_test``).
        """
        # One lookup serves both the path and the reported name; for every
        # index that is valid for the list, item % image_num picks the same entry.
        name = self.image_list[item]
        image, im_size = load_image_test(os.path.join(self.image_root, name))
        return {'image': torch.Tensor(image), 'name': name, 'size': im_size}

    def save_folder(self):
        """Return the result folder for this set, or None if none is configured."""
        return self.test_fold

    def __len__(self):
        """Return the number of entries read from the list file."""
        return self.image_num
# Build the DataLoader for either split. No augmentation is applied here;
# the training dataset itself performs a random flip.
def get_loader(batch_size, mode='train', num_thread=1, test_mode=0, sal_mode='e'):
    """Create a dataset and a DataLoader wrapping it.

    Args:
        batch_size: batch size handed to the DataLoader.
        mode: 'train' builds ImageDataTrain with shuffling enabled; any other
            value builds ImageDataTest without shuffling.
        num_thread: number of DataLoader worker processes.
        test_mode, sal_mode: forwarded to ImageDataTest (unused for training).

    Returns:
        (data_loader, dataset)
    """
    training = (mode == 'train')
    if training:
        dataset = ImageDataTrain()
    else:
        dataset = ImageDataTest(test_mode=test_mode, sal_mode=sal_mode)
    loader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_thread,
    )
    return loader, dataset
def load_image(pah):
    """Read an image with OpenCV and return it as a mean-subtracted array.

    Args:
        pah: path of the image file.

    Returns:
        float32 array in channel-first layout (the loaded array transposed
        with axes (2, 0, 1)), with a fixed per-channel mean subtracted.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if the file exists but OpenCV cannot decode it.
    """
    # Fail here with a clear error; previously a missing file only printed a
    # message and then crashed later with an unrelated broadcasting error.
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        # cv2.imread signals an unreadable file by returning None.
        raise OSError('Cannot read image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    # Channel means, in the channel order produced by cv2.imread (no reversal
    # is needed, unlike when loading with PIL).
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    in_ = in_.transpose((2, 0, 1))
    return in_
def load_image_test(pah):
    """Read a test image with OpenCV; return it mean-subtracted plus its size.

    Args:
        pah: path of the image file.

    Returns:
        (in_, im_size): ``in_`` is a float32 array in channel-first layout
        (the loaded array transposed with axes (2, 0, 1)) with a fixed
        per-channel mean subtracted; ``im_size`` is a tuple of the first two
        dimensions of the image as loaded, before the transpose.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if the file exists but OpenCV cannot decode it.
    """
    # Fail here with a clear error; previously a missing file only printed a
    # message and then crashed later with an unrelated broadcasting error.
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        # cv2.imread signals an unreadable file by returning None.
        raise OSError('Cannot read image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    im_size = tuple(in_.shape[:2])
    # Channel means, in the channel order produced by cv2.imread (no reversal
    # is needed, unlike when loading with PIL).
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    in_ = in_.transpose((2, 0, 1))
    return in_, im_size
def load_edge_label(pah):
    """
    Read an edge map as a float32 array of shape 1 x height x width.
    Values are divided by 255 and every pixel above 0.5 is then set to 1;
    the remaining pixels keep their scaled value.
    For a multi-channel image only the first channel is used.
    """
    if not os.path.exists(pah):
        print('File Not Exists')
    edge = np.array(Image.open(pah), dtype=np.float32)
    if edge.ndim == 3:
        edge = edge[:, :, 0]
    edge = edge / 255.
    # Saturate the strong responses.
    edge[edge > 0.5] = 1.
    return edge[np.newaxis, ...]
def load_skel_label(pah):
    """
    Read a skeleton map as a float32 array of shape 1 x height x width.
    Values are divided by 255 and every pixel above 0 is then set to 1,
    so the result contains only 0 and 1 for non-negative input.
    For a multi-channel image only the first channel is used.
    """
    if not os.path.exists(pah):
        print('File Not Exists')
    skel = np.array(Image.open(pah), dtype=np.float32)
    if skel.ndim == 3:
        skel = skel[:, :, 0]
    skel = skel / 255.
    # Any non-zero response counts as skeleton.
    skel[skel > 0.] = 1.
    return skel[np.newaxis, ...]
def load_sal_label(pah):
    """
    Read a saliency map as a float32 array of shape 1 x height x width,
    with values divided by 255.
    For a multi-channel image only the first channel is used.
    """
    if not os.path.exists(pah):
        print('File Not Exists')
    sal = np.array(Image.open(pah), dtype=np.float32)
    if sal.ndim == 3:
        sal = sal[:, :, 0]
    sal = sal / 255.
    return sal[np.newaxis, ...]
def load_sem_label(pah):
    """
    Read a label map as a float32 array of shape 1 x height x width.
    Pixel values are kept as stored (no rescaling).
    For a multi-channel image only the first channel is used.
    """
    if not os.path.exists(pah):
        print('File Not Exists')
    sem = np.array(Image.open(pah), dtype=np.float32)
    if sem.ndim == 3:
        sem = sem[:, :, 0]
    return sem[np.newaxis, ...]
def edge_thres_transform(x, thres):
    """Set every element of ``x`` that is >= ``thres`` to 1; keep the rest.

    Args:
        x: input tensor.
        thres: threshold compared against ``x`` element-wise.

    Returns:
        Tensor shaped like ``x`` with values >= ``thres`` replaced by 1.
    """
    # Build the ones on x's own device (still float32, as before) so the
    # function also works for tensors that do not live on the CPU.
    ones = torch.ones_like(x, dtype=torch.float32)
    return torch.where(x >= thres, ones, x)
def skel_thres_transform(x, thres):
    """Binarize ``x``: 1 where ``x > thres``, 0 elsewhere.

    Args:
        x: input tensor.
        thres: threshold compared against ``x`` element-wise (strict).

    Returns:
        float32 tensor of zeros and ones, on the same device as ``x``.
    """
    # Casting the comparison mask yields the same float32 0/1 values as
    # selecting between CPU-allocated ones/zeros, but stays on x's device.
    return (x > thres).to(torch.float32)
def cv_random_flip(img, label, edge):
    """Flip all three arrays along their last axis with probability 1/2.

    The same decision is applied to ``img``, ``label`` and ``edge`` so they
    stay aligned. Flipped arrays are returned as fresh copies; when no flip
    happens the inputs are returned unchanged.
    """
    if random.randint(0, 1) != 1:
        return img, label, edge
    flipped_img = img[:, :, ::-1].copy()
    flipped_label = label[:, :, ::-1].copy()
    flipped_edge = edge[:, :, ::-1].copy()
    return flipped_img, flipped_label, flipped_edge
def cv_random_crop_flip(img, label, resize_size, crop_size, random_flip=True):
    """Resize, randomly crop and optionally flip an image/label pair.

    Args:
        img: image array indexed (channel, row, col).
        label: label array with a leading singleton axis, indexed (0, row, col).
        resize_size: (height, width) both arrays are resized to first.
        crop_size: (height, width) of the window cropped from the resized
            arrays at a random position.
        random_flip: when True, flip both outputs along the last axis with
            probability 1/2; when False, never flip.

    Returns:
        (img, label): the cropped image, channel-first, and the cropped label
        with its leading singleton axis restored.
    """
    def get_params(img_size, output_size):
        # Choose the top-left corner (i, j) of a crop of output_size.
        h, w = img_size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    # Always bind flip_flag: it used to be assigned only when random_flip was
    # true, so random_flip=False raised UnboundLocalError further down.
    flip_flag = random.randint(0, 1) if random_flip else 0

    img = img.transpose((1, 2, 0))  # H, W, C
    label = label[0, :, :]  # H, W
    # cv2.resize takes its target size as (width, height).
    img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
    label = cv2.resize(label, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)

    i, j, h, w = get_params(resize_size, crop_size)
    img = img[i:i + h, j:j + w, :].transpose((2, 0, 1))  # C, H, W
    label = label[i:i + h, j:j + w][np.newaxis, ...]  # 1, H, W

    if flip_flag == 1:
        img = img[:, :, ::-1].copy()
        label = label[:, :, ::-1].copy()
    return img, label
def random_crop(img, label, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0, padding_mode='constant'):
    """Crop the same random window out of an image and its label.

    Args:
        img, label: objects exposing ``.size`` as (width, height) and accepted
            by the functional ``pad`` / ``crop`` helpers bound to ``F``.
        size: crop size as (height, width); a single number means a square.
        padding: optional padding applied to both inputs before cropping.
        pad_if_needed: pad an input that is narrower or shorter than the
            requested crop.
        fill_img, fill_label: fill values used when padding img / label.
        padding_mode: padding mode forwarded to ``F.pad``.

    Returns:
        [cropped_img, cropped_label]
    """
    def pick_window(image, target):
        # Returns (top, left, height, width) of the crop window.
        width, height = image.size
        target_h, target_w = target
        if (width, height) == (target_w, target_h):
            return 0, 0, height, width
        top = random.randint(0, height - target_h)
        left = random.randint(0, width - target_w)
        return top, left, target_h, target_w

    if isinstance(size, numbers.Number):
        side = int(size)
        size = (side, side)

    if padding is not None:
        img = F.pad(img, padding, fill_img, padding_mode)
        label = F.pad(label, padding, fill_label, padding_mode)

    # Too narrow: pad horizontally.
    if pad_if_needed and img.size[0] < size[1]:
        img_pad_w = int((1 + size[1] - img.size[0]) / 2)
        label_pad_w = int((1 + size[1] - label.size[0]) / 2)
        img = F.pad(img, (img_pad_w, 0), fill_img, padding_mode)
        label = F.pad(label, (label_pad_w, 0), fill_label, padding_mode)

    # Too short: pad vertically.
    if pad_if_needed and img.size[1] < size[0]:
        img_pad_h = int((1 + size[0] - img.size[1]) / 2)
        label_pad_h = int((1 + size[0] - label.size[1]) / 2)
        img = F.pad(img, (0, img_pad_h), fill_img, padding_mode)
        label = F.pad(label, (0, label_pad_h), fill_label, padding_mode)

    top, left, crop_h, crop_w = pick_window(img, size)
    cropped_img = F.crop(img, top, left, crop_h, crop_w)
    cropped_label = F.crop(label, top, left, crop_h, crop_w)
    return [cropped_img, cropped_label]