Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import math | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| from collections import defaultdict | |
| from itertools import combinations | |
def read_translations(path, n_repeats):
    """Group the lines of a translation file into per-segment lists.

    The file is read line by line. Every ``n_repeats`` consecutive lines form
    one segment. Each line is whitespace-normalized: runs of whitespace are
    collapsed to a single space and the ends are stripped.

    Args:
        path: path of the text file to read.
        n_repeats: number of consecutive lines that belong to one segment.

    Returns:
        A ``defaultdict(list)`` mapping the 0-based segment index to the list
        of ``n_repeats`` normalized lines for that segment. Trailing lines
        that do not fill a complete group are not included; a warning about
        them is written to stderr.
    """
    translations = defaultdict(list)
    segment_counter = 0
    segment_translations = []
    # Use a context manager so the handle is closed deterministically.
    with open(path) as fh:
        for line in fh:
            segment_translations.append(" ".join(line.split()))
            if len(segment_translations) == n_repeats:
                translations[segment_counter] = segment_translations
                segment_translations = []
                segment_counter += 1
    if segment_translations:
        # Leftover lines are still dropped, but no longer silently.
        sys.stderr.write(
            "\nWarning: ignoring %d trailing line(s) in %s that do not form a "
            "complete group of %s" % (len(segment_translations), path, n_repeats)
        )
    return translations
def generate_input(translations, n_repeats):
    """Write every pair of a segment's translations as aligned hypothesis/reference lines.

    Segments are visited in sorted key order. For each index pair ``(i, j)``
    produced by ``itertools.combinations(range(n_repeats), 2)``, translation
    ``i`` is written to the hypothesis (MT) file and translation ``j`` to the
    reference file, on the same line number in both files.

    Args:
        translations: mapping of segment id to a list of translation strings.
        n_repeats: number of translations each segment must contain.

    Returns:
        ``(ref_path, mt_path)``: paths of two new temporary files. The files
        are closed (fully flushed) on return; the caller must delete them.

    Raises:
        ValueError: if a segment does not contain exactly ``n_repeats`` entries.
    """
    ref_fd, ref_path = tempfile.mkstemp()
    mt_fd, mt_path = tempfile.mkstemp()
    # Wrap the descriptors returned by mkstemp instead of reopening the paths.
    # Closing both files here guarantees that their contents are on disk
    # before an external process reads them.
    with os.fdopen(ref_fd, "w") as ref_fh, os.fdopen(mt_fd, "w") as mt_fh:
        for segid in sorted(translations.keys()):
            segment = translations[segid]
            if len(segment) != n_repeats:
                raise ValueError(
                    "segment %r has %d translations, expected %s"
                    % (segid, len(segment), n_repeats)
                )
            for idx1, idx2 in combinations(range(n_repeats), 2):
                mt_fh.write(segment[idx1].strip() + "\n")
                ref_fh.write(segment[idx2].strip() + "\n")
    sys.stderr.write("\nSaved translations to %s and %s" % (ref_path, mt_path))
    return ref_path, mt_path
def run_meteor(ref_path, mt_path, metric_path, lang="en"):
    """Run the Meteor jar on aligned hypothesis/reference files.

    Executes ``java -Xmx2G -jar <metric_path> <mt_path> <ref_path> -p
    "0.5 0.2 0.6 0.75" -norm -l <lang>``. The process's stdout is captured in
    a new temporary file. Both input files are deleted afterwards, whether or
    not the process succeeded.

    Args:
        ref_path: file of reference lines; deleted by this function.
        mt_path: file of hypothesis lines aligned with ``ref_path``; deleted
            by this function.
        metric_path: path to the Meteor jar passed to ``java -jar``.
        lang: value passed to Meteor's ``-l`` option.

    Returns:
        Path of the temporary file holding Meteor's stdout.

    Raises:
        subprocess.CalledProcessError: if the java process exits non-zero.
    """
    out_fd, out_path = tempfile.mkstemp()
    try:
        # Reuse mkstemp's descriptor and close it once the process finishes.
        with os.fdopen(out_fd, "w") as out_fh:
            # check_call surfaces a failed run instead of leaving an empty or
            # partial output file to be parsed later.
            subprocess.check_call(
                [
                    "java",
                    "-Xmx2G",
                    "-jar",
                    metric_path,
                    mt_path,
                    ref_path,
                    "-p",
                    "0.5 0.2 0.6 0.75",  # default parameters, only changed alpha to give equal weight to P and R
                    "-norm",
                    "-l",
                    lang,
                ],
                stdout=out_fh,
            )
    finally:
        os.remove(ref_path)
        os.remove(mt_path)
    sys.stderr.write("\nSaved Meteor output to %s" % out_path)
    return out_path
def read_output(meteor_output_path, n_repeats):
    """Average Meteor segment scores over each segment's group of translation pairs.

    Only lines starting with ``"Segment "`` are parsed. The score is the
    second tab-separated field of such a line. Consecutive runs of
    C(n_repeats, 2) scores are averaged into one value. This assumes the
    scores appear in the same pair order that ``generate_input`` writes.
    A trailing incomplete run is ignored. The output file is deleted after
    reading.

    Args:
        meteor_output_path: path of the captured Meteor stdout.
        n_repeats: number of translations per segment (must be >= 2).

    Returns:
        List of per-segment average scores, in file order.

    Raises:
        ValueError: if ``n_repeats`` is less than 2, since then no pairs exist.
    """
    if n_repeats < 2:
        raise ValueError("n_repeats must be at least 2, got %r" % (n_repeats,))
    # Number of unordered pairs among n_repeats items, computed as an exact int.
    n_combinations = n_repeats * (n_repeats - 1) // 2
    raw_scores = []
    average_scores = []
    # Close the handle before removing the file (required on Windows).
    with open(meteor_output_path) as fh:
        for line in fh:
            if not line.startswith("Segment "):
                continue
            raw_scores.append(float(line.strip().split("\t")[1]))
            if len(raw_scores) == n_combinations:
                average_scores.append(sum(raw_scores) / n_combinations)
                raw_scores = []
    os.remove(meteor_output_path)
    return average_scores
def main():
    """Command-line entry point: per-segment pairwise Meteor agreement.

    Reads ``--infile``, where each segment has ``--repeat_times`` consecutive
    translation lines. It then scores every pair of translations of the same
    segment with the Meteor jar given by ``--meteor``. Finally it writes one
    averaged score per segment, one per line, to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Average pairwise Meteor scores between repeated translations."
    )
    # All inputs are mandatory; argparse reports a missing one up front
    # instead of failing later with a TypeError on None.
    parser.add_argument(
        "-i", "--infile", required=True,
        help="file with repeat_times consecutive translations per segment",
    )
    parser.add_argument(
        "-n", "--repeat_times", type=int, required=True,
        help="number of translations per segment",
    )
    parser.add_argument(
        "-m", "--meteor", required=True,
        help="path to the Meteor jar",
    )
    parser.add_argument(
        "-o", "--output", required=True,
        help="file to write one averaged score per segment to",
    )
    parser.add_argument(
        "-l", "--lang", default="en",
        help="language code passed to Meteor's -l option (default: en)",
    )
    args = parser.parse_args()
    translations = read_translations(args.infile, args.repeat_times)
    sys.stderr.write("\nGenerating input for Meteor...")
    ref_path, mt_path = generate_input(translations, args.repeat_times)
    sys.stderr.write("\nRunning Meteor...")
    out_path = run_meteor(ref_path, mt_path, args.meteor, lang=args.lang)
    sys.stderr.write("\nReading output...")
    scores = read_output(out_path, args.repeat_times)
    sys.stderr.write("\nWriting results...")
    with open(args.output, "w") as o:
        for scr in scores:
            o.write("{}\n".format(scr))
# Run the command-line pipeline only when this file is executed as a script,
# not when it is imported as a module.
if "__main__" == __name__:
    main()