LingConv / utils.py
mohdelgaar's picture
add missing file
9f22f23
raw
history blame
498 Bytes
from transformers import T5Tokenizer, T5EncoderModel
import torch
def sequence_mask(lengths, max_len=None):
    """
    Build a 0/1 padding mask from per-sequence lengths.

    :param lengths: 1d tensor [batch_size] of sequence lengths
    :param max_len: int, mask width; any falsy value falls back to lengths.max()
    :return: long tensor [batch_size, max_len], 1 where position < length, else 0
    """
    # Falsy max_len (None or 0) defaults to the longest sequence in the batch.
    width = max_len if max_len else lengths.max()
    # Position indices 0..width-1, on the same device/dtype as `lengths`.
    positions = torch.arange(0, width, device=lengths.device).type_as(lengths)
    # Broadcast (width,) against (batch_size, 1): True where the position
    # is still inside the sequence.
    valid = positions < lengths.unsqueeze(1)
    return valid.long()