"""Tokenize a text dataset and regroup it into fixed-length chunks."""

from collections.abc import Sequence
from itertools import chain

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, PreTrainedTokenizerBase

# Columns dropped right after tokenization. group_texts concatenates every
# remaining column as a list, so raw string columns must not survive.
DEFAULT_REMOVE_COLUMNS = ("text", "text_clean", "language")


def preprocess_dataset(
    dataset: Dataset | DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    *,
    chunk_size: int = 128,
    remove_columns: Sequence[str] | None = None,
) -> Dataset | DatasetDict:
    """Tokenize ``dataset`` and regroup the tokens into ``chunk_size`` blocks.

    Args:
        dataset: Dataset (or DatasetDict) with a ``"text"`` column.
        tokenizer: Tokenizer instance applied to the ``"text"`` column.
        chunk_size: Number of tokens per output example.
        remove_columns: Original columns to drop after tokenization;
            defaults to ``DEFAULT_REMOVE_COLUMNS``.

    Returns:
        A dataset of the same kind whose rows are ``chunk_size``-token chunks
        with ``input_ids``, the other tokenizer outputs and ``labels``.
    """
    if remove_columns is None:
        remove_columns = DEFAULT_REMOVE_COLUMNS
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer},
        remove_columns=list(remove_columns),
    )
    return tokenized_dataset.map(
        group_texts,
        batched=True,
        fn_kwargs={"chunk_size": chunk_size},
    )


def tokenize_function(examples, tokenizer: PreTrainedTokenizerBase):
    """Tokenize a batch's ``"text"`` column.

    For fast tokenizers a ``"word_ids"`` column is added, giving the word
    index of every token in each sequence (presumably for whole-word
    masking downstream — confirm against the consumer).
    """
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        # BatchEncoding.word_ids is only available on fast tokenizers.
        result["word_ids"] = [
            result.word_ids(i) for i in range(len(result["input_ids"]))
        ]
    return result


def group_texts(examples, chunk_size: int = 128):
    """Concatenate a batch's token lists and split them into fixed chunks.

    Every column is flattened and cut into consecutive ``chunk_size`` slices.
    The trailing remainder shorter than ``chunk_size`` is dropped. A
    ``"labels"`` column is added as an independent copy of ``"input_ids"``.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    # chain.from_iterable is linear; sum(lists, []) re-copies on every step.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the tail so every chunk has exactly chunk_size tokens.
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }
    # Copy each chunk so editing labels (e.g. -100 masking) cannot alias input_ids.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result