Spaces:
Runtime error
Runtime error
| import torch, os | |
| from PIL import Image | |
| import numpy as np | |
| import torchvision.transforms as transforms | |
| import torch.utils.data as data | |
| from einops import rearrange | |
class ImageLoader:
    """Callable that opens image files located under a fixed root directory.

    Args:
        root: Directory that the relative paths passed to ``__call__`` are
            resolved against.
    """

    def __init__(self, root):
        self.img_dir = root

    def __call__(self, img):
        """Load an image and return it in RGB mode.

        Args:
            img: Path of the image file relative to ``self.img_dir``.

        Returns:
            The image converted to mode ``'RGB'``.
        """
        path = f'{self.img_dir}/{img}'
        # Close the underlying file deterministically instead of waiting for
        # garbage collection; ``convert`` returns a separate copy, so the
        # result remains usable after the file has been closed.
        with Image.open(path) as opened:
            return opened.convert('RGB')
def imagenet_transform(phase):
    """Build the image preprocessing pipeline for the given phase.

    Args:
        phase: ``'train'`` selects a random resized crop to 224 plus a random
            horizontal flip; ``'test'`` selects a fixed resize to 224x224.
            Both pipelines end with ``ToTensor``.

    Returns:
        A ``transforms.Compose`` instance for the requested phase.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    if phase == 'train':
        return transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if phase == 'test':
        return transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
        ])
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")
class Dataset_embedding(data.Dataset):
    """Dataset yielding ``(type index, transformed image)`` pairs.

    Samples are ``[type_name, file_name]`` pairs. File names are listed from
    the directory of the first type only and paired with every selected type,
    so each type directory is expected to contain identically named files
    (NOTE(review): confirm against the on-disk layout).

    Args:
        cfg_data: Configuration object providing ``type_name`` (sequence of
            type names) plus ``train_dir`` or ``test_dir`` depending on
            ``phase``.
        phase: ``'train'`` or ``'test'``.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, cfg_data, phase='train'):
        # Validate up front; otherwise ``self.data`` would never be set and
        # the failure would surface later with an unrelated error.
        if phase not in ('train', 'test'):
            raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")

        self.transform = imagenet_transform(phase)
        self.type_name = cfg_data.type_name
        self.type2idx = {name: idx for idx, name in enumerate(self.type_name)}

        if phase == 'train':
            root = cfg_data.train_dir
            first_type = 0
        else:
            root = cfg_data.test_dir
            # The test split skips the first type.
            # NOTE(review): presumably intentional (train covers every type);
            # confirm with the dataset owner.
            first_type = 1

        self.loader = ImageLoader(root)
        # os.listdir returns entries in arbitrary order, so sample order may
        # differ between file systems.
        file_names = os.listdir(f'{root}/{self.type_name[0]}')
        self.data = [
            [self.type_name[i], file_name]
            for i in range(first_type, len(self.type_name))
            for file_name in file_names
        ]
        print(f'The amount of {phase} data is {len(self.data)}')

    def __getitem__(self, index):
        """Return ``(type index, transformed image)`` for sample ``index``."""
        type_name, image_name = self.data[index]
        scene = self.type2idx[type_name]
        image = self.transform(self.loader(f'{type_name}/{image_name}'))
        return (scene, image)

    def __len__(self):
        """Return the number of ``[type_name, file_name]`` samples."""
        return len(self.data)
def init_embedding_data(cfg_em, phase):
    """Create the data loaders for the embedding datasets.

    Args:
        cfg_em: Configuration object; ``batch`` and ``num_workers`` are read
            here and the object is passed on to ``Dataset_embedding``.
        phase: ``'train'`` builds a shuffled train loader and an unshuffled
            test loader, both with batch size ``cfg_em.batch``.
            ``'inference'`` builds only an unshuffled test loader with batch
            size 1.

    Returns:
        A ``(train_loader, test_loader)`` tuple. ``train_loader`` is ``None``
        in the ``'inference'`` phase.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'inference'``.
    """
    if phase not in ('train', 'inference'):
        raise ValueError(
            f"phase must be 'train' or 'inference', got {phase!r}"
        )

    # Inference has no train loader; define the name so the shared return
    # below does not raise UnboundLocalError.
    train_loader = None

    if phase == 'train':
        train_dataset = Dataset_embedding(cfg_em, 'train')
        test_dataset = Dataset_embedding(cfg_em, 'test')
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=cfg_em.batch,
                                       shuffle=True,
                                       num_workers=cfg_em.num_workers,
                                       pin_memory=True)
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=cfg_em.batch,
                                      shuffle=False,
                                      num_workers=cfg_em.num_workers,
                                      pin_memory=True)
        print(len(train_dataset), len(test_dataset))
    else:
        test_dataset = Dataset_embedding(cfg_em, 'test')
        test_loader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=cfg_em.num_workers,
                                      pin_memory=True)

    return train_loader, test_loader