# Spaces: Running on Zero
| import os | |
| import json | |
| import torch | |
| import decord | |
| import torchvision | |
| import numpy as np | |
| import random | |
| from PIL import Image | |
| from einops import rearrange | |
| from typing import Dict, List, Tuple | |
| from torchvision import transforms | |
| import traceback | |
# Module-level caches shared via `global` in get_class_labels():
# class_labels_map: class name -> index mapping loaded from the classmap JSON.
class_labels_map = None
# cls_sample_cnt: per-class count of samples collected so far by load_annotations().
cls_sample_cnt = None
def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """Pick `num_samples` frames at equal spacing between two frame indices.

    Args:
        frames (tensor): video frames, shape
            (num video frames, channel, height, width).
        start_idx (int): index of the first frame of the clip.
        end_idx (int): index of the last frame of the clip.
        num_samples (int): how many frames to keep.

    Returns:
        tensor of temporally sampled frames, shape
        (num clip frames, channel, height, width).
    """
    # Evenly spaced (float) positions, truncated to valid integer indices.
    positions = torch.linspace(start_idx, end_idx, num_samples)
    positions = positions.clamp(0, frames.shape[0] - 1).long()
    return frames.index_select(0, positions)
def numpy2tensor(x):
    """Wrap a numpy array as a torch tensor (shares the underlying memory)."""
    tensor = torch.from_numpy(x)
    return tensor
def get_filelist(file_path):
    """Recursively collect the full path of every file under `file_path`."""
    collected = []
    for root, _dirs, names in os.walk(file_path):
        # Full paths; to collect bare filenames instead, append `name` alone.
        collected.extend(os.path.join(root, name) for name in names)
    return collected
def load_annotation_data(data_file_path):
    """Read the JSON annotation file at `data_file_path` and return its contents."""
    with open(data_file_path, 'r') as fp:
        return json.loads(fp.read())
def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """Return the (class name -> index) map and a per-class sample counter.

    Both are cached in module globals, so the classmap JSON is read at most
    once per process. `num_class` is accepted but not used here; the class
    subset is applied later in load_annotations().
    """
    global class_labels_map, cls_sample_cnt
    if class_labels_map is None:
        # First call: load the map and start every class counter at zero.
        cls_sample_cnt = {}
        class_labels_map = load_annotation_data(anno_pth)
        for name in class_labels_map:
            cls_sample_cnt[name] = 0
    return class_labels_map, cls_sample_cnt
def load_annotations(ann_file, num_class, num_samples_per_cls):
    """Build a list of {'video', 'label'} samples from a TSV annotation file.

    Each line is `<video path>\t<class name>`. Only classes whose index is
    below `num_class` are kept, capped at `num_samples_per_cls` samples each.
    """
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            fields = line.strip().split('\t')
            # First field is the video path; everything after it is the label.
            frame_dir = fields[0]
            label = fields[1:]
            assert label, f'missing label in line: {line}'
            assert len(label) == 1
            class_name = label[0]
            class_index = int(class_to_idx[class_name])
            # Keep only the chosen class subset, up to the per-class cap.
            if class_index < num_class and cls_sample_cnt[class_name] < num_samples_per_cls:
                dataset.append({'video': frame_dir, 'label': class_index})
                cls_sample_cnt[class_name] += 1
    return dataset
class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize a video reader."""

    def __init__(self, num_threads=1, **kwargs):
        """
        Args:
            num_threads (int): number of decoding threads passed to decord.
            **kwargs: stored for forward compatibility; currently unused.
        """
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)  # always decode on CPU
        self.kwargs = kwargs

    def __call__(self, filename):
        """Open `filename` and return a decord.VideoReader for it.

        Args:
            filename (str): path to the video file.

        Returns:
            decord.VideoReader over the video.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # BUG FIX: the original interpolated `self.sr`, an attribute that is
        # never assigned anywhere in this class, so calling repr() raised
        # AttributeError. Report only the attributes that actually exist.
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
class FaceForensicsImages(torch.utils.data.Dataset):
    """Load FaceForensics videos paired with extra still frames.

    Each item concatenates a sampled, transformed video clip with
    `use_image_num` independently loaded still frames along dim 0.

    Args:
        configs: options namespace; this class reads data_path, num_frames,
            frame_data_path, frame_data_txt and use_image_num from it.
        transform (callable): applied to the sampled clip (T C H W).
        temporal_sample (callable): maps a total frame count to a
            (start_index, end_index) window to sample from.
    """
    def __init__(self,
                 configs,
                 transform=None,
                 temporal_sample=None):
        self.configs = configs
        self.data_path = configs.data_path
        # Every file found recursively under data_path is treated as a video.
        self.video_lists = get_filelist(configs.data_path)
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.target_video_len = self.configs.num_frames
        self.v_decoder = DecordInit()  # NOTE(review): constructed but never used below — confirm it is needed
        self.video_length = len(self.video_lists)
        # FaceForensics still-frame files: one path per line of the txt file,
        # relative to frame_data_path; shuffled once at construction.
        self.video_frame_path = configs.frame_data_path
        self.video_frame_txt = configs.frame_data_txt
        self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
        random.shuffle(self.video_frame_files)
        self.use_image_num = configs.use_image_num
        # (name `image_tranform` kept as-is; it is referenced in __getitem__)
        self.image_tranform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])

    def __getitem__(self, index):
        # Videos are indexed modulo the video list, because __len__ reports
        # the (typically larger) number of still-frame files instead.
        video_index = index % self.video_length
        path = self.video_lists[video_index]
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        # Choose a temporal window, then take target_video_len evenly spaced
        # frames inside it.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.target_video_len
        frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
        video = vframes[frame_indice]
        # Video-transformer preprocessing.
        video = self.transform(video)  # T C H W
        # Load use_image_num still frames starting at `index`. On any failure
        # (unreadable file, or index+i past the end of the list) a new random
        # base index is drawn and the current frame is retried.
        # NOTE(review): frames collected before a failure are kept, so one
        # item can mix frames from different base indices — confirm intended.
        images = []
        for i in range(self.use_image_num):
            while True:
                try:
                    image = Image.open(os.path.join(self.video_frame_path, self.video_frame_files[index+i])).convert("RGB")
                    image = self.image_tranform(image).unsqueeze(0)
                    images.append(image)
                    break
                except Exception as e:
                    traceback.print_exc()
                    index = random.randint(0, len(self.video_frame_files) - self.use_image_num)
        images = torch.cat(images, dim=0)
        assert len(images) == self.use_image_num
        # Final tensor: clip frames followed by still frames, along dim 0.
        video_cat = torch.cat([video, images], dim=0)
        return {'video': video_cat, 'video_name': 1}

    def __len__(self):
        # Dataset length is driven by the still-frame list, not the videos.
        return len(self.video_frame_files)
if __name__ == '__main__':
    # Smoke test: build the dataset from command-line paths, iterate it with a
    # DataLoader, and print each batch's shape and label.
    import argparse
    import torchvision
    import video_transforms
    import torch.utils.data as Data
    import torchvision.transforms as transform
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--use-image-num", type=int, default=5)
    parser.add_argument("--frame_interval", type=int, default=3)
    parser.add_argument("--dataset", type=str, default='webvideo10m')
    # BUG FIX: the original used `type=bool, default=''`. argparse's type=bool
    # converts any non-empty string — including "False" — to True, so the
    # option could never be switched off from the command line. A store_true
    # flag parses correctly and keeps the old falsy default.
    parser.add_argument("--test-run", action='store_true')
    parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/")
    parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
    parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt")
    config = parser.parse_args()

    # Temporal crop long enough to hold num_frames at the given interval.
    temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
    transform_webvideo = transform.Compose([
        video_transforms.ToTensorVideo(),
        transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample)
    dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)

    for i, video_data in enumerate(dataloader):
        video, video_label = video_data['video'], video_data['video_name']
        print(video.shape)
        print(video_label)
    print('Done!')