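"""Generate pragmatic contexts for a tangram reference game.

A 10-image context is assembled from three CLIP-similarity blocks: each block
is seeded with a base tangram sufficiently dissimilar from everything already
in the context, then filled with tangrams sampled (without replacement, via a
temperature-scaled softmax) in proportion to their CLIP similarity to the
base. Three targets are drawn from the finished context, and the speaker and
listener each see the context in an independent random order.
"""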
import os
import pickle
import random

import torch
import torch.nn.functional as F

EMPTY_DATA_PATH = "tangram_pngs/"
CLIP_FOLDER = "clip_similarities"

def generate_complete_game():
    # Load the tangram corpus and the precomputed CLIP similarity tables
    curr_corpus = get_data()

    clip_files = os.listdir(CLIP_FOLDER)
    clip_model = {}
    for filename in clip_files:
        # Each pickle maps tangram names to CLIP similarity scores
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)

        # Key by the tangram name, e.g. "page6-51" from a filename
        # beginning "page6-51-"
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities

    # Build the pragmatic context from the corpus and similarity tables
    context_dict = get_pragmatic_context(curr_corpus, clip_model)
    return context_dict

def get_pragmatic_context(curr_corpus, clip_model):
    # Accumulators for the sampled context
    overall_context = []
    base_tangrams = []
    individual_blocks = []

    # Split the 10-image context across 3 similarity blocks (sizes [4, 3, 3])
    block_sizes = evenly_spread_values(10, 3)

    for i in range(3):
        # Sample a base tangram that is dissimilar to the context so far
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model, overall_context)
        base_tangrams.append(base_tangram)

        # Sample a block of tangrams similar to the base (base included)
        similarity_block = sample_similarity_block(curr_corpus, base_tangram, block_sizes[i], clip_model)
        individual_blocks.append(similarity_block)
        overall_context.extend(similarity_block)

        # Drop tangrams already placed in the context from the corpus
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in overall_context]

    # Sample the targets uniformly at random from the full context
    targets = random.sample(overall_context, 3)

    # Present the context to speaker and listener in independent random orders
    speaker_order = list(range(len(overall_context)))
    random.shuffle(speaker_order)
    speaker_images = [overall_context[i] for i in speaker_order]

    listener_order = list(range(len(overall_context)))
    random.shuffle(listener_order)
    listener_images = [overall_context[i] for i in listener_order]

    context_dict = {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }

    return context_dict

def evenly_spread_values(block_size, num_similarity_blocks):
    # Distribute block_size items across the blocks as evenly as possible,
    # e.g. evenly_spread_values(10, 3) -> [4, 3, 3]
    sim_block_sizes = [0 for _ in range(num_similarity_blocks)]
    for i in range(block_size):
        idx = i % num_similarity_blocks
        sim_block_sizes[idx] += 1
    return sim_block_sizes

def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    # Restrict to tangrams that are not too similar to the current context
    candidate_base_tangrams = get_candidate_base_tangrams(curr_corpus, clip_model,
                                                          overall_context)

    base_tangram = random.choice(candidate_base_tangrams)
    return base_tangram

def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    candidate_base_tangrams = []
    for tangram in curr_corpus:
        if valid_base_tangram(overall_context, tangram, clip_model):
            candidate_base_tangrams.append(tangram)
    return candidate_base_tangrams

def valid_base_tangram(overall_context, tangram, clip_model):
    # A valid base must not score above the similarity cutoff against any
    # tangram already in the context ([:-4] strips the ".png" extension)
    for context_tangram in overall_context:
        if clip_model[context_tangram[:-4]][tangram[:-4]] > 1:
            return False
    return True

def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model):
    # Rank candidates by CLIP similarity to the base tangram
    base_similarities = clip_model[base_tangram[:-4]]
    sorted_similarities = sorted(base_similarities.items(), reverse=True, key=lambda x: x[1])
    # Keep only candidates still in the corpus, and exclude the base itself
    # (guards against re-sampling the base if its self-similarity is stored)
    sorted_similarities = [sim for sim in sorted_similarities
                           if sim[0] + ".png" in curr_corpus and sim[0] + ".png" != base_tangram]

    # Separate out the tangrams and the scores
    sorted_tangrams = [sim[0] + ".png" for sim in sorted_similarities]
    sorted_scores = [sim[1] for sim in sorted_similarities]
    k = similarity_block_size - 1  # the base tangram fills the remaining slot

    # A low temperature (0.055) sharpens the softmax toward the most similar tangrams
    distribution = get_similarity_distribution(sorted_scores, 0.055)
    sampled_indices = sample_without_replacement(distribution, k)
    similarity_block = [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
    return similarity_block

def get_similarity_distribution(scores, temperature):
    # Temperature-scaled softmax: p_i = exp(s_i / T) / sum_j exp(s_j / T)
    logits = torch.tensor(scores) / temperature
    probs = F.softmax(logits, dim=0)
    return probs

def sample_without_replacement(distribution, K):
    # Draw K distinct indices in proportion to the distribution by zeroing out
    # each sampled index and renormalizing; equivalent to
    # torch.multinomial(distribution, K, replacement=False)
    new_distribution = torch.clone(distribution)

    samples = []
    for i in range(K):
        current_sample = torch.multinomial(new_distribution, num_samples=1).item()
        samples.append(current_sample)

        new_distribution[current_sample] = 0
        new_distribution = new_distribution / torch.sum(new_distribution)

    return samples

def get_data(restricted_dataset=""):
    # Get the list of all tangram image paths
    if restricted_dataset == "":
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # Pickled names lack extensions; append ".png" to match the corpus
        paths = [path + ".png" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)

    # Remove known duplicate tangrams
    for duplicate in ['page6-51.png', 'page6-66.png', 'page4-170.png']:
        if duplicate in paths:
            paths.remove(duplicate)

    return paths
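
# A minimal usage sketch (assumes the tangram_pngs/ directory and the
# clip_similarities/ pickles described above are in place):
if __name__ == "__main__":
    context = generate_complete_game()
    print("speaker context: ", context["speaker_context"])
    print("listener context:", context["listener_context"])
    print("targets:         ", context["targets"])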