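"""Dataset classes for video, video-plus-speech, and video animation training.

Each dataset reads an annotation file (CSV or JSON), samples a clip of frames
with decord, optionally loads aligned audio / control / face / background /
mask streams, and returns a dict of tensors for the training loop.
"""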
import csv
import gc
import io
import json
import math
import os
import random
from contextlib import contextmanager
from threading import Thread

import albumentations
import cv2
import librosa
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from PIL import Image
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset

from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)


class WebVid10M(Dataset):
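    """WebVid-10M-style dataset: a CSV with `videoid`, `name` and `page_dir`
    columns plus a flat folder of `<videoid>.mp4` files."""
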
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False, is_image=False,
    ):
        print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        if not self.enable_bucket:
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.
            del video_reader
        else:
            pixel_values = video_reader.get_batch(batch_index).asnumpy()

        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length - 1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample


class VideoDataset(Dataset):
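    """Generic text-video dataset driven by a JSON annotation file whose
    entries carry `file_path` and `text` fields."""
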
    def __init__(
        self,
        ann_path, data_root=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False,
        video_length_drop_start=0.1, video_length_drop_end=0.9,
        text_drop_ratio=0.1,
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        # Attribute names must match what get_batch() reads; the drop-window
        # and text-drop defaults mirror VideoAnimateDataset below.
        self.video_sample_stride = sample_stride
        self.video_sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end
        self.text_drop_ratio = text_drop_ratio

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size[0]),
                transforms.CenterCrop(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']

        if self.data_root is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.data_root, video_id)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
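            # Only sample inside the [video_length_drop_start, video_length_drop_end]
            # fraction of the video, discarding its first and last portions.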
            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            if min_sample_n_frames == 0:
                raise ValueError("No frames in video.")

            video_length = int(self.video_length_drop_end * len(video_reader))
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

            try:
                sample_args = (video_reader, batch_index)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                del video_reader

        if not self.enable_bucket:
            pixel_values = self.video_transforms(pixel_values)

        # Randomly replace the text with an empty prompt.
        if random.random() < self.text_drop_ratio:
            text = ''
        return pixel_values, text
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            sample = {}
            try:
                pixel_values, name = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values
        return sample


class VideoSpeechDataset(Dataset):
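    """Text-video dataset that additionally loads the audio track aligned with
    the sampled clip (annotation entries carry an `audio_path` field)."""
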
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False,
        audio_sr=16000,       # target audio sampling rate
        text_drop_ratio=0.1,  # probability of dropping the text prompt
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.audio_sr = audio_sr
        self.text_drop_ratio = text_drop_ratio

        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(video_sample_size[0]),
                transforms.CenterCrop(video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']
        audio_id = video_dict['audio_path']

        if self.data_root is None:
            video_path = video_id
        else:
            video_path = os.path.join(self.data_root, video_id)
        if self.data_root is None:
            audio_path = audio_id
        else:
            audio_path = os.path.join(self.data_root, audio_id)

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found for {video_path}")

        with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
            total_frames = len(video_reader)
            fps = video_reader.get_avg_fps()  # original video frame rate

            # Number of frames actually sampled (respecting the video boundary).
            max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
            actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
            if actual_n_frames <= 0:
                raise ValueError(f"Video too short: {video_path}")

            # Randomly choose the start frame.
            max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
            start_frame = random.randint(0, max_start) if max_start > 0 else 0
            frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]

            # Read the video frames.
            try:
                sample_args = (video_reader, frame_indices)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            # Video post-processing.
            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                pixel_values = self.pixel_transforms(pixel_values)

        # === Load and slice the audio matching the clip ===
        # Start and end times of the video clip, in seconds.
        start_time = start_frame / fps
        end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
        duration = end_time - start_time

        # Load the whole audio with librosa, resampled to the target rate,
        # then slice it (rather than seeking to the exact segment first).
        audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)

        # Convert times to sample indices.
        start_sample = int(start_time * self.audio_sr)
        end_sample = int(end_time * self.audio_sr)

        # Slice safely.
        if start_sample >= len(audio_input):
            # Audio too short: fall back to silence of the expected length.
            audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
        else:
            audio_segment = audio_input[start_sample:end_sample]
            # Pad with zeros if the slice is shorter than expected.
            target_len = int(duration * self.audio_sr)
            if len(audio_segment) < target_len:
                audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')

        # === Random text dropout ===
        if random.random() < self.text_drop_ratio:
            text = ''
        return pixel_values, text, audio_segment, sample_rate
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            sample = {}
            try:
                pixel_values, text, audio, sample_rate = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = text
                sample["audio"] = torch.from_numpy(audio).float()  # convert to tensor
                sample["sample_rate"] = sample_rate
                sample["idx"] = idx
                break
            except Exception as e:
                print(f"Error processing {idx}: {e}, retrying with random idx...")
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size(), image_start_only=True)
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values
        return sample
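
# A minimal annotation entry consumed by VideoSpeechDataset (illustrative; the
# field names come from get_batch above, and the paths are placeholders):
#   {"file_path": "videos/0001.mp4", "audio_path": "audios/0001.wav", "text": "a person talking"}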


class VideoSpeechControlDataset(Dataset):
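    """VideoSpeechDataset variant that also reads an aligned control video
    (`control_file_path`), optionally with preceding motion-context frames."""
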
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False,
        audio_sr=16000,
        text_drop_ratio=0.1,
        enable_motion_info=False,
        motion_frames=73,
    ):
        print(f"loading annotations from {ann_path} ...")
        self.dataset = json.load(open(ann_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.data_root = data_root
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.audio_sr = audio_sr
        self.text_drop_ratio = text_drop_ratio
        self.enable_motion_info = enable_motion_info
        self.motion_frames = motion_frames

        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(video_sample_size[0]),
                transforms.CenterCrop(video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        self.video_sample_size = video_sample_size
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, text = video_dict['file_path'], video_dict['text']
        audio_id = video_dict['audio_path']
        control_video_id = video_dict['control_file_path']

        if self.data_root is None:
            video_path = video_id
            audio_path = audio_id
        else:
            video_path = os.path.join(self.data_root, video_id)
            audio_path = os.path.join(self.data_root, audio_id)
            control_video_id = os.path.join(self.data_root, control_video_id)

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found for {video_path}")

        # Video information
        with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
            total_frames = len(video_reader)
            fps = video_reader.get_avg_fps()
            if fps <= 0:
                raise ValueError(f"Video has non-positive fps: {video_path}")
            local_video_sample_stride = self.video_sample_stride
            new_fps = int(fps // local_video_sample_stride)
            while new_fps > 30:
                local_video_sample_stride = local_video_sample_stride + 1
                new_fps = int(fps // local_video_sample_stride)

            max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
            actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
            if actual_n_frames <= 0:
                raise ValueError(f"Video too short: {video_path}")

            max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
            start_frame = random.randint(0, max_start) if max_start > 0 else 0
            frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]

            try:
                sample_args = (video_reader, frame_indices)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            _, height, width, channel = np.shape(pixel_values)
            if self.enable_motion_info:
                # Motion context: frames preceding the clip, grey-padded to a
                # fixed length of `motion_frames`.
                motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
                if start_frame > 0:
                    motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
                    motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
                    motion_frame_indices = motion_frame_indices[-self.motion_frames:]

                    _motion_sample_args = (video_reader, motion_frame_indices)
                    _motion_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
                    )
                    motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values

                if not self.enable_bucket:
                    motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
                    motion_pixel_values = motion_pixel_values / 255.
                    motion_pixel_values = self.pixel_transforms(motion_pixel_values)
            else:
                motion_pixel_values = None

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                pixel_values = self.pixel_transforms(pixel_values)

        # Audio information
        start_time = start_frame / fps
        end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
        duration = end_time - start_time

        audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
        start_sample = int(start_time * self.audio_sr)
        end_sample = int(end_time * self.audio_sr)

        if start_sample >= len(audio_input):
            raise ValueError(f"Audio file too short: {audio_path}")
        audio_segment = audio_input[start_sample:end_sample]

        target_len = int(duration * self.audio_sr)
        if len(audio_segment) < target_len:
            raise ValueError(f"Audio file too short: {audio_path}")

        # Control information
        with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
            try:
                sample_args = (control_video_reader, frame_indices)
                control_pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
                resized_frames = []
                for i in range(len(control_pixel_values)):
                    frame = control_pixel_values[i]
                    resized_frame = resize_frame(frame, max(self.video_sample_size))
                    resized_frames.append(resized_frame)
                control_pixel_values = np.array(resized_frames)
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                control_pixel_values = control_pixel_values / 255.
                del control_video_reader
                control_pixel_values = self.pixel_transforms(control_pixel_values)

        if random.random() < self.text_drop_ratio:
            text = ''
        return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            sample = {}
            try:
                pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["motion_pixel_values"] = motion_pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["text"] = text
                sample["audio"] = torch.from_numpy(audio).float()  # convert to tensor
                sample["sample_rate"] = sample_rate
                sample["fps"] = new_fps
                sample["idx"] = idx
                break
            except Exception as e:
                print(f"Error processing {idx}: {e}, retrying with random idx...")
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size(), image_start_only=True)
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values
        return sample


class VideoAnimateDataset(Dataset):
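    """Animation dataset: alongside the target video it loads control, face,
    background and mask videos plus a reference image, as given by the
    per-entry `*_file_path` fields."""
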
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512,
        video_sample_stride=4,
        video_sample_n_frames=16,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))
        else:
            raise ValueError(f"Unsupported annotation format: {ann_path}")

        self.data_root = data_root

        # Repeat the video entries to balance the number of images and videos.
        if video_repeat > 0:
            self.dataset = []
            for data in dataset:
                if data.get('type', 'image') != 'video':
                    self.dataset.append(data)
            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        self.larger_side_of_image_and_video = min(self.video_sample_size)
    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if self.data_root is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.data_root, video_id)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            min_sample_n_frames = min(
                self.video_sample_n_frames,
                int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
            )
            if min_sample_n_frames == 0:
                raise ValueError("No frames in video.")

            video_length = int(self.video_length_drop_end * len(video_reader))
            clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
            start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

            try:
                sample_args = (video_reader, batch_index)
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
                resized_frames = []
                for i in range(len(pixel_values)):
                    frame = pixel_values[i]
                    resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                    resized_frames.append(resized_frame)
                pixel_values = np.array(resized_frames)
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.
                del video_reader

        if not self.enable_bucket:
            pixel_values = self.video_transforms(pixel_values)

        # Randomly replace the text with an empty prompt.
        if random.random() < self.text_drop_ratio:
            text = ''

        control_video_id = data_info['control_file_path']
        if control_video_id is not None:
            if self.data_root is not None:
                control_video_id = os.path.join(self.data_root, control_video_id)

            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                try:
                    sample_args = (control_video_reader, batch_index)
                    control_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(control_pixel_values)):
                        frame = control_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    control_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                    control_pixel_values = control_pixel_values / 255.
                    del control_video_reader
                    control_pixel_values = self.video_transforms(control_pixel_values)
        else:
            if not self.enable_bucket:
                control_pixel_values = torch.zeros_like(pixel_values)
            else:
                control_pixel_values = np.zeros_like(pixel_values)
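        # The face / background / mask streams below follow the same
        # read-resize-normalize pattern as the control video.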
        face_video_id = data_info['face_file_path']
        if face_video_id is not None:
            if self.data_root is not None:
                face_video_id = os.path.join(self.data_root, face_video_id)

            with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
                try:
                    sample_args = (face_video_reader, batch_index)
                    face_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(face_pixel_values)):
                        frame = face_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    face_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
                    face_pixel_values = face_pixel_values / 255.
                    del face_video_reader
                    face_pixel_values = self.video_transforms(face_pixel_values)
        else:
            if not self.enable_bucket:
                face_pixel_values = torch.zeros_like(pixel_values)
            else:
                face_pixel_values = np.zeros_like(pixel_values)
        background_video_id = data_info.get('background_file_path', None)
        if background_video_id is not None:
            if self.data_root is not None:
                background_video_id = os.path.join(self.data_root, background_video_id)

            with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
                try:
                    sample_args = (background_video_reader, batch_index)
                    background_pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(background_pixel_values)):
                        frame = background_pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    background_pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
                    background_pixel_values = background_pixel_values / 255.
                    del background_video_reader
                    background_pixel_values = self.video_transforms(background_pixel_values)
        else:
            if not self.enable_bucket:
                background_pixel_values = torch.ones_like(pixel_values) * 127.5
            else:
                background_pixel_values = np.ones_like(pixel_values) * 127.5
        mask_video_id = data_info.get('mask_file_path', None)
        if mask_video_id is not None:
            if self.data_root is not None:
                mask_video_id = os.path.join(self.data_root, mask_video_id)

            with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
                try:
                    sample_args = (mask_video_reader, batch_index)
                    mask = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(mask)):
                        frame = mask[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    mask = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
                    mask = mask / 255.
                    del mask_video_reader
        else:
            if not self.enable_bucket:
                mask = torch.ones_like(pixel_values)
            else:
                mask = np.ones_like(pixel_values) * 255

        # Keep a single mask channel; frames are channel-last (f, h, w, c) in
        # bucket mode and channel-first (f, c, h, w) otherwise.
        mask = mask[:, :, :, :1] if self.enable_bucket else mask[:, :1]
        ref_pixel_values_path = data_info['ref_file_path']
        if self.data_root is not None:
            ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
        ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
        if not self.enable_bucket:
            raise ValueError("Training without enable_bucket is not supported yet.")
        else:
            ref_pixel_values = np.array(ref_pixel_values)

        return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')

        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
                    self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["face_pixel_values"] = face_pixel_values
                sample["background_pixel_values"] = background_pixel_values
                sample["mask"] = mask
                sample["ref_pixel_values"] = ref_pixel_values
                sample["clip_pixel_values"] = ref_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)
        return sample


if __name__ == "__main__":
    if 1:
        dataset = VideoDataset(
            ann_path="./webvidval/results_2M_val.json",
            sample_size=256,
            sample_stride=4, sample_n_frames=16,
        )
    if 0:
        dataset = WebVid10M(
            csv_path="./webvid/results_2M_val.csv",
            video_folder="./webvid/2M_val",
            sample_size=256,
            sample_stride=4, sample_n_frames=16,
            is_image=False,
        )
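    if 0:
        # Illustrative smoke test for VideoSpeechDataset; the annotation path
        # below is a placeholder, not a file shipped with this repo.
        dataset = VideoSpeechDataset(
            ann_path="./anns/video_speech.json",
            video_sample_size=256,
            video_sample_stride=4,
            video_sample_n_frames=16,
        )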
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))