import os
import pickle
import random

import torch
import torch.nn.functional as F

EMPTY_DATA_PATH = "tangram_pngs/"
CLIP_FOLDER = "clip_similarities"


def generate_complete_game():
    # First get the corpus and the CLIP similarity model
    curr_corpus = get_data()
    clip_files = os.listdir(CLIP_FOLDER)
    clip_model = {}
    for filename in clip_files:
        # Get the values: a dict of similarity scores keyed by tangram name
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)

        # Get the key: the tangram name (e.g. "page6-51") is the first two
        # '-'-separated pieces of the filename
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities

    # Next get the pragmatic context
    context_dict = get_pragmatic_context(curr_corpus, clip_model)
    return context_dict


def get_pragmatic_context(curr_corpus, clip_model):
    # Initialize the lists needed for generation
    overall_context = []
    base_tangrams = []
    individual_blocks = []

    # Initialize the parameters for generation: 10 images spread as evenly as
    # possible over 3 similarity blocks, i.e. block sizes [4, 3, 3]
    block_sizes = evenly_spread_values(10, 3)
    for i in range(3):
        # Sample the base tangram
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model, overall_context)
        base_tangrams.append(base_tangram)

        # Sample the similarity block
        similarity_block = sample_similarity_block(curr_corpus, base_tangram, block_sizes[i], clip_model)  # TODO
        individual_blocks.append(similarity_block)
        overall_context.extend(similarity_block)

        # Filter the sampled images out of the corpus
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in overall_context]

    # Sample the targets at random
    targets = random.sample(overall_context, 3)

    # Construct the dictionary: speaker and listener see the same images in
    # independently shuffled orders
    speaker_order = list(range(len(overall_context)))
    random.shuffle(speaker_order)
    speaker_images = [overall_context[i] for i in speaker_order]

    listener_order = list(range(len(overall_context)))
    random.shuffle(listener_order)
    listener_images = [overall_context[i] for i in listener_order]

    context_dict = {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }
    return context_dict


def evenly_spread_values(block_size, num_similarity_blocks):
    # Distribute block_size items round-robin across num_similarity_blocks bins
    sim_block_sizes = [0 for _ in range(num_similarity_blocks)]
    for i in range(block_size):
        idx = i % num_similarity_blocks
        sim_block_sizes[idx] += 1
    return sim_block_sizes


def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    # Get the list of candidate tangrams and pick one uniformly at random
    candidate_base_tangrams = get_candidate_base_tangrams(curr_corpus, clip_model, overall_context)
    base_tangram = random.choice(candidate_base_tangrams)
    return base_tangram


def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    candidate_base_tangrams = []
    for tangram in curr_corpus:
        if valid_base_tangram(overall_context, tangram, clip_model):
            candidate_base_tangrams.append(tangram)
    return candidate_base_tangrams


def valid_base_tangram(overall_context, tangram, clip_model):
    # Reject a candidate that is too similar to any tangram already in the
    # context; [:-4] strips the ".png" extension used as the clip_model key
    for context_tangram in overall_context:
        if clip_model[context_tangram[:-4]][tangram[:-4]] > 1:
            return False
    return True


def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size, clip_model):
    # Get the most similar tangrams to the base tangram
    base_similarities = clip_model[base_tangram[:-4]]
    sorted_similarities = sorted(base_similarities.items(), reverse=True, key=lambda x: x[1])
    sorted_similarities = [sim for sim in sorted_similarities if sim[0] + ".png" in curr_corpus]

    # Separate out the tangrams and the scores
    sorted_tangrams = [sim[0] + ".png" for sim in sorted_similarities]
    sorted_scores = [sim[1] for sim in sorted_similarities]

    # Sample the rest of the block (the base tangram fills one slot) from a
    # temperature-scaled softmax over the similarity scores
    k = similarity_block_size - 1
    distribution = get_similarity_distribution(sorted_scores, 0.055)
    sampled_indices = sample_without_replacement(distribution, k)
    similarity_block = [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
    return similarity_block


def get_similarity_distribution(scores, temperature):
    # Lower temperatures concentrate probability mass on the most similar tangrams
    logits = torch.tensor(scores) / temperature
    probs = F.softmax(logits, dim=0)
    return probs


def sample_without_replacement(distribution, k):
    # Draw k distinct indices: zero out each sampled index and renormalize
    # so it cannot be drawn again
    new_distribution = torch.clone(distribution)
    samples = []
    for _ in range(k):
        current_sample = torch.multinomial(new_distribution, num_samples=1).item()
        samples.append(current_sample)
        new_distribution[current_sample] = 0
        new_distribution = new_distribution / torch.sum(new_distribution)
    return samples


def get_data(restricted_dataset=""):
    # Get the list of all paths
    if restricted_dataset == "":
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # The rest of the pipeline assumes ".png" filenames
        paths = [path + ".png" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)

    # Remove known duplicates
    for duplicate in ['page6-51.png', 'page6-66.png', 'page4-170.png']:
        if duplicate in paths:
            paths.remove(duplicate)
    return paths
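
# Minimal usage sketch: builds one game and prints the sampled contexts.
# Assumes the tangram_pngs/ and clip_similarities/ folders referenced above
# are populated locally; otherwise generate_complete_game() will fail on
# os.listdir.
if __name__ == "__main__":
    game = generate_complete_game()
    print("Speaker context:", game["speaker_context"])
    print("Listener context:", game["listener_context"])
    print("Targets:", game["targets"])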