Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from itertools import zip_longest | |
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replace out-of-vocabulary words with position-tagged ``<unk-N>`` tokens.

    Every source token that is not in ``vocabulary`` is rewritten as
    ``<unk-N>``, where N is the zero-based position of the first occurrence
    of that word in the source sequence. Target tokens that match one of
    those source OOV words are rewritten to the same ``<unk-N>``; all other
    target tokens are copied unchanged (the vocabulary is not consulted for
    the target side).

    Args:
        source_in: iterable of source lines (e.g. an open text file).
        target_in: iterable of target lines aligned line-by-line with
            ``source_in``, or None when there is no target side.
        vocabulary: container of in-vocabulary words.
        source_out: writable text stream; receives one line per source line.
        target_out: writable text stream, or None. When given, it receives
            one line per source line (an empty line where no target line is
            available).

    Raises:
        ValueError: if ``target_in`` yields more lines than ``source_in``.
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

    if target_in is None:
        target_in = []

    # A list/tuple vocabulary makes every membership test linear in the
    # vocabulary size; hash it once up front instead.
    if isinstance(vocabulary, (list, tuple)):
        vocabulary = frozenset(vocabulary)

    for seq_num, (source_seq, target_seq) in enumerate(
        zip_longest(source_in, target_in), start=1
    ):
        # zip_longest pads the shorter input with None. A missing source line
        # means the inputs are misaligned, so fail with a clear message
        # rather than an AttributeError on None.
        if source_seq is None:
            raise ValueError(
                "target has more sequences than source "
                "(no source sequence for target line {})".format(seq_num)
            )

        source_seq_out = []
        # Maps each OOV word to the position of its first occurrence.
        word_to_pos = {}
        for position, token in enumerate(source_seq.strip().split()):
            if token in vocabulary:
                source_seq_out.append(token)
            else:
                # Repeated OOV words reuse the position of the first one.
                oov_pos = word_to_pos.setdefault(token, position)
                source_seq_out.append(format_unk(oov_pos))
        source_out.write(" ".join(source_seq_out) + "\n")

        if target_out is None:
            # Nothing to emit for the target side; skip the wasted work.
            continue

        target_tokens = [] if target_seq is None else target_seq.strip().split()
        target_seq_out = [
            format_unk(word_to_pos[token]) if token in word_to_pos else token
            for token in target_tokens
        ]
        target_out.write(" ".join(target_seq_out) + "\n")
def main():
    """Command-line entry point: tag OOV words in source/target text files.

    Reads the vocabulary file (one entry per line), then rewrites the source
    file -- and the target file, when given -- by calling ``replace_oovs``,
    writing the results to ``--source-out`` / ``--target-out``.
    """
    # Imported locally so this function does not depend on edits to the
    # file's top-level import block.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser(
        description="Replaces out-of-vocabulary words in both source and target "
        "sequences with tokens that indicate the position of the word "
        "in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", default=None
    )
    parser.add_argument("--vocab", type=str, help="vocabulary file", required=True)
    parser.add_argument(
        "--source-out",
        type=str,
        help="where to write source sequences with <unk-N> entries",
        required=True,
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences with <unk-N> entries",
        default=None,
    )
    args = parser.parse_args()

    with open(args.vocab, encoding="utf-8") as vocab:
        # A set gives O(1) membership tests; the vocabulary is consulted once
        # per source token.
        vocabulary = frozenset(vocab.read().splitlines())

    # ExitStack closes every file that was opened, even if a later open() or
    # replace_oovs() raises. Files are opened in the same order as before.
    with ExitStack() as stack:
        target_in = None
        if args.target is not None:
            target_in = stack.enter_context(
                open(args.target, "r", encoding="utf-8")
            )

        target_out = None
        if args.target_out is not None:
            target_out = stack.enter_context(
                open(args.target_out, "w", encoding="utf-8")
            )

        source_in = stack.enter_context(open(args.source, "r", encoding="utf-8"))
        source_out = stack.enter_context(
            open(args.source_out, "w", encoding="utf-8")
        )

        replace_oovs(source_in, target_in, vocabulary, source_out, target_out)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()