import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchaudio


class OmniDataset(Dataset):
    def __init__(self, data_list, vision_tokenizer, audio_tokenizer):
        """
        data_list: List of dicts with {'text': str, 'img_path': str, 'audio_path': str}
        """
        self.data = data_list
        self.v_tok = vision_tokenizer
        self.a_tok = audio_tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Process Image to Visual Tokens
        image = Image.open(item['img_path']).convert("RGB").resize((224, 224))
        # Note: You'd use your vision_tokenizer.py logic here

        # 2. Process Audio to Acoustic Tokens
        waveform, sr = torchaudio.load(item['audio_path'])
        # Note: You'd use your audio_tokenizer.py logic here

        # 3. Text is handled by the main model's embedding layer
        text = item['text']

        return {
            "text": text,
            "image": image,
            "audio": waveform
        }


# Example: Initializing the loader for your Omni-training
# dataset = OmniDataset(my_data, v_tokenizer, a_tokenizer)
# loader = DataLoader(dataset, batch_size=4, shuffle=True)
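
# One practical caveat, not covered in the snippet above: DataLoader's default
# collate_fn cannot batch PIL Images or variable-length waveforms, so the
# commented-out loader would fail with batch_size > 1 on these raw samples.
# The sketch below is one minimal workaround under two assumptions: the
# tokenizers haven't yet converted samples to fixed-size tensors, and all
# audio clips share a channel count (e.g. all mono). `omni_collate` is a
# hypothetical helper name, not part of torch or torchaudio.
def omni_collate(batch):
    # Keep text and PIL images as plain lists; the model-side tokenizers
    # can consume them downstream.
    texts = [sample["text"] for sample in batch]
    images = [sample["image"] for sample in batch]

    # Zero-pad each waveform (shape: [channels, num_frames]) on the right
    # to the longest clip in the batch, then stack into a single tensor.
    waveforms = [sample["audio"] for sample in batch]
    max_len = max(w.shape[-1] for w in waveforms)
    audio = torch.stack([
        torch.nn.functional.pad(w, (0, max_len - w.shape[-1]))
        for w in waveforms
    ])

    return {"text": texts, "image": images, "audio": audio}

# Usage with the loader from the example above:
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=omni_collate)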