| import os | |
| import cv2 | |
| import sys | |
| import numpy as np | |
| import torchvision.transforms as transforms | |
| from torch.utils.data.dataset import Dataset | |
| from PIL import Image | |
| from threading import Thread | |
# Put the repository root (parent of this file's directory) on sys.path so the
# project-local `data` and `utils` packages imported next can be resolved.
filepath = os.path.dirname(__file__)
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
| from data.custom_transforms import * | |
| from utils.misc import * | |
# Disable Pillow's decompression-bomb pixel limit so very large images can be
# opened. NOTE(review): this also removes Pillow's protection against
# maliciously huge images -- only acceptable for trusted inputs.
Image.MAX_IMAGE_PIXELS = None
def get_transform(tfs):
    """Build a ``transforms.Compose`` pipeline from a transform spec.

    Args:
        tfs: Mapping from a transform name (resolved with ``eval`` in this
            module's namespace) to either a dict of constructor keyword
            arguments or ``None`` for a no-argument constructor. Iteration
            order of the mapping is the order of the pipeline.

    Returns:
        A ``transforms.Compose`` over the instantiated transforms.
    """
    comp = []
    for key, value in tfs.items():
        # SECURITY NOTE: ``eval`` executes ``key`` as Python code in this
        # module's namespace (so names pulled in by the module's star imports
        # resolve). Only feed specs that come from trusted configuration.
        factory = eval(key)
        comp.append(factory(**value) if value is not None else factory())
    return transforms.Compose(comp)
class RGB_Dataset(Dataset):
    """Dataset of (RGB image, grayscale mask) pairs.

    For each subset name in ``sets`` the files in ``<root>/<set>/images`` and
    ``<root>/<set>/masks`` ending in ``.jpg``/``.png`` are listed, ordered with
    the project ``sort`` helper and paired positionally. Pairs whose image and
    mask pixel sizes differ are dropped by ``filter_files``.

    Args:
        root: Dataset root directory.
        sets: Iterable of subset directory names under ``root``.
        tfs: Transform spec passed to ``get_transform``.
    """

    def __init__(self, root, sets, tfs):
        self.images, self.gts = [], []
        # `subset` instead of `set` so the builtin is not shadowed.
        for subset in sets:
            image_root = os.path.join(root, subset, 'images')
            gt_root = os.path.join(root, subset, 'masks')
            images = [os.path.join(image_root, f) for f in os.listdir(image_root)
                      if f.lower().endswith(('.jpg', '.png'))]
            images = sort(images)
            gts = [os.path.join(gt_root, f) for f in os.listdir(gt_root)
                   if f.lower().endswith(('.jpg', '.png'))]
            gts = sort(gts)
            self.images.extend(images)
            self.gts.extend(gts)
        self.filter_files()
        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __getitem__(self, index):
        """Load pair ``index`` and return the transformed sample dict.

        Before transformation the sample holds ``image`` (RGB PIL image),
        ``gt`` (mode 'L' PIL image), ``name`` (image file name without
        extension) and ``shape`` (the mask's ``(height, width)``).
        """
        image = Image.open(self.images[index]).convert('RGB')
        gt = Image.open(self.gts[index]).convert('L')
        shape = gt.size[::-1]  # PIL reports (width, height); store (height, width)
        name = os.path.splitext(os.path.basename(self.images[index]))[0]
        sample = {'image': image, 'gt': gt, 'name': name, 'shape': shape}
        return self.transform(sample)

    def filter_files(self):
        """Keep only the pairs whose image and mask have the same pixel size.

        Only the image headers are inspected. ``Image.open`` is lazy and keeps
        the file open, so each file is opened in a ``with`` block to release
        the handle immediately instead of waiting for garbage collection.

        Raises:
            AssertionError: if the number of images and masks differ.
        """
        assert len(self.images) == len(self.gts), (
            f'found {len(self.images)} images but {len(self.gts)} masks')
        images, gts = [], []
        for img_path, gt_path in zip(self.images, self.gts):
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images, self.gts = images, gts

    def __len__(self):
        return self.size
class ImageLoader:
    """Iterate over single images from a directory or one image file.

    Each ``next()`` returns the transformed sample dict for one image with a
    leading batch dimension added to ``image`` (and to ``image_resized`` when
    the transform produced it).

    Args:
        root: Directory containing ``.jpg``/``.png``/``.jpeg`` files, or the
            path of a single image file.
        tfs: Transform spec passed to ``get_transform``.

    Raises:
        FileNotFoundError: if ``root`` is neither a directory nor a file.
    """

    def __init__(self, root, tfs):
        if os.path.isdir(root):
            self.images = sort([os.path.join(root, f) for f in os.listdir(root)
                                if f.lower().endswith(('.jpg', '.png', '.jpeg'))])
        elif os.path.isfile(root):
            self.images = [root]
        else:
            # Fail clearly here instead of with an AttributeError on self.images.
            raise FileNotFoundError(f'No such file or directory: {root!r}')
        self.size = len(self.images)
        self.index = 0  # lets next() work even if iter() was not called first
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.size:
            raise StopIteration
        path = self.images[self.index]
        image = Image.open(path).convert('RGB')
        shape = image.size[::-1]  # (height, width)
        name = os.path.splitext(os.path.basename(path))[0]
        sample = {'image': image, 'name': name, 'shape': shape, 'original': image}
        sample = self.transform(sample)
        # Add a leading batch dimension; assumes the transform turned these
        # entries into torch tensors -- TODO confirm against the transform spec.
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        self.index += 1
        return sample

    def __len__(self):
        return self.size
class VideoLoader:
    """Iterate frame-by-frame over one or more video files.

    Videos are played one after another. Every decoded frame is converted from
    OpenCV's BGR order to RGB, wrapped in a PIL image and passed through the
    transform. When a video yields no further frame, its capture is released
    and one untransformed sentinel sample
    ``{'image': None, 'shape': None, 'name': <video name>, 'original': None}``
    is returned before the next video starts.

    Args:
        root: Directory containing ``.mp4``/``.avi``/``.mov`` files, or the
            path of a single video file.
        tfs: Transform spec passed to ``get_transform``.

    Raises:
        FileNotFoundError: if ``root`` is neither a directory nor a file.
    """

    def __init__(self, root, tfs):
        if os.path.isdir(root):
            # The leading dot on '.mov' matters: without it any file name that
            # merely ends in "mov" would be picked up. Sorted so the
            # processing order is deterministic.
            self.videos = sort([os.path.join(root, f) for f in os.listdir(root)
                                if f.lower().endswith(('.mp4', '.avi', '.mov'))])
        elif os.path.isfile(root):
            self.videos = [root]
        else:
            raise FileNotFoundError(f'No such file or directory: {root!r}')
        self.size = len(self.videos)
        self.index = 0
        self.cap = None   # cv2.VideoCapture of the video currently being read
        self.fps = None   # FPS reported by OpenCV for the current video
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        if self.index >= self.size:
            raise StopIteration
        path = self.videos[self.index]
        if self.cap is None:
            self.cap = cv2.VideoCapture(path)
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        ret, frame = self.cap.read()
        name = os.path.splitext(os.path.basename(path))[0]
        if not ret:
            # No more frames (or unreadable): release this video, advance to
            # the next one and emit the end-of-video sentinel.
            self.cap.release()
            self.cap = None
            self.index += 1
            return {'image': None, 'shape': None, 'name': name, 'original': None}
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
        image = Image.fromarray(frame).convert('RGB')
        shape = image.size[::-1]  # (height, width)
        sample = {'image': image, 'shape': shape, 'name': name, 'original': image}
        sample = self.transform(sample)
        # Add a leading batch dimension; assumes tensors after the transform.
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        return sample

    def __len__(self):
        return self.size
class WebcamLoader:
    """Stream the most recent webcam frame as a transformed sample.

    A daemon thread keeps reading frames from camera ``ID`` (requested at
    640x480) into ``self.imgs``. Each ``next()`` uses the newest frame and then
    discards the older ones. Iteration stops once the reader thread has exited
    (camera closed or a read failed) or when 'q' is pressed in an OpenCV
    window.

    NOTE(review): unlike VideoLoader, frames are NOT converted from OpenCV's
    BGR order to RGB before reaching PIL / the transform. This may be
    intentional (e.g. callers showing 'original' with cv2.imshow, which
    expects BGR) -- confirm before changing.

    ``__len__`` returns 0 because the stream has no fixed length; note that
    this also makes ``bool(loader)`` False.
    """

    def __init__(self, ID, tfs):
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.transform = get_transform(tfs)
        self.imgs = []
        # Only keep the first frame if it was actually read; a failed read
        # yields None, which Image.fromarray cannot handle later.
        ret, frame = self.cap.read()
        if ret:
            self.imgs.append(frame)
        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Reader-thread loop: append frames until the camera closes or a read fails."""
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret is True:
                self.imgs.append(frame)
            else:
                break

    def __iter__(self):
        return self

    def __next__(self):
        # Newest captured frame, or a black 480x640 placeholder if none yet.
        if self.imgs:
            frame = self.imgs[-1]
        else:
            frame = np.zeros((480, 640, 3), dtype=np.uint8)
        # cv2.waitKey only sees key presses while an OpenCV window has focus;
        # masking with 0xFF keeps just the key code on builds that report
        # extra modifier bits.
        if not self.thread.is_alive() or (cv2.waitKey(1) & 0xFF) == ord('q'):
            cv2.destroyAllWindows()
            raise StopIteration
        image = Image.fromarray(frame).convert('RGB')
        shape = image.size[::-1]  # (height, width)
        sample = {'image': image, 'shape': shape, 'name': 'webcam', 'original': image}
        sample = self.transform(sample)
        # Add a leading batch dimension; assumes tensors after the transform.
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        # Drop all but the newest frame so the buffer does not grow unbounded.
        del self.imgs[:-1]
        return sample

    def __len__(self):
        return 0
class RefinementLoader:
    """Iterate over (image, coarse segmentation) pairs for refinement.

    The ``.jpg``/``.png``/``.jpeg`` files of ``image_dir`` and ``seg_dir`` are
    each ordered with the project ``sort`` helper and paired by position. Each
    ``next()`` returns the transformed sample with ``image`` and ``mask`` (the
    segmentation) given a leading batch dimension; the intermediate ``gt`` key
    is removed.

    Args:
        image_dir: Directory of input images.
        seg_dir: Directory of segmentation maps (loaded as mode 'L').
        tfs: Transform spec passed to ``get_transform``.

    Raises:
        ValueError: if the two directories hold different numbers of files,
            since positional pairing would then be wrong.
    """

    def __init__(self, image_dir, seg_dir, tfs):
        exts = ('.jpg', '.png', '.jpeg')
        images = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                  if f.lower().endswith(exts)]
        segs = [os.path.join(seg_dir, f) for f in os.listdir(seg_dir)
                if f.lower().endswith(exts)]
        # Fail early: a count mismatch would otherwise surface as an
        # IndexError mid-iteration or as silently mis-paired files.
        if len(images) != len(segs):
            raise ValueError(
                f'{image_dir!r} has {len(images)} images but '
                f'{seg_dir!r} has {len(segs)} segmentation maps')
        self.images = sort(images)
        self.segs = sort(segs)
        self.size = len(self.images)
        self.index = 0  # lets next() work even if iter() was not called first
        self.transform = get_transform(tfs)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.size:
            raise StopIteration
        path = self.images[self.index]
        image = Image.open(path).convert('RGB')
        seg = Image.open(self.segs[self.index]).convert('L')
        shape = image.size[::-1]  # (height, width)
        name = os.path.splitext(os.path.basename(path))[0]
        sample = {'image': image, 'gt': seg, 'name': name, 'shape': shape, 'original': image}
        sample = self.transform(sample)
        # Add a leading batch dimension; assumes tensors after the transform.
        sample['image'] = sample['image'].unsqueeze(0)
        sample['mask'] = sample.pop('gt').unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        self.index += 1
        return sample

    def __len__(self):
        return self.size