# ctr-ll4: src/MLM/datasets/preprocess_dataset.py
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer


def preprocess_dataset(dataset: Dataset | DatasetDict, tokenizer: AutoTokenizer) -> Dataset | DatasetDict:
    """Tokenize the raw text columns, then regroup the tokens into fixed-size chunks."""
    tokenized_dataset = dataset.map(
        lambda examples: tokenize_function(examples, tokenizer),
        batched=True,
        remove_columns=["text", "text_clean", "language"],
    )
    return tokenized_dataset.map(group_texts, batched=True)


def tokenize_function(examples, tokenizer: AutoTokenizer):
    result = tokenizer(examples["text"])
    # Fast tokenizers expose word_ids(), which maps each token back to its source
    # word; keep them so whole-word masking stays possible downstream.
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


def group_texts(examples, chunk_size: int = 128):
    # Concatenate every sequence in the batch, then split the stream into
    # chunk_size-token blocks, dropping the remainder that does not fill a chunk.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples["input_ids"])
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Labels start as a copy of the inputs; an MLM data collator typically masks
    # the inputs downstream.
    result["labels"] = result["input_ids"].copy()
    return result
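

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how preprocess_dataset could be driven, assuming a toy
# dataset that provides the "text", "text_clean", and "language" columns and
# the "distilbert-base-uncased" checkpoint; both the checkpoint and the data
# are assumptions made here for demonstration only.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    raw = Dataset.from_dict(
        {
            "text": ["Masked language modeling pretrains on raw text."] * 32,
            "text_clean": ["masked language modeling pretrains on raw text"] * 32,
            "language": ["en"] * 32,
        }
    )
    processed = preprocess_dataset(raw, tokenizer)
    # After grouping, every remaining row holds exactly 128 token ids with
    # matching labels; the leftover tokens that do not fill a chunk are dropped.
    print(processed)
    print(len(processed[0]["input_ids"]))  # 128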