Spaces:
Build error
Build error
""" | |
Taken from ESPnet, modified by Florian Lux
""" | |
import os | |
from abc import ABC | |
import torch | |
def cumsum_durations(durations):
    """
    Turn a sequence of durations into segment boundaries and segment midpoints.

    Args:
        durations: Iterable of durations. Elements are summed and halved, so
            they must support ``+`` and ``/ 2``.

    Returns:
        tuple: ``(boundaries, centers)``. ``boundaries`` is the running sum of
        ``durations`` starting at 0, so it has one more element than
        ``durations``. ``centers[k]`` is the midpoint between
        ``boundaries[k]`` and ``boundaries[k + 1]``.
    """
    boundaries = [0]
    running_total = 0
    for duration in durations:
        running_total = duration + running_total
        boundaries.append(running_total)
    centers = [(start + end) / 2 for start, end in zip(boundaries, boundaries[1:])]
    return boundaries, centers
def delete_old_checkpoints(checkpoint_dir, keep=5):
    """
    Delete all but the ``keep`` newest step checkpoints in ``checkpoint_dir``.

    Only files named ``checkpoint_<step>.pt`` with a decimal ``<step>`` are
    considered. ``best.pt`` and every other file are left untouched.
    Checkpoints are ordered by their numeric step, and those with the smallest
    steps are removed.

    Args:
        checkpoint_dir (str): Directory containing the checkpoints.
        keep (int): Number of most recent checkpoints to keep (>= 0).

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError("keep must be non-negative, got {}".format(keep))
    prefix, suffix = "checkpoint_", ".pt"
    checkpoints = list()
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step = name[len(prefix):-len(suffix)]
        # Skip names like "checkpoint_final.pt" instead of crashing on int().
        if step.isdecimal():
            # Remember the real file name so zero-padded steps are deleted correctly.
            checkpoints.append((int(step), name))
    checkpoints.sort()
    # Slice by count: `[:-keep]` would delete nothing when keep == 0,
    # and the clamp avoids a negative end index when fewer than `keep` exist.
    num_to_delete = max(len(checkpoints) - keep, 0)
    for _, name in checkpoints[:num_to_delete]:
        os.remove(os.path.join(checkpoint_dir, name))
def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    """
    Return the path of the checkpoint with the highest step in ``checkpoint_dir``.

    Only files named ``checkpoint_<step>.pt`` with a decimal ``<step>`` are
    considered. ``best.pt`` and any other file are ignored.

    Args:
        checkpoint_dir (str): Directory containing the checkpoints.
        verbose (bool): If True, print which checkpoint is reloaded.

    Returns:
        str or None: Path to the newest checkpoint, or None (after printing a
        notice) when no checkpoint is found.
    """
    prefix, suffix = "checkpoint_", ".pt"
    newest_step, newest_name = None, None
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step = name[len(prefix):-len(suffix)]
        # Skip unrelated names like "checkpoint_final.pt" instead of crashing on int().
        if not step.isdecimal():
            continue
        step = int(step)
        if newest_step is None or step > newest_step:
            newest_step, newest_name = step, name
    if newest_name is None:
        print("No previous checkpoints found, cannot reload.")
        return None
    if verbose:
        print("Reloading {}".format(newest_name))
    # Return the real file name rather than one rebuilt from the parsed step.
    return os.path.join(checkpoint_dir, newest_name)
def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor, List or Tuple): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that the lengths
            refer to. Must not be 0 (the batch axis). Only used when ``xs``
            is given.
        device (torch.device or str, optional): Device the mask is built on.
            When ``xs`` is given, the result is moved to ``xs.device``.

    Returns:
        Tensor: Mask that is True where the position index is >= the
            sequence's length (the padded part). Its shape is
            (B, max(lengths)) when ``xs`` is None, otherwise ``xs.shape``.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> make_pad_mask([5, 3, 2]).int().tolist()
        [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
    if isinstance(lengths, tuple):
        lengths = list(lengths)
    elif not isinstance(lengths, list):
        # Array-like of lengths (e.g. a tensor); convert to a Python list.
        lengths = lengths.tolist()
    bs = int(len(lengths))
    maxlen = int(max(lengths)) if xs is None else xs.size(length_dim)
    # device=None selects torch's default device, same as omitting the argument.
    seq_range = torch.arange(0, maxlen, dtype=torch.int64, device=device)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = torch.tensor(lengths, dtype=torch.int64, device=seq_range.device).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None): keep the batch and
        # length axes, insert singleton axes elsewhere so the mask broadcasts to xs.
        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of the non-padded part.

    This is the element-wise negation of ``make_pad_mask`` called with the
    same arguments.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that the lengths
            refer to; must not be 0.
        device (torch.device or str, optional): Passed through to
            ``make_pad_mask``.

    Returns:
        Tensor: Mask that is True at non-padded (valid) positions
            (dtype=torch.bool in PyTorch 1.2+).
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim, device=device)
    return ~pad_mask
def initialize(model, init):
    """
    Initialize the parameters of a neural network module in place.

    The method works in three passes:

    1. Every parameter with more than one dimension is re-drawn with the
       chosen method.
    2. Every one-dimensional parameter (e.g. biases) is set to zero.
    3. All ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` submodules are
       reset to their default initialization, which e.g. restores the
       LayerNorm scale zeroed in pass 2.

    NOTE: other 1-D parameters, such as BatchNorm weights, are not reset and
    therefore stay zero.

    Args:
        model (torch.nn.Module): Module whose parameters are initialized.
        init (str): One of "xavier_uniform", "xavier_normal",
            "kaiming_uniform" or "kaiming_normal". The kaiming variants
            use ``nonlinearity="relu"``.

    Raises:
        ValueError: If ``init`` is not one of the supported methods.
    """
    initializers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda t: torch.nn.init.kaiming_uniform_(t, nonlinearity="relu"),
        "kaiming_normal": lambda t: torch.nn.init.kaiming_normal_(t, nonlinearity="relu"),
    }
    # Validate up front so a typo is reported even for models that have no
    # multi-dimensional parameter (previously such calls silently succeeded).
    if init not in initializers:
        raise ValueError("Unknown initialization: {}".format(init))
    init_fn = initializers[init]
    # weight init
    for p in model.parameters():
        if p.dim() > 1:
            init_fn(p.data)
    # bias init
    for p in model.parameters():
        if p.dim() == 1:
            p.data.zero_()
    # reset some modules with default init
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)):
            m.reset_parameters()
def pad_list(xs, pad_value):
    """
    Stack a list of tensors into one batch, padding along the first axis.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]
            that agree on every dimension after the first.
        pad_value (float): Value written into the padded positions.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`), allocated with ``xs[0].new`` so
            it shares the dtype and device of ``xs[0]``.
    """
    batch_size = len(xs)
    longest = max(t.size(0) for t in xs)
    trailing_shape = xs[0].size()[1:]
    padded = xs[0].new(batch_size, longest, *trailing_shape)
    padded.fill_(pad_value)
    for row, tensor in enumerate(xs):
        padded[row, : tensor.size(0)] = tensor
    return padded
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create a lower-triangular mask for subsequent steps (size, size).

    Entry (i, j) is set (True / 1) when j <= i and cleared otherwise, so each
    step can only see itself and earlier steps.

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :return: lower-triangular mask of shape (size, size)
    :rtype: torch.Tensor
    """
    ones = torch.ones(size, size, device=device, dtype=dtype)
    # In-place lower-triangular projection; returns the same tensor object.
    return ones.tril_()
class ScorerInterface:
    """
    Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x):
        """
        Return the initial decoding state (optional hook).

        Args:
            x (torch.Tensor): The encoded feature tensor.

        Returns:
            The initial state. This default keeps no state and returns None.
        """
        return None

    def select_state(self, state, i, new_id=None):
        """
        Pick the state of hypothesis ``i`` in the main beam search.

        Args:
            state: Decoder state for the prefix tokens, or None.
            i (int): Index of the state to select.
            new_id (int): New label index; ignored by this default.

        Returns:
            ``state[i]``, or None when ``state`` is None.
        """
        if state is None:
            return None
        return state[i]

    def score(self, y, state, x):
        """
        Score the next token (must be implemented by subclasses).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for the prefix tokens.
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Scores for the next token with shape
                `(n_vocab)`, and the next state for ys.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def final_score(self, state):
        """
        Score the end of the sequence (optional hook).

        Args:
            state: Scorer state for the prefix tokens.

        Returns:
            float: The final score; this default always returns 0.0.
        """
        return 0.0
class BatchScorerInterface(ScorerInterface, ABC):
    """
    Scorer interface with batched entry points.

    The default implementations here loop over the batch and delegate to the
    per-item methods inherited from ``ScorerInterface``.
    """

    def batch_init_state(self, x):
        """
        Return the initial decoding state for a batch (optional hook).

        Args:
            x (torch.Tensor): The encoded feature tensor.

        Returns:
            Whatever ``self.init_state(x)`` returns.
        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score the next token for a batch of prefixes.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for the prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: The batched scores for the next
                token, reshaped to `(n_batch, -1)` (i.e. `(n_batch, n_vocab)`),
                and the list of next states for ys.
        """
        per_item_scores = []
        next_states = []
        for prefix, prefix_state, feature in zip(ys, states, xs):
            item_score, item_state = self.score(prefix, prefix_state, feature)
            next_states.append(item_state)
            per_item_scores.append(item_score)
        concatenated = torch.cat(per_item_scores, 0)
        return concatenated.view(ys.shape[0], -1), next_states
def to_device(m, x):
    """Send a tensor to the device of a module or of a reference tensor.

    Args:
        m (torch.nn.Module or torch.Tensor): Reference whose device is used.
            For a module, the device of its first parameter is used. If it
            has no parameters, the device of its first buffer is used.
        x (Tensor): Torch tensor to move.

    Returns:
        Tensor: ``x`` located on the same device as ``m``.

    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
        ValueError: If ``m`` is a module with neither parameters nor buffers,
            so no device can be inferred from it.
    """
    if isinstance(m, torch.nn.Module):
        # A bare next(m.parameters()) leaks StopIteration for parameter-less
        # modules (e.g. torch.nn.ReLU), so fall back to buffers first.
        reference = next(m.parameters(), None)
        if reference is None:
            reference = next(m.buffers(), None)
        if reference is None:
            raise ValueError(
                f"Cannot infer a device from {type(m).__name__}: "
                "it has no parameters or buffers"
            )
        device = reference.device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            "Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}"
        )
    return x.to(device)