patrickvonplaten's picture
uP
e82dda8
raw
history blame
1.38 kB
from datasets import load_dataset
from collections import Counter
from random import choices, shuffle
import os
# Module-level accumulator; not populated anywhere in the visible code.
parti_prompt_results = []

# Hub organization that hosts one dataset per evaluated model.
ORG = "diffusers-parti-prompts"

# Display name -> "train" split of that model's dataset.
# Hub repo ids have the form "<org>/<name>" and always use a forward slash,
# so they are built with an f-string; os.path.join would insert the OS path
# separator (a backslash on Windows) and produce an invalid repo id.
SUBMISSIONS = {
    "Stable Diffusion 1-5": load_dataset(f"{ORG}/sd-v1-5")["train"],
    # Not enabled yet:
    # "Stable Diffusion 2-1": load_dataset(f"{ORG}/sd-v2-1")["train"],
    # "IF-1-0": None,
    # "Karlo": None,
    # "Kandinsky": None,
}

# Number of prompts handed out per call to start().
NUM_QUESTIONS = 25

submission_names = list(SUBMISSIONS)

# Row count of the first submission; start() uses it as the range of image ids.
num_images = len(SUBMISSIONS[submission_names[0]])
def start():
    """Pick a batch of prompts together with the candidate images to compare.

    Every image id starts with a submission count of 0; recorded counts
    (currently none, see the TODO below) override those defaults. The
    ``8 * NUM_QUESTIONS`` ids with the lowest counts form the candidate
    pool, from which ``NUM_QUESTIONS`` ids are drawn at random.

    Returns:
        dict: maps the question index (``0 .. NUM_QUESTIONS - 1``) to a dict
        with the keys ``"prompt"`` (the ``"Prompt"`` field of the first
        submission's row), ``"id"`` (the drawn image id) and ``"choices"``
        (submission index -> that submission's ``"images"`` field for the
        id, inserted in a freshly shuffled order).
    """
    # Function-scope import: the file header only imports `choices`/`shuffle`.
    from random import sample

    counts = dict.fromkeys(range(num_images), 0)

    # TODO: load the real counts once the submissions dataset exists, e.g.
    #   submissions = load_dataset(f"{ORG}/submissions")
    #   submitted_ids = Counter(submissions["ids"])
    submitted_ids = {}
    counts.update(submitted_ids)

    # Order ids by ascending count. A plain sorted(counts) would order by the
    # id itself and ignore the counts. The sort is stable, so ids with equal
    # counts keep their original (ascending id) order.
    ranked_ids = sorted(counts, key=counts.get)

    # Keep only the least-seen ids as candidates.
    id_candidates = ranked_ids[: 8 * NUM_QUESTIONS]

    # Draw without replacement so one batch never repeats a prompt. When the
    # pool is smaller than NUM_QUESTIONS, fall back to drawing with
    # replacement so the result still contains NUM_QUESTIONS entries.
    if len(id_candidates) >= NUM_QUESTIONS:
        image_ids = sample(id_candidates, k=NUM_QUESTIONS)
    else:
        image_ids = choices(id_candidates, k=NUM_QUESTIONS)

    images = {}
    for question_idx, image_id in enumerate(image_ids):
        # Shuffle per question so the position of a model's image varies.
        order = list(range(len(SUBMISSIONS)))
        shuffle(order)

        images[question_idx] = {
            "prompt": SUBMISSIONS[submission_names[0]][image_id]["Prompt"],
            "id": image_id,
            "choices": {
                sub_idx: SUBMISSIONS[submission_names[sub_idx]][image_id]["images"]
                for sub_idx in order
            },
        }

    return images
# Calls start() once at import time and keeps the resulting question set in
# the module-level name `hello`.
# NOTE(review): nothing in the visible code reads `hello`; possibly a
# debugging leftover. Confirm before removing.
hello = start()