Spaces:
Sleeping
Sleeping
File size: 498 Bytes
9f22f23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from transformers import T5Tokenizer, T5EncoderModel
import torch
def sequence_mask(lengths, max_len=None):
    """
    Creates a 0/1 mask from sequence lengths.

    Entry ``[i, j]`` of the result is 1 when ``j < lengths[i]`` and 0
    otherwise.

    :param lengths: 1d tensor [batch_size]
    :param max_len: int, number of mask columns; when ``None`` the largest
        value in ``lengths`` is used (0 if ``lengths`` is empty)
    :return: long tensor [batch_size, max_len] on the same device as
        ``lengths``
    """
    if max_len is None:
        # Compare against None explicitly: a caller-supplied max_len of 0 is
        # a legitimate request for a zero-width mask and must not fall back
        # to the batch maximum.
        # An empty batch has no maximum (``Tensor.max`` raises on it), so it
        # gets a zero-width mask.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0

    # Positions 0..max_len-1 in the dtype of ``lengths`` so the comparison
    # below is between like-typed tensors.
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)

    # Broadcasting [1, max_len] against [batch_size, 1] yields the
    # [batch_size, max_len] grid without materializing repeated rows.
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1)).long()
|