Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import re | |
| import sys | |
| class OOVIndexError(IndexError): | |
| def __init__(self, pos, source_seq, target_seq): | |
| super(OOVIndexError, self).__init__( | |
| "A <unk-N> tag in the target sequence refers to a position that is " | |
| "outside the source sequence. Most likely there was a mismatch in " | |
| "provided source and target sequences. Otherwise this would mean that " | |
| "the pointing mechanism somehow attended to a position that is past " | |
| "the actual sequence end." | |
| ) | |
| self.source_pos = pos | |
| self.source_seq = source_seq | |
| self.target_seq = target_seq | |
| def replace_oovs(source_in, target_in, target_out): | |
| """Replaces <unk-N> tokens in the target text with the corresponding word in | |
| the source text. | |
| """ | |
| oov_re = re.compile("^<unk-([0-9]+)>$") | |
| for source_seq, target_seq in zip(source_in, target_in): | |
| target_seq_out = [] | |
| pos_to_word = source_seq.strip().split() | |
| for token in target_seq.strip().split(): | |
| m = oov_re.match(token) | |
| if m: | |
| pos = int(m.group(1)) | |
| if pos >= len(pos_to_word): | |
| raise OOVIndexError(pos, source_seq, target_seq) | |
| token_out = pos_to_word[pos] | |
| else: | |
| token_out = token | |
| target_seq_out.append(token_out) | |
| target_out.write(" ".join(target_seq_out) + "\n") | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="Replaces <unk-N> tokens in target sequences with words from " | |
| "the corresponding position in the source sequence." | |
| ) | |
| parser.add_argument( | |
| "--source", type=str, help="text file with source sequences", required=True | |
| ) | |
| parser.add_argument( | |
| "--target", type=str, help="text file with target sequences", required=True | |
| ) | |
| parser.add_argument( | |
| "--target-out", | |
| type=str, | |
| help="where to write target sequences without <unk-N> " "entries", | |
| required=True, | |
| ) | |
| args = parser.parse_args() | |
| target_in = ( | |
| open(args.target, "r", encoding="utf-8") if args.target is not None else None | |
| ) | |
| target_out = ( | |
| open(args.target_out, "w", encoding="utf-8") | |
| if args.target_out is not None | |
| else None | |
| ) | |
| with open(args.source, "r", encoding="utf-8") as source_in, open( | |
| args.target, "r", encoding="utf-8" | |
| ) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out: | |
| replace_oovs(source_in, target_in, target_out) | |
| if __name__ == "__main__": | |
| try: | |
| main() | |
| except OOVIndexError as e: | |
| print(e, file=sys.stderr) | |
| print("Source sequence:", e.source_seq.strip(), file=sys.stderr) | |
| print("Target sequence:", e.target_seq.strip(), file=sys.stderr) | |
| print( | |
| "Source sequence length:", | |
| len(e.source_seq.strip().split()), | |
| file=sys.stderr, | |
| ) | |
| print("The offending tag points to:", e.source_pos) | |
| sys.exit(2) | |