Spaces:
Sleeping
Sleeping
import os | |
from collections import Counter | |
from nltk.tokenize import RegexpTokenizer | |
from source.config import Config | |
class Vocab: | |
""" | |
Offers word2index and index2word functionality after counting words in input sentences. | |
Allows choosing the size of the vocabulary by taking the most common words. Explicitly reserves four indices: | |
<pad>, <sos>, <eos> and <unk>. | |
""" | |
def __init__(self, sentence_splitter=None): | |
""" | |
Args: | |
sentence_splitter: tokenizing function | |
""" | |
self.config = Config() | |
self.counter = Counter() | |
self.word2index = dict() | |
self.index2word = dict() | |
self.size = 0 | |
# predefined tokens | |
self.PADDING_INDEX = 0 | |
self.SOS = 1 | |
self.EOS = 2 | |
self.UNKNOWN_WORD_INDEX = 3 | |
if sentence_splitter is None: | |
# matches sequences of characters including ones between < > | |
word_regex = r'(?:\w+|<\w+>)' | |
# tokenize the string into words | |
sentence_splitter = RegexpTokenizer(word_regex).tokenize | |
self.splitter = sentence_splitter | |
def add_sentence(self, sentence: str): | |
""" | |
Update word counts from sentence after tokenizing it into words | |
""" | |
self.counter.update(self.splitter(sentence)) | |
def build_vocab(self, vocab_size: int, file_name: str): | |
""" Build vocabulary dictionaries word2index and index2word from a text file at config.ROOT path | |
Args: | |
vocab_size (int): size of vocabulary (including 4 predefined tokens: <pad>, <sos>, <eos>, <unk>) | |
file_name (str): name of the text file from which the vocabulary will be built. | |
Note: the lines in file are assumed to be in form: 'word SPACE index' and | |
it asssumes a header line (for example: 'captions.txt') | |
""" | |
filepath = os.path.join(self.config.ROOT, file_name) | |
try: | |
with open(filepath, 'r', encoding='utf-8') as file: | |
for i, line in enumerate(file): | |
# ignore header line | |
if i == 0: | |
continue | |
caption = line.strip().lower().split(",", 1)[1] # id=0, caption=1 | |
self.add_sentence(caption) | |
except Exception as e: | |
print(f"Error processing file {filepath}: {e}") | |
return | |
# adding predefined tokens in the vocabulary | |
self._add_predefined_tokens() | |
words = self.counter.most_common(vocab_size - 4) | |
# (index + 4) because first 4 tokens are the predefined ones | |
for index, (word, _) in enumerate(words, start=4): | |
self.word2index[word] = index | |
self.index2word[index] = word | |
self.size = len(self.word2index) | |
# adding predefined tokens in the vocabulary | |
self.index2word[self.PADDING_INDEX] = '<pad>' | |
self.word2index['<pad>'] = self.PADDING_INDEX | |
self.index2word[self.SOS] = '<sos>' | |
self.word2index['<sos>'] = self.SOS | |
self.index2word[self.EOS] = '<eos>' | |
self.word2index['<eos>'] = self.EOS | |
self.index2word[self.UNKNOWN_WORD_INDEX] = '<unk>' | |
self.word2index['<unk>'] = self.UNKNOWN_WORD_INDEX | |
def word_to_index(self, word: str) -> int: | |
""" Map word to index from word2index dictionary in vocabulary | |
Args: | |
word (str): word to be mapped | |
Returns: | |
int: index matched to the word | |
""" | |
try: | |
return self.word2index[word] | |
except KeyError: | |
return self.UNKNOWN_WORD_INDEX | |
def index_to_word(self, index: int) -> str: | |
""" Map word to index from index2word dictionary in vocabulary | |
Args: | |
word (str): index to be mapped | |
Returns: | |
str: word matched to the index | |
""" | |
try: | |
return self.index2word[index] | |
except KeyError: | |
return self.index2word[self.UNKNOWN_WORD_INDEX] | |
def load_vocab(self, file_name: str): | |
""" Load the word2index and index2word dictionaries from a text file at config.ROOT path | |
Args: | |
file_name (str): name of the text file where the vocabulary is saved (i.e 'word2index.txt') | |
Note: the lines in file are assumed to be in form: 'word SPACE index' and it asssumes a header line | |
""" | |
filepath = os.path.join(self.config.ROOT, file_name) | |
self.word2index = dict() | |
self.index2word = dict() | |
try: | |
with open(filepath) as file: | |
for line in file: | |
line = line.strip().split(' ') | |
word, index = line[0], line[1] | |
self.word2index[word] = int(index) | |
self.index2word[int(index)] = word | |
except Exception as e: | |
print(f"Error loading vocabulary from file {filepath}: {e}") | |