| | import os |
| | import sys |
| | import string |
| | from tqdm import tqdm |
| | from collections import defaultdict |
| | from typing import List, Tuple, Dict |
| |
|
| |
|
def read_lines(fname: str) -> List[str]:
    """
    Reads all lines from an input file and returns them as a list of strings.

    The file is decoded as UTF-8 so that reading is symmetric with the UTF-8
    writing done by `create_txt`, independent of the platform's locale.

    Args:
        fname (str): path to the input file to read

    Returns:
        List[str]: a list of strings, where each string is a line from the file
        (line terminators included), or an empty list if `fname` is not an
        existing regular file.
    """
    # isfile (not exists) so that a directory path yields [] instead of crashing in open()
    if not os.path.isfile(fname):
        return []

    with open(fname, "r", encoding="utf-8") as f:
        return f.readlines()
| |
|
| |
|
def create_txt(out_file: str, lines: List[str]):
    """
    Creates (or overwrites) a UTF-8 text file and writes the given lines to it.

    Every line ends up terminated by exactly one newline. Lines that already end
    with a newline are written as-is; the others get one appended. The decision
    is made per line, not from the first line only. This keeps the line count
    intact when the list mixes terminated and unterminated lines, e.g. when the
    final, newline-less line of an input file is no longer last after
    deduplication.

    Args:
        out_file (str): path to the output file to be created.
        lines (List[str]): strings to write; an empty list produces an empty file.
    """
    with open(out_file, "w", encoding="utf-8") as outfile:
        outfile.writelines(
            line if line.endswith("\n") else line + "\n" for line in lines
        )
| |
|
| |
|
def pair_dedup_lists(src_list: List[str], tgt_list: List[str]) -> Tuple[List[str], List[str]]:
    """
    Removes duplicate (source, target) pairs from two parallel lists.

    Elements are paired position-wise, as with `zip`, so surplus items of the
    longer list are ignored. A pair is dropped when an identical pair occurred
    earlier. The first occurrence of every pair is kept in its original order,
    so the result is deterministic.

    Args:
        src_list (List[str]): a list of strings from source language data.
        tgt_list (List[str]): a list of strings from target language data.

    Returns:
        Tuple[List[str], List[str]]: the deduplicated `(src_list, tgt_list)`;
        both lists are empty when there is nothing to pair.
    """
    # dict keys keep insertion order: an order-stable "set" of pairs
    unique_pairs = dict.fromkeys(zip(src_list, tgt_list))
    src_deduped = [src for src, _ in unique_pairs]
    tgt_deduped = [tgt for _, tgt in unique_pairs]
    return src_deduped, tgt_deduped
| |
|
| |
|
def pair_dedup_files(src_file: str, tgt_file: str):
    """
    Removes duplicate (source, target) line pairs from two parallel files, in place.

    Lines are paired by position and both files are overwritten with the
    deduplicated content. If the files have different line counts, the surplus
    lines of the longer file cannot be paired and are dropped. A warning is
    printed in that case so the loss is not silent.

    Args:
        src_file (str): path to the source language file to deduplicate.
        tgt_file (str): path to the target language file to deduplicate.
    """
    src_lines = read_lines(src_file)
    tgt_lines = read_lines(tgt_file)

    if len(src_lines) != len(tgt_lines):
        print(
            f"Warning: {src_file} has {len(src_lines)} lines but {tgt_file} has "
            f"{len(tgt_lines)} lines; unpaired trailing lines will be dropped"
        )

    # only position-paired lines take part in deduplication, so count those
    len_before = min(len(src_lines), len(tgt_lines))
    src_deduped, tgt_deduped = pair_dedup_lists(src_lines, tgt_lines)
    num_duplicates = len_before - len(src_deduped)

    print(f"Dropped duplicate pairs in {src_file} Num duplicates -> {num_duplicates}")
    create_txt(src_file, src_deduped)
    create_txt(tgt_file, tgt_deduped)
| |
|
| |
|
# Deletion table for strip_and_normalize: ASCII punctuation plus the Devanagari
# danda (U+0964). Built once at import time instead of on every call.
_NORMALIZE_DELETE_TABLE = str.maketrans("", "", string.punctuation + "\u0964")


def strip_and_normalize(line: str) -> str:
    """
    Strips and normalizes a string by lowercasing it, removing whitespace and punctuation.

    All whitespace is removed, including the line terminator left by
    `readlines()` and tabs/carriage returns, not only spaces. Thus the same
    sentence normalizes identically whether or not it was the
    newline-less last line of a file.

    Args:
        line (str): string to strip and normalize.

    Returns:
        str: stripped and normalized version of the input string.
    """
    line = "".join(line.split()).lower()
    return line.translate(_NORMALIZE_DELETE_TABLE)
| |
|
| |
|
def expand_tupled_list(list_of_tuples: List[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """
    Expands a list of 2-tuples into two lists holding the first and second elements.

    Args:
        list_of_tuples (List[Tuple[str, str]]): a list (or any iterable) of tuples,
            where each tuple contains two strings.

    Returns:
        Tuple[List[str], List[str]]: a tuple containing two lists, the first being the first elements of the
        tuples in `list_of_tuples` and the second being the second elements; both are empty
        for empty input.

    Raises:
        ValueError: if an element is not a 2-tuple.
    """
    # materialize once so a generator argument can be walked twice
    pairs = list(list_of_tuples)
    list_a = [a for a, _ in pairs]
    list_b = [b for _, b in pairs]
    return list_a, list_b
| |
|
| |
|
def normalize_and_gather_all_benchmarks(devtest_dir: str) -> Dict[str, Dict[str, List[str]]]:
    """
    Normalizes and gathers all benchmark datasets from a directory into a dictionary.

    Expected layout: `devtest_dir/<benchmark>/<src_lang>-<tgt_lang>/`, where each pair
    directory contains `dev.<src_lang>`, `dev.<tgt_lang>`, `test.<src_lang>` and
    `test.<tgt_lang>`.

    The following are skipped:
      - non-directory entries;
      - pair directories whose name is not exactly `src_lang-tgt_lang`;
      - pairs where any of the four files is missing or empty.

    Args:
        devtest_dir (str): path to the directory containing one subdirectory per benchmark.

    Returns:
        Dict[str, Dict[str, List[str]]]: a dictionary mapping language pairs (in the format
        "`src_lang-tgt_lang`") to a dict with keys "src" and "tgt". These hold the
        position-aligned normalized source and target lines of dev+test across all
        benchmarks, with duplicate (src, tgt) pairs removed.
    """
    devtest_pairs_normalized = defaultdict(lambda: defaultdict(list))

    for benchmark in sorted(os.listdir(devtest_dir)):
        benchmark_dir = f"{devtest_dir}/{benchmark}"
        if not os.path.isdir(benchmark_dir):
            continue
        print(benchmark_dir)
        for pair in tqdm(sorted(os.listdir(benchmark_dir))):
            pair_dir = f"{benchmark_dir}/{pair}"
            langs = pair.split("-")
            # ignore stray files and directories not named "<src>-<tgt>"
            if not os.path.isdir(pair_dir) or len(langs) != 2:
                continue
            src_lang, tgt_lang = langs

            src_dev = read_lines(f"{pair_dir}/dev.{src_lang}")
            tgt_dev = read_lines(f"{pair_dir}/dev.{tgt_lang}")
            src_test = read_lines(f"{pair_dir}/test.{src_lang}")
            tgt_test = read_lines(f"{pair_dir}/test.{tgt_lang}")

            # both sides of both splits are required to build aligned (src, tgt) pairs
            if not (src_dev and tgt_dev and src_test and tgt_test):
                print(f"{benchmark} does not have {src_lang}-{tgt_lang} data")
                continue

            if len(src_dev) != len(tgt_dev) or len(src_test) != len(tgt_test):
                print(
                    f"Warning: {benchmark} {src_lang}-{tgt_lang} has mismatched "
                    "source/target line counts; pairs may be misaligned"
                )

            # dev and test are pooled: both count as held-out data
            src_devtest = [strip_and_normalize(line) for line in src_dev + src_test]
            tgt_devtest = [strip_and_normalize(line) for line in tgt_dev + tgt_test]

            devtest_pairs_normalized[pair]["src"].extend(src_devtest)
            devtest_pairs_normalized[pair]["tgt"].extend(tgt_devtest)

    # the same sentence pair may appear in several benchmarks; keep one copy
    for pair_data in devtest_pairs_normalized.values():
        pair_data["src"], pair_data["tgt"] = pair_dedup_lists(
            pair_data["src"], pair_data["tgt"]
        )

    return devtest_pairs_normalized
| |
|
| |
|
def remove_train_devtest_overlaps(train_dir: str, devtest_dir: str):
    """
    Removes overlapping data between the training and dev/test (benchmark)
    datasets for all language pairs, rewriting the training files in place.

    A training pair `train_dir/<src>-<tgt>/train.{src,tgt}` is dropped when either
    of these holds:
      - its normalized source line matches the normalized source side of ANY
        benchmark, for any language pair;
      - its normalized target line is in the accumulated set of target overlaps,
        which holds the target-side overlaps found for this pair and for every
        pair processed before it.

    Args:
        train_dir (str): path of the directory containing the training data.
        devtest_dir (str): path of the directory containing the dev/test data.
    """
    devtest_pairs_normalized = normalize_and_gather_all_benchmarks(devtest_dir)

    # every normalized benchmark source sentence, across all language pairs
    all_src_sentences_normalized = set()
    for pair_data in devtest_pairs_normalized.values():
        all_src_sentences_normalized.update(pair_data["src"])

    # target-side overlaps accumulate across pairs, as in the original implementation
    tgt_overlaps = set()

    for pair in os.listdir(train_dir):
        langs = pair.split("-")
        # ignore stray files and directories not named "<src>-<tgt>"
        if not os.path.isdir(f"{train_dir}/{pair}") or len(langs) != 2:
            continue
        src_lang, tgt_lang = langs

        src_train = read_lines(f"{train_dir}/{pair}/train.{src_lang}")
        tgt_train = read_lines(f"{train_dir}/{pair}/train.{tgt_lang}")

        len_before = len(src_train)
        if len_before == 0:
            continue
        if len(tgt_train) != len_before:
            print(
                f"Warning: {pair} has {len_before} source and {len(tgt_train)} "
                "target training lines; unpaired trailing lines will be dropped"
            )

        src_train_normalized = [strip_and_normalize(line) for line in src_train]
        tgt_train_normalized = [strip_and_normalize(line) for line in tgt_train]

        tgt_devtest_normalized = devtest_pairs_normalized[pair]["tgt"]
        tgt_overlaps.update(set(tgt_train_normalized).intersection(tgt_devtest_normalized))

        # Walk raw and normalized lines in lockstep, so the kept raw lines are always
        # the ones whose normalized forms were tested (no separate index to drift).
        new_src_train, new_tgt_train = [], []
        for src_line, tgt_line, src_line_norm, tgt_line_norm in tqdm(
            zip(src_train, tgt_train, src_train_normalized, tgt_train_normalized),
            total=len_before,
        ):
            if src_line_norm in all_src_sentences_normalized:
                continue
            if tgt_line_norm in tgt_overlaps:
                continue
            new_src_train.append(src_line)
            new_tgt_train.append(tgt_line)

        len_after = len(new_src_train)
        print(
            f"Detected overlaps between train and devetest for {pair} is {len_before - len_after}"
        )
        print(f"saving new files at {train_dir}/{pair}/")
        create_txt(f"{train_dir}/{pair}/train.{src_lang}", new_src_train)
        create_txt(f"{train_dir}/{pair}/train.{tgt_lang}", new_tgt_train)
| |
|
| |
|
if __name__ == "__main__":
    # expected invocation: <script> <train_data_dir> <devtest_data_dir>
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} <train_data_dir> <devtest_data_dir>")

    train_data_dir = sys.argv[1]
    devtest_data_dir = sys.argv[2]

    remove_train_devtest_overlaps(train_data_dir, devtest_data_dir)
| |
|