File size: 5,963 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from typing import List
from torch.utils.data import IterableDataset, Dataset
from omegaconf import DictConfig
import torch
import numpy as np
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.deprecated.video_1x_wm import WorldModel1XDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.ego4d import Ego4DVideoDataset

# Registry mapping a subset name (as it appears in `subset/<name>` config keys
# and in the per-split weight tables) to the dataset class that loads it.
subset_classes = {
    "dummy": DummyVideoDataset,
    "something_something": SomethingSomethingDataset,
    "epic_kitchen": EpicKitchenDataset,
    "pandas": PandasVideoDataset,
    "agibot_world": AgibotWorldDataset,
    "video_1x_wm": WorldModel1XDataset,
    "ego4d": Ego4DVideoDataset,
    "droid": DroidVideoDataset,
}

# All Open X-Embodiment subsets share a single loader class.
_OPENX_SUBSET_NAMES = (
    "austin_buds",
    "austin_sailor",
    "austin_sirius",
    "bc_z",
    "berkeley_autolab",
    "berkeley_cable",
    "berkeley_fanuc",
    "bridge",
    "cmu_stretch",
    "dlr_edan",
    "dobbe",
    "fmb",
    "fractal",
    "iamlab_cmu",
    "jaco_play",
    "language_table",
    "nyu_franka",
    "roboturk",
    "stanford_hydra",
    "taco_play",
    "toto",
    "ucsd_kitchen",
    "utaustin_mutex",
    "viola",
)
subset_classes.update(dict.fromkeys(_OPENX_SUBSET_NAMES, OpenXVideoDataset))


class MixtureDataset(IterableDataset):
    """
    A fault tolerant mixture of video datasets.

    Instantiates every sub-dataset listed under ``cfg[split].weight``,
    normalizes the weights (scaling each weight by dataset size first when
    ``weight_type == "relative"``), and then yields samples drawn from the
    mixture indefinitely. Per-sample loading errors are logged and skipped
    unless ``cfg.debug`` is set, in which case they are re-raised.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        """
        Args:
            cfg: OmegaConf config. Keys of the form ``subset/<name>`` hold
                per-subset configs; ``cfg[split].weight`` maps subset names to
                sampling weights and ``cfg[split].weight_type`` is one of
                ``"relative"`` or ``"absolute"``.
            split: Split name (e.g. ``"training"``). Must not be ``"all"``.

        Raises:
            ValueError: if ``split == "all"``, or a name in the weight table
                has no matching ``subset/<name>`` config entry.
        """
        super().__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.split = split
        self.random_seed = np.random.get_state()[1][0]  # Get current numpy random seed
        self.subset_cfg = {
            k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
        }
        if split == "all":
            raise ValueError("split cannot be `all` for MixtureDataset`")
        weight = dict(self.cfg[split].weight)
        # Check if all keys in weight exist in subset_cfg
        for key in weight:
            if key not in self.subset_cfg:
                raise ValueError(
                    f"Dataset '{key}' specified in weights but not found in configuration"
                )
        # Keep only the subsets that actually participate in this split.
        self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}
        weight_type = self.cfg[split].weight_type  # one of relative or absolute
        self.subsets: List[Dataset] = []
        for subset_name, subset_cfg in self.subset_cfg.items():
            # Propagate shared video/text options from the top-level config
            # into every subset config so all subsets produce uniform samples.
            subset_cfg["height"] = self.cfg.height
            subset_cfg["width"] = self.cfg.width
            subset_cfg["n_frames"] = self.cfg.n_frames
            subset_cfg["fps"] = self.cfg.fps
            subset_cfg["load_video_latent"] = self.cfg.load_video_latent
            subset_cfg["load_prompt_embed"] = self.cfg.load_prompt_embed
            subset_cfg["max_text_tokens"] = self.cfg.max_text_tokens
            subset_cfg["image_to_video"] = self.cfg.image_to_video
            self.subsets.append(subset_classes[subset_name](subset_cfg, split))
            if weight_type == "relative":
                # Relative weights are per-sample: scale by dataset size.
                weight[subset_name] = weight[subset_name] * len(self.subsets[-1])

        # Normalize weights to sum to 1
        total_weight = sum(weight.values())
        self.normalized_weights = {k: v / total_weight for k, v in weight.items()}

        # Store dataset sizes for printing
        dataset_sizes = {
            subset_name: len(subset)
            for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
        }

        # Print normalized weights and dataset sizes in a nice format
        print("\nDataset information for split '{}':".format(self.split))
        print("-" * 60)
        print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
        print("-" * 60)
        for subset_name, norm_weight in sorted(
            self.normalized_weights.items(), key=lambda x: -x[1]
        ):
            size = dataset_sizes[subset_name]
            orig_weight = self.cfg[split].weight[subset_name]
            print(
                f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
            )
        print("-" * 60)

        # Calculate cumulative probabilities for sampling
        self.cumsum_weights = {}
        cumsum = 0
        for k, v in self.normalized_weights.items():
            cumsum += v
            self.cumsum_weights[k] = cumsum

        # Precompute name -> position once so __iter__ does not rebuild the
        # key list and linearly scan it for every sample it yields.
        self._subset_index = {name: i for i, name in enumerate(self.subset_cfg)}

        # some scripts want to access the records
        self.records = []
        for subset in self.subsets:
            self.records.extend(subset.records)

    def __iter__(self):
        """Yield samples forever, choosing a subset per the normalized weights.

        Failed sample loads are printed and skipped (re-raised when
        ``self.debug`` is set).
        """
        # Fallback subset: floating-point rounding can leave the final
        # cumulative weight fractionally below 1.0, in which case the drawn
        # value matches no entry; without this, `selected_subset` could be
        # unbound (NameError) or stale from a previous iteration.
        last_subset = next(reversed(self.cumsum_weights))
        while True:
            # Sample a random subset based on weights using numpy random
            rand = np.random.random()
            selected_subset = last_subset
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break

            # Get the corresponding dataset index (precomputed in __init__)
            subset_idx = self._subset_index[selected_subset]

            try:
                # Sample randomly from the selected dataset using numpy random
                dataset = self.subsets[subset_idx]
                idx = np.random.randint(len(dataset))
                yield dataset[idx]
            except Exception as e:
                if self.debug:
                    raise  # bare raise preserves the original traceback
                print(f"Error sampling from {selected_subset}: {str(e)}")