# respect / config_generator.py
# chenzizhao's picture
# cosmetics
# 87b7a45
import dataclasses
import functools
import logging
import os
import pickle
import pprint
import random
from typing import List
# Directory whose entries _get_data() lists when called with hb_split=False.
# Relative path, so it is resolved against the current working directory.
EMPTY_DATA_PATH = "tangram_pngs/"
# Directory holding the pickled split files ("test_imgs.pkl" and
# "train_imgs.pkl") that _get_data() loads when hb_split is True.
SPLIT_PATH = "dataset_splits/"
@dataclasses.dataclass(frozen=True)
class GameConfig:
    """Read-only record of one game's setup.

    ``frozen=True`` blocks reassignment of the three fields after
    construction; the lists stored in them remain ordinary mutable lists.
    """

    # Context entries for the speaker.
    speaker_context: List[str]

    # Context entries for the listener.
    listener_context: List[str]

    # Entries selected as targets.
    targets: List[str]
def generate_game_config() -> GameConfig:
    """Build a randomly drawn game configuration.

    Ten entries are sampled from the corpus returned by ``_get_data()`` to
    form the speaker's context, and between three and five of those
    (inclusive) are sampled as targets. The listener's context holds the
    same ten entries in a separately shuffled order. The finished
    configuration is logged at INFO level and returned.
    """
    speaker_view = random.sample(_get_data(), 10)
    chosen_targets = random.sample(speaker_view, random.randint(3, 5))

    # The listener sees the same entries, reordered by an in-place shuffle
    # of a copy so the speaker's ordering is left untouched.
    listener_view = list(speaker_view)
    random.shuffle(listener_view)

    config = GameConfig(
        speaker_context=speaker_view,
        listener_context=listener_view,
        targets=chosen_targets,
    )
    logging.info(f"context_dict: {pprint.pformat(dataclasses.asdict(config))}")
    return config
@functools.cache
def _get_data(hb_split: bool = True) -> List[str]:
    """Return the image file names that game contexts are sampled from.

    Args:
        hb_split: When true (the default), names are loaded from the pickled
            ``test_imgs.pkl`` and ``train_imgs.pkl`` files under
            ``SPLIT_PATH`` (test entries first, then train) and each gets a
            ``.png`` suffix. When false, the entries of the
            ``EMPTY_DATA_PATH`` directory are used instead, in sorted order.

    Returns:
        The file names with ``.DS_Store`` and the listed duplicate images
        removed. ``functools.cache`` hands the same list object to every
        caller that passes the same argument, so treat it as read-only.

    Raises:
        OSError: If the split files or the image directory cannot be read.
    """
    if hb_split:
        # 912 images
        names = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            # NOTE: unpickling can execute arbitrary code; these files are
            # assumed to be trusted project data.
            with open(os.path.join(SPLIT_PATH, split_file), "rb") as f:
                names.extend(pickle.load(f))
        names = [name + ".png" for name in names]
    else:
        # 1013 images
        # os.listdir() returns entries in arbitrary order; sorting keeps
        # seeded sampling reproducible across machines and filesystems.
        names = sorted(os.listdir(EMPTY_DATA_PATH))

    # Names dropped from the corpus: the macOS metadata file plus the images
    # the original code lists as duplicates.
    excluded = frozenset(
        {".DS_Store", "page6-51.png", "page6-66.png", "page4-170.png"}
    )
    return [name for name in names if name not in excluded]