Spaces:
Running
on
Zero
Running
on
Zero
# This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{yao2021wenet, | |
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
# booktitle={Proc. Interspeech}, | |
# year={2021}, | |
# address={Brno, Czech Republic }, | |
# organization={IEEE} | |
# } | |
# @article{zhang2022wenet, | |
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
# journal={arXiv preprint arXiv:2203.15455}, | |
# year={2022} | |
# } | |
# | |
"""Utility functions for Transformer."""
import math | |
from typing import List, Tuple | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
# NOTE(review): presumably the default label id marking padded target
# positions (the docstring examples in this module use ignore_id = -1);
# it is not referenced by the functions visible here -- confirm with callers.
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Pad a list of variable-length tensors into a single batch tensor.

    The tensors may differ in their first dimension only. Any trailing
    dimensions are taken from ``xs[0]`` and preserved in the output, so
    both ``(T_i,)`` label sequences and ``(T_i, D)`` feature sequences
    are supported.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (int): Value written to positions past each tensor's
            own length.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`) with the dtype and device of
            ``xs[0]``.

    Raises:
        ValueError: If ``xs`` is empty (raised by ``max``).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    # Carry over the trailing dimensions so that inputs of shape (T_i, *)
    # fit into the buffer; for 1-D inputs this is simply [n_batch, max_len].
    out_shape = [n_batch, max_len] + list(xs[0].size()[1:])
    pad = torch.zeros(out_shape, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad
def add_blank(ys_pad: torch.Tensor, blank: int, ignore_id: int) -> torch.Tensor:
    """Prepend a blank column and turn padding entries into blank.

    Used to build the input of a transducer predictor.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index of padding; every entry equal to it is
            replaced by ``blank`` in the result

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    batch_size = ys_pad.size(0)
    # One blank token per batch row, shaped [bs, 1] so it can be
    # concatenated in front of the targets along the label axis.
    blank_column = torch.tensor(
        [[blank]], dtype=torch.long, requires_grad=False, device=ys_pad.device
    ).repeat(batch_size, 1)
    with_blank = torch.cat([blank_column, ys_pad], dim=1)  # [bs, Lmax+1]
    is_padding = with_blank == ignore_id
    return torch.where(is_padding, blank, with_blank)
def add_sos_eos(
    ys_pad: torch.Tensor, sos: int, eos: int, ignore_id: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder input/output targets by adding <sos> and <eos> labels.

    Padding entries are stripped from every row first. The input sequence
    is then ``<sos> + tokens`` (re-padded with ``eos``) and the output
    sequence is ``tokens + <eos>`` (re-padded with ``ignore_id``).

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    device = ys_pad.device
    sos_token = torch.tensor(
        [sos], dtype=torch.long, requires_grad=False, device=device
    )
    eos_token = torch.tensor(
        [eos], dtype=torch.long, requires_grad=False, device=device
    )
    inputs = []
    outputs = []
    for row in ys_pad:
        tokens = row[row != ignore_id]  # drop the padding of this row
        inputs.append(torch.cat([sos_token, tokens], dim=0))
        outputs.append(torch.cat([tokens, eos_token], dim=0))
    return pad_list(inputs, eos), pad_list(outputs, ignore_id)
def reverse_pad_list(
    ys_pad: torch.Tensor, ys_lens: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    """Reverse the valid part of every padded sequence and pad again.

    For each row only the first ``ys_lens[b]`` entries are kept; they are
    cast to int32, flipped, and the rows are padded back to a common
    length with ``pad_value``.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded int32 tensor (B, max(ys_lens)).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]], dtype=torch.int32)
    """
    reversed_seqs = []
    for seq, length in zip(ys_pad, ys_lens):
        valid = seq.int()[:length]  # strip the padding before flipping
        reversed_seqs.append(torch.flip(valid, [0]))
    return pad_sequence(reversed_seqs, batch_first=True, padding_value=pad_value)
def th_accuracy(
    pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
) -> float:
    """Calculate token-level accuracy, skipping ignored target positions.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0). When every target equals
            ``ignore_label`` there is nothing to score and 0.0 is returned.
    """
    batch_size = pad_targets.size(0)
    max_len = pad_targets.size(1)
    # Restore the (B, Lmax, D) layout, then take the best class per position.
    pad_pred = pad_outputs.view(batch_size, max_len, pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    denominator = int(torch.sum(mask))
    if denominator == 0:
        # No valid target position: avoid a ZeroDivisionError.
        return 0.0
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    return float(numerator) / float(denominator)
def get_rnn(rnn_type: str) -> torch.nn.Module:
    """Return the torch.nn recurrent layer class selected by ``rnn_type``.

    Args:
        rnn_type (str): One of "rnn", "lstm" or "gru".

    Returns:
        The class itself (``torch.nn.RNN``, ``torch.nn.LSTM`` or
        ``torch.nn.GRU``), not an instance.

    Raises:
        AssertionError: If ``rnn_type`` is not one of the supported names.
    """
    assert rnn_type in ["rnn", "lstm", "gru"]
    explicit = {"rnn": torch.nn.RNN, "lstm": torch.nn.LSTM}
    # Anything that is neither "rnn" nor "lstm" resolves to GRU.
    return explicit.get(rnn_type, torch.nn.GRU)
def get_activation(act):
    """Return a new instance of the activation module named ``act``.

    Args:
        act (str): One of "hardtanh", "tanh", "relu", "selu", "swish"
            or "gelu".

    Returns:
        A freshly constructed activation module.

    Raises:
        KeyError: If ``act`` is not a supported activation name.
    """
    swish_cls = getattr(torch.nn, "SiLU", None)
    if swish_cls is None:
        # Import the project's Swish only when torch does not provide SiLU,
        # so the function does not depend on that module otherwise.
        from modules.wenet_extractor.transformer.swish import Swish

        swish_cls = Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": swish_cls,
        "gelu": torch.nn.GELU,
    }

    return activation_funcs[act]()
def get_subsample(config):
    """Map the configured encoder input layer to its subsample value.

    Reads ``config["encoder_conf"]["input_layer"]`` and returns 4 for
    "conv2d", 6 for "conv2d6" and 8 for "conv2d8".
    NOTE(review): presumably this is the time-subsampling factor of the
    convolutional front-end -- confirm against the encoder implementation.

    Args:
        config: Mapping that contains ``encoder_conf`` -> ``input_layer``.

    Returns:
        int: 4, 6 or 8.

    Raises:
        AssertionError: If the input layer is not one of the three names.
    """
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    subsample_by_layer = {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}
    return subsample_by_layer.get(input_layer)
def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse runs of repeated tokens, then drop the blank token.

    This is the collapse rule used for CTC alignments: consecutive equal
    ids count as one emission, and blank emissions are removed.

    Args:
        hyp (List[int]): Token id sequence, possibly with repeats/blanks.
        blank_id (int): Id of the blank token. Defaults to 0.

    Returns:
        List[int]: New list with one entry per run of equal non-blank ids.

    Examples:
        >>> remove_duplicates_and_blank([0, 1, 1, 0, 1, 2, 2])
        [1, 1, 2]
    """
    new_hyp: List[int] = []
    for i, token in enumerate(hyp):
        if i > 0 and token == hyp[i - 1]:
            # Still inside the current run; it was already handled.
            continue
        if token != blank_id:
            new_hyp.append(token)
    return new_hyp
def replace_duplicates_with_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Keep the first token of every run and blank out its repeats.

    The output has the same length as the input. A token that equals the
    one directly before it is replaced by ``blank_id``.

    Args:
        hyp (List[int]): Token id sequence.
        blank_id (int): Id of the blank token. Defaults to 0.

    Returns:
        List[int]: New list with repeated tokens replaced by ``blank_id``.

    Examples:
        >>> replace_duplicates_with_blank([1, 1, 2, 0, 0, 3, 3, 3])
        [1, 0, 2, 0, 0, 3, 0, 0]
    """
    new_hyp: List[int] = []
    for i, token in enumerate(hyp):
        is_repeat = i > 0 and token == hyp[i - 1]
        # A repeated blank is written as blank either way, so no special case.
        new_hyp.append(blank_id if is_repeat else token)
    return new_hyp
def log_add(args: List[int]) -> float:
    """Numerically stable log-sum-exp of the given log-domain values.

    Computes ``log(sum(exp(a) for a in args))``. The maximum is subtracted
    before exponentiating, so the largest term becomes ``exp(0)``.

    Args:
        args: Log-domain values to add.

    Returns:
        float: The log of the summed exponentials; negative infinity when
            every value is negative infinity (or ``args`` is empty).
    """
    neg_inf = -float("inf")
    if all(value == neg_inf for value in args):
        return neg_inf
    peak = max(args)
    shifted_total = sum(math.exp(value - peak) for value in args)
    return peak + math.log(shifted_total)