import re

from transformers import TextDataset, DataCollatorForLanguageModeling


def get_length_param(text: str, tokenizer) -> str:
    """Maps text to 1 of 4 buckets based on its length after encoding.

    Parameters
    ----------
    text: str
        The text to be assigned 1 of 4 length parameters.
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    Returns
    -------
    len_param: str
        One of four buckets:
        '1' for short, '2' for medium, '3' for long texts and '-' for all others.
    """
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param
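
# Minimal usage sketch for get_length_param (an assumption for illustration:
# any HuggingFace GPT-2-style tokenizer works here; the checkpoint name below
# is illustrative, not taken from this repo).
def _demo_length_buckets():
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
    print(get_length_param("Привет!", tok))                # few tokens -> '1'
    print(get_length_param("Привет! Как дела? " * 10, tok))  # longer text -> '2' or '3'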
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Maps a message to '1' (machine) or '0' (human) depending on who sent it.

    Parameters
    ----------
    text: dict
        Dict containing field 'from' with the name of the user who sent the message.
    machine_name_in_chat: str
        Name of the machine side of the dialogue, i.e. the speaker to be predicted.
    """
    if text['from'] == machine_name_in_chat:
        return '1'  # machine
    else:
        return '0'  # human
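
# Minimal usage sketch for get_user_param. The message dicts below are
# fabricated sample data in the Telegram-export shape this function expects;
# 'Кирилл Гельван' is simply the default machine_name_in_chat used above.
def _demo_user_param():
    print(get_user_param({'from': 'Кирилл Гельван', 'text': 'Привет!'},
                         machine_name_in_chat='Кирилл Гельван'))  # -> '1' (machine)
    print(get_user_param({'from': 'Friend', 'text': 'Hi'},
                         machine_name_in_chat='Кирилл Гельван'))  # -> '0' (human)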
def build_text_file(data_json: dict, dest_path: str,
                    tokenizer, machine_name_in_chat='Кирилл Гельван'):
    """Creates a text file for training in the special format for ruDialoGPT-3.

    Parameters
    ----------
    data_json: dict
        Dict containing 'text' (message) and 'from' (user who sent the message).
    dest_path: str
        Path of the file to write the data to.
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    """
    new_data = ''
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        # Skip empty or non-string messages (e.g. stickers or photos in the export)
        if message['text'] == '' or not isinstance(message['text'], str):
            continue
        if next_message['text'] == '' or not isinstance(next_message['text'], str):
            continue
        user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
        # Length bucket of the *next* message, i.e. of the reply to be generated
        length = get_length_param(next_message['text'], tokenizer)
        message_text = re.sub(r"\n", ". ", message['text'])
        new_data += f"|{user}|{length}|{message_text}{tokenizer.eos_token}" + "\n"
    with open(dest_path, 'w') as f:
        f.write(new_data)
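
# Minimal end-to-end sketch for build_text_file. The two-message list is
# fabricated sample data in the expected export shape, and the tokenizer
# checkpoint name is again only illustrative.
def _demo_build_text_file():
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
    sample = [
        {'from': 'Friend', 'text': 'Привет! Как дела?'},
        {'from': 'Кирилл Гельван', 'text': 'Привет, всё отлично.'},
    ]
    build_text_file(sample, 'train_sample.txt', tok)
    # Each written line looks like |0|1|Привет! Как дела?  followed by
    # tokenizer.eos_token: sender flag, length bucket of the reply, message text.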
def load_dataset(train_path, test_path, tokenizer):
    """Creates train and test PyTorch datasets and a collate_fn using HuggingFace.

    Parameters
    ----------
    train_path: str
        Path to the train data.
    test_path: str
        Path to the test data.
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    """
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=256)

    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=256)

    # mlm=False -> causal language modelling (next-token prediction), not masked LM
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    return train_dataset, test_dataset, data_collator
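
# Sketch of how the returned pieces plug into a HuggingFace Trainer for
# causal-LM fine-tuning. The checkpoint name, file paths, and TrainingArguments
# values are assumptions for illustration, not settings taken from this repo.
def _demo_finetune():
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              Trainer, TrainingArguments)
    tok = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
    model = AutoModelForCausalLM.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
    train_ds, test_ds, collator = load_dataset('train.txt', 'test.txt', tok)
    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir='checkpoints',
                               num_train_epochs=1,
                               per_device_train_batch_size=2),
        data_collator=collator,
        train_dataset=train_ds,
        eval_dataset=test_ds,
    )
    trainer.train()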