from transformers import T5Tokenizer, T5EncoderModel
import torch


def sequence_mask(lengths, max_len=None):
    """
    Creates a 0/1 mask (LongTensor) from sequence lengths.

    :param lengths: 1d tensor [batch_size] of sequence lengths
    :param max_len: int, width of the mask; defaults to lengths.max()
    :return: LongTensor [batch_size, max_len], 1 at valid positions, 0 at padding
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    # Broadcast [0, 1, ..., max_len-1] to [batch_size, max_len]; positions
    # before each sequence's length compare as 1, the padded tail as 0.
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1))
            .long())
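
# Minimal usage sketch (hypothetical, not from the original source): build a
# mask from per-sequence lengths, then check it against the attention mask a
# right-padding tokenizer produces. The pairing with the T5 imports above is
# an assumption based on those imports; the model name "t5-small" is
# illustrative.
if __name__ == "__main__":
    lengths = torch.tensor([3, 5, 2])
    print(sequence_mask(lengths))
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1],
    #         [1, 1, 0, 0, 0]])

    # With T5 (downloads weights; assumes right padding, the tokenizer's
    # default), the reconstructed mask matches the tokenizer's attention_mask:
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    enc = tokenizer(["a short input", "a somewhat longer input sentence"],
                    padding=True, return_tensors="pt")
    lengths = enc["attention_mask"].sum(dim=1)
    assert torch.equal(sequence_mask(lengths), enc["attention_mask"])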