import pickle
from collections import Counter


class TorchVocab(object):
    """
    :property freqs: collections.Counter, holds the frequency of each word in the corpus
    :property stoi: dict, mapping from string to id
    :property itos: list, mapping from id to string
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param counter: collections.Counter, counts the frequency of each word in the data
        :param max_size: int, maximum size of the vocabulary. None means no limit. Default: None.
        :param min_freq: int, minimum frequency for a word to enter the vocabulary.
            Words that occur fewer times than this are not added.
        :param specials: list of str, tokens registered in the vocabulary up front
        :param vectors: list of vectors, pretrained vectors, e.g. Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # occurrences of the special tokens are not counted when building the vocabulary
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort alphabetically first, then by frequency in descending order
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # words whose frequency is below min_freq are not added to the vocab
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # build stoi by swapping the keys and values of itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        # rebuild stoi from the current ordering of itos
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        # append words from another vocab v that this vocab does not contain yet
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
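
# A minimal usage sketch (illustrative, not part of the original module):
# building a TorchVocab directly from a Counter. Specials come first in
# itos, followed by the remaining words sorted by descending frequency,
# ties broken alphabetically.
#
#   counter = Counter({"apple": 3, "banana": 2, "cherry": 1})
#   vocab = TorchVocab(counter, min_freq=2)
#   vocab.itos            # ['<pad>', '<oov>', 'apple', 'banana']
#   vocab.stoi["apple"]   # 2  ("cherry" is dropped: freq 1 < min_freq 2)
#   len(vocab)            # 4
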
class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    # meant to be overridden by subclasses
    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # meant to be overridden by subclasses
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)
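
# Persistence sketch (illustrative; the path below is made up): the vocab is
# serialized with pickle, so load_vocab returns the same object state that
# save_vocab wrote, including freqs, stoi and itos.
#
#   vocab.save_vocab("data/vocab.pkl")
#   vocab = Vocab.load_vocab("data/vocab.pkl")
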

# builds a vocab from text (an iterable of lines or of token lists)
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        # unknown words are mapped to unk_index
        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # <eos> is index 2
        if with_sos:
            seq = [self.sos_index] + seq  # <sos> is index 3

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            # pad up to seq_len with <pad>
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            # truncate down to seq_len
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        # ids that fall outside the vocabulary are rendered as "<id>";
        # pad tokens are dropped unless with_pad is True
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
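

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative; the toy corpus is made up):
    # build a WordVocab from raw lines, encode a sentence with <sos>/<eos>
    # and padding, then decode it back.
    texts = [
        "the quick brown fox",
        "the lazy dog",
        "the quick dog",
    ]
    vocab = WordVocab(texts, min_freq=1)

    # "cat" is out of vocabulary and maps to unk_index; the sequence is
    # padded with pad_index up to seq_len
    seq = vocab.to_seq("the quick cat", seq_len=8, with_sos=True, with_eos=True)
    print(seq)

    # decoding drops the pad tokens unless with_pad=True is passed
    print(vocab.from_seq(seq, join=True))  # "<sos> the quick <unk> <eos>"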