Spaces:
Runtime error
Runtime error
from typing import Iterable, List, Optional
class Vocab:
    """Map string tokens to integer ids, with reserved PAD and UNK entries.

    ``PAD`` is registered with id 0 and ``UNK`` with id 1; tokens supplied to
    the constructor are numbered from 2 in iteration order.
    """

    PAD = "[PAD]"
    UNK = "[UNK]"

    def __init__(self, vocab: Iterable[str]) -> None:
        """Build the token-to-id table.

        Args:
            vocab: Tokens to register; ids are assigned from 2 upward in
                iteration order.
        """
        # NOTE: later keys win in a dict display, so a token that repeats in
        # `vocab` (or equals PAD/UNK) ends up with its last assigned index.
        self.token2idx = {
            Vocab.PAD: 0,
            Vocab.UNK: 1,
            **{token: i for i, token in enumerate(vocab, 2)},
        }

    def pad_id(self) -> int:
        """Return the id stored for the padding token."""
        return self.token2idx[Vocab.PAD]

    def unk_id(self) -> int:
        """Return the id stored for the unknown token."""
        return self.token2idx[Vocab.UNK]

    def tokens(self) -> List[str]:
        """Return every registered token, in table insertion order."""
        return list(self.token2idx.keys())

    def token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the UNK id if it is not registered."""
        # `unk_id` is a regular method, so it must be *called*; passing the
        # bound method itself would leak a function object into the output.
        return self.token2idx.get(token, self.unk_id())

    def encode(self, tokens: List[str]) -> List[int]:
        """Convert a token sequence to a list of ids (unknowns become UNK)."""
        return [self.token_to_id(token) for token in tokens]

    def encode_batch(
        self, batch_tokens: List[List[str]], to_len: Optional[int] = None
    ) -> List[List[int]]:
        """Encode several token sequences and bring them to a common length.

        Args:
            batch_tokens: One token list per example.
            to_len: Target length for every encoded sequence. When ``None``,
                the length of the longest encoded sequence in the batch is
                used (0 for an empty batch).

        Returns:
            The encoded sequences after being passed through ``pad_to_len``
            with the PAD id as the padding value.
        """
        batch_ids = [self.encode(tokens) for tokens in batch_tokens]
        if to_len is None:
            # `default=0` keeps an empty batch from raising ValueError.
            to_len = max((len(ids) for ids in batch_ids), default=0)
        # Call `pad_id()` so the padding value is the integer id.
        return pad_to_len(batch_ids, to_len, self.pad_id())
def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
    """Return copies of ``seqs`` cut or right-padded to ``to_len`` items.

    Each sequence is sliced with ``seq[:to_len]``; sequences shorter than
    ``to_len`` are then extended on the right with ``padding``.
    """
    result = []
    for seq in seqs:
        shortfall = to_len - len(seq)
        # No filler is needed when the sequence already meets the length.
        filler = [padding] * shortfall if shortfall > 0 else []
        result.append(seq[:to_len] + filler)
    return result