import numpy as np
from PIL import Image, UnidentifiedImageError
from pathlib import Path
from random import randint, choice
import torch
from torch.utils.data import Dataset


class TextImageDataset(Dataset):
    def __init__(self,
                 folder,
                 text_len=40,
                 truncate_captions=False,
                 text_tokenizer=None,
                 image_tokenizer=None,
                 shuffle=False
                 ):
        """
        @param folder: Folder containing images and text files matched by their paths' respective "stem"
        @param truncate_captions: Rather than throw an exception, captions which are too long will be truncated.
        """
        super().__init__()
        self.shuffle = shuffle
        path = Path(folder)

        # Collect caption and image files, then keep only the stems present in both sets.
        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'), *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg'), *path.glob('**/*.bmp')
        ]

        text_files = {text_file.stem: text_file for text_file in text_files}
        image_files = {image_file.stem: image_file for image_file in image_files}

        keys = (image_files.keys() & text_files.keys())

        self.keys = list(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.text_len = text_len
        self.truncate_captions = truncate_captions
        self.text_tokenizer = text_tokenizer
        self.image_tokenizer = image_tokenizer

    def __len__(self):
        return len(self.keys)

    def random_sample(self):
        return self.__getitem__(randint(0, self.__len__() - 1))

    def sequential_sample(self, ind):
        if ind >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(ind + 1)

    def skip_sample(self, ind):
        # On a bad sample, fall back to a random index (if shuffling) or the next index.
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(ind=ind)

    def __getitem__(self, ind):
        key = self.keys[ind]

        text_file = self.text_files[key]
        image_file = self.image_files[key]

        # Each non-empty line of the caption file is a candidate description.
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        try:
            description = choice(descriptions)
        except IndexError as zero_captions_in_file_ex:
            print(f"An exception occurred trying to load file {text_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        tokenized_text = self.text_tokenizer.tokenize(
            description,
            self.text_len,
            truncate_text=self.truncate_captions
        ).squeeze(0)

        try:
            # Flatten the image to (H * W, 3) and map each RGB pixel to a token id,
            # using the string form of the pixel as the image_tokenizer lookup key.
            image = Image.open(image_file).convert('RGB')
            pixels = np.array(image).reshape(-1, 3)
            tokenized_image = [self.image_tokenizer[str(idx)] for idx in pixels]
            tokenized_image = torch.tensor(tokenized_image)
        except (UnidentifiedImageError, OSError) as corrupt_image_exceptions:
            print(f"An exception occurred trying to load file {image_file}.")
            print(f"Skipping index {ind}")
            return self.skip_sample(ind)

        # Success
        return tokenized_text, tokenized_image
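
For context, below is a minimal usage sketch. The two helper classes and the folder path are hypothetical stand-ins, not part of the dataset above: TextImageDataset only assumes that text_tokenizer exposes a tokenize(text, text_len, truncate_text=...) method returning a tensor it can squeeze to shape (text_len,), and that image_tokenizer supports indexing by the string form of an RGB pixel and returns an integer token id.

# --- Usage sketch (hypothetical helpers, assumptions noted in comments) ---
import torch


class DummyTextTokenizer:
    # Hypothetical stand-in: returns a (1, text_len) tensor of toy token ids,
    # which __getitem__ above squeezes to (text_len,).
    def tokenize(self, text, text_len, truncate_text=False):
        ids = [ord(ch) % 256 for ch in text][:text_len]
        ids = ids + [0] * (text_len - len(ids))  # pad to fixed length
        return torch.tensor(ids).unsqueeze(0)


class DummyImageVocab:
    # Hypothetical stand-in: maps the string form of an RGB pixel,
    # e.g. "[255   0  13]", to a deterministic toy token id.
    def __getitem__(self, pixel_key):
        r, g, b = (int(v) for v in pixel_key.strip('[]').split())
        return ((r * 256 + g) * 256 + b) % 8192  # fold into a small toy vocabulary


dataset = TextImageDataset(
    folder='path/to/dataset',          # placeholder: folder with foo.jpg / foo.txt pairs
    text_len=40,
    truncate_captions=True,
    text_tokenizer=DummyTextTokenizer(),
    image_tokenizer=DummyImageVocab(),
    shuffle=True,
)

text_tokens, image_tokens = dataset[0]
print(text_tokens.shape)   # torch.Size([40])
print(image_tokens.shape)  # torch.Size([H * W]) -- one token per pixel

Note that the image token sequence has length H * W, so it varies with image size; batching through a DataLoader requires images of equal size or a custom collate function.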