# ML6-UniKP / build_vocab.py
import pickle
from collections import Counter


class TorchVocab(object):
    """
    :property freqs: collections.Counter, holds the frequency of each token in the corpus
    :property stoi: dict, maps a token string to its integer id
    :property itos: list, maps an integer id back to its token string
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param counter: collections.Counter, token frequencies counted over the data
        :param max_size: int, maximum vocabulary size; None means no limit (default: None)
        :param min_freq: int, minimum occurrence count for a token to be included;
            tokens occurring fewer times are not added to the vocabulary
        :param specials: list of str, special tokens registered in the vocabulary up front
        :param vectors: list of vectors, pretrained embeddings, e.g. Vocab.load_vectors
        """
self.freqs = counter
counter = counter.copy()
min_freq = max(min_freq, 1)
self.itos = list(specials)
        # special tokens are not counted when building the vocabulary
for tok in specials:
del counter[tok]
max_size = None if max_size is None else max_size + len(self.itos)
        # sort alphabetically first, then by frequency (descending); the stable
        # sort breaks frequency ties alphabetically
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
        # tokens with frequency below min_freq are not added to the vocab
for word, freq in words_and_frequencies:
if freq < min_freq or len(self.itos) == max_size:
break
self.itos.append(word)
        # build stoi by inverting the id -> token mapping in itos
self.stoi = {tok: i for i, tok in enumerate(self.itos)}
self.vectors = None
if vectors is not None:
self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
else:
assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
if self.freqs != other.freqs:
return False
if self.stoi != other.stoi:
return False
if self.itos != other.itos:
return False
if self.vectors != other.vectors:
return False
return True

    def __len__(self):
return len(self.itos)

    def vocab_rerank(self):
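        # rebuild stoi after itos has been modified externally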
self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
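        # merge another vocab's tokens into this one; sort=True adds them alphabetically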
words = sorted(v.itos) if sort else v.itos
for w in words:
if w not in self.stoi:
self.itos.append(w)
self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
def __init__(self, counter, max_size=None, min_freq=1):
self.pad_index = 0
self.unk_index = 1
self.eos_index = 2
self.sos_index = 3
self.mask_index = 4
super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"], max_size=max_size, min_freq=min_freq)

    # meant to be overridden in subclasses
    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:
pass

    # meant to be overridden in subclasses
    def from_seq(self, seq, join=False, with_pad=False):
pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
with open(vocab_path, "rb") as f:
return pickle.load(f)

    def save_vocab(self, vocab_path):
with open(vocab_path, "wb") as f:
pickle.dump(self, f)


# builds a vocab from lines of text (or pre-tokenized token lists)
class WordVocab(Vocab):
def __init__(self, texts, max_size=None, min_freq=1):
print("Building Vocab")
counter = Counter()
for line in texts:
if isinstance(line, list):
words = line
else:
words = line.replace("\n", "").replace("\t", "").split()
for word in words:
counter[word] += 1
super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
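        """Convert a sentence (str or list of tokens) to a list of token ids.

        Unknown tokens map to unk_index. If seq_len is given, the sequence is
        padded with pad_index (or truncated) to exactly seq_len entries. With
        with_len=True, the length before padding is returned alongside the ids.
        """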
if isinstance(sentence, str):
sentence = sentence.split()
seq = [self.stoi.get(word, self.unk_index) for word in sentence]
if with_eos:
            seq += [self.eos_index]  # append the end-of-sequence token (index 2)
if with_sos:
seq = [self.sos_index] + seq
origin_seq_len = len(seq)
if seq_len is None:
pass
elif len(seq) <= seq_len:
seq += [self.pad_index for _ in range(seq_len - len(seq))]
else:
seq = seq[:seq_len]
return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
words = [self.itos[idx]
if idx < len(self.itos)
else "<%d>" % idx
for idx in seq
if not with_pad or idx != self.pad_index]
return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
with open(vocab_path, "rb") as f:
return pickle.load(f)
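

# A minimal usage sketch (not part of the original file, added for illustration):
# build a WordVocab from a few whitespace-tokenized lines, round-trip a sentence
# through to_seq/from_seq, and persist/reload the vocab with pickle.
if __name__ == "__main__":
    texts = [
        "the cat sat on the mat",
        "the dog sat on the log",
    ]
    vocab = WordVocab(texts, max_size=None, min_freq=1)
    print("vocab size:", len(vocab))

    # encode with sos/eos markers, padded to a fixed length of 8
    seq = vocab.to_seq("the cat sat", seq_len=8, with_eos=True, with_sos=True)
    print("ids:", seq)

    # decode back to tokens, skipping padding
    print("tokens:", vocab.from_seq(seq, join=True, with_pad=False))

    # save to disk and reload; the reloaded vocab compares equal
    vocab.save_vocab("vocab.pkl")
    reloaded = WordVocab.load_vocab("vocab.pkl")
    assert reloaded == vocab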