# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

import torch

from fairseq import utils


class SequenceScorer(object):
    """Scores the target for a given source sentence."""

    def __init__(
        self,
        tgt_dict,
        softmax_batch=None,
        compute_alignment=False,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.softmax_batch = softmax_batch or sys.maxsize
        assert self.softmax_batch > 0
        self.compute_alignment = compute_alignment
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )
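
    # The returned hypotheses mirror SequenceGenerator's interface: one list
    # per sentence, each containing a single dict with the reference tokens,
    # the total (length-normalized) score, attention, alignment, and the
    # per-position scores of the gold target.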
    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample["net_input"]
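
        # Normalizing over the full vocabulary for every target position at
        # once can exhaust memory on large batches, so the decoder output is
        # optionally flattened and scored in chunks of at most
        # self.softmax_batch positions.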
        def batch_for_softmax(dec_out, target):
            # assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e
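
        # For each position, pick out the probability assigned to the gold
        # target token, i.e. probs[b, t, target[b, t]].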
        def gather_target_probs(probs, target):
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample["target"]

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
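            # decoder_out[0] holds the output logits of shape
            # (bsz, tgt_len, vocab); any extras (e.g. attention) follow it.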
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if type(attn) is dict:
                attn = attn.get("attn", None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
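            # A single model can be scored directly in log space; for an
            # ensemble, plain probabilities are requested so they can be
            # averaged across models first (the log is taken after the loop).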
            for bd, tgt, is_single in batched:
                sample["target"] = tgt
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample
                ).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
                    )
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
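                # restore the full target so later chunks and models see it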
sample["target"] = orig_target | |
probs = probs.view(sample["target"].shape) | |

            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                if torch.is_tensor(attn):
                    attn = attn.data
                else:
                    attn = attn[0]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
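
        # Ensembling: average the per-token probabilities over models, then
        # move to log space, so the final score reflects the arithmetic mean
        # of the ensemble's probabilities rather than of their log-probs.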
        if len(models) > 1:
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
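        # Each sentence's score is the length-normalized sum of per-token
        # log-probabilities, i.e. the average log-prob of the reference.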
        for i in range(bsz):
            # remove padding from ref
            ref = (
                utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad)
                if sample["target"] is not None
                else None
            )
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len]
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample["net_input"]["src_tokens"][i],
                        sample["target"][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append(
                [
                    {
                        "tokens": ref,
                        "score": score_i,
                        "attention": avg_attn_i,
                        "alignment": alignment,
                        "positional_scores": avg_probs_i,
                    }
                ]
            )
        return hypos
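

# Usage sketch (illustrative only, not part of the original module). Assuming
# a checkpoint and task have already been loaded with fairseq's helpers and
# `sample` is a prepared, collated batch, scoring might look like:
#
#     from fairseq import checkpoint_utils
#
#     models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#         ["checkpoint.pt"]
#     )
#     scorer = SequenceScorer(task.target_dictionary, softmax_batch=512)
#     hypos = scorer.generate(models, sample)
#     for sent_hypos in hypos:
#         h = sent_hypos[0]
#         print(h["score"], h["positional_scores"])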