|
# Module authorship metadata.
__author__ = "Taneem Jan, taneemishere.github.io"
|
|
|
import sys |
|
import numpy as np |
|
|
|
# Reserved tokens; Vocabulary.__init__ registers them first, in this order.
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PLACEHOLDER = ' '

# Text written between a token and its encoded vector in the serialized vocabulary.
SEPARATOR = "->"
|
|
|
|
|
class Vocabulary:
    """Token <-> index mapping with one-hot ("binary") vector encodings.

    Attributes:
        vocabulary: dict mapping each token to its integer index.
        token_lookup: dict mapping each integer index back to its token.
        binary_vocabulary: dict mapping each token to a one-hot numpy vector
            of length ``size``; filled by ``create_binary_representation``
            (lazily, via serialization) or by ``retrieve``.
        size: number of tokens registered so far.
    """

    def __init__(self):
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        # The reserved tokens are registered first, so they get indices 0, 1 and 2.
        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register ``token`` under the next free index; known tokens are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """(Re)build the one-hot vector of length ``size`` for every registered token."""
        # dict.items() exists on both Python 2 and 3, so no version switch is needed.
        for key, value in self.vocabulary.items():
            binary = np.zeros(self.size)
            binary[value] = 1
            self.binary_vocabulary[key] = binary

    def get_serialized_binary_representation(self):
        """Return the one-hot encodings as text, one ``token->v0,v1,...`` line per token.

        The encodings are (re)built first if they are missing or stale, i.e. if
        the number of cached vectors no longer matches ``size`` because tokens
        were appended after they were last computed.
        """
        if len(self.binary_vocabulary) != self.size:
            self.create_binary_representation()

        lines = []
        for key, value in self.binary_vocabulary.items():
            # Disable numpy's line wrapping and "..." summarisation (applied by
            # default to arrays above the print threshold) so each vector is
            # written in full on the same line as its separator, which is the
            # layout retrieve() parses.
            array_as_string = np.array2string(
                value,
                separator=',',
                max_line_width=sys.maxsize,
                threshold=sys.maxsize,
            )
            # Strip the surrounding "[" and "]".
            lines.append("{}{}{}\n".format(key, SEPARATOR, array_as_string[1:-1]))
        return "".join(lines)

    def save(self, path):
        """Write the serialized encodings to ``<path>/words.vocab``, overwriting it."""
        # Serialize before opening so a failure cannot leave a truncated file behind.
        serialized = self.get_serialized_binary_representation()
        output_file_name = "{}/words.vocab".format(path)
        with open(output_file_name, 'w') as output_file:
            output_file.write(serialized)

    def retrieve(self, path):
        """Load tokens, indices and one-hot vectors from ``<path>/words.vocab``.

        Loaded entries are added to (and override) the current mappings, and
        ``size`` is reset to the resulting number of tokens.

        A line containing no separator is accumulated and prefixed to the next
        entry's key. This lets tokens that contain newline characters, which
        ``save`` writes verbatim, round-trip.

        Raises:
            IndexError: if a stored vector has no element equal to 1.
        """
        buffer = ""
        with open("{}/words.vocab".format(path), 'r') as input_file:
            for line in input_file:
                try:
                    separator_position = len(buffer) + line.index(SEPARATOR)
                except ValueError:
                    # No separator: this line is part of a multi-line key.
                    buffer += line
                    continue

                buffer += line
                key = buffer[:separator_position]
                value = np.fromstring(buffer[separator_position + len(SEPARATOR):], sep=',')
                # Store a plain int, matching the indices produced by append().
                index = int(np.where(value == 1)[0][0])

                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key
                buffer = ""

        self.size = len(self.vocabulary)
|
|