from pathlib import Path
from typing import Optional, Union

import numpy as np
import PIL.Image as Image
from torch.utils.data import Dataset

from vhap.util.log import get_logger
# Module-level logger, created by passing this module's __name__ to the
# project's logging helper get_logger.
logger = get_logger(__name__)
class ImageFolderDataset(Dataset):
    """Map-style dataset over the ``*.jpg`` images found directly in one folder.

    Each item holds the decoded image and its path. Optionally, a background
    image preloaded from ``background_folder`` is attached; images and
    backgrounds are matched through a camera id derived from their file names.
    """

    def __init__(
        self,
        image_folder: Union[str, Path],
        background_folder: Optional[Union[str, Path]] = None,
        background_fname2camId=lambda x: x,
        image_fname2camId=lambda x: x,
    ):
        """Index the images of ``image_folder`` and optionally preload backgrounds.

        Args:
            image_folder: Folder with the following layout (a ``str`` is
                accepted and converted to ``Path``)::

                    <image_folder>/
                    |---xx.jpg
                    |---...

            background_folder: Optional folder of ``*.jpg`` background images.
                It is joined as ``image_folder / background_folder``, so a
                relative path is resolved against ``image_folder`` while an
                absolute path is used as-is.
            background_fname2camId: Maps a background file name (e.g.
                ``"cam0.jpg"``) to the key stored in ``self.backgrounds``.
                Defaults to the identity.
            image_fname2camId: Maps an image file name to the key used to look
                up its background in ``__getitem__``. Defaults to the identity.
        """
        super().__init__()
        # Accept plain strings too: ``str`` has no ``glob``/``/`` operators.
        image_folder = Path(image_folder)
        self.image_fname2camId = image_fname2camId
        self.background_folder = background_folder
        logger.info(f"Initializing dataset from folder {image_folder}")
        self.image_paths = sorted(image_folder.glob('*.jpg'))
        if not self.image_paths:
            logger.warning(f"No '*.jpg' images found in {image_folder}")
        if background_folder is not None:
            # Backgrounds are decoded once here; __getitem__ hands out the same
            # array object to every image that maps to the same camera id.
            self.backgrounds = {}
            for background_path in sorted((image_folder / background_folder).glob('*.jpg')):
                cam_id = background_fname2camId(background_path.name)
                self.backgrounds[cam_id] = self._load_image(background_path)

    @property
    def background_foler(self):
        """Misspelled alias of ``background_folder``, kept for backward compatibility."""
        return self.background_folder

    @background_foler.setter
    def background_foler(self, value):
        self.background_folder = value

    @staticmethod
    def _load_image(path):
        """Decode the image at ``path`` into a numpy array and close the file."""
        # np.array() forces PIL to read the pixel data, so the array is complete
        # before the context manager closes the underlying file handle.
        with Image.open(path) as img:
            return np.array(img)

    def __len__(self):
        """Return the number of images found in the folder."""
        return len(self.image_paths)

    def __getitem__(self, i):
        """Load the ``i``-th image (in sorted path order).

        Returns:
            dict with ``"rgb"`` (numpy array of the decoded image),
            ``"image_path"`` (``str``) and, when a background folder was given,
            ``"background"`` (the preloaded background for this image's camera).

        Raises:
            KeyError: if backgrounds are enabled but none was loaded for the
                camera id derived from this image's file name.
        """
        image_path = self.image_paths[i]
        background = None
        if self.background_folder is not None:
            # Resolve the background before decoding the image so a missing
            # camera fails fast with a descriptive message.
            cam_id = self.image_fname2camId(image_path.name)
            try:
                background = self.backgrounds[cam_id]
            except KeyError:
                raise KeyError(
                    f"No background loaded for camera id {cam_id!r} "
                    f"(image {image_path.name}); available ids: {list(self.backgrounds)}"
                ) from None
        item = {
            "rgb": self._load_image(image_path),
            'image_path': str(image_path),
        }
        if background is not None:
            item['background'] = background
        return item
if __name__ == "__main__":
    # Smoke test: build the dataset, inspect one sample, then iterate the whole
    # dataset once through a DataLoader.
    import sys

    from tqdm import tqdm
    from torch.utils.data import DataLoader

    # The folder may be given as the first CLI argument; it defaults to the
    # previously hard-coded './xx'.
    folder = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('./xx')
    dataset = ImageFolderDataset(image_folder=folder)
    print(len(dataset))
    if len(dataset) == 0:
        sys.exit(f"No images found in {folder}")
    sample = dataset[0]
    print(sample.keys())
    print(sample["rgb"].shape)
    # batch_size=None disables automatic batching: items are yielded one at a
    # time instead of being collated into batches.
    dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for item in tqdm(dataloader):
        pass