# Page header captured together with the file (not part of the program):
# patrickvonplaten's picture
# update
# fc365d2
# raw
# history blame
# No virus
# 13.9 kB
from datasets import load_dataset
from collections import Counter, defaultdict
from random import sample, shuffle
from collections import Counter
import datasets
from pandas import DataFrame
from huggingface_hub import list_datasets
import os
import gradio as gr
import secrets
# NOTE(review): not referenced anywhere in the visible part of this file; kept
# because other code may import it.
parti_prompt_results = []

# Hub namespace that hosts one dataset of pre-generated images per model.
ORG = "diffusers-parti-prompts"

# "train" split of every model's image dataset, keyed by the short model name.
# Hub repo ids always use "/" as separator, so they are built with f-strings
# rather than os.path.join (which would insert "\\" on Windows).
SUBMISSIONS = {
    "kand2": load_dataset(f"{ORG}/kandinsky-2-2")["train"],
    "sdxl": load_dataset(f"{ORG}/sdxl-1.0-refiner")["train"],
    "wuerst": load_dataset(f"{ORG}/wuerstchen")["train"],
    "karlo": load_dataset(f"{ORG}/karlo-v1")["train"],
}

# Model page for every short model name used in SUBMISSIONS.
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}
# Markdown shown on the result screen, one text per model.
# Every text starts with a "## " heading so that it renders as a title.
KANDINSKY = """
## The creative one 🎨!
![img](https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/5dfcb1aada6d0311fd3d5448/rETvCyoUD5Mr9wm6OxUhe.png?w=200&h=200&f=face)
\n You mostly resonate with **Kandinsky 2.2** released by AI Forever.
\n Kandinsky 2.2 has a similar architecture to DALLE-2 and works extremely well for artistic, colorful generations.
\n Check out your soulmate [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).
"""
SDXL_RESULT = """
## The powerful one ⚡!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/7vmYr2XwVcPtkLzac_jxQ.png)
\n You mostly resonate with **Stable Diffusion XL** released by Stability AI.
\n Stable Diffusion XL consists of two diffusion models that are chained together, a base model and a refiner model. Together, the system contains roughly 5 billion parameters.
\n It's the latest open-source release of Stable Diffusion and allows to render stunning images of much larger sizes than Stable Diffusion v1.
Try it out [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
"""
WUERSTCHEN = """
## The innovative one ⚗️ !
![img](https://www.gravatar.com/avatar/3219846609129e84790fb83793998d61?d=retro&size=100)
\n You mostly resonate with **Wuerstchen** released by the WARP team.
\n Wuerstchen is a three stage diffusion model that proposed a very novel, innovative model architecture.
\n Wuerstchen is able to generate very large images (up to 1024x2048) in just a few seconds.
\n The model has an amazing image quality vs. speed trade-off.
\n Check out your new best friend [here](https://huggingface.co/warp-ai/wuerstchen).
"""
KARLO = """
## The precise one 🎯!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/1670220967262-615ed619c807b26d117a49bd.png)
\n You mostly resonate with **Karlo** released by KakaoBrain.
\n Karlo is based on the same architecture as DALLE-2 and has been trained on the [well curated COYO dataset](https://huggingface.co/datasets/kakaobrain/coyo-700m).
\n Play around with it [here](https://huggingface.co/kakaobrain/karlo-v1-alpha).
"""

# Result text looked up by the short model name that collected the most votes.
RESULT = {
    "kand2": KANDINSKY,
    "wuerst": WUERSTCHEN,
    "sdxl": SDXL_RESULT,
    "karlo": KARLO,
}
# Number of prompts a user rates in one round.
NUM_QUESTIONS = 10

# Short model names joined with "-"; also part of the namespace that stores the results.
MODEL_KEYS = "-".join(SUBMISSIONS)
SUBMISSION_ORG = "results-" + MODEL_KEYS

# Instruction appended to the question counter above the images.
PROMPT_FORMAT = " Select all pictures that correctly match the prompt and click on 'Submit'. Remember that if no image matches the prompt, no image shall be selected."

# Model names in SUBMISSIONS order; the `choice_*` columns store indices into this list.
submission_names = [*SUBMISSIONS]

# Size of the first model's dataset, used as the range of prompt ids.
num_images = len(SUBMISSIONS[submission_names[0]])
def load_submissions():
    """Count how often each prompt id occurs across all stored submissions.

    Every dataset listed for the author ``SUBMISSION_ORG`` is loaded and the
    values of the ``id`` column of its ``train`` split are tallied.

    Returns:
        Counter: maps a prompt id to the number of times it was found.
    """
    # Resolve all repo ids first, then download the datasets one by one.
    repo_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]

    tally = Counter()
    for repo_id in repo_ids:
        tally.update(load_dataset(repo_id)["train"]["id"])
    return tally
# Submission counts per prompt id, computed once when the module is imported;
# start() reads this to prefer the prompts with the lowest counts.
SUBMITTED_IDS = load_submissions()
def generate_random_hash(length=8):
    """
    Generates a random hexadecimal string of the specified length.

    Args:
        length (int): Number of hex characters to generate. Must be even,
            because every random byte is rendered as two hex characters.

    Returns:
        str: A random lowercase hex string with exactly `length` characters.

    Raises:
        ValueError: If `length` is odd.
    """
    if length % 2 != 0:
        raise ValueError("Length should be an even number.")
    # token_hex draws the random bytes itself; no separate token_bytes call is needed.
    return secrets.token_hex(length // 2)
def refresh(row_number, dataframe):
    """Reload the submission counts and draw new questions once a round is over.

    Args:
        row_number: index of the current question.
        dataframe: table of the round in progress.

    Returns:
        `dataframe` unchanged while `row_number` differs from NUM_QUESTIONS,
        otherwise a new table produced by `start()`.
    """
    global SUBMITTED_IDS

    if row_number != NUM_QUESTIONS:
        return dataframe

    # Store the fresh counts in the module-level counter that start() reads.
    # start() is called without arguments because it reads that counter itself;
    # passing the counts positionally to a zero-parameter start() raised TypeError.
    SUBMITTED_IDS = load_submissions()
    return start()
def start(submitted_ids=None):
    """Draw the questions for a new round, favoring rarely rated prompts.

    Args:
        submitted_ids: optional mapping of prompt id to submission count.
            Defaults to the module-level ``SUBMITTED_IDS``.

    Returns:
        DataFrame: one row per question (index 0 .. NUM_QUESTIONS - 1) with the
        columns ``prompt``, ``result`` (empty string), ``id``, ``Challenge``,
        ``Category``, ``Note`` and one ``choice_<n>`` column per model. The
        ``choice_<n>`` value is the index into ``submission_names`` of the
        model whose image is shown in slot ``n``.
    """
    if submitted_ids is None:
        submitted_ids = SUBMITTED_IDS

    # Every prompt id starts at zero; known counts override that default.
    counts = {image_id: 0 for image_id in range(num_images)}
    counts.update(submitted_ids)

    # Shuffle first, then stable-sort by count: ids end up ordered by ascending
    # count, in random order among ids that share the same count.
    ranked_ids = list(counts)
    shuffle(ranked_ids)
    ranked_ids.sort(key=counts.__getitem__)

    # Keep a pool of the least rated ids and pick the questions from it at random.
    id_candidates = ranked_ids[: 10 * NUM_QUESTIONS]
    image_ids = sample(id_candidates, k=NUM_QUESTIONS)

    # Prompt metadata is read from the first model's dataset.
    reference = SUBMISSIONS[submission_names[0]]

    images = {}
    for question, image_id in enumerate(image_ids):
        # Random slot order so that a model cannot be recognized by its position.
        order = list(range(len(SUBMISSIONS)))
        shuffle(order)

        row = reference[image_id]
        record = {
            "prompt": row["Prompt"],
            "result": "",
            "id": image_id,
            "Challenge": row["Challenge"],
            "Category": row["Category"],
            "Note": row["Note"],
        }
        for slot, model_index in enumerate(order):
            record[f"choice_{slot}"] = model_index
        images[question] = record

    return DataFrame.from_dict(images, orient="index")
def process(dataframe, row_number=0):
    """Build the values that fill the gallery for one question.

    Args:
        dataframe: table of the current round.
        row_number: index of the question to show.

    Returns:
        tuple: the images in slot order, one ``False`` per checkbox, the prompt
        as a markdown heading and the counter text. When `row_number` equals
        NUM_QUESTIONS the images are ``None`` and both texts are empty.
    """
    if row_number == NUM_QUESTIONS:
        blanks = [None] * len(RESULT)
        unchecked = [False] * len(RESULT)
        return (*blanks, *unchecked, "", "")

    current = dataframe.iloc[row_number]
    image_id = int(current["id"])

    # Translate the stored model indices into model names, slot by slot.
    choices = [
        submission_names[current[f"choice_{slot}"]]
        for slot in range(len(SUBMISSIONS))
    ]
    images = [SUBMISSIONS[model][image_id]["images"] for model in choices]

    prompt_text = SUBMISSIONS[choices[0]][image_id]["Prompt"]
    heading = f'# "{prompt_text}"'
    counter = f"***{row_number + 1}/{NUM_QUESTIONS} {PROMPT_FORMAT}***"
    image_buttons = [False] * len(images)
    return (*images, *image_buttons, heading, counter)
def write_result(user_choice, row_number, dataframe):
    """Store the models picked for the current question and advance the counter.

    Args:
        user_choice: checkbox selection encoded as an integer whose decimal
            digits are the 0/1 flags of the slots, e.g. 1010 for slots 0 and 2.
        row_number: index of the question that was answered.
        dataframe: table of the current round; its ``result`` cell for
            `row_number` is overwritten.

    Returns:
        tuple: ``(row_number + 1, dataframe)``, or the inputs unchanged when
        `row_number` equals NUM_QUESTIONS.

    Raises:
        ValueError: if `user_choice` is negative.
    """
    if row_number == NUM_QUESTIONS:
        return row_number, dataframe

    encoded = int(user_choice)
    if encoded < 0:
        raise ValueError("user_choice must not be negative.")

    # The integer encoding drops leading zeros (0011 arrives as 11), so pad back
    # to one digit per slot; otherwise the flags shift to the wrong slots.
    flags = str(encoded).zfill(len(SUBMISSIONS))

    current = dataframe.iloc[row_number]
    chosen_models = [
        submission_names[current[f"choice_{slot}"]]
        for slot, flag in enumerate(flags)
        if bool(int(flag))
    ]
    print(chosen_models)

    dataframe.loc[row_number, "result"] = ",".join(chosen_models)
    return row_number + 1, dataframe
def get_index(evt: gr.SelectData) -> int:
    """Return the ``index`` attribute of a select event."""
    selected = evt.index
    return selected
def change_view(row_number, dataframe):
    """Switch between the gallery and the result screen.

    While the round is running only the gallery is shown. Once `row_number`
    equals NUM_QUESTIONS the ``id`` and ``result`` columns are pushed to the
    Hub as a new dataset and the result screen is shown.

    Args:
        row_number: index of the current question.
        dataframe: table of the current round.

    Returns:
        dict: update for ``intro_view``, ``result_view``, ``gallery_view``,
        ``start_view`` and ``result``.
    """
    if row_number != NUM_QUESTIONS:
        return {
            intro_view: gr.update(visible=False),
            result_view: gr.update(visible=False),
            gallery_view: gr.update(visible=True),
            start_view: gr.update(visible=False),
            result: "",
        }

    # Flatten the comma separated model names of all answers, skipping empty cells.
    votes = [
        name
        for cell in dataframe["result"].values
        for name in cell.split(",")
        if len(name) > 0
    ]

    # Only the prompt id and the chosen models are uploaded.
    dataset = datasets.Dataset.from_pandas(dataframe)
    extra_columns = [
        column for column in dataset.column_names if column not in ("id", "result")
    ]
    dataset = dataset.remove_columns(extra_columns)

    # Hub repo ids use "/" on every platform, hence no os.path.join here.
    repo_id = f"{SUBMISSION_ORG}/{generate_random_hash()}"
    dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))

    if votes:
        favorite_model = Counter(votes).most_common(1)[0][0]
        result_text = RESULT[favorite_model]
    else:
        # Leaving every image unchecked is a valid answer, so there may be no
        # vote at all; show a neutral message instead of failing on an empty tally.
        result_text = "You did not select any image in this round, so no favorite model could be determined."

    return {
        intro_view: gr.update(visible=False),
        result_view: gr.update(visible=True),
        gallery_view: gr.update(visible=False),
        start_view: gr.update(visible=True),
        result: result_text,
    }
# Page title rendered as a markdown heading.
TITLE = "# What AI model is best for you? 👩‍⚕️"

# Intro text shown before the first round.
DESCRIPTION = """
***How it works*** 📖 \n\n
- Upon clicking start, you are shown image descriptions alongside four AI generated images.
\n- Select all images that match the prompt. **Note** that you should leave all images *unchecked* if no image matches the prompt.
\n- Answer **10** questions to find out what AI generator most resonates with you.
\n- Your submissions contribute to [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard) ❤️.
\n\n
"""

# Background information shown below the start button.
NOTE = """\n\n\n\n
The prompts you are shown originate from the [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts) dataset.
Parti Prompts is designed to test text-to-image AI models on 1600+ prompts of varying difficulty and categories.
The images you are shown have been pre-generated with 4 state-of-the-art open-sourced text-to-image models.
Your answers will be used to contribute to the official [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard).
Every couple months, the generated images will be updated with possibly improved models. The current models and code that was used to generate the images can be verified here:\n
- [kandinsky-2-2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) \n
- [wuerstchen](https://huggingface.co/warp-ai/wuerstchen) \n
- [sdxl-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) \n
- [karlo](https://huggingface.co/datasets/diffusers-parti-prompts/karlo-v1) \n
"""
# One entry per compared model.
# NOTE(review): not referenced in the visible part of this file.
GALLERY_COLUMN_NUM = len(SUBMISSIONS)
# Quiz UI: layout first, then the event chains of the two buttons.
# NOTE(review): the nesting below was reconstructed from the order of the
# `with` statements because the reviewed copy had lost its indentation;
# confirm the layout against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    # Intro text; change_view() hides this column on every call.
    with gr.Column(visible=True) as intro_view:
        gr.Markdown(DESCRIPTION)
    # Column names and types of the round table; they mirror the keys that
    # start() writes for every question.
    headers = ["prompt", "result", "id", "Challenge", "Category", "Note"] + [
        f"choice_{i}" for i in range(len(SUBMISSIONS))
    ]
    datatype = ["str", "str", "number", "str", "str", "str"] + len(SUBMISSIONS) * [
        "number"
    ]
    # Hidden state: index of the question currently shown.
    with gr.Column(visible=False):
        row_number = gr.Number(
            label="Current row selection index",
            value=0,
            precision=0,
            interactive=False,
        )
    # Create Data Frame
    # Result screen; change_view() shows it once row_number equals NUM_QUESTIONS.
    with gr.Column(visible=False) as result_view:
        result = gr.Markdown("")
        dataframe = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            row_count=NUM_QUESTIONS,
            col_count=(6 + len(SUBMISSIONS), "fixed"),
            interactive=False,
        )
        gr.Markdown("Click on start to play again!")
    with gr.Column(visible=True) as start_view:
        start_button = gr.Button("Start").style(full_width=True)
        gr.Markdown(NOTE)
    # Hidden state: checkbox selection as encoded by integerize().
    with gr.Column(visible=False):
        selected_image = gr.Number(label="Selected index", value=-1, precision=0)
    # Question screen: counter, prompt, four image/checkbox pairs and the submit button.
    with gr.Column(visible=False) as gallery_view:
        with gr.Row():
            counter = gr.Markdown(f"***1/{NUM_QUESTIONS} {PROMPT_FORMAT}***")
        with gr.Row():
            prompt = gr.Markdown("")
        with gr.Blocks():
            with gr.Row():
                with gr.Column() as c1:
                    image_1 = gr.Image(interactive=False)
                    image_1_button = gr.Checkbox(False, label="Image 1").style(full_width=True)
                with gr.Column() as c2:
                    image_2 = gr.Image(interactive=False)
                    image_2_button = gr.Checkbox(False, label="Image 2").style(full_width=True)
                with gr.Column() as c3:
                    image_3 = gr.Image(interactive=False)
                    image_3_button = gr.Checkbox(False, label="Image 3").style(full_width=True)
                with gr.Column() as c4:
                    image_4 = gr.Image(interactive=False)
                    image_4_button = gr.Checkbox(False, label="Image 4").style(full_width=True)
        with gr.Row():
            submit_button = gr.Button("Submit").style(full_width=True)
    # Start: draw new questions, wrap the counter, switch the view and fill
    # the gallery (process() falls back to its default row_number of 0 here).
    start_button.click(
        fn=start,
        inputs=[],
        outputs=dataframe,
        show_progress=True
    ).then(
        # Map a counter that sits at NUM_QUESTIONS back to 0.
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=[intro_view, result_view, gallery_view, start_view, result],
    ).then(
        fn=process,
        inputs=[dataframe],
        outputs=[image_1, image_2, image_3, image_4, image_1_button, image_2_button, image_3_button, image_4_button, prompt, counter]
    )
    def integerize(x1, x2, x3, x4):
        """Pack the four checkbox states into one integer of 0/1 decimal digits.

        Leading zeros are lost in the conversion, e.g. (False, False, True, True)
        yields 11 rather than 0011.
        """
        number = f"{int(x1)}{int(x2)}{int(x3)}{int(x4)}"
        return int(number)
    # Submit: encode the selection, store it, switch the view if the round is
    # over, then show the next question.
    submit_button.click(
        fn=integerize,
        inputs=[image_1_button, image_2_button, image_3_button, image_4_button],
        outputs=[selected_image],
    ).then(
        fn=write_result,
        inputs=[selected_image, row_number, dataframe],
        outputs=[row_number, dataframe],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=[intro_view, result_view, gallery_view, start_view, result]
    ).then(
        fn=process,
        inputs=[dataframe, row_number],
        outputs=[image_1, image_2, image_3, image_4, image_1_button, image_2_button, image_3_button, image_4_button, prompt, counter],
    ).then(
        # Map a counter that sits at NUM_QUESTIONS back to 0.
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        # NOTE(review): the previous step has already replaced NUM_QUESTIONS by 0,
        # so refresh's reload branch looks unreachable from this chain; confirm
        # the intended order.
        fn=refresh,
        inputs=[row_number, dataframe],
        outputs=[dataframe],
    )
# Start the app.
demo.launch()