"""Fault-tolerant weighted mixture of video datasets.

`MixtureDataset` wraps a collection of per-subset video datasets and yields an
infinite stream of samples, drawing each sample from a subset chosen according
to configured (relative or absolute) weights.
"""

from typing import List

import numpy as np
import torch
from omegaconf import DictConfig
from torch.utils.data import Dataset, IterableDataset

from datasets.agibot_world import AgibotWorldDataset
from datasets.deprecated.video_1x_wm import WorldModel1XDataset
from datasets.droid import DroidVideoDataset
from datasets.dummy import DummyVideoDataset
from datasets.ego4d import Ego4DVideoDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.pandas import PandasVideoDataset
from datasets.something_something import SomethingSomethingDataset

# Registry mapping the subset name (the part after "subset/" in config keys)
# to the dataset class that loads it. All OpenX-derived subsets share one class.
subset_classes = dict(
    dummy=DummyVideoDataset,
    something_something=SomethingSomethingDataset,
    epic_kitchen=EpicKitchenDataset,
    pandas=PandasVideoDataset,
    agibot_world=AgibotWorldDataset,
    video_1x_wm=WorldModel1XDataset,
    ego4d=Ego4DVideoDataset,
    droid=DroidVideoDataset,
    austin_buds=OpenXVideoDataset,
    austin_sailor=OpenXVideoDataset,
    austin_sirius=OpenXVideoDataset,
    bc_z=OpenXVideoDataset,
    berkeley_autolab=OpenXVideoDataset,
    berkeley_cable=OpenXVideoDataset,
    berkeley_fanuc=OpenXVideoDataset,
    bridge=OpenXVideoDataset,
    cmu_stretch=OpenXVideoDataset,
    dlr_edan=OpenXVideoDataset,
    dobbe=OpenXVideoDataset,
    fmb=OpenXVideoDataset,
    fractal=OpenXVideoDataset,
    iamlab_cmu=OpenXVideoDataset,
    jaco_play=OpenXVideoDataset,
    language_table=OpenXVideoDataset,
    nyu_franka=OpenXVideoDataset,
    roboturk=OpenXVideoDataset,
    stanford_hydra=OpenXVideoDataset,
    taco_play=OpenXVideoDataset,
    toto=OpenXVideoDataset,
    ucsd_kitchen=OpenXVideoDataset,
    utaustin_mutex=OpenXVideoDataset,
    viola=OpenXVideoDataset,
)

# Top-level settings copied verbatim into every subset's config.
_SHARED_SUBSET_KEYS = (
    "height",
    "width",
    "n_frames",
    "fps",
    "load_video_latent",
    "load_prompt_embed",
    "max_text_tokens",
    "image_to_video",
)


class MixtureDataset(IterableDataset):
    """A fault tolerant mixture of video datasets.

    Samples are drawn by first picking a subset via weighted sampling over the
    normalized per-subset weights, then picking a uniformly random index inside
    that subset. Per-sample errors are swallowed (with a printed message) unless
    ``cfg.debug`` is set, making the stream robust to occasional bad records.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        """Build the mixture.

        Args:
            cfg: Top-level config. Must contain ``subset/<name>`` sub-configs,
                shared keys (height, width, n_frames, ...), and per-split
                ``weight`` / ``weight_type`` entries.
            split: Dataset split; must not be "all".

        Raises:
            ValueError: If ``split == "all"`` or a weighted subset has no
                matching ``subset/<name>`` configuration.
        """
        super().__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.split = split
        # Snapshot of the current numpy seed; kept for external inspection
        # (not consumed by this class itself).
        self.random_seed = np.random.get_state()[1][0]

        # "subset/<name>" config entries, keyed by <name>.
        self.subset_cfg = {
            k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
        }
        if split == "all":
            raise ValueError("split cannot be `all` for `MixtureDataset`")

        weight = dict(self.cfg[split].weight)
        # Every weighted subset must have a configuration to instantiate from.
        for key in weight:
            if key not in self.subset_cfg:
                raise ValueError(
                    f"Dataset '{key}' specified in weights but not found in configuration"
                )
        # Drop configured subsets that carry no weight for this split.
        self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}

        weight_type = self.cfg[split].weight_type  # one of "relative" or "absolute"
        self.subsets: List[Dataset] = []
        for subset_name, subset_cfg in self.subset_cfg.items():
            # NOTE: this writes into the shared cfg nodes in place (OmegaConf
            # sub-configs are references, not copies).
            for shared_key in _SHARED_SUBSET_KEYS:
                subset_cfg[shared_key] = self.cfg[shared_key]
            self.subsets.append(subset_classes[subset_name](subset_cfg, split))
            if weight_type == "relative":
                # Relative weights are scaled by dataset size so they express
                # per-sample (not per-dataset) preference.
                weight[subset_name] = weight[subset_name] * len(self.subsets[-1])

        # Normalize weights to sum to 1.
        total_weight = sum(weight.values())
        self.normalized_weights = {k: v / total_weight for k, v in weight.items()}

        # Store dataset sizes for printing.
        dataset_sizes = {
            subset_name: len(subset)
            for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
        }

        # Print normalized weights and dataset sizes in a nice format.
        print("\nDataset information for split '{}':".format(self.split))
        print("-" * 60)
        print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
        print("-" * 60)
        for subset_name, norm_weight in sorted(
            self.normalized_weights.items(), key=lambda x: -x[1]
        ):
            size = dataset_sizes[subset_name]
            orig_weight = self.cfg[split].weight[subset_name]
            print(
                f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
            )
        print("-" * 60)

        # Cumulative probabilities for inverse-CDF sampling in __iter__.
        self.cumsum_weights = {}
        cumsum = 0
        for k, v in self.normalized_weights.items():
            cumsum += v
            self.cumsum_weights[k] = cumsum

        # Some scripts want flat access to all records across subsets.
        self.records = []
        for subset in self.subsets:
            self.records.extend(subset.records)

    def __iter__(self):
        """Yield samples forever, choosing a subset per draw by weight.

        Errors raised while fetching a single sample are printed and skipped
        unless ``self.debug`` is set, in which case they propagate.
        """
        # Hoisted: the key order is fixed after __init__, no need to rebuild
        # the list on every draw.
        subset_names = list(self.subset_cfg.keys())

        # PyTorch DataLoader workers do NOT reseed numpy's global RNG, so
        # drawing from np.random here would make every worker emit an
        # identical sample stream. Use a per-worker Generator instead.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            rng = np.random.default_rng(worker_info.seed % (2**32))
        else:
            # Single-process: derive the stream from the global numpy state so
            # np.random.seed() upstream still controls reproducibility.
            rng = np.random.default_rng(np.random.randint(0, 2**31))

        while True:
            rand = rng.random()
            # Default to the last subset: floating-point round-off can leave
            # the final cumulative weight slightly below `rand`, which would
            # otherwise leave `selected_subset` unbound.
            selected_subset = subset_names[-1]
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break

            subset_idx = subset_names.index(selected_subset)
            dataset = self.subsets[subset_idx]
            try:
                # Uniform random index within the selected subset.
                idx = int(rng.integers(len(dataset)))
                sample = dataset[idx]
                yield sample
            except Exception as e:
                if self.debug:
                    raise  # re-raise as-is to preserve the original traceback
                else:
                    print(f"Error sampling from {selected_subset}: {str(e)}")
                    continue