import pickle
from collections import Counter


class TorchVocab(object):
    """
    :property freqs: collections.Counter, holds the frequency of each word in the corpus
    :property stoi: dict, mapping from string to id
    :property itos: list, mapping from id to string
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param counter: collections.Counter, counts the frequency of each word in the data
        :param max_size: int, maximum size of the vocabulary. None means no limit. Default: None
        :param min_freq: int, minimum frequency for a word to be added to the vocabulary.
                         Words occurring fewer than min_freq times are not added.
        :param specials: list of str, tokens registered in the vocabulary up front
        :param vectors: list of vectors, pretrained vectors, e.g. Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # frequencies of special tokens are not counted when building the vocabulary
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort alphabetically first, then (stable) by frequency in descending order,
        # so words with equal frequency end up in alphabetical order
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        # words with frequency below min_freq are not added to the vocab
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # build stoi as the inverse mapping of itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            # load_vectors is expected to be supplied by a subclass or mixin
            # (e.g. torchtext's Vocab); it is not defined in this file
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
                         max_size=max_size, min_freq=min_freq)

    # meant to be overridden
    def to_seq(self, sentence, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # meant to be overridden
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# builds a vocab from texts (an iterable of lines or of token lists)
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        # unknown words map to unk_index
        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # eos_index == 2
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            # pad up to seq_len
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            # truncate down to seq_len
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        # note: pad tokens are dropped only when with_pad is True;
        # ids outside the vocabulary are rendered as "<id>"
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
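

# A minimal usage sketch, not part of the original module: builds a WordVocab
# from a small in-memory corpus and round-trips a sentence through
# to_seq/from_seq, then through save_vocab/load_vocab. The sample sentences
# and the "word_vocab.pkl" path are illustrative assumptions.
if __name__ == "__main__":
    corpus = [
        "the quick brown fox jumps over the lazy dog",
        "the quick red fox",
    ]
    vocab = WordVocab(corpus, max_size=None, min_freq=1)

    # "green" is out of vocabulary, so it maps to unk_index;
    # the result is padded with pad_index up to seq_len=10
    seq = vocab.to_seq("the quick green fox", seq_len=10, with_eos=True, with_sos=True)
    print(seq)  # ids for <sos> the quick <unk> fox <eos> plus padding

    # with_pad=True drops the pad tokens on the way back
    print(vocab.from_seq(seq, join=True, with_pad=True))

    # pickle round-trip via save_vocab / load_vocab
    vocab.save_vocab("word_vocab.pkl")
    restored = WordVocab.load_vocab("word_vocab.pkl")
    assert len(restored) == len(vocab)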