# LVP/datasets/mixture.py
# Author: kiwhansong — "add demo" (commit 142a1ac)
from typing import List
from torch.utils.data import IterableDataset, Dataset
from omegaconf import DictConfig
import torch
import numpy as np
from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.deprecated.video_1x_wm import WorldModel1XDataset
from datasets.agibot_world import AgibotWorldDataset
from datasets.ego4d import Ego4DVideoDataset
# Registry mapping each subset name to the dataset class that loads it.
# All Open-X Embodiment subsets share the same OpenXVideoDataset loader.
_OPENX_SUBSET_NAMES = (
    "austin_buds",
    "austin_sailor",
    "austin_sirius",
    "bc_z",
    "berkeley_autolab",
    "berkeley_cable",
    "berkeley_fanuc",
    "bridge",
    "cmu_stretch",
    "dlr_edan",
    "dobbe",
    "fmb",
    "fractal",
    "iamlab_cmu",
    "jaco_play",
    "language_table",
    "nyu_franka",
    "roboturk",
    "stanford_hydra",
    "taco_play",
    "toto",
    "ucsd_kitchen",
    "utaustin_mutex",
    "viola",
)

subset_classes = {
    "dummy": DummyVideoDataset,
    "something_something": SomethingSomethingDataset,
    "epic_kitchen": EpicKitchenDataset,
    "pandas": PandasVideoDataset,
    "agibot_world": AgibotWorldDataset,
    "video_1x_wm": WorldModel1XDataset,
    "ego4d": Ego4DVideoDataset,
    "droid": DroidVideoDataset,
    **{name: OpenXVideoDataset for name in _OPENX_SUBSET_NAMES},
}
class MixtureDataset(IterableDataset):
    """
    A fault-tolerant, weighted mixture of video datasets.

    Samples are drawn indefinitely: each iteration first selects a subset
    via inverse-CDF lookup on the normalized per-subset weights, then draws
    a uniformly random index from that subset. Errors raised while loading
    a sample are logged and skipped, unless ``cfg.debug`` is set, in which
    case they propagate.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        """
        Args:
            cfg: Mixture configuration. Expected to hold per-subset configs
                under ``subset/<name>`` keys, shared media options (height,
                width, n_frames, fps, load_video_latent, load_prompt_embed,
                max_text_tokens, image_to_video), a ``debug`` flag, and a
                ``cfg[split]`` section with ``weight`` (subset name -> weight)
                and ``weight_type`` ("relative" scales each weight by the
                subset's length; otherwise weights are used as given).
            split: Which split to load; must not be ``"all"``.

        Raises:
            ValueError: if ``split == "all"``, if a weighted subset has no
                ``subset/<name>`` config, or if the weights do not sum to a
                positive value.
        """
        super().__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.split = split
        # First word of the numpy MT state, kept as a reproducibility record.
        self.random_seed = np.random.get_state()[1][0]
        # Per-subset configs live under "subset/<name>" keys.
        self.subset_cfg = {
            k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
        }
        if split == "all":
            raise ValueError("split cannot be `all` for MixtureDataset`")
        weight = dict(self.cfg[split].weight)
        # Every weighted subset must have a matching "subset/<name>" config.
        for key in weight:
            if key not in self.subset_cfg:
                raise ValueError(
                    f"Dataset '{key}' specified in weights but not found in configuration"
                )
        # Keep only the subsets that actually carry a weight for this split.
        self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}
        weight_type = self.cfg[split].weight_type  # one of relative or absolute
        self.subsets: List[Dataset] = []
        for subset_name, subset_cfg in self.subset_cfg.items():
            # Propagate the shared media options into each subset's config.
            subset_cfg["height"] = self.cfg.height
            subset_cfg["width"] = self.cfg.width
            subset_cfg["n_frames"] = self.cfg.n_frames
            subset_cfg["fps"] = self.cfg.fps
            subset_cfg["load_video_latent"] = self.cfg.load_video_latent
            subset_cfg["load_prompt_embed"] = self.cfg.load_prompt_embed
            subset_cfg["max_text_tokens"] = self.cfg.max_text_tokens
            subset_cfg["image_to_video"] = self.cfg.image_to_video
            self.subsets.append(subset_classes[subset_name](subset_cfg, split))
            if weight_type == "relative":
                # Relative weights are scaled by dataset size, so the weight
                # expresses sampling frequency per example, not per subset.
                weight[subset_name] = weight[subset_name] * len(self.subsets[-1])
        # Normalize weights to sum to 1.
        total_weight = sum(weight.values())
        if total_weight <= 0:
            # Explicit error: a bare ZeroDivisionError below would be much
            # harder to trace back to an all-zero weight config.
            raise ValueError("Dataset weights must sum to a positive value")
        self.normalized_weights = {k: v / total_weight for k, v in weight.items()}
        # Dataset sizes, kept for the summary table below.
        dataset_sizes = {
            subset_name: len(subset)
            for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
        }
        # Print normalized weights and dataset sizes in a nice format.
        print("\nDataset information for split '{}':".format(self.split))
        print("-" * 60)
        print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
        print("-" * 60)
        for subset_name, norm_weight in sorted(
            self.normalized_weights.items(), key=lambda x: -x[1]
        ):
            size = dataset_sizes[subset_name]
            orig_weight = self.cfg[split].weight[subset_name]
            print(
                f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
            )
        print("-" * 60)
        # Cumulative probabilities for inverse-CDF sampling in __iter__.
        self.cumsum_weights = {}
        cumsum = 0
        for k, v in self.normalized_weights.items():
            cumsum += v
            self.cumsum_weights[k] = cumsum
        # Some scripts want direct access to the flattened subset records.
        self.records = []
        for subset in self.subsets:
            self.records.extend(subset.records)

    def __iter__(self):
        """Yield samples forever, choosing a subset per draw by weight."""
        # Hoisted out of the loop: previously this list was rebuilt on every
        # single draw of an infinite sampling loop.
        subset_names = list(self.subset_cfg.keys())
        while True:
            # Sample a random subset based on weights using numpy random.
            rand = np.random.random()
            # Default to the last subset: floating-point rounding can leave
            # the final cumulative weight fractionally below 1.0, which would
            # otherwise leave `selected_subset` unbound (NameError).
            selected_subset = subset_names[-1]
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break
            subset_idx = subset_names.index(selected_subset)
            try:
                # Sample uniformly at random from the selected dataset.
                dataset = self.subsets[subset_idx]
                idx = np.random.randint(len(dataset))
                sample = dataset[idx]
                yield sample
            except Exception as e:
                if self.debug:
                    raise  # bare raise preserves the original traceback
                # Fault tolerance: log the failure and move to the next draw.
                print(f"Error sampling from {selected_subset}: {str(e)}")
                continue