from typing import List
from torch.utils.data import IterableDataset, Dataset
from omegaconf import DictConfig
import torch
import numpy as np
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.deprecated.video_1x_wm import WorldModel1XDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.ego4d import Ego4DVideoDataset
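
# Registry mapping subset names from the config to dataset classes. All
# OpenX robotics subsets share the OpenXVideoDataset implementation and
# differ only in the per-subset config they receive.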
subset_classes = dict(
dummy=DummyVideoDataset,
something_something=SomethingSomethingDataset,
epic_kitchen=EpicKitchenDataset,
pandas=PandasVideoDataset,
agibot_world=AgibotWorldDataset,
video_1x_wm=WorldModel1XDataset,
ego4d=Ego4DVideoDataset,
droid=DroidVideoDataset,
austin_buds=OpenXVideoDataset,
austin_sailor=OpenXVideoDataset,
austin_sirius=OpenXVideoDataset,
bc_z=OpenXVideoDataset,
berkeley_autolab=OpenXVideoDataset,
berkeley_cable=OpenXVideoDataset,
berkeley_fanuc=OpenXVideoDataset,
bridge=OpenXVideoDataset,
cmu_stretch=OpenXVideoDataset,
dlr_edan=OpenXVideoDataset,
dobbe=OpenXVideoDataset,
fmb=OpenXVideoDataset,
fractal=OpenXVideoDataset,
iamlab_cmu=OpenXVideoDataset,
jaco_play=OpenXVideoDataset,
language_table=OpenXVideoDataset,
nyu_franka=OpenXVideoDataset,
roboturk=OpenXVideoDataset,
stanford_hydra=OpenXVideoDataset,
taco_play=OpenXVideoDataset,
toto=OpenXVideoDataset,
ucsd_kitchen=OpenXVideoDataset,
utaustin_mutex=OpenXVideoDataset,
viola=OpenXVideoDataset,
)
class MixtureDataset(IterableDataset):
"""
    A fault-tolerant mixture of video datasets: draws samples from its
    subsets in proportion to per-split weights and skips samples that fail
    to load (unless cfg.debug is set, in which case errors are re-raised).
"""
def __init__(self, cfg: DictConfig, split: str = "training"):
super().__init__()
self.cfg = cfg
self.debug = cfg.debug
self.split = split
        self.random_seed = np.random.get_state()[1][0]  # first word of the numpy RNG state (equals the seed right after seeding)
self.subset_cfg = {
k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
}
        if split == "all":
            raise ValueError("split cannot be 'all' for MixtureDataset")
weight = dict(self.cfg[split].weight)
# Check if all keys in weight exist in subset_cfg
for key in weight:
if key not in self.subset_cfg:
raise ValueError(
f"Dataset '{key}' specified in weights but not found in configuration"
)
self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}
        weight_type = self.cfg[split].weight_type  # either "relative" or "absolute"
self.subsets: List[Dataset] = []
for subset_name, subset_cfg in self.subset_cfg.items():
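            # Propagate shared video options from the top-level config into
            # each subset config so all subsets decode clips consistently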
subset_cfg["height"] = self.cfg.height
subset_cfg["width"] = self.cfg.width
subset_cfg["n_frames"] = self.cfg.n_frames
subset_cfg["fps"] = self.cfg.fps
subset_cfg["load_video_latent"] = self.cfg.load_video_latent
subset_cfg["load_prompt_embed"] = self.cfg.load_prompt_embed
subset_cfg["max_text_tokens"] = self.cfg.max_text_tokens
subset_cfg["image_to_video"] = self.cfg.image_to_video
self.subsets.append(subset_classes[subset_name](subset_cfg, split))
if weight_type == "relative":
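                # scale relative weights by dataset size so the sampling
                # probability is proportional to weight * len(subset)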
weight[subset_name] = weight[subset_name] * len(self.subsets[-1])
# Normalize weights to sum to 1
total_weight = sum(weight.values())
self.normalized_weights = {k: v / total_weight for k, v in weight.items()}
# Store dataset sizes for printing
dataset_sizes = {
subset_name: len(subset)
for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
}
# Print normalized weights and dataset sizes in a nice format
        print(f"\nDataset information for split '{self.split}':")
print("-" * 60)
print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
print("-" * 60)
for subset_name, norm_weight in sorted(
self.normalized_weights.items(), key=lambda x: -x[1]
):
size = dataset_sizes[subset_name]
orig_weight = self.cfg[split].weight[subset_name]
print(
f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
)
print("-" * 60)
# Calculate cumulative probabilities for sampling
self.cumsum_weights = {}
cumsum = 0
for k, v in self.normalized_weights.items():
cumsum += v
self.cumsum_weights[k] = cumsum
        # expose the concatenated per-subset records for scripts that need them
self.records = []
for subset in self.subsets:
self.records.extend(subset.records)
def __iter__(self):
while True:
            # Inverse-CDF sampling: draw u ~ U[0, 1) and pick the first
            # subset whose cumulative weight reaches it
            rand = np.random.random()
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break
            else:
                # floating-point rounding can leave the final cumsum just
                # below 1.0; fall back to the last subset in that case
                selected_subset = subset_name
# Get the corresponding dataset index
subset_idx = list(self.subset_cfg.keys()).index(selected_subset)
try:
# Sample randomly from the selected dataset using numpy random
dataset = self.subsets[subset_idx]
idx = np.random.randint(len(dataset))
sample = dataset[idx]
yield sample
            except Exception as e:
                if self.debug:
                    raise
                # fault tolerance: log the failure and keep sampling
                print(f"Error sampling from {selected_subset}: {e}")
                continue
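

# Minimal usage sketch, assuming an OmegaConf config shaped the way
# __init__ expects: top-level video options, "subset/<name>" groups, and a
# per-split weight table. All values below are illustrative, and the empty
# "dummy" subset config is an assumption: DummyVideoDataset may require
# additional fields, and is assumed to expose __len__ and .records like
# the other subsets.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "debug": True,
            "height": 256,
            "width": 256,
            "n_frames": 16,
            "fps": 8,
            "load_video_latent": False,
            "load_prompt_embed": False,
            "max_text_tokens": 77,
            "image_to_video": False,
            # hypothetical subset entry; real fields depend on DummyVideoDataset
            "subset/dummy": {},
            "training": {"weight": {"dummy": 1.0}, "weight_type": "absolute"},
        }
    )
    dataset = MixtureDataset(cfg, split="training")
    sample = next(iter(dataset))  # draw one sample from the weighted mixture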