# slot_tagging / utils.py
# Header captured from a file-viewer page: user xjlulu, revision fba58f1
# ("good run"), 1.31 kB. (Viewer links "raw" / "history blame" omitted.)
from typing import Iterable, List, Optional
class Vocab:
    """Map string tokens to integer ids, reserving ids for padding and unknowns.

    ``Vocab.PAD`` is inserted with id 0 and ``Vocab.UNK`` with id 1; the tokens
    of ``vocab`` receive consecutive ids starting at 2, in iteration order.
    """

    PAD = "[PAD]"
    UNK = "[UNK]"

    def __init__(self, vocab: Iterable[str]) -> None:
        """Build the token-to-id table.

        Args:
            vocab: Tokens to index; the first one gets id 2, the next 3, etc.

        Note:
            The table is a dict built by insertion, so a token repeated in
            ``vocab`` keeps its first position in :attr:`tokens` but takes the
            id of its last occurrence. Likewise, a ``vocab`` containing
            ``"[PAD]"`` or ``"[UNK]"`` overwrites the reserved id. Callers
            presumably pass unique, non-reserved tokens — TODO confirm.
        """
        self.token2idx = {
            Vocab.PAD: 0,
            Vocab.UNK: 1,
            **{token: i for i, token in enumerate(vocab, 2)},
        }

    @property
    def pad_id(self) -> int:
        """Id stored for the ``[PAD]`` token."""
        return self.token2idx[Vocab.PAD]

    @property
    def unk_id(self) -> int:
        """Id stored for the ``[UNK]`` token; used for out-of-vocabulary input."""
        return self.token2idx[Vocab.UNK]

    @property
    def tokens(self) -> List[str]:
        """All tokens in the table, in dict insertion order."""
        return list(self.token2idx.keys())

    def token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or :attr:`unk_id` if it is not in the table."""
        return self.token2idx.get(token, self.unk_id)

    def encode(self, tokens: List[str]) -> List[int]:
        """Map each token of one sequence to its id (unknowns -> :attr:`unk_id`)."""
        return [self.token_to_id(token) for token in tokens]

    def encode_batch(
        self, batch_tokens: List[List[str]], to_len: Optional[int] = None
    ) -> List[List[int]]:
        """Encode a batch of token sequences and bring them to a common length.

        Args:
            batch_tokens: One token list per example.
            to_len: Target length. When ``None``, the length of the longest
                encoded sequence is used (0 for an empty batch).

        Returns:
            One id list per example, produced by ``pad_to_len``: for a
            non-negative ``to_len``, longer sequences are truncated and
            shorter ones right-padded with :attr:`pad_id`.
        """
        batch_ids = [self.encode(tokens) for tokens in batch_tokens]
        if to_len is None:
            # default=0 lets an empty batch return [] instead of max() raising
            # ValueError on an empty sequence.
            to_len = max((len(ids) for ids in batch_ids), default=0)
        return pad_to_len(batch_ids, to_len, self.pad_id)
def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
    """Force every sequence in ``seqs`` to length ``to_len``.

    Each sequence is cut to its first ``to_len`` items and, if it was shorter
    than ``to_len``, right-filled with ``padding``. New lists are built; the
    input sequences are not modified.

    Args:
        seqs: Id sequences to normalise.
        to_len: Desired length of every output sequence (expected non-negative).
        padding: Value appended to sequences shorter than ``to_len``.

    Returns:
        A list holding one adjusted sequence per input sequence, in order.
    """
    adjusted = []
    for seq in seqs:
        shortfall = to_len - len(seq)
        filler = [padding] * shortfall if shortfall > 0 else []
        adjusted.append(seq[:to_len] + filler)
    return adjusted