# cogen/config_generator.py
# Last commit: 8133f69 by momergul ("Initial commit"), 5.43 kB
import math
import os
import json
import pickle
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
# Directory whose listing get_data() uses as the default tangram corpus.
EMPTY_DATA_PATH = 'tangram_pngs/'
# Directory holding one pickled similarity table per tangram; each file is
# loaded by generate_complete_game() and keyed by tangram name.
CLIP_FOLDER = 'clip_similarities'
def generate_complete_game():
    """Build one game configuration (speaker/listener contexts and targets).

    Loads the tangram corpus with ``get_data()`` and the pickled similarity
    tables stored in ``CLIP_FOLDER``, then delegates to
    ``get_pragmatic_context``.

    Returns:
        The dict produced by ``get_pragmatic_context`` (keys
        ``"speaker_context"``, ``"listener_context"`` and ``"targets"``).
    """
    # Corpus first: get_data() shuffles, so keep the original RNG order.
    curr_corpus = get_data()

    # Map tangram name -> unpickled similarity table.
    clip_model = {}
    # Sorted so that, should two files map to the same tangram name, the
    # winner does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(CLIP_FOLDER)):
        # Skip hidden files such as macOS ".DS_Store" (get_data() filters the
        # same artifact); unpickling them would raise.
        if filename.startswith("."):
            continue
        # NOTE(review): pickle.load can execute arbitrary code; CLIP_FOLDER
        # must only contain trusted files.
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)
        # Key = first two '-'-separated fields of the file name, presumably
        # "pageN-M" to match the corpus names — verify against the data.
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities

    return get_pragmatic_context(curr_corpus, clip_model)
def get_pragmatic_context(curr_corpus, clip_model, num_similarity_blocks=3,
                          context_size=10, num_targets=3):
    """Sample a game context made of blocks of mutually similar tangrams.

    ``context_size`` is split across ``num_similarity_blocks`` blocks with
    ``evenly_spread_values``. Each block is seeded by a base tangram from
    ``sample_similarity_block_base`` and filled by
    ``sample_similarity_block``. Tangrams placed in the context are removed
    from the corpus before the next block is sampled.

    Args:
        curr_corpus: list of candidate tangram file names (not mutated).
        clip_model: nested mapping name -> {name: similarity score}.
        num_similarity_blocks: number of similarity blocks (default 3).
        context_size: number of tangrams spread across the blocks
            (default 10).
        num_targets: number of targets drawn from the context (default 3).

    Returns:
        dict with ``"speaker_context"`` and ``"listener_context"`` (the same
        tangrams, each in an independent random order) and ``"targets"``
        (``num_targets`` tangrams sampled without replacement from the
        context).
    """
    overall_context = []
    block_sizes = evenly_spread_values(context_size, num_similarity_blocks)
    for block_size in block_sizes:
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model,
                                                    overall_context)
        similarity_block = sample_similarity_block(curr_corpus, base_tangram,
                                                   block_size, clip_model)
        overall_context.extend(similarity_block)
        # Later blocks must not reuse anything already in the context.
        used = set(overall_context)
        curr_corpus = [tangram for tangram in curr_corpus
                       if tangram not in used]

    targets = random.sample(overall_context, num_targets)

    # Shuffle index lists (rather than the images directly) so the RNG is
    # consumed exactly as before and seeded runs stay reproducible.
    speaker_order = list(range(len(overall_context)))
    random.shuffle(speaker_order)
    speaker_images = [overall_context[i] for i in speaker_order]

    listener_order = list(range(len(overall_context)))
    random.shuffle(listener_order)
    listener_images = [overall_context[i] for i in listener_order]

    return {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }
def evenly_spread_values(block_size, num_similarity_blocks):
    """Split ``block_size`` items across ``num_similarity_blocks`` buckets.

    Items are dealt out round-robin, so the earliest buckets absorb the
    remainder, e.g. ``evenly_spread_values(10, 3) == [4, 3, 3]``.

    Returns:
        List of ``num_similarity_blocks`` bucket sizes.
    """
    counts = [0] * num_similarity_blocks
    for item in range(block_size):
        counts[item % num_similarity_blocks] += 1
    return counts
def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    """Pick, uniformly at random, a tangram to seed a new similarity block.

    Candidates are the entries of ``curr_corpus`` accepted by
    ``get_candidate_base_tangrams`` given the tangrams already in
    ``overall_context``.

    Raises:
        ValueError: if no candidate is available.
    """
    candidate_base_tangrams = get_candidate_base_tangrams(curr_corpus, clip_model,
                                                          overall_context)
    if not candidate_base_tangrams:
        raise ValueError(
            f"No valid base tangram among {len(curr_corpus)} corpus entries "
            f"({len(overall_context)} tangrams already in the context)."
        )
    # random.sample(..., 1)[0] is kept instead of random.choice so seeded
    # runs draw the same values as before.
    return random.sample(candidate_base_tangrams, 1)[0]
def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    """Return, in corpus order, the tangrams ``valid_base_tangram`` accepts."""
    return [
        candidate
        for candidate in curr_corpus
        if valid_base_tangram(overall_context, candidate, clip_model)
    ]
def valid_base_tangram(overall_context, tangram, clip_model, threshold=1):
    """Return True if ``tangram`` may seed a new similarity block.

    The candidate is rejected when its score against any tangram already in
    ``overall_context`` is strictly greater than ``threshold``. Scores are
    read as ``clip_model[context_name][candidate_name]`` with the last four
    characters (the file extension, e.g. ".png") stripped from both names.

    Args:
        overall_context: file names already placed in the game context.
        tangram: candidate file name.
        clip_model: nested mapping name -> {name: similarity score}.
        threshold: score above which the candidate is rejected. Defaults to
            1, the previously hard-coded cut-off.
            NOTE(review): if the stored scores never exceed 1 (e.g. cosine
            similarities) this default rejects nothing — confirm intent.

    Raises:
        KeyError: if a needed entry is missing from ``clip_model``.
    """
    candidate_key = tangram[:-4]
    # `not any(score > threshold)` (rather than `all(score <= threshold)`)
    # keeps the original treatment of NaN scores as acceptable.
    return not any(
        clip_model[context_tangram[:-4]][candidate_key] > threshold
        for context_tangram in overall_context
    )
def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model, temperature=0.055):
    """Sample a block of tangrams similar to ``base_tangram``.

    The block is ``base_tangram`` followed by ``similarity_block_size - 1``
    other tangrams from ``curr_corpus``, drawn without replacement with
    probabilities from a softmax (at ``temperature``) over their similarity
    score to the base.

    Args:
        curr_corpus: available tangram file names (``<name>.png``).
        base_tangram: file name of the block's base tangram.
        similarity_block_size: number of tangrams in the block, base
            included.
        clip_model: nested mapping name -> {name: similarity score}; names
            carry no file extension.
        temperature: softmax temperature; defaults to the previously
            hard-coded 0.055. Lower values favour the highest scores.

    Returns:
        List of file names, base first.

    Raises:
        ValueError: if fewer than ``similarity_block_size - 1`` candidates
            (other than the base) remain in ``curr_corpus``.
    """
    available = set(curr_corpus)  # O(1) membership instead of a list scan
    base_similarities = clip_model[base_tangram[:-4]]
    ranked = sorted(base_similarities.items(), reverse=True,
                    key=lambda item: item[1])

    sorted_tangrams = []
    sorted_scores = []
    for name, score in ranked:
        filename = name + ".png"
        # The base is still in curr_corpus at this point; if clip_model has
        # an entry for the base itself it could otherwise be drawn and
        # appear twice in the block.
        if filename in available and filename != base_tangram:
            sorted_tangrams.append(filename)
            sorted_scores.append(score)

    k = similarity_block_size - 1
    if k > len(sorted_tangrams):
        raise ValueError(
            f"Need {k} tangrams similar to {base_tangram!r} but only "
            f"{len(sorted_tangrams)} candidates remain in the corpus."
        )
    distribution = get_similarity_distribution(sorted_scores, temperature)
    sampled_indices = sample_without_replacement(distribution, k)
    return [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
def get_similarity_distribution(scores, temperature):
    """Turn similarity scores into probabilities with a tempered softmax.

    ``probs[i] = exp(scores[i] / T) / sum_j exp(scores[j] / T)``.

    Args:
        scores: sequence of numeric similarity scores.
        temperature: positive softmax temperature ``T``; smaller values
            concentrate mass on the highest scores.

    Returns:
        1-D tensor of torch's default float dtype, same length as ``scores``
        (empty if ``scores`` is empty).

    Raises:
        ValueError: if ``temperature`` is not positive. Zero would divide by
            zero, and a negative value would silently favour the *lowest*
            scores.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    # torch.tensor with the default dtype matches what the legacy
    # torch.Tensor(list) constructor produced.
    logits = torch.tensor([score / temperature for score in scores],
                          dtype=torch.get_default_dtype())
    return F.softmax(logits, dim=0)
def sample_without_replacement(distribution, K):
    """Draw ``K`` distinct indices from a categorical distribution.

    Indices are drawn one at a time with ``torch.multinomial``. After each
    draw the chosen weight is zeroed and the rest renormalised, so every
    draw is proportional to the remaining mass. The input is not modified.

    Args:
        distribution: 1-D tensor of non-negative weights.
        K: number of indices to draw; ``K <= 0`` returns ``[]``.

    Returns:
        List of ``K`` distinct ``int`` indices, in draw order.

    Raises:
        ValueError: if ``distribution`` has fewer than ``K`` strictly
            positive entries. Otherwise the renormalisation would divide
            0 by 0 and ``torch.multinomial`` would fail on NaN weights.
    """
    available = int((distribution > 0).sum().item())
    if K > available:
        raise ValueError(
            f"Cannot draw {K} distinct indices from a distribution with only "
            f"{available} positive entries."
        )
    weights = torch.clone(distribution)  # never mutate the caller's tensor
    samples = []
    for _ in range(K):
        current_sample = torch.multinomial(weights, num_samples=1).item()
        samples.append(current_sample)
        weights[current_sample] = 0
        weights = weights / torch.sum(weights)
    return samples
def get_data(restricted_dataset=""):
    """Return the shuffled list of tangram file names to build games from.

    Args:
        restricted_dataset: optional path to a pickle holding a list of
            tangram names without extension. When empty (the default), the
            directory listing of ``EMPTY_DATA_PATH`` is used instead.

    Returns:
        A shuffled list of file names, without ``.DS_Store`` entries and
        without the known duplicate tangrams.
    """
    if restricted_dataset == "":
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted files here.
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # NOTE(review): these names get ".svg" while sample_similarity_block
        # matches candidates as "<name>.png" — confirm this is intended.
        paths = [path + ".svg" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)

    # Drop tangrams that duplicate others. Match on the name without its
    # extension so the filter also applies to the ".svg" names built for a
    # restricted dataset. Filtering after the shuffle keeps the
    # surviving order identical to removing the items one by one.
    duplicates = {"page6-51", "page6-66", "page4-170"}
    return [path for path in paths
            if os.path.splitext(path)[0] not in duplicates]