# ImageCaption/source/vocab.py
# Last change: commit 7cb93ae ("Update code") by nssharmaofficial
from collections import Counter
from nltk.tokenize import RegexpTokenizer
from source.config import Config
class Vocab:
    """
    Offers word2index and index2word functionality after counting words in input sentences.

    Explicitly reserves four indices: <pad> (0), <sos> (1), <eos> (2) and <unk> (3).

    Word counts are accumulated in ``self.counter`` by ``add_sentence``. The counter
    can be used to choose the size of the vocabulary by taking the most common words
    (``self.counter.most_common(k)``). The word <-> index mappings are populated by
    ``load_vocab``.
    """

    def __init__(self, sentence_splitter=None):
        """
        Args:
            sentence_splitter: tokenizing function; called with a sentence (str) and
                expected to return an iterable of word tokens. Defaults to an NLTK
                ``RegexpTokenizer`` that keeps runs of word characters and whole
                angle-bracket tokens such as ``<unk>``.
        """
        self.config = Config()
        self.counter = Counter()
        self.word2index = dict()
        self.index2word = dict()
        # Number of words in word2index; set by load_vocab.
        self.size = 0
        # predefined tokens
        self.PADDING_INDEX = 0
        self.SOS = 1
        self.EOS = 2
        self.UNKNOWN_WORD_INDEX = 3
        if sentence_splitter is None:
            # matches runs of word characters, or whole tokens between < > (e.g. <unk>)
            word_regex = r'(?:\w+|<\w+>)'
            # tokenize the string into words
            sentence_splitter = RegexpTokenizer(word_regex).tokenize
        self.splitter = sentence_splitter

    def add_sentence(self, sentence: str):
        """
        Update word counts from sentence after tokenizing it into words.

        Args:
            sentence (str): sentence to tokenize with ``self.splitter`` and count
        """
        self.counter.update(self.splitter(sentence))

    def word_to_index(self, word: str) -> int:
        """Map a word to its index from the word2index dictionary.

        Args:
            word (str): word to be mapped
        Returns:
            int: index matched to the word, or ``UNKNOWN_WORD_INDEX`` if the word
            is not in the vocabulary
        """
        return self.word2index.get(word, self.UNKNOWN_WORD_INDEX)

    def index_to_word(self, index: int) -> str:
        """Map an index to its word from the index2word dictionary.

        Args:
            index (int): index to be mapped
        Returns:
            str: word matched to the index. For an unknown index, this is the word
            stored at ``UNKNOWN_WORD_INDEX``, or the literal ``'<unk>'`` when that
            entry is missing (e.g. no vocabulary has been loaded yet).
        """
        try:
            return self.index2word[index]
        except KeyError:
            # Do not raise when the vocabulary is empty; fall back to the <unk> token.
            return self.index2word.get(self.UNKNOWN_WORD_INDEX, '<unk>')

    @staticmethod
    def _parse_entry(fields):
        """Return ``(word, index)`` from a split vocabulary line, or None if it is not 'word index'."""
        if len(fields) < 2:
            return None
        try:
            return fields[0], int(fields[1])
        except ValueError:
            return None

    def load_vocab(self, filepath: str):
        """Load the word2index and index2word dictionaries from a text file.

        Args:
            filepath (str): path of the text file where the vocabulary is saved
                (i.e. 'word2index.txt'); the file is read as UTF-8.

        Each line is expected to have the form 'word SPACE index'. Blank lines are
        skipped. If the first non-blank line does not parse as 'word index', it is
        treated as a header and skipped.

        On success, ``self.size`` is set to the number of loaded words. If the file
        cannot be read, or a later line is malformed, an error message is printed
        (not raised) and the vocabulary is left empty rather than partially filled.
        """
        # Start from an empty vocabulary so a failed load never leaves stale data.
        self.word2index = dict()
        self.index2word = dict()
        self.size = 0
        word2index, index2word = {}, {}
        try:
            with open(filepath, encoding='utf-8') as file:
                is_first_entry = True
                for line_number, line in enumerate(file, start=1):
                    # split() with no separator tolerates repeated spaces, tabs and '\r\n'
                    fields = line.split()
                    if not fields:
                        continue
                    entry = self._parse_entry(fields)
                    was_first, is_first_entry = is_first_entry, False
                    if entry is None:
                        if was_first:
                            continue  # header line
                        raise ValueError(
                            f"line {line_number} is not of the form 'word index': {line.strip()!r}"
                        )
                    word, index = entry
                    word2index[word] = index
                    index2word[index] = word
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file; ValueError: malformed line or bad UTF-8.
            print(f"Error loading vocabulary from file {filepath}: {e}")
            return
        # Commit only after the whole file parsed successfully.
        self.word2index = word2index
        self.index2word = index2word
        self.size = len(word2index)