import json import os import random import numpy as np import torch import torchvision.transforms as transforms from PIL import Image from torch.utils.data.dataset import Dataset class CC15M(Dataset): def __init__( self, json_path, video_folder=None, resolution=512, enable_bucket=False, ): print(f"loading annotations from {json_path} ...") self.dataset = json.load(open(json_path, 'r')) self.length = len(self.dataset) print(f"data scale: {self.length}") self.enable_bucket = enable_bucket self.video_folder = video_folder resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution) self.pixel_transforms = transforms.Compose([ transforms.Resize(resolution[0]), transforms.CenterCrop(resolution), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ]) def get_batch(self, idx): video_dict = self.dataset[idx] video_id, name = video_dict['file_path'], video_dict['text'] if self.video_folder is None: video_dir = video_id else: video_dir = os.path.join(self.video_folder, video_id) pixel_values = Image.open(video_dir).convert("RGB") return pixel_values, name def __len__(self): return self.length def __getitem__(self, idx): while True: try: pixel_values, name = self.get_batch(idx) break except Exception as e: print(e) idx = random.randint(0, self.length-1) if not self.enable_bucket: pixel_values = self.pixel_transforms(pixel_values) else: pixel_values = np.array(pixel_values) sample = dict(pixel_values=pixel_values, text=name) return sample class ImageEditDataset(Dataset): def __init__( self, ann_path, data_root=None, image_sample_size=512, text_drop_ratio=0.1, enable_bucket=False, enable_inpaint=False, return_file_name=False, ): # Loading annotations from files print(f"loading annotations from {ann_path} ...") if ann_path.endswith('.csv'): with open(ann_path, 'r') as csvfile: dataset = list(csv.DictReader(csvfile)) elif ann_path.endswith('.json'): dataset = json.load(open(ann_path)) self.data_root = data_root self.dataset = dataset self.length = len(self.dataset) print(f"data scale: {self.length}") # TODO: enable bucket training self.enable_bucket = enable_bucket self.text_drop_ratio = text_drop_ratio self.enable_inpaint = enable_inpaint self.return_file_name = return_file_name # Image params self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) self.image_transforms = transforms.Compose([ transforms.Resize(min(self.image_sample_size)), transforms.CenterCrop(self.image_sample_size), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) ]) def get_batch(self, idx): data_info = self.dataset[idx % len(self.dataset)] image_path, text = data_info['file_path'], data_info['text'] if self.data_root is not None: image_path = os.path.join(self.data_root, image_path) image = Image.open(image_path).convert('RGB') if not self.enable_bucket: raise ValueError("Not enable_bucket is not supported now. ") else: image = np.expand_dims(np.array(image), 0) source_image_path = data_info.get('source_file_path', []) source_image = [] if isinstance(source_image_path, list): for _source_image_path in source_image_path: if self.data_root is not None: _source_image_path = os.path.join(self.data_root, _source_image_path) _source_image = Image.open(_source_image_path).convert('RGB') source_image.append(_source_image) else: if self.data_root is not None: _source_image_path = os.path.join(self.data_root, source_image_path) _source_image = Image.open(_source_image_path).convert('RGB') source_image.append(_source_image) if not self.enable_bucket: raise ValueError("Not enable_bucket is not supported now. ") else: source_image = [np.array(_source_image) for _source_image in source_image] if random.random() < self.text_drop_ratio: text = '' return image, source_image, text, 'image', image_path def __len__(self): return self.length def __getitem__(self, idx): data_info = self.dataset[idx % len(self.dataset)] data_type = data_info.get('type', 'image') while True: sample = {} try: data_info_local = self.dataset[idx % len(self.dataset)] data_type_local = data_info_local.get('type', 'image') if data_type_local != data_type: raise ValueError("data_type_local != data_type") pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx) sample["pixel_values"] = pixel_values sample["source_pixel_values"] = source_pixel_values sample["text"] = name sample["data_type"] = data_type sample["idx"] = idx if self.return_file_name: sample["file_name"] = os.path.basename(file_path) if len(sample) > 0: break except Exception as e: print(e, self.dataset[idx % len(self.dataset)]) idx = random.randint(0, self.length-1) if self.enable_inpaint and not self.enable_bucket: mask = get_random_mask(pixel_values.size()) mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask sample["mask_pixel_values"] = mask_pixel_values sample["mask"] = mask clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 sample["clip_pixel_values"] = clip_pixel_values return sample if __name__ == "__main__": dataset = CC15M( csv_path="./cc15m_add_index.json", resolution=512, ) dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) for idx, batch in enumerate(dataloader): print(batch["pixel_values"].shape, len(batch["text"]))