Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 -u | |
| import argparse | |
| import fileinput | |
| import logging | |
| import os | |
| import sys | |
| from fairseq.models.transformer import TransformerModel | |
| logging.getLogger().setLevel(logging.INFO) | |
| def main(): | |
| parser = argparse.ArgumentParser(description="") | |
| parser.add_argument("--en2fr", required=True, help="path to en2fr model") | |
| parser.add_argument( | |
| "--fr2en", required=True, help="path to fr2en mixture of experts model" | |
| ) | |
| parser.add_argument( | |
| "--user-dir", help="path to fairseq examples/translation_moe/src directory" | |
| ) | |
| parser.add_argument( | |
| "--num-experts", | |
| type=int, | |
| default=10, | |
| help="(keep at 10 unless using a different model)", | |
| ) | |
| parser.add_argument( | |
| "files", | |
| nargs="*", | |
| default=["-"], | |
| help='input files to paraphrase; "-" for stdin', | |
| ) | |
| args = parser.parse_args() | |
| if args.user_dir is None: | |
| args.user_dir = os.path.join( | |
| os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/ | |
| "translation_moe", | |
| "src", | |
| ) | |
| if os.path.exists(args.user_dir): | |
| logging.info("found user_dir:" + args.user_dir) | |
| else: | |
| raise RuntimeError( | |
| "cannot find fairseq examples/translation_moe/src " | |
| "(tried looking here: {})".format(args.user_dir) | |
| ) | |
| logging.info("loading en2fr model from:" + args.en2fr) | |
| en2fr = TransformerModel.from_pretrained( | |
| model_name_or_path=args.en2fr, | |
| tokenizer="moses", | |
| bpe="sentencepiece", | |
| ).eval() | |
| logging.info("loading fr2en model from:" + args.fr2en) | |
| fr2en = TransformerModel.from_pretrained( | |
| model_name_or_path=args.fr2en, | |
| tokenizer="moses", | |
| bpe="sentencepiece", | |
| user_dir=args.user_dir, | |
| task="translation_moe", | |
| ).eval() | |
| def gen_paraphrases(en): | |
| fr = en2fr.translate(en) | |
| return [ | |
| fr2en.translate(fr, inference_step_args={"expert": i}) | |
| for i in range(args.num_experts) | |
| ] | |
| logging.info("Type the input sentence and press return:") | |
| for line in fileinput.input(args.files): | |
| line = line.strip() | |
| if len(line) == 0: | |
| continue | |
| for paraphrase in gen_paraphrases(line): | |
| print(paraphrase) | |
| if __name__ == "__main__": | |
| main() | |