"""Sample randomized game configurations from a corpus of image file names."""

import dataclasses
import functools
import logging
import os
import pickle
import pprint
import random
from typing import List

EMPTY_DATA_PATH = "tangram_pngs/"
SPLIT_PATH = "dataset_splits/"

# Pickled split files inside SPLIT_PATH, loaded and concatenated in this order.
_SPLIT_FILES = ("test_imgs.pkl", "train_imgs.pkl")

# File names that are always dropped from the corpus (the original code called
# this list `dup_images`; presumably they duplicate other images -- confirm).
_EXCLUDED_IMAGES = frozenset({"page6-51.png", "page6-66.png", "page4-170.png"})


@dataclasses.dataclass(frozen=True)
class GameConfig:
    """One sampled game: the same images in two orders, plus the targets."""

    # Sampled image file names, in the order drawn.
    speaker_context: List[str]
    # The same file names as ``speaker_context``, in an independently
    # shuffled order.
    listener_context: List[str]
    # Subset of ``speaker_context`` chosen as targets.
    targets: List[str]


def generate_game_config(
    *,
    context_size: int = 10,
    min_targets: int = 3,
    max_targets: int = 5,
) -> GameConfig:
    """Draw a random game configuration from the image corpus.

    With the default arguments this behaves exactly like the original
    hard-coded version (10 context images, 3 to 5 targets) and consumes the
    global ``random`` state in the same order.

    Args:
        context_size: Number of distinct images sampled into the context.
        min_targets: Smallest number of targets that may be drawn.
        max_targets: Largest number of targets that may be drawn (inclusive).

    Returns:
        A ``GameConfig`` whose ``listener_context`` is a permutation of
        ``speaker_context`` and whose ``targets`` are drawn from that context.

    Raises:
        ValueError: If the target range is inconsistent with the context
            size, or if the corpus holds fewer than ``context_size`` images
            (raised by ``random.sample``).
    """
    if not 0 <= min_targets <= max_targets <= context_size:
        raise ValueError(
            "expected 0 <= min_targets <= max_targets <= context_size, got "
            f"min_targets={min_targets}, max_targets={max_targets}, "
            f"context_size={context_size}"
        )

    corpus = _get_data()
    context = random.sample(corpus, context_size)
    num_targets = random.randint(min_targets, max_targets)
    targets = random.sample(context, num_targets)

    # Shuffle indices rather than the list itself so the speaker keeps the
    # sampled order while the listener sees a permutation of the same items.
    listener_order = list(range(len(context)))
    random.shuffle(listener_order)

    config = GameConfig(
        speaker_context=context,
        listener_context=[context[i] for i in listener_order],
        targets=targets,
    )
    logging.info(f"context_dict: {pprint.pformat(dataclasses.asdict(config))}")
    return config


@functools.cache
def _get_data(hb_split: bool = True) -> List[str]:
    """Return the image file names that make up the corpus.

    The result is memoized per argument value, so every caller receives the
    same list object; treat it as read-only.

    Args:
        hb_split: When true (the default), build the corpus from the pickled
            split files in ``SPLIT_PATH``, appending ``".png"`` to each entry.
            When false, list the contents of ``EMPTY_DATA_PATH`` instead.

    Returns:
        File names with ``".DS_Store"`` and the excluded images removed.
    """
    if not hb_split:
        # 1013 images (per the original author's note).
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # 912 images (per the original author's note).
        stems = []
        for split_file in _SPLIT_FILES:
            # SECURITY: pickle.load can execute arbitrary code; these files
            # must only ever come from a trusted source.
            with open(os.path.join(SPLIT_PATH, split_file), "rb") as f:
                stems.extend(pickle.load(f))
        paths = [stem + ".png" for stem in stems]

    return [
        path
        for path in paths
        if path != ".DS_Store" and path not in _EXCLUDED_IMAGES
    ]