"""Pick a random batch of Parti-prompt images for a model-comparison questionnaire.

At import time this module loads every dataset listed in ``SUBMISSIONS`` and
builds one questionnaire (``hello``) by calling :func:`start`.
"""
import os  # noqa: F401  (kept: may be used by code outside this view)
from collections import Counter  # noqa: F401  (needed once submission counting is enabled)
from random import choices, sample, shuffle  # noqa: F401  (`choices` kept for compatibility)

from datasets import load_dataset

parti_prompt_results = []

ORG = "diffusers-parti-prompts"

# Hub repo ids always use "/" as separator, so they are built explicitly
# rather than with os.path.join (which would yield "\\" on Windows).
SUBMISSIONS = {
    "Stable Diffusion 1-5": load_dataset(f"{ORG}/sd-v1-5")["train"],
    # "Stable Diffusion 2-1": load_dataset(f"{ORG}/sd-v2-1")["train"],
    # "IF-1-0": None,
    # "Karlo": None,
    # "Kandinsky": None,
}

# Number of questions (distinct image ids) per questionnaire.
NUM_QUESTIONS = 25
# The questions are drawn from the `CANDIDATE_POOL_FACTOR * NUM_QUESTIONS`
# ids that have been submitted the fewest times.
CANDIDATE_POOL_FACTOR = 8

submission_names = list(SUBMISSIONS.keys())
# All submissions are indexed with the same ids; the first one defines the range.
num_images = len(SUBMISSIONS[submission_names[0]])


def start():
    """Build one questionnaire of up to ``NUM_QUESTIONS`` distinct prompts.

    Image ids are ranked by how often they have already been submitted
    (currently always 0, because loading of previous submissions is disabled),
    the least-submitted ids form the candidate pool, and the questions are
    drawn from that pool without replacement.

    Returns:
        dict: maps the question index (0-based) to a dict with the keys

        * ``"prompt"``  -- the ``"Prompt"`` field of the first submission's
          row for that image id,
        * ``"id"``      -- the image id,
        * ``"choices"`` -- maps each model's index in ``submission_names`` to
          that model's ``"images"`` field for the id; the dict is populated
          in a freshly shuffled order for every question.

        Fewer than ``NUM_QUESTIONS`` entries are returned when fewer images
        are available (an empty dict when there are none).
    """
    counts = dict.fromkeys(range(num_images), 0)

    # submissions = load_dataset(f"{ORG}/submissions")
    # submitted_ids = Counter(submissions["ids"])
    submitted_ids = {}
    counts.update(submitted_ids)

    # Rank ids by submission count, lowest first. Sorting the dict itself
    # would order by id and ignore the counts entirely.
    ranked_ids = sorted(counts, key=counts.get)

    # Restrict to the least-submitted ids.
    id_candidates = ranked_ids[: CANDIDATE_POOL_FACTOR * NUM_QUESTIONS]

    # Draw without replacement so a questionnaire never repeats a prompt;
    # cap `k` because `sample` raises if it exceeds the population size.
    image_ids = sample(id_candidates, k=min(NUM_QUESTIONS, len(id_candidates)))

    reference = SUBMISSIONS[submission_names[0]]
    images = {}
    for question, image_id in enumerate(image_ids):
        # Randomise the model order per question.
        order = list(range(len(submission_names)))
        shuffle(order)

        images[question] = {
            "prompt": reference[image_id]["Prompt"],
            "id": image_id,
            "choices": {
                model: SUBMISSIONS[submission_names[model]][image_id]["images"]
                for model in order
            },
        }

    return images


hello = start()