Spaces:
Runtime error
Runtime error
File size: 5,408 Bytes
c7272f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import pickle
from collections import Counter
class TorchVocab(object):
    """Bidirectional token <-> id mapping built from a frequency counter.

    :property freqs: collections.Counter, the token frequencies passed to
        ``__init__`` (the same object; it is never modified).
    :property itos: list of str, id -> token. Special tokens come first (in the
        order given), followed by corpus tokens sorted by descending frequency,
        ties broken in ascending token order.
    :property stoi: dict, token -> id; the inverse of ``itos``.
    :property vectors: pre-trained vectors, or None when none were loaded.
    """

    # __eq__ is defined below, which already makes instances unhashable;
    # spelled out here so the intent is explicit.
    __hash__ = None

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<oov>'),
                 vectors=None, unk_init=None, vectors_cache=None):
        """Build the vocabulary from token counts.

        :param counter: collections.Counter with the frequency of each token in
            the data. Stored as ``self.freqs``; a copy is used internally.
        :param max_size: int or None, maximum number of non-special tokens.
            None (the default) means no limit.
        :param min_freq: int, tokens occurring fewer than ``min_freq`` times are
            not added. Values below 1 are treated as 1.
        :param specials: iterable of str, tokens registered first, in order.
            Their counts in ``counter`` are ignored.
        :param vectors: pre-trained vectors, forwarded to ``self.load_vectors``.
        :param unk_init: forwarded to ``load_vectors``; must be None when
            ``vectors`` is None.
        :param vectors_cache: forwarded to ``load_vectors`` as ``cache``; must be
            None when ``vectors`` is None.
        """
        self.freqs = counter
        # Work on a copy so removing the specials does not touch the caller's counter.
        counter = counter.copy()
        min_freq = max(min_freq, 1)
        self.itos = list(specials)
        # Special tokens are not counted as ordinary vocabulary entries.
        # Iterate self.itos (not `specials`) so a one-shot iterable still works;
        # Counter.__delitem__ ignores keys that are absent.
        for tok in self.itos:
            del counter[tok]
        # max_size limits non-special tokens, so the cap on len(itos) includes the specials.
        max_size = None if max_size is None else max_size + len(self.itos)
        # Descending frequency first, then ascending token order for ties.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: (-tup[1], tup[0]))
        for word, freq in words_and_frequencies:
            # The list is sorted by descending frequency, so once one token falls
            # below min_freq every remaining token does too.
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)
        # Invert itos to build the token -> id lookup.
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.vectors = None
        if vectors is not None:
            # NOTE(review): load_vectors is not defined on this class; a subclass
            # must provide it, otherwise this line raises AttributeError.
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        """Return True when counts, both mappings and vectors are all equal."""
        if not isinstance(other, TorchVocab):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on objects without these attributes.
            return NotImplemented
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        """Return the number of entries, specials included."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append every token of vocabulary ``v`` not already present.

        :param v: another vocabulary exposing ``itos``.
        :param sort: if True, new tokens are appended in sorted order rather
            than in ``v``'s id order.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
class Vocab(TorchVocab):
    """TorchVocab with five fixed special tokens and pickle persistence.

    The specials take the first five ids, in this order:
    ``<pad>`` = 0, ``<unk>`` = 1, ``<eos>`` = 2, ``<sos>`` = 3, ``<mask>`` = 4.
    ``to_seq`` / ``from_seq`` are placeholders meant to be overridden.
    """

    def __init__(self, counter, max_size=None, min_freq=1):
        """
        :param counter: collections.Counter of token frequencies.
        :param max_size: int or None, maximum number of non-special tokens.
        :param min_freq: int, minimum count for a token to be included.
        """
        # These ids must match the positions in the specials list below.
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        """Convert a sentence to ids. Placeholder for subclasses; returns None here."""
        # The parameter name `sentece` (sic) is kept for keyword-argument compatibility.
        pass

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert ids back to tokens. Placeholder for subclasses; returns None here."""
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        """Load a pickled vocabulary from ``vocab_path``.

        Callable both as ``Vocab.load_vocab(path)`` and on an instance.
        Security: pickle.load can execute arbitrary code; only load files
        from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``, overwriting any existing file."""
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)
# Build a vocabulary from text: raw lines or pre-tokenized sentences.
class WordVocab(Vocab):
    """Word-level vocabulary whose tokens are whitespace-separated words."""

    def __init__(self, texts, max_size=None, min_freq=1):
        """Count the words in ``texts`` and build the vocabulary.

        :param texts: iterable whose items are either a list of tokens (used
            as-is) or a string (split on any whitespace).
        :param max_size: int or None, maximum number of non-special tokens.
        :param min_freq: int, minimum count for a word to be included.
        """
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            # str.split() with no argument treats spaces, tabs and newlines all as
            # separators, matching how to_seq tokenizes. Deleting tabs first (as
            # before) glued together the words on either side of a tab.
            words = line if isinstance(line, list) else line.split()
            counter.update(words)
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        """Convert a sentence to a list of token ids.

        :param sentence: str (split on whitespace) or a list of tokens. Unknown
            tokens map to ``unk_index``.
        :param seq_len: int or None. Shorter sequences are right-padded with
            ``pad_index``; longer ones are truncated, which can drop a trailing
            ``<eos>``. None keeps the natural length.
        :param with_eos: append ``eos_index``.
        :param with_sos: prepend ``sos_index``.
        :param with_len: also return the length before padding/truncation.
        :return: list of int, or ``(list of int, int)`` when ``with_len`` is True.
        """
        if isinstance(sentence, str):
            sentence = sentence.split()
        seq = [self.stoi.get(word, self.unk_index) for word in sentence]
        if with_eos:
            seq += [self.eos_index]  # eos_index is 2
        if with_sos:
            seq = [self.sos_index] + seq
        origin_seq_len = len(seq)
        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index] * (seq_len - len(seq))
        else:
            seq = seq[:seq_len]
        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        """Convert token ids back to words.

        :param seq: iterable of int ids. Ids outside ``[0, len(itos))`` are
            rendered as ``"<id>"`` instead of being looked up.
        :param join: return a single space-joined string instead of a list.
        :param with_pad: NOTE: despite the name, True *removes* ``pad_index``
            ids from the output; False keeps them.
        """
        words = [self.itos[idx]
                 # Guard the lower bound too: a negative id would otherwise be
                 # resolved by Python's negative indexing to an unrelated word.
                 if 0 <= idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]
        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        """Load a pickled WordVocab from ``vocab_path``.

        Callable both as ``WordVocab.load_vocab(path)`` and on an instance.
        Security: pickle.load can execute arbitrary code; only load files
        from trusted sources.
        """
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
|