"""Data loading utilities for multimodal (text / image / audio) training.

Part of Mini-gpt-0.000001 (data_loader.py), uploaded by AIencoder, revision 93bedb7.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchaudio
class OmniDataset(Dataset):
    """Multimodal dataset yielding one raw text string, image, and waveform per sample.

    Each element of ``data_list`` is a dict with keys:
        'text' (str), 'img_path' (str), 'audio_path' (str).

    __getitem__ returns a dict:
        'text'  -> str (tokenized later by the main model's embedding layer),
        'image' -> PIL.Image, 224x224 RGB,
        'audio' -> torch.Tensor waveform from torchaudio.load.
    """

    def __init__(self, data_list, vision_tokenizer, audio_tokenizer):
        """
        data_list: List of dicts with {'text': str, 'img_path': str, 'audio_path': str}
        vision_tokenizer / audio_tokenizer: stored for later use; not applied
            here yet — plug the vision_tokenizer.py / audio_tokenizer.py logic
            into __getitem__ (see the notes there).
        """
        self.data = data_list
        self.v_tok = vision_tokenizer
        self.a_tok = audio_tokenizer

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and return the sample at index ``idx``."""
        item = self.data[idx]

        # 1. Process Image to Visual Tokens.
        # Image.open is lazy and holds the file handle; the context manager
        # guarantees it is closed even if convert/resize raises.
        with Image.open(item['img_path']) as img:
            image = img.convert("RGB").resize((224, 224))
        # Note: You'd use your vision_tokenizer.py logic here (self.v_tok)

        # 2. Process Audio to Acoustic Tokens
        waveform, sr = torchaudio.load(item['audio_path'])
        # Note: You'd use your audio_tokenizer.py logic here (self.a_tok)

        # 3. Text is handled by the main model's embedding layer
        text = item['text']

        return {
            "text": text,
            "image": image,
            "audio": waveform,
        }
# Example: Initializing the loader for your Omni-training
# dataset = OmniDataset(my_data, v_tokenizer, a_tokenizer)
# loader = DataLoader(dataset, batch_size=4, shuffle=True)