from typing import Tuple, List

import seaborn as sns
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset

import my_model.config.fine_tuning_config as config
from my_model.utilities.gen_utilities import is_pycharm
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
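
# The fine-tuning config module (my_model/config/fine_tuning_config.py) is expected to provide at
# least DATASET_FILE, MAX_TOKEN_COUNT, TEST_SIZE and SEED, as used below. A minimal sketch with
# illustrative values only (the real values live in that module):
#   DATASET_FILE = "fine_tuning_data.csv"   # hypothetical path to the CSV dataset
#   MAX_TOKEN_COUNT = 1024                  # per-sample token limit referenced in this handler
#   TEST_SIZE = 0.1                         # hypothetical train/eval split ratio
#   SEED = 42                               # hypothetical shuffling seed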


class FinetuningDataHandler:
    """
    A class dedicated to handling data for fine-tuning LLaMA-2 Chat models. It manages loading,
    inspecting, preparing, and splitting the dataset, and is specifically designed to filter out
    data samples that exceed a specified token count limit. This is crucial for models with
    context-length constraints, and it also helps keep GPU RAM usage under control (memory grows
    with the number of tokens per sample), ensuring efficient and effective model fine-tuning.

    Attributes:
        tokenizer (AutoTokenizer): Tokenizer used for tokenizing the dataset.
        dataset_file (str): File path to the dataset.
        max_token_count (int): Maximum allowable token count per data sample.

    Methods:
        load_llm_tokenizer: Loads the LLM tokenizer and adds special tokens, if not already loaded.
        load_dataset: Loads the dataset from a specified file path.
        plot_tokens_count_distribution: Plots the distribution of token counts in the dataset.
        filter_dataset_by_indices: Filters the dataset based on valid indices, removing samples exceeding token limits.
        get_token_counts: Calculates token counts for each sample in the dataset.
        prepare_dataset: Tokenizes and filters the dataset, preparing it for training. Also visualizes the token count
                         distribution before and after filtering.
        split_dataset_for_train_eval: Divides the dataset into training and evaluation sets.
        inspect_prepare_split_data: Coordinates the data preparation and splitting process for fine-tuning.
    """

    def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
        """
        Initializes the FinetuningDataHandler class.

        Args:
            tokenizer (AutoTokenizer, optional): Tokenizer to use for tokenizing the dataset. Defaults to None.
            dataset_file (str): Path to the dataset file. Defaults to config.DATASET_FILE.
        """
        self.tokenizer = tokenizer  # The tokenizer used for processing the dataset.
        self.dataset_file = dataset_file  # Path to the fine-tuning dataset file.
        self.max_token_count = config.MAX_TOKEN_COUNT  # Max token count for filtering, set to 1,024 in the config.

    def load_llm_tokenizer(self) -> None:
        """
        Loads the LLM tokenizer and adds special tokens, if not already loaded.
        If the tokenizer is already loaded, this method does nothing.

        Returns:
            None
        """
        if self.tokenizer is None:
            llm_manager = Llama2ModelManager()  # Initialize the LLaMA-2 model manager.
            # We only need the tokenizer for data inspection, not the model itself.
            self.tokenizer = llm_manager.load_tokenizer()
            llm_manager.add_special_tokens()  # Add special tokens specific to the LLAMA2 vocab for efficient tokenization.

    def load_dataset(self) -> Dataset:
        """
        Loads the dataset from the specified file path. The dataset is expected to be in CSV format.

        Returns:
            Dataset: The loaded dataset, ready for processing.
        """
        # Loading a single CSV file returns a DatasetDict with one 'train' split.
        return load_dataset('csv', data_files=self.dataset_file)

    def plot_tokens_count_distribution(self, token_counts: List[int], title: str = "Token Count Distribution") -> None:
        """
        Plots the distribution of token counts in the dataset for visualization purposes.

        Args:
            token_counts (List[int]): List of token counts, each count representing the number of tokens in a
                                      dataset sample.
            title (str): Title for the plot, highlighting the nature of the distribution.

        Returns:
            None
        """
        if is_pycharm():  # Ensure compatibility with PyCharm's environment for interactive plots.
            import matplotlib  # The import is kept here intentionally, so the backend is set before pyplot loads.
            matplotlib.use('TkAgg')  # Set the backend to 'TkAgg'.
        import matplotlib.pyplot as plt  # The import is kept here intentionally (after any backend selection).

        sns.set_style("whitegrid")
        plt.figure(figsize=(15, 6))
        plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
        plt.title(title, fontsize=16)
        plt.xlabel("Number of Tokens", fontsize=14)
        plt.ylabel("Number of Samples", fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.show()

    def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: List[int]) -> Dataset:
        """
        Filters the dataset based on a list of valid indices. This method is used to exclude
        data samples that have a token count exceeding the specified maximum token count.

        Args:
            dataset (Dataset): The dataset to be filtered.
            valid_indices (List[int]): Indices of samples with token counts within the limit.

        Returns:
            Dataset: Filtered dataset containing only samples with valid indices.
        """
        return dataset['train'].select(valid_indices)  # Select only samples with valid indices based on token count.

    def get_token_counts(self, dataset: Dataset) -> List[int]:
        """
        Calculates and returns the token counts for each sample in the dataset.
        This method reads the 'text' field and handles both a DatasetDict with a 'train' split
        and an already-filtered Dataset.

        Args:
            dataset (Dataset): The dataset for which to count tokens.

        Returns:
            List[int]: List of token counts per sample in the dataset.
        """
        if 'train' in dataset:
            return [len(self.tokenizer.tokenize(s)) for s in dataset["train"]["text"]]
        else:
            # After filtering out samples with an unacceptable token count, the dataset is
            # already the selected 'train' split (a plain Dataset, not a DatasetDict).
            return [len(self.tokenizer.tokenize(s)) for s in dataset["text"]]

    def prepare_dataset(self) -> Tuple[Dataset, Dataset]:
        """
        Prepares the dataset for fine-tuning by tokenizing the data and filtering out samples
        that exceed the maximum usable context window (configurable through max_token_count).
        It also visualizes the token count distribution before and after filtering.

        Returns:
            Tuple[Dataset, Dataset]: The training and evaluation datasets, post-filtering.
        """
        dataset = self.load_dataset()
        self.load_llm_tokenizer()

        # Count tokens in each dataset sample before filtering.
        token_counts_before_filtering = self.get_token_counts(dataset)
        # Plot the token count distribution before filtering for visualization.
        self.plot_tokens_count_distribution(token_counts_before_filtering, "Token Count Distribution Before Filtration")

        # Identify valid indices based on the max token count.
        valid_indices = [i for i, count in enumerate(token_counts_before_filtering) if count <= self.max_token_count]
        # Filter the dataset to exclude samples with excessive token counts.
        filtered_dataset = self.filter_dataset_by_indices(dataset, valid_indices)

        token_counts_after_filtering = self.get_token_counts(filtered_dataset)
        self.plot_tokens_count_distribution(token_counts_after_filtering, "Token Count Distribution After Filtration")

        return self.split_dataset_for_train_eval(filtered_dataset)  # Split the dataset into training and evaluation sets.

    def split_dataset_for_train_eval(self, dataset: Dataset) -> Tuple[Dataset, Dataset]:
        """
        Splits the dataset into training and evaluation datasets.

        Args:
            dataset (Dataset): The dataset to split.

        Returns:
            Tuple[Dataset, Dataset]: The split training and evaluation datasets.
        """
        split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
        train_data, eval_data = split_data['train'], split_data['test']
        return train_data, eval_data

    def inspect_prepare_split_data(self) -> Tuple[Dataset, Dataset]:
        """
        Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.

        Returns:
            Tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
        """
        return self.prepare_dataset()


# Example usage
if __name__ == "__main__":
    # Please uncomment the lines below to test the data preparation.
    # data_handler = FinetuningDataHandler()
    # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
    # print(fine_tuning_data_train, fine_tuning_data_eval)
    pass
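
    # A tokenizer can also be supplied explicitly instead of letting the handler load one via
    # Llama2ModelManager. A minimal sketch, assuming a LLaMA-2 chat checkpoint; the checkpoint
    # name below is illustrative only:
    # llama2_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    # data_handler = FinetuningDataHandler(tokenizer=llama2_tokenizer)
    # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()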