# fairseq/ngram_repeat_block.py
# Originally from Microsoft Corporation.
# Licensed under the MIT License.
"""Wrapper for ngram_repeat_block cuda extension"""
import math
import warnings
from typing import Dict, List

import torch
from torch import nn

try:
    from fairseq import ngram_repeat_block_cuda

    EXTENSION_BUILT = True
except ImportError:
    EXTENSION_BUILT = False


def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly."""
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
    # Dummy inputs for a quick kernel smoke test.
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
    lprobs = torch.rand((8, 12), device="cuda")
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        warnings.warn(
            "NGramRepeatBlock extension must be rebuilt. "
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False
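
# Note: NGramRepeatBlock below runs this check from its constructor, so a
# typical caller only needs, e.g. (illustrative, not in the original file):
#
#     blocker = NGramRepeatBlock(no_repeat_ngram_size=3)  # CUDA kernel if usable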


class NGramRepeatBlock(nn.Module):
    """Wrapper class for calling ngram_repeat_block cuda extension."""

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        self.use_extension = is_cuda_extension_usable() if use_extension else False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        pass

    @torch.jit.unused
    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        return ngram_repeat_block_cuda.forward(
            tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size
        )

    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """Mask out tokens that would complete an already-generated ngram.

        Args:
            tokens (Tensor): input tokens (bsz * beam_size, seq_len)
            lprobs (Tensor): log probabilities for the next token
                (bsz * beam_size, vocab_size); expected to be updated in place
            bsz (int): batch size
            beam_size (int): beam size
            step (int): current decoding step
        """
        msg = f"expected {bsz * beam_size} got"
        assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}"
        assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}"
        if self.use_extension:
            return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)
        else:
            return self._no_repeat_ngram(
                tokens,
                lprobs,
                bsz,
                beam_size,
                step,
            )

    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        """For each hypothesis, collect its previous ngrams and set the lprobs of
        tokens that would repeat an ngram to -inf."""
        gen_ngrams: List[Dict[str, List[int]]] = [
            torch.jit.annotate(Dict[str, List[int]], {})
            for bbsz_idx in range(bsz * beam_size)
        ]
        cpu_tokens = tokens.cpu()
        for bbsz_idx in range(bsz * beam_size):
            gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
            # Slide a window of size no_repeat_ngram_size over the hypothesis and
            # record, for each (n-1)-token prefix, every token that has followed it.
            for ngram in self.transpose_list(
                [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
            ):
                key = ",".join([str(x) for x in ngram[:-1]])
                gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
                    key, torch.jit.annotate(List[int], [])
                ) + [ngram[-1]]
        if step + 2 - self.no_repeat_ngram_size >= 0:
            banned_tokens = [
                self.calculate_banned_tokens(
                    tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
                )
                for bbsz_idx in range(bsz * beam_size)
            ]
        else:
            # No tokens are banned until at least no_repeat_ngram_size tokens
            # have been generated.
            banned_tokens = [
                torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
            ]
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
            ] = torch.tensor(-math.inf).to(lprobs)
        return lprobs

    @staticmethod
    def calculate_banned_tokens(
        tokens,
        step: int,
        gen_ngrams: List[Dict[str, List[int]]],
        no_repeat_ngram_size: int,
        bbsz_idx: int,
    ):
        tokens_list: List[int] = tokens[
            bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
        ].tolist()
        # Before decoding the next token, prevent decoding of ngrams that have
        # already appeared: look up the (n-1)-token prefix that would begin the
        # next ngram and return every token recorded after it.
        ngram_index = ",".join([str(x) for x in tokens_list])
        return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))

    @staticmethod
    def transpose_list(l: List[List[int]]):
        """Transpose a ragged list of lists, truncating to the shortest row,
        e.g. [[1, 2, 3], [4, 5]] -> [[1, 4], [2, 5]]."""
        # GeneratorExps aren't supported in TorchScript, so list comprehensions
        # are used and the lint warning is ignored.
        min_len = min([len(x) for x in l])  # noqa
        l2 = [[row[i] for row in l] for i in range(min_len)]
        return l2
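

if __name__ == "__main__":
    # Minimal smoke test of the pure-Python fallback (an illustrative sketch,
    # not part of the original module; token values and vocab size are made up).
    # With no_repeat_ngram_size=2 and the history [0, 5, 6, 5], the bigram
    # (6, 5) has already been generated, so after the current prefix token 6
    # the continuation 5 must be masked to -inf.
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)
    tokens = torch.tensor([[0, 5, 6, 5]], dtype=torch.long)  # (bsz * beam_size, step + 2)
    lprobs = torch.zeros(1, 10)  # (bsz * beam_size, vocab_size)
    lprobs = blocker(tokens, lprobs, bsz=1, beam_size=1, step=2)
    assert lprobs[0, 5].item() == -math.inf
    print("lprobs after blocking:", lprobs[0].tolist())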