Spaces: Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import dataclasses | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
| from typing import Any, Optional | |
@dataclass(eq=False)
class CoTrackerData:
    """
    Dataclass for storing video tracks data.

    The shapes noted on each field are for a collated batch: ``collate_fn``
    and ``collate_fn_train`` create the leading ``B`` dimension by stacking
    per-sample tensors along dim 0.

    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` inherited
    from ``object``. A generated field-wise ``__eq__`` would compare tensor
    fields with ``==``, which for multi-element tensors does not produce a
    single bool.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajectory: torch.Tensor  # B, S, N, 2
    visibility: torch.Tensor  # B, S, N
    # optional data
    valid: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    # Name of the sequence; after collation this holds a list of names.
    seq_name: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format
def collate_fn(batch):
    """
    Collate function for video tracks data.

    Stacks the per-sample tensors of a list of ``CoTrackerData`` samples
    along a new leading batch dimension (dim 0).

    Args:
        batch: Non-empty list of ``CoTrackerData`` samples.

    Returns:
        A ``CoTrackerData`` whose ``video``, ``trajectory`` and ``visibility``
        are stacked along dim 0 and whose ``seq_name`` is the list of
        per-sample names. The optional tensor fields ``valid``,
        ``segmentation`` and ``query_points`` are stacked only when the
        first sample has them; otherwise they are ``None``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    def _stack(name):
        # Stack attribute `name` of every sample along a new dim 0.
        return torch.stack([getattr(sample, name) for sample in batch], dim=0)

    def _stack_optional(name):
        # Presence is decided by the first sample alone; a batch that mixes
        # samples with and without the field makes torch.stack raise on
        # the None entries.
        if getattr(batch[0], name) is None:
            return None
        return _stack(name)

    return CoTrackerData(
        video=_stack("video"),
        trajectory=_stack("trajectory"),
        visibility=_stack("visibility"),
        valid=_stack_optional("valid"),
        segmentation=_stack_optional("segmentation"),
        seq_name=[sample.seq_name for sample in batch],
        query_points=_stack_optional("query_points"),
    )
def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.

    Each element of ``batch`` is a ``(sample, gotit)`` pair. The samples'
    ``video``, ``trajectory``, ``visibility`` and ``valid`` tensors are
    stacked along a new leading dim 0, and their ``seq_name`` values are
    gathered into a list. No other fields are passed to ``CoTrackerData``.

    Args:
        batch: Sequence of ``(CoTrackerData, gotit)`` pairs.

    Returns:
        A tuple ``(CoTrackerData, gotit)``, where ``gotit`` is the list of
        second elements of the pairs, in batch order.
    """
    flags = [flag for _, flag in batch]
    samples = [sample for sample, _ in batch]

    def stacked(attr):
        return torch.stack([getattr(s, attr) for s in samples], dim=0)

    collated = CoTrackerData(
        video=stacked("video"),
        trajectory=stacked("trajectory"),
        visibility=stacked("visibility"),
        valid=stacked("valid"),
        seq_name=[s.seq_name for s in samples],
    )
    return collated, flags
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort conversion of `t` to float on a cuda device.

    Args:
        t: Input. Any object exposing ``.float()`` whose result exposes
            ``.cuda()`` (e.g. a ``torch.Tensor``) is converted; objects
            lacking either method are returned unchanged.
    Returns:
        ``t.float().cuda()`` when both methods exist, otherwise ``t`` itself.
        Integer/bool tensors are therefore cast to float as well. Only
        ``AttributeError`` is swallowed; any other error (such as one raised
        by ``.cuda()`` itself) propagates to the caller.
    """
    try:
        moved = t.float().cuda()
    except AttributeError:
        return t
    return moved
def dataclass_to_cuda_(obj):
    """
    Move every field of dataclass instance `obj` to cuda, in place.

    Each field value is passed through ``try_to_cuda`` and the result is
    written back onto ``obj`` with ``setattr``.

    Args:
        obj: Dataclass instance to update.
    Returns:
        The same ``obj`` instance, after its fields were reassigned.
    """
    names = [field.name for field in dataclasses.fields(obj)]
    for name in names:
        setattr(obj, name, try_to_cuda(getattr(obj, name)))
    return obj