Spaces:
Sleeping
Sleeping
# References:
#   https://github.com/optas/shapeglot
#   https://github.com/63days/PartGlot
from six.moves import cPickle | |
def unpickle_data(file_name, python2_to_3=False):
    """Restore data previously saved with pickle_data().

    The file is read as a pickled item count followed by that many pickled
    objects; the objects are loaded lazily, one per iteration.

    :param file_name: file holding the pickled data.
    :param python2_to_3: (boolean), if True, pickle happened under python2x,
        unpickling under python3x (every load uses encoding="latin1").
    :return: a generator over the un-pickled items.

    Note, about implementing the python2_to_3 see
        https://stackoverflow.com/questions/28218466/unpickling-a-python-2-object-with-python-3

    Warning: unpickling can execute arbitrary code; only use this on files
    from a trusted source.
    """
    # Pass ``encoding`` only when requested so the default path stays a plain
    # ``cPickle.load(in_file)`` call.
    load_kwargs = {"encoding": "latin1"} if python2_to_3 else {}

    # The context manager closes the file even when the consumer stops
    # iterating early (generator closed) or a load raises.
    with open(file_name, "rb") as in_file:
        size = cPickle.load(in_file, **load_kwargs)
        for _ in range(size):
            yield cPickle.load(in_file, **load_kwargs)
def get_mask_of_game_data(
    game_data: DataFrame,
    word2int: Dict,
    only_correct: bool,
    only_easy_context: bool,
    max_seq_len: int,
    only_one_part_name: bool,
):
    """Build a boolean selection mask over the rows of ``game_data``.

    only_correct (if True): mask will be 1 in location iff human listener predicted correctly.
    only_easy_context (if True): uses only easy context examples (more dissimilar triplet chairs)
    max_seq_len: drops examples with len(utterance) > max_seq_len
    only_one_part_name (if True): uses only utterances describing only one part in the given set.

    Reads the ``correct``, ``context_condition`` and ``text`` columns of
    ``game_data``. ``word2int`` is only forwarded to ``get_part_indicator``.

    Returns a tuple ``(mask, part_indicator)``. ``part_indicator`` is returned
    exactly as produced by ``get_part_indicator``; it is not filtered by
    ``mask``.
    """
    # Start from the listener-correctness column, or accept every row.
    mask = np.array(game_data.correct)
    if not only_correct:
        # Builtin ``bool`` replaces the ``np.bool`` alias, which was removed
        # in NumPy 1.24; on older versions both name the same dtype.
        mask = np.ones_like(mask, dtype=bool)

    if only_easy_context:
        context_mask = np.array(game_data.context_condition == "easy", dtype=bool)
        mask = np.logical_and(mask, context_mask)

    # The length filter is always applied, independent of the flags above.
    short_mask = np.array(game_data.text.apply(len) <= max_seq_len, dtype=bool)
    mask = np.logical_and(mask, short_mask)

    # The part indicator is computed (and returned) even when its mask is not
    # used for filtering.
    part_indicator, part_mask = get_part_indicator(game_data.text, word2int)
    if only_one_part_name:
        mask = np.logical_and(mask, part_mask)

    return mask, part_indicator