""" Wrapper for ngram_repeat_block cuda extension """ |
|
import torch |
|
from torch import nn |
|
|
|
import math |
|
from typing import Dict, List, Optional |
|
import warnings |
|
|
|
try: |
|
from fairseq import ngram_repeat_block_cuda |
|
|
|
EXTENSION_BUILT = True |
|
except ImportError: |
|
EXTENSION_BUILT = False |
|
|
|
|
|
def is_cuda_extension_usable() -> bool: |
|
"""Check whether ngram_repeat_block_cuda is built properly""" |
|
if not EXTENSION_BUILT or not torch.cuda.is_available(): |
|
return False |
|
bsz = 2 |
|
tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda") |
|
lprobs = torch.rand((8, 12), device="cuda") |
|
try: |
|
outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) |
|
outputs = outputs + 4 |
|
return True |
|
except RuntimeError: |
|
warnings.warn( |
|
"NGramRepeatBlock extension must be rebuilt." |
|
'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' |
|
) |
|
return False |
|
|
|
|
|
class NGramRepeatBlock(nn.Module): |
|
""" Wrapper class for calling ngram_repeat_block cuda extension """ |
|
|
|
def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): |
|
super().__init__() |
|
self.use_extension = is_cuda_extension_usable() if use_extension else False |
|
self.no_repeat_ngram_size = no_repeat_ngram_size |
|
|
|
def reset_parameters(self): |
|
pass |
|
|
|
@torch.jit.unused |
|
def call_cuda_extension( |
|
self, |
|
tokens, |
|
lprobs, |
|
bsz: int, |
|
beam_size: int, |
|
step: int, |
|
): |
|
return ngram_repeat_block_cuda.forward( |
|
tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size |
|
) |
|
|
|
def forward( |
|
self, |
|
tokens, |
|
lprobs, |
|
bsz: int, |
|
beam_size: int, |
|
step: int, |
|
): |
|
""" |
|
Args: |
|
tokens(Tensor): Input tokens(Bsz*beam, seq_len) |
|
lprobs(Tensor): likelihood probability, |
|
Expected to be updated in place.(Bsz*beam, vocab_size) |
|
bsz(int): batch size |
|
step(int): current step |
|
beam_size(int): beam size |
|
no_repeat_ngram_size(int): Ngram size |
|
""" |
|
msg = f"expected {bsz *beam_size} got" |
|
assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}" |
|
assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}" |
|
if self.use_extension: |
|
return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step) |
|
|
|
else: |
|
return self._no_repeat_ngram( |
|
tokens, |
|
lprobs, |
|
bsz, |
|
beam_size, |
|
step, |
|
) |
|
|
|
def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): |
|
"""For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" |
|
gen_ngrams: List[Dict[str, List[int]]] = [ |
|
torch.jit.annotate(Dict[str, List[int]], {}) |
|
for bbsz_idx in range(bsz * beam_size) |
|
] |
|
cpu_tokens = tokens.cpu() |
|
for bbsz_idx in range(bsz * beam_size): |
|
gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() |
|
for ngram in self.transpose_list( |
|
[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)] |
|
): |
|
key = ",".join([str(x) for x in ngram[:-1]]) |
|
gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( |
|
key, torch.jit.annotate(List[int], []) |
|
) + [ngram[-1]] |
|
if step + 2 - self.no_repeat_ngram_size >= 0: |
|
|
|
banned_tokens = [ |
|
self.calculate_banned_tokens( |
|
tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx |
|
) |
|
for bbsz_idx in range(bsz * beam_size) |
|
] |
|
else: |
|
banned_tokens = [ |
|
torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) |
|
] |
|
for bbsz_idx in range(bsz * beam_size): |
|
lprobs[bbsz_idx][ |
|
torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64) |
|
] = torch.tensor(-math.inf).to(lprobs) |
|
return lprobs |
|
|
|
@staticmethod |
|
def calculate_banned_tokens( |
|
tokens, |
|
step: int, |
|
gen_ngrams: List[Dict[str, List[int]]], |
|
no_repeat_ngram_size: int, |
|
bbsz_idx: int, |
|
): |
|
tokens_list: List[int] = tokens[ |
|
bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1 |
|
].tolist() |
|
|
|
ngram_index = ",".join([str(x) for x in tokens_list]) |
|
return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], [])) |
|
|
|
@staticmethod |
|
def transpose_list(l: List[List[int]]): |
|
|
|
min_len = min([len(x) for x in l]) |
|
l2 = [[row[i] for row in l] for i in range(min_len)] |
|
return l2 |
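

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the fairseq API): it
# shows how NGramRepeatBlock can be called directly on a toy batch, using the
# pure-PyTorch fallback path so it also runs on CPU. The vocabulary size,
# token ids, and decoding state below are made up for the example; inside
# fairseq the blocker is normally driven by the beam search loop instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bsz, beam_size, vocab_size = 1, 2, 10
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)

    # Two hypotheses (bsz * beam_size rows); each has already generated some
    # bigrams that must not be repeated.
    tokens = torch.tensor([[3, 4, 5, 3, 4], [1, 3, 4, 2, 3]], dtype=torch.long)
    step = tokens.size(1) - 1  # index of the last generated token

    # Uniform log-probabilities for the next token of every hypothesis.
    lprobs = torch.full((bsz * beam_size, vocab_size), -math.log(vocab_size))

    out = blocker(tokens, lprobs, bsz=bsz, beam_size=beam_size, step=step)

    # Hypothesis 0 ends in token 4 and has already produced the bigram (4, 5),
    # so token 5 is banned; hypothesis 1 ends in token 3 and has produced
    # (3, 4), so token 4 is banned. Banned entries are set to -inf.
    print(out)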