import itertools
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import default_collate
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
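# LabelSmoother.ignore_index is -100, which matches the default ignore_index
# of torch.nn.CrossEntropyLoss, so label positions filled with IGNORE_TOKEN_ID
# contribute nothing to the loss.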
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        # Right-pad every sample to the longest sequence in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_TOKEN_ID,
        )
        # Truncate to the tokenizer's maximum context length.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        # Optional multimodal fields: concatenate per-sample tensors along the
        # batch dimension when present.
        if "images" in instances[0]:
            images = [instance["images"] for instance in instances]
            batch["images"] = torch.cat(images, dim=0)
        if "doclm_images" in instances[0]:
            images = [instance["doclm_images"] for instance in instances]
            batch["doclm_images"] = torch.cat(images, dim=0)
        if "image_paths" in instances[0]:
            batch["image_paths"] = [instance["image_paths"] for instance in instances]
        if "pixel_values" in instances[0]:
            batch["pixel_values"] = torch.cat([instance["pixel_values"] for instance in instances])
        if "image_flags" in instances[0]:
            batch["image_flags"] = torch.cat([instance["image_flags"] for instance in instances])
        return batch
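
# Minimal usage sketch, not part of the original file. It assumes network
# access to download the gpt2 tokenizer (which ships without a pad token) and
# instances shaped the way the collator expects: 1-D LongTensors under
# "input_ids" and "labels".
def _demo_supervised_collator():
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    instances = [
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([4, 5]), "labels": torch.tensor([4, 5])},
    ]
    batch = collator(instances)
    # Both samples are padded to length 3; the padded position carries
    # pad_token_id in input_ids and IGNORE_TOKEN_ID in labels.
    assert batch["input_ids"].shape == (2, 3)
    assert batch["labels"][1].tolist() == [4, 5, IGNORE_TOKEN_ID]
    assert batch["attention_mask"].tolist() == [[True, True, True], [True, True, False]]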
def collate_fn_deepspeed_old(batch):
    keys = list(set().union(*[set(x.keys()) for x in batch]))
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None
    # Move variable-size image fields out of the samples so default_collate
    # only sees fields with uniform shapes.
    for k in keys:
        if "images" in k or k == "image_indices":
            for x, y in zip(tmp_batch, batch):
                if k in y:
                    x[k] = y.pop(k)
    new_batch = default_collate(batch)
    for k in keys:
        if "images" in k or k == "image_indices":
            cat_dim = 0 if k != "image_indices" else 1
            if k == "image_indices":
                # Rewrite the leading entry of each index tensor so it points
                # at the sample's position within the collated batch.
                cnt = 0
                for sample in tmp_batch:
                    if k in sample:
                        sample[k][0] = cnt
                        cnt += 1
            new_batch[k] = torch.cat([x[k] for x in tmp_batch if k in x], dim=cat_dim)
    if actual_seq_len is not None:
        # Merge per-sample cumulative sequence lengths into one flat tensor by
        # offsetting each sample by its start position in the packed batch;
        # this assumes every sample is packed to the same length.
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)
    return new_batch
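
# A worked sketch of the actual_seq_len merge used by both collate functions.
# Assumption: each sample carries a cu_seqlens-style tensor of cumulative
# packed-sequence boundaries ending at a fixed window size (8 here), as
# consumed by varlen attention kernels.
def _demo_actual_seq_len_merge():
    actual_seq_len = [torch.tensor([0, 3, 8]), torch.tensor([0, 5, 8])]
    seq_len = actual_seq_len[0][-1]  # packed window size, 8
    merged = torch.cat([elem + i * seq_len for i, elem in enumerate(actual_seq_len)])
    # Sample 1's boundaries are shifted by 8, its offset in the flat batch.
    assert merged.tolist() == [0, 3, 8, 8, 13, 16]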
def collate_fn_deepspeed(batch):
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None
    # Pull variable-size multimodal fields out before default_collate, which
    # requires uniform shapes across samples.
    if "images" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["images"] = x.pop("images")
            new_x["image_indices"] = x.pop("image_indices")
    if "audios" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["audios"] = x.pop("audios")
            new_x["audio_indices"] = x.pop("audio_indices")
    new_batch = default_collate(batch)
    if "images" in tmp_batch[0]:
        new_batch["images"] = torch.cat([x["images"] for x in tmp_batch], dim=0)
        # Stamp each index tensor's leading slice with the sample's position
        # in the collated batch, then concatenate along dim 1.
        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["image_indices"])):
                sample["image_indices"][j][0, :, :] = sample_idx
        new_batch["image_indices"] = torch.cat([x["image_indices"] for x in tmp_batch], dim=1)
    if "audios" in tmp_batch[0]:
        # Audio clips vary in length, so keep them as flat Python lists
        # rather than concatenated tensors.
        new_batch["audios"] = list(itertools.chain.from_iterable([x["audios"] for x in tmp_batch]))
        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["audio_indices"])):
                sample["audio_indices"][j][0, :, :] = sample_idx
        new_batch["audio_indices"] = list(
            itertools.chain.from_iterable([x["audio_indices"] for x in tmp_batch])
        )
    if actual_seq_len is not None:
        # Same cu_seqlens merge as in collate_fn_deepspeed_old.
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)
    return new_batch
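
# Usage sketch for collate_fn_deepspeed, not from the original file. The tensor
# shapes below are assumptions chosen only to satisfy the indexing above (each
# audio_indices entry is a 3-D tensor whose [0, :, :] slice holds the batch
# index); the real dataset's layout may differ.
def _demo_collate_fn_deepspeed():
    batch = [
        {
            "input_ids": torch.tensor([1, 2, 3, 4]),
            "audios": [torch.randn(80, 7)],
            "audio_indices": [torch.zeros(1, 1, 4, dtype=torch.long)],
            "actual_seq_len": torch.tensor([0, 2, 4]),
        },
        {
            "input_ids": torch.tensor([5, 6, 7, 8]),
            "audios": [torch.randn(80, 11)],
            "audio_indices": [torch.zeros(1, 1, 4, dtype=torch.long)],
            "actual_seq_len": torch.tensor([0, 3, 4]),
        },
    ]
    out = collate_fn_deepspeed(batch)
    assert out["input_ids"].shape == (2, 4)              # stacked by default_collate
    assert len(out["audios"]) == 2                       # flat list; ragged shapes survive
    assert out["audio_indices"][1][0, 0, 0].item() == 1  # batch index stamped in
    assert out["actual_seq_len"].tolist() == [0, 2, 4, 4, 7, 8]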