Spaces:
Running
Running
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BLEU scoring of generated translations against reference translations.
"""
import argparse | |
import os | |
import sys | |
from fairseq.data import dictionary | |
from fairseq.scoring import bleu | |
def get_parser():
    """Build the command-line parser for the BLEU scoring script.

    Returns:
        argparse.ArgumentParser: parser accepting ``-s/--sys`` (default
        ``-``), the required ``-r/--ref``, ``-o/--order`` (int, default 4)
        and the boolean switches ``--ignore-case``, ``--sacrebleu`` and
        ``--sentence-bleu``.
    """
    # (flags, keyword options), listed in the order they appear in --help.
    option_specs = (
        (("-s", "--sys"), {"default": "-", "help": "system output"}),
        (("-r", "--ref"), {"required": True, "help": "references"}),
        (
            ("-o", "--order"),
            {
                "default": 4,
                "metavar": "N",
                "type": int,
                "help": "consider ngrams up to this order",
            },
        ),
        (
            ("--ignore-case",),
            {"action": "store_true", "help": "case-insensitive scoring"},
        ),
        (
            ("--sacrebleu",),
            {"action": "store_true", "help": "score with sacrebleu"},
        ),
        (
            ("--sentence-bleu",),
            {
                "action": "store_true",
                "help": "report sentence-level BLEUs (i.e., with +1 smoothing)",
            },
        ),
    )

    bleu_parser = argparse.ArgumentParser(
        description="Command-line script for BLEU scoring."
    )
    for flags, options in option_specs:
        bleu_parser.add_argument(*flags, **options)
    return bleu_parser
def cli_main():
    """Score system translations against references and print the result.

    Parses the command line with ``get_parser``, echoes the parsed
    arguments, then reads hypotheses from ``--sys`` (standard input when it
    is ``-``) and references from ``--ref``. Depending on the flags the
    score is computed by:

    * ``--sacrebleu``: ``sacrebleu.corpus_bleu`` over the whole corpus;
    * ``--sentence-bleu``: fairseq's scorer, reset for every sentence pair,
      printing one indexed line per pair;
    * otherwise: fairseq's scorer accumulated over all sentence pairs.

    ``--ignore-case`` lowercases both inputs in every mode.

    Raises:
        AssertionError: if the system output file (when not ``-``) or the
            reference file does not exist.
    """
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    # Kept as asserts so the exception type raised on a missing file is
    # unchanged. NOTE: these checks are skipped when Python runs with -O.
    assert args.sys == "-" or os.path.exists(
        args.sys
    ), "System output file {} does not exist".format(args.sys)
    assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref)

    # Named ``vocab`` rather than ``dict`` so the builtin is not shadowed.
    vocab = dictionary.Dictionary()

    def readlines(fd):
        """Yield the lines of ``fd``, lowercased when ``--ignore-case`` is set."""
        # Iterate the handle directly instead of ``fd.readlines()`` so the
        # whole file is not loaded into memory up front.
        for line in fd:
            yield line.lower() if args.ignore_case else line

    def make_scorer():
        """Create a fairseq BLEU scorer from the dictionary's pad/eos/unk values."""
        # Both fairseq scoring modes build the scorer the same way, through
        # a BleuConfig, so they cannot drift apart.
        return bleu.Scorer(
            bleu.BleuConfig(
                pad=vocab.pad(),
                eos=vocab.eos(),
                unk=vocab.unk(),
            )
        )

    if args.sacrebleu:
        import sacrebleu

        def score(fdsys):
            with open(args.ref) as fdref:
                # sacrebleu's documented API takes sequences of strings, so
                # read the lines into lists rather than passing file
                # handles. Going through ``readlines`` also makes
                # --ignore-case effective in this mode.
                hypotheses = list(readlines(fdsys))
                references = list(readlines(fdref))
            print(sacrebleu.corpus_bleu(hypotheses, [references]).format())

    elif args.sentence_bleu:

        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = make_scorer()
                for i, (sys_tok, ref_tok) in enumerate(
                    zip(readlines(fdsys), readlines(fdref))
                ):
                    # Start from fresh statistics for every sentence pair.
                    scorer.reset(one_init=True)
                    sys_tok = vocab.encode_line(sys_tok)
                    ref_tok = vocab.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                    print(i, scorer.result_string(args.order))

    else:

        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = make_scorer()
                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
                    sys_tok = vocab.encode_line(sys_tok)
                    ref_tok = vocab.encode_line(ref_tok)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == "-":
        score(sys.stdin)
    else:
        with open(args.sys, "r") as f:
            score(f)
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli_main()