# Source: ZIT-Controlnet / videox_fun/data/dataset_image_video.py
# Initial commit d2c9b66 by Alexander Bagus.
import csv
import gc
import io
import json
import math
import os
import random
from contextlib import contextmanager
from random import shuffle
from threading import Thread
import albumentations
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from packaging import version as pver
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset
from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
custom_meshgrid, get_random_mask, get_relative_pose,
get_video_reader_batch, padding_image, process_pose_file,
process_pose_params, ray_condition, resize_frame,
resize_image_with_target_area)
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper that never mixes content types inside one batch.

    Indices drawn from ``sampler`` are grouped by the ``type`` entry of the
    matching annotation (``dataset.dataset[idx]``, ``'image'`` when the entry
    is missing). A batch is emitted as soon as one group holds ``batch_size``
    indices.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information. Its annotation
            list is read from ``dataset.dataset``.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``False``, the incomplete groups that remain once
            ``sampler`` is exhausted are emitted as smaller batches. If
            ``True``, they are not emitted; they stay buffered and are
            completed by the next iteration over this sampler.

    Raises:
        TypeError: If ``sampler`` is not a ``Sampler`` instance.
        ValueError: If ``batch_size`` is not a positive integer.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # One bucket of pending indices per content type.
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            # setdefault keeps an unexpected type from raising KeyError; such
            # samples are simply batched among themselves.
            bucket = self.bucket.setdefault(content_type, [])
            bucket.append(idx)

            # Yield a batch of indices that all share the same content type.
            if len(bucket) >= self.batch_size:
                yield bucket[:]
                del bucket[:]

        # Honour ``drop_last``: flush what is left as smaller batches.
        if not self.drop_last:
            for bucket in self.bucket.values():
                if bucket:
                    yield bucket[:]
                    del bucket[:]
class ImageVideoDataset(Dataset):
    """Mixed image / video dataset driven by a CSV or JSON annotation file.

    Every annotation is a mapping with ``file_path`` and ``text`` entries and
    an optional ``type`` entry. ``type == 'video'`` selects the video loading
    path; any other value (or a missing entry) is loaded as a still image.

    Args:
        ann_path (str): Path to a ``.csv`` or ``.json`` annotation file.
        data_root (str, optional): Directory joined in front of every
            ``file_path``. Paths are used as-is when ``None``.
        video_sample_size (int or sequence): Crop size for videos; an int is
            expanded to a square.
        video_sample_stride (int): Frame stride used to size the sampled clip.
        video_sample_n_frames (int): Maximum number of frames per clip.
        image_sample_size (int or sequence): Crop size for images; an int is
            expanded to a square.
        video_repeat (int): When positive, the dataset holds every non-video
            entry once followed by the video entries repeated this many
            times. ``0`` keeps the annotation list unchanged.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
        enable_bucket (bool): When ``True`` the tensor transforms are skipped
            and numpy arrays are returned instead.
        video_length_drop_start (float): Fraction used for the lower bound of
            the sampled clip start.
        video_length_drop_end (float): Fraction of the video length used as
            the end of the usable range.
        enable_inpaint (bool): Add mask entries to each sample (only when
            ``enable_bucket`` is ``False``).
        return_file_name (bool): Add the base name of the loaded file to each
            sample under ``file_name``.

    Raises:
        ValueError: If ``ann_path`` is neither a ``.csv`` nor a ``.json`` file.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        dataset = self._load_annotations(ann_path)

        self.data_root = data_root

        # It's used to balance num of images and videos.
        if video_repeat > 0:
            self.dataset = [
                data for data in dataset if data.get('type', 'image') != 'video'
            ]
            videos = [
                data for data in dataset if data.get('type', 'image') == 'video'
            ]
            for _ in range(video_repeat):
                self.dataset.extend(videos)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Target passed to ``resize_frame`` for every decoded video frame.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation list from a ``.csv`` or ``.json`` file."""
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                return list(csv.DictReader(csvfile))
        if ann_path.endswith('.json'):
            # Use a context manager so the handle is closed deterministically.
            with open(ann_path, 'r') as jsonfile:
                return json.load(jsonfile)
        raise ValueError(
            f"Unsupported annotation file {ann_path!r}: expected a .csv or .json file."
        )

    def get_batch(self, idx):
        """Load the sample at ``idx`` (taken modulo the dataset size).

        Returns:
            tuple: ``(pixel_values, text, data_type, file_path)`` where
            ``data_type`` is ``'video'`` or ``'image'``.

        Raises:
            ValueError: If no frame can be sampled, or if decoding the video
                times out or fails.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                # Evenly spaced frame indices covering the chosen clip.
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    pixel_values = np.array([
                        resize_frame(frame, self.larger_side_of_image_and_video)
                        for frame in pixel_values
                    ])
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    # (frames, H, W, C) array -> (frames, C, H, W) tensor,
                    # scaled by 1/255 before the normalising transforms.
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
            return pixel_values, text, 'video', video_dir

        image_path, text = data_info['file_path'], data_info['text']
        if self.data_root is not None:
            image_path = os.path.join(self.data_root, image_path)
        image = Image.open(image_path).convert('RGB')
        if not self.enable_bucket:
            image = self.image_transforms(image).unsqueeze(0)
        else:
            image = np.expand_dims(np.array(image), 0)
        if random.random() < self.text_drop_ratio:
            text = ''
        return image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        When loading fails, the error is printed and a random replacement
        index is tried; a replacement is only accepted if it has the same
        data type as the item originally requested.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')

        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)
                break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked positions are filled with -1.
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back from [-1, 1] to [0, 255].
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample
class ImageVideoControlDataset(Dataset):
    """Image / video dataset that also returns a control signal per sample.

    Every annotation is a mapping with ``file_path``, ``text`` and
    ``control_file_path`` entries, an optional ``type`` entry
    (``'video'`` selects the video path, anything else the image path) and an
    optional ``object_file_path`` list used when subject info is enabled.

    Args:
        ann_path (str): Path to a ``.csv`` or ``.json`` annotation file.
        data_root (str, optional): Directory joined in front of ``file_path``
            and ``control_file_path``. Paths are used as-is when ``None``.
        video_sample_size (int or sequence): Crop size for videos; an int is
            expanded to a square. Index 0 is used as height and index 1 as
            width when reading camera pose files.
        video_sample_stride (int): Frame stride used to size the sampled clip.
        video_sample_n_frames (int): Maximum number of frames per clip.
        image_sample_size (int or sequence): Crop size for images; an int is
            expanded to a square.
        video_repeat (int): When positive, the dataset holds every non-video
            entry once followed by the video entries repeated this many
            times. ``0`` keeps the annotation list unchanged.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
        enable_bucket (bool): When ``True`` the tensor transforms are skipped
            and numpy arrays are returned instead.
        video_length_drop_start (float): Fraction used for the lower bound of
            the sampled clip start.
        video_length_drop_end (float): Fraction of the video length used as
            the end of the usable range.
        enable_inpaint (bool): Add mask entries to each sample (only when
            ``enable_bucket`` is ``False``).
        enable_camera_info (bool): Treat a ``.txt`` control file as a camera
            pose file and return ``control_camera_values``.
        return_file_name (bool): Add the base name of ``file_path`` to each
            sample under ``file_name``.
        enable_subject_info (bool): Load up to four subject images listed in
            ``object_file_path``.
        padding_subject_info (bool): Pad subject images to the visual size
            (and stack them) instead of resizing them to a target area.

    Raises:
        ValueError: If ``ann_path`` is neither a ``.csv`` nor a ``.json`` file.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
        enable_camera_info=False,
        return_file_name=False,
        enable_subject_info=False,
        padding_subject_info=True,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        dataset = self._load_annotations(ann_path)

        self.data_root = data_root

        # It's used to balance num of images and videos.
        if video_repeat > 0:
            self.dataset = [
                data for data in dataset if data.get('type', 'image') != 'video'
            ]
            videos = [
                data for data in dataset if data.get('type', 'image') == 'video'
            ]
            for _ in range(video_repeat):
                self.dataset.extend(videos)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_camera_info = enable_camera_info
        self.enable_subject_info = enable_subject_info
        self.padding_subject_info = padding_subject_info
        # Previously accepted but ignored; now honoured in ``__getitem__``.
        self.return_file_name = return_file_name

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        if self.enable_camera_info:
            # Same geometry as the video transforms, without normalisation.
            self.video_transforms_camera = transforms.Compose(
                [
                    transforms.Resize(min(self.video_sample_size)),
                    transforms.CenterCrop(self.video_sample_size)
                ]
            )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        # Target passed to ``resize_frame`` for every decoded video frame.
        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation list from a ``.csv`` or ``.json`` file."""
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                return list(csv.DictReader(csvfile))
        if ann_path.endswith('.json'):
            # Use a context manager so the handle is closed deterministically.
            with open(ann_path, 'r') as jsonfile:
                return json.load(jsonfile)
        raise ValueError(
            f"Unsupported annotation file {ann_path!r}: expected a .csv or .json file."
        )

    def _read_frames(self, reader, batch_index, idx):
        """Decode the ``batch_index`` frames from ``reader`` and resize each one."""
        try:
            frames = func_timeout(
                VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(reader, batch_index)
            )
            return np.array([
                resize_frame(frame, self.larger_side_of_image_and_video)
                for frame in frames
            ])
        except FunctionTimedOut:
            raise ValueError(f"Read {idx} timeout.")
        except Exception as e:
            raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

    def _empty_control(self, pixel_values):
        """Return an all-zero control signal shaped like ``pixel_values``."""
        if not self.enable_bucket:
            return torch.zeros_like(pixel_values)
        return np.zeros_like(pixel_values)

    def _load_subject_images(self, data_info, visual_width, visual_height):
        """Load up to four randomly chosen subject images for one sample."""
        # Shuffle a copy so the shared annotation list is not reordered.
        subject_paths = list(data_info.get('object_file_path', []))
        shuffle(subject_paths)

        subject_images = []
        for subject_path in subject_paths[:4]:
            subject_image = Image.open(subject_path).convert('RGB')
            if self.padding_subject_info:
                img = padding_image(subject_image, visual_width, visual_height)
            else:
                img = resize_image_with_target_area(subject_image, 1024 * 1024)

            # Random horizontal flip.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            subject_images.append(np.array(img))

        # Padded images are stacked into one array; resized ones stay a list.
        if self.padding_subject_info:
            return np.array(subject_images)
        return subject_images

    def get_batch(self, idx):
        """Load the sample at ``idx`` (taken modulo the dataset size).

        Returns:
            tuple: ``(pixel_values, control_pixel_values, subject_image,
            control_camera_values, text, data_type)``. ``subject_image`` is
            ``None`` unless subject info is enabled, and
            ``control_camera_values`` is ``None`` for images and whenever no
            camera pose file is used.

        Raises:
            ValueError: If no frame can be sampled, or if decoding a video
                times out or fails.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image') == 'video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Keep the frame count: the camera branch below needs it after
                # the reader has been deleted / the context has been closed.
                video_num_frames = len(video_reader)
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(video_num_frames * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * video_num_frames)
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                # Evenly spaced frame indices covering the chosen clip.
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                pixel_values = self._read_frames(video_reader, batch_index, idx)

                if not self.enable_bucket:
                    # (frames, H, W, C) array -> (frames, C, H, W) tensor,
                    # scaled by 1/255 before the normalising transforms.
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']
            if control_video_id is not None and self.data_root is not None:
                control_video_id = os.path.join(self.data_root, control_video_id)

            if self.enable_camera_info:
                # Guard against a missing control path before calling .lower().
                if control_video_id is not None and control_video_id.lower().endswith('.txt'):
                    control_pixel_values = self._empty_control(pixel_values)
                    if not self.enable_bucket:
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                        control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                        # NOTE(review): unlike the bucket branch, the result is
                        # not indexed with batch_index here - confirm intended.
                        control_camera_values = F.interpolate(control_camera_values, size=(video_num_frames, control_camera_values.size(3)), mode='bilinear', align_corners=True)
                        control_camera_values = self.video_transforms_camera(control_camera_values)
                    else:
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                        control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                        control_camera_values = F.interpolate(control_camera_values, size=(video_num_frames, control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                        # Keep only the entries of the sampled frames.
                        control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                else:
                    control_pixel_values = self._empty_control(pixel_values)
                    control_camera_values = None
            else:
                if control_video_id is not None:
                    with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                        # Same frame indices as the main video.
                        control_pixel_values = self._read_frames(control_video_reader, batch_index, idx)

                        if not self.enable_bucket:
                            control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                            control_pixel_values = control_pixel_values / 255.
                            del control_video_reader
                            control_pixel_values = self.video_transforms(control_pixel_values)
                else:
                    control_pixel_values = self._empty_control(pixel_values)
                control_camera_values = None

            if self.enable_subject_info:
                if not self.enable_bucket:
                    visual_height, visual_width = pixel_values.shape[-2:]
                else:
                    visual_height, visual_width = pixel_values.shape[1:3]
                subject_image = self._load_subject_images(data_info, visual_width, visual_height)
            else:
                subject_image = None

            return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"

        image_path, text = data_info['file_path'], data_info['text']
        if self.data_root is not None:
            image_path = os.path.join(self.data_root, image_path)
        image = Image.open(image_path).convert('RGB')
        if not self.enable_bucket:
            image = self.image_transforms(image).unsqueeze(0)
        else:
            image = np.expand_dims(np.array(image), 0)

        if random.random() < self.text_drop_ratio:
            text = ''

        control_image_id = data_info['control_file_path']
        if self.data_root is not None:
            control_image_id = os.path.join(self.data_root, control_image_id)

        control_image = Image.open(control_image_id).convert('RGB')
        if not self.enable_bucket:
            control_image = self.image_transforms(control_image).unsqueeze(0)
        else:
            control_image = np.expand_dims(np.array(control_image), 0)

        if self.enable_subject_info:
            if not self.enable_bucket:
                visual_height, visual_width = image.shape[-2:]
            else:
                visual_height, visual_width = image.shape[1:3]
            subject_image = self._load_subject_images(data_info, visual_width, visual_height)
        else:
            subject_image = None

        return image, control_image, subject_image, None, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        When loading fails, the error is printed and a random replacement
        index is tried; a replacement is only accepted if it has the same
        data type as the item originally requested.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')

        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)

                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["subject_image"] = subject_image
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if self.enable_camera_info:
                    sample["control_camera_values"] = control_camera_values
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(data_info_local['file_path'])
                break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked positions are filled with 0.
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back from [-1, 1] to [0, 255].
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample
class ImageVideoSafetensorsDataset(Dataset):
    """Dataset whose items are safetensors files listed in a JSON annotation.

    Args:
        ann_path (str): Path to a ``.json`` file holding a list of mappings,
            each with a ``file_path`` entry.
        data_root (str, optional): Directory joined in front of every
            ``file_path``. Paths are used as-is when ``None``.

    Raises:
        ValueError: If ``ann_path`` is not a ``.json`` file.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if not ann_path.endswith('.json'):
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"Unsupported annotation file {ann_path!r}: expected a .json file."
            )
        # Use a context manager so the handle is closed deterministically.
        with open(ann_path, 'r') as jsonfile:
            dataset = json.load(jsonfile)

        self.data_root = data_root
        self.dataset = dataset
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the content of the ``idx``-th file as loaded by ``load_file``."""
        path = self.dataset[idx]["file_path"]
        if self.data_root is not None:
            path = os.path.join(self.data_root, path)
        return load_file(path)
class TextDataset(Dataset):
    """Text-only dataset read from a JSON list of mappings with a ``text`` entry.

    Args:
        ann_path (str): Path to the JSON annotation file.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
    """

    def __init__(self, ann_path, text_drop_ratio=0.0):
        print(f"loading annotations from {ann_path} ...")
        with open(ann_path, 'r') as f:
            self.dataset = json.load(f)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        self.text_drop_ratio = text_drop_ratio

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"text": ..., "idx": ...}`` for ``idx``.

        If the item cannot be read, the error is printed and a random index
        is tried instead; the returned ``idx`` is the one actually used.
        """
        while True:
            try:
                item = self.dataset[idx]
                text = item['text']

                # Randomly drop text (for classifier-free guidance)
                if random.random() < self.text_drop_ratio:
                    text = ''

                return {
                    "text": text,
                    "idx": idx
                }
            except Exception as e:
                print(f"Error at index {idx}: {e}, retrying with random index...")
                # random.randint includes both ends, so the last index can be
                # drawn too and a one-item dataset does not raise here
                # (np.random.randint excludes its upper bound).
                idx = random.randint(0, self.length - 1)