| from datasets import Dataset, DatasetDict | |
| from transformers import AutoTokenizer | |
def preprocess_dataset(
    dataset: Dataset | DatasetDict,
    tokenizer: AutoTokenizer,
    *,
    chunk_size: int = 128,
    remove_columns: list[str] | None = None,
) -> Dataset | DatasetDict:
    """Tokenize a dataset and regroup the tokens into fixed-size chunks.

    Runs two batched ``map`` passes: first ``tokenize_function`` (which
    tokenizes the ``"text"`` column), then ``group_texts`` (which
    concatenates the tokenized columns and splits them into chunks).

    Args:
        dataset: Dataset (or dict of splits) to preprocess.
        tokenizer: Tokenizer forwarded to ``tokenize_function``.
        chunk_size: Chunk length forwarded to ``group_texts``. Defaults to
            128, which matches ``group_texts``' own default.
        remove_columns: Columns dropped during the tokenization pass.
            ``None`` selects the original hard-coded set
            ``["text", "text_clean", "language"]``.

    Returns:
        The result of the second ``map`` call, i.e. the chunked dataset.
    """
    if remove_columns is None:
        remove_columns = ["text", "text_clean", "language"]

    # Extra arguments go through fn_kwargs rather than a closure so both
    # passes are configured the same way.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=remove_columns,
        fn_kwargs={"tokenizer": tokenizer},
    )
    return tokenized_dataset.map(
        group_texts,
        batched=True,
        fn_kwargs={"chunk_size": chunk_size},
    )
def tokenize_function(examples, tokenizer: AutoTokenizer):
    """Tokenize the ``"text"`` column of a batch.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry.
        tokenizer: Callable tokenizer exposing an ``is_fast`` attribute.

    Returns:
        Whatever ``tokenizer(examples["text"])`` returns. When
        ``tokenizer.is_fast`` is truthy, a ``"word_ids"`` entry is added
        holding ``word_ids(i)`` for each sequence index ``i`` in
        ``"input_ids"``.
    """
    encoding = tokenizer(examples["text"])
    if not tokenizer.is_fast:
        return encoding

    # One word-id list per encoded sequence, in batch order.
    sequence_count = len(encoding["input_ids"])
    word_ids_per_sequence = []
    for sequence_index in range(sequence_count):
        word_ids_per_sequence.append(encoding.word_ids(sequence_index))
    encoding["word_ids"] = word_ids_per_sequence
    return encoding
def group_texts(examples, chunk_size: int = 128):
    """Concatenate every column of a batch and re-split it into equal chunks.

    Args:
        examples: Mapping from column name to a list of per-example lists.
            Must contain an ``"input_ids"`` column, whose concatenated
            length determines how many chunks are produced.
        chunk_size: Number of items in each output chunk.

    Returns:
        A dict with the same keys as ``examples`` plus ``"labels"``. Each
        value is a list of chunks of exactly ``chunk_size`` items; the
        trailing remainder that does not fill a whole chunk is dropped.
        ``"labels"`` is a shallow copy of the ``"input_ids"`` chunk list.

    Raises:
        ZeroDivisionError: If ``chunk_size`` is 0.
    """
    # Flatten in one pass. ``sum(lists, [])`` re-copies the accumulator on
    # every addition, which is quadratic in the total number of items.
    concatenated_examples = {
        key: [item for sequence in examples[key] for item in sequence]
        for key in examples.keys()
    }

    total_length = len(concatenated_examples["input_ids"])
    # Round down to a multiple of chunk_size so that every chunk is full.
    total_length = (total_length // chunk_size) * chunk_size

    result = {
        key: [
            items[start : start + chunk_size]
            for start in range(0, total_length, chunk_size)
        ]
        for key, items in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result