# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media

from PIL import Image
from typing import Mapping, Tuple, Union

from cotracker.datasets.utils import CoTrackerData

DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]


def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """Resize a video to output_size."""
    # If you have a GPU, consider replacing this with a GPU-enabled resize op,
    # such as a jitted jax.image.resize. It will make things faster.
    return media.resize_video(video, output_size)
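
# Illustrative sketch (not part of the original file): the GPU-enabled
# alternative hinted at in the comment above could look roughly like this,
# assuming JAX is installed. The interpolation filter may differ from
# mediapy's default, so resized pixels will not match bit-for-bit.
#
#     import functools
#     import jax
#
#     @functools.partial(jax.jit, static_argnums=(1,))
#     def resize_video_jax(video, output_size):
#         n_frames, _, _, channels = video.shape
#         return jax.image.resize(
#             video,
#             (n_frames, output_size[0], output_size[1], channels),
#             method="bilinear",
#         )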

def sample_queries_first(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.

    Given a set of frames and tracks with no query points, use the first
    visible point in each track as the query.

    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x, y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.

    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3].
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1].
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1].
        occluded: Occlusion flags of shape [1, n_queries, n_frames], where True
          indicates occluded.
    """
    valid = np.sum(~target_occluded, axis=1) > 0
    target_points = target_points[valid, :]
    target_occluded = target_occluded[valid, :]

    query_points = []
    for i in range(target_points.shape[0]):
        index = np.where(target_occluded[i] == 0)[0][0]
        x, y = target_points[i, index, 0], target_points[i, index, 1]
        query_points.append(np.array([index, y, x]))  # [t, y, x]
    query_points = np.stack(query_points, axis=0)

    return {
        "video": frames[np.newaxis, ...],
        "query_points": query_points[np.newaxis, ...],
        "target_points": target_points[np.newaxis, ...],
        "occluded": target_occluded[np.newaxis, ...],
    }
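
# Illustrative usage sketch (not part of the original file): shapes follow the
# docstring above; the random inputs below are placeholders, not real data.
#
#     occluded = np.zeros((4, 8), dtype=bool)            # 4 tracks, 8 frames, all visible
#     points = np.random.rand(4, 8, 2)                   # [x, y] in [0, 1]
#     frames = np.random.rand(8, 64, 64, 3) * 2.0 - 1.0  # video scaled to [-1, 1]
#     packed = sample_queries_first(occluded, points, frames)
#     # packed["video"].shape         == (1, 8, 64, 64, 3)
#     # packed["query_points"].shape  == (1, 4, 3), rows are [t, y, x]
#     # packed["target_points"].shape == (1, 4, 8, 2)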

def sample_queries_strided(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
    query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.

    Given a set of frames and tracks with no query points, sample queries
    strided every query_stride frames, ignoring points that are not visible
    at the selected frames.

    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x, y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.
      query_stride: When sampling query points, search for un-occluded points
        every query_stride frames and convert each one into a query.

    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]. The video
          has floats scaled to the range [-1, 1].
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1].
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1].
        occluded: Occlusion flags of shape [1, n_queries, n_frames], where True
          indicates occluded.
        trackgroup: Index of the original track that each query point was
          sampled from. This is useful for visualization.
    """
    tracks = []
    occs = []
    queries = []
    trackgroups = []
    total = 0
    trackgroup = np.arange(target_occluded.shape[0])
    for i in range(0, target_occluded.shape[1], query_stride):
        mask = target_occluded[:, i] == 0
        query = np.stack(
            [
                i * np.ones(target_occluded.shape[0:1]),
                target_points[:, i, 1],
                target_points[:, i, 0],
            ],
            axis=-1,
        )
        queries.append(query[mask])
        tracks.append(target_points[mask])
        occs.append(target_occluded[mask])
        trackgroups.append(trackgroup[mask])
        total += np.array(np.sum(target_occluded[:, i] == 0))

    return {
        "video": frames[np.newaxis, ...],
        "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
        "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
        "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
        "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
    }
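
# Illustrative usage sketch (not part of the original file): with all points
# visible, each track is re-queried once per sampled frame, so 4 tracks over
# 8 frames with query_stride=5 yield queries at t=0 and t=5, i.e. 8 queries.
#
#     occluded = np.zeros((4, 8), dtype=bool)
#     points = np.random.rand(4, 8, 2)
#     frames = np.random.rand(8, 64, 64, 3) * 2.0 - 1.0
#     packed = sample_queries_strided(occluded, points, frames, query_stride=5)
#     # packed["query_points"].shape  == (1, 8, 3)
#     # packed["target_points"].shape == (1, 8, 8, 2)
#     # packed["trackgroup"][0]       == [0, 1, 2, 3, 0, 1, 2, 3]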

class TapVidDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        dataset_type="davis",
        resize_to_256=True,
        queried_first=True,
    ):
        self.dataset_type = dataset_type
        self.resize_to_256 = resize_to_256
        self.queried_first = queried_first
        if self.dataset_type == "kinetics":
            all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
            points_dataset = []
            for pickle_path in all_paths:
                with open(pickle_path, "rb") as f:
                    data = pickle.load(f)
                    points_dataset = points_dataset + data
            self.points_dataset = points_dataset
        else:
            with open(data_root, "rb") as f:
                self.points_dataset = pickle.load(f)
            if self.dataset_type == "davis":
                self.video_names = list(self.points_dataset.keys())
        print("found %d unique videos in %s" % (len(self.points_dataset), data_root))

    def __getitem__(self, index):
        if self.dataset_type == "davis":
            video_name = self.video_names[index]
        else:
            video_name = index
        video = self.points_dataset[video_name]
        frames = video["video"]

        if isinstance(frames[0], bytes):
            # TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s.
            def decode(frame):
                byteio = io.BytesIO(frame)
                img = Image.open(byteio)
                return np.array(img)

            frames = np.array([decode(frame) for frame in frames])

        target_points = self.points_dataset[video_name]["points"]
        if self.resize_to_256:
            frames = resize_video(frames, [256, 256])
            target_points *= np.array([255, 255])  # 1 should be mapped to 256-1
        else:
            target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])

        target_occ = self.points_dataset[video_name]["occluded"]
        if self.queried_first:
            converted = sample_queries_first(target_occ, target_points, frames)
        else:
            converted = sample_queries_strided(target_occ, target_points, frames)
        assert converted["target_points"].shape[1] == converted["query_points"].shape[1]

        trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()  # T, N, D
        rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
        visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
            1, 0
        )  # T, N
        query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3 (t, y, x)
        return CoTrackerData(
            rgbs,
            trajs,
            visibles,
            seq_name=str(video_name),
            query_points=query_points,
        )

    def __len__(self):
        return len(self.points_dataset)
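
# Illustrative usage sketch (not part of the original file): loading the
# TAP-Vid DAVIS split and fetching one sample.  The pickle path below is a
# placeholder; point it at your local copy of the DAVIS annotations file.
#
#     dataset = TapVidDataset(
#         data_root="/path/to/tapvid_davis.pkl",
#         dataset_type="davis",
#         queried_first=True,
#     )
#     sample = dataset[0]  # a CoTrackerData built from, in order:
#     #   rgbs     [T, 3, H, W]  float frames (H = W = 256 when resize_to_256=True)
#     #   trajs    [T, N, 2]     point tracks in pixel coordinates
#     #   visibles [T, N]        boolean visibility mask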