Spaces:
Running
Running
#!/usr/bin/env python3 -u | |
import argparse | |
import fileinput | |
import logging | |
import os | |
import sys | |
from fairseq.models.transformer import TransformerModel | |
# The root logger defaults to WARNING; lower it to INFO so the progress
# messages main() emits through logging.info() (user_dir lookup, model
# loading, input prompt) are actually shown.
logging.root.setLevel(logging.INFO)
def main():
    """Paraphrase English input lines by round-trip translation.

    Each non-empty input line is translated English -> French with the
    ``--en2fr`` model, then translated back French -> English once per
    expert of the ``--fr2en`` mixture-of-experts model.  One paraphrase per
    expert is printed to stdout.

    Input is read from the files named on the command line, or from stdin
    when none are given (or ``-`` is passed).

    Raises:
        SystemExit: via ``argparse`` on invalid arguments, including a
            non-positive ``--num-experts``.
        RuntimeError: if the translation_moe user directory cannot be found.
    """
    parser = argparse.ArgumentParser(
        description="Paraphrase English sentences by round-trip translation "
        "through French with a mixture-of-experts fr2en model."
    )
    parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    args = parser.parse_args()

    # range(0) (or a negative count) would make every input line silently
    # produce no paraphrases, so reject it up front.
    if args.num_experts < 1:
        parser.error("--num-experts must be a positive integer")

    if args.user_dir is None:
        # Default to the sibling translation_moe example: two dirname() calls
        # go from this script's path up to the examples/ directory.
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            "translation_moe",
            "src",
        )
    if os.path.exists(args.user_dir):
        logging.info("found user_dir:%s", args.user_dir)
    else:
        raise RuntimeError(
            "cannot find fairseq examples/translation_moe/src "
            "(tried looking here: {})".format(args.user_dir)
        )

    logging.info("loading en2fr model from:%s", args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from:%s", args.fr2en)
    # The fr2en model needs the translation_moe task, which is registered
    # from the user directory resolved above.
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=args.user_dir,
        task="translation_moe",
    ).eval()

    def gen_paraphrases(en):
        """Return one English paraphrase of ``en`` per fr2en expert."""
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={"expert": i})
            for i in range(args.num_experts)
        ]

    # The prompt only makes sense when a person is typing into stdin; skip it
    # for file input and for piped stdin.
    if "-" in args.files and sys.stdin is not None and sys.stdin.isatty():
        logging.info("Type the input sentence and press return:")

    with fileinput.input(args.files) as lines:
        for line in lines:
            line = line.strip()
            if not line:
                continue
            for paraphrase in gen_paraphrases(line):
                print(paraphrase)
            # Flush per input sentence so results appear promptly even when
            # stdout is block-buffered (e.g. run without ``python3 -u``).
            sys.stdout.flush()


if __name__ == "__main__":
    main()