patrickvonplaten's picture
Update app.py
d2b128f
raw
history blame contribute delete
No virus
14 kB
from datasets import load_dataset
from collections import Counter, defaultdict
from random import sample, shuffle
from collections import Counter
import datasets
from pandas import DataFrame
from huggingface_hub import list_datasets
import os
import gradio as gr
import secrets
# NOTE(review): not referenced anywhere in the visible code; kept in case
# another part of the project reads it.
parti_prompt_results = []

# Hub organisation that hosts the pre-generated image datasets.
ORG = "diffusers-parti-prompts"

# Short model key -> dataset name inside ORG. Insertion order matters:
# the keys are joined (in this order) to derive the submission organisation.
_SUBMISSION_DATASETS = {
    "kand2": "kandinsky-2-2",
    "sdxl": "sdxl-1.0-refiner",
    "wuerst": "wuerstchen",
    "karlo": "karlo-v1",
}

# Short model key -> "train" split of that model's generated images.
# Hub repo ids always use "/" as separator, so the id is built explicitly
# instead of via os.path.join (which would emit "\\" on Windows).
SUBMISSIONS = {
    key: load_dataset(f"{ORG}/{dataset_name}")["train"]
    for key, dataset_name in _SUBMISSION_DATASETS.items()
}

# Short model key -> model card URL.
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}
# Markdown "personality cards" rendered on the result screen, one per model.
# The "\n" escapes are real newlines that force markdown paragraph breaks.
KANDINSKY = """
## The creative one 🎨!
![img](https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/5dfcb1aada6d0311fd3d5448/rETvCyoUD5Mr9wm6OxUhe.png?w=200&h=200&f=face)
\n You mostly resonate with **Kandinsky 2.2** released by AI Forever.
\n Kandinsky 2.2 has a similar architecture to DALLE-2 and works extremely well for artistic, colorful generations.
\n Check out your soulmate [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).
"""
SDXL_RESULT = """
## The powerful one ⚡!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/7vmYr2XwVcPtkLzac_jxQ.png)
\n You mostly resonate with **Stable Diffusion XL** released by Stability AI.
\n Stable Diffusion XL consists of two diffusion models that are chained together, a base model and a refiner model. Together, the system contains roughly 5 billion parameters.
\n It's the latest open-source release of Stable Diffusion and allows to render stunning images of much larger sizes than Stable Diffusion v1.
Try it out [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
"""
WUERSTCHEN = """
## The innovative one ⚗️ !
![img](https://www.gravatar.com/avatar/3219846609129e84790fb83793998d61?d=retro&size=100)
\n You mostly resonate with **Wuerstchen** released by the WARP team.
\n Wuerstchen is a three stage diffusion model that proposed a very novel, innovative model architecture.
\n Wuerstchen is able to generate very large images (up to 1024x2048) in just a few seconds.
\n The model has an amazing image quality vs. speed trade-off.
\n Check out your new best friend [here](https://huggingface.co/warp-ai/wuerstchen).
"""
KARLO = """
## The precise one 🎯!
![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/1670220967262-615ed619c807b26d117a49bd.png)
\n You mostly resonate with **Karlo** released by KakaoBrain.
\n Karlo is based on the same architecture as DALLE-2 and has been trained on the [well curated COYO dataset](https://huggingface.co/datasets/kakaobrain/coyo-700m).
\n Play around with it [here](https://huggingface.co/kakaobrain/karlo-v1-alpha).
"""
# Short model key -> markdown card shown when that model wins the quiz.
RESULT = dict(
    kand2=KANDINSKY,
    wuerst=WUERSTCHEN,
    sdxl=SDXL_RESULT,
    karlo=KARLO,
)

# Number of prompts a user rates in one round.
NUM_QUESTIONS = 10

# The Hub account that collects user submissions is named after the
# dash-joined model keys (in SUBMISSIONS insertion order).
MODEL_KEYS = "-".join(SUBMISSIONS)
SUBMISSION_ORG = "result-" + MODEL_KEYS

# Instruction text appended to the "<n>/<total>" question counter.
PROMPT_FORMAT = (
    " Select the image that best matches the prompt and click on 'Submit'."
    " Remember that if multiple images match the prompt equally well, select them all."
    " If no image matches the prompt, no image shall be selected."
)

# Model keys as a list, so a model can be addressed by integer position.
submission_names = [*SUBMISSIONS]

# Row count of the first model's dataset.
num_images = len(SUBMISSIONS[submission_names[0]])
def load_submissions():
    """Count how often each image id appears across all stored submissions.

    Lists every dataset owned by ``SUBMISSION_ORG``, loads the "train" split
    of each one and tallies the values of its "id" column.

    Returns:
        Counter: image id -> number of times it occurs in the submissions.
    """
    # Resolve all repo ids up front, then load them one by one.
    repo_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]

    tally = Counter()
    for repo_id in repo_ids:
        tally.update(load_dataset(repo_id)["train"]["id"])
    return tally


# Snapshot of the submission counts taken once at import time.
SUBMITTED_IDS = load_submissions()
def generate_random_hash(length=8):
    """Generate a random lowercase hexadecimal string.

    Despite the name, nothing is hashed: the value is a fresh random token
    drawn from the ``secrets`` module.

    Args:
        length (int): Number of hex characters to return. Must be even,
            because every random byte renders as two hex digits.

    Returns:
        str: A random hex string with exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is odd.
    """
    if length % 2 != 0:
        raise ValueError("Length should be an even number.")
    # token_hex draws its own random bytes; no separate token_bytes call needed.
    return secrets.token_hex(length // 2)
def refresh(row_number, dataframe):
    """Return a fresh question table once a round is finished.

    Args:
        row_number: Index of the current question.
        dataframe: The question table of the running round.

    Returns:
        ``dataframe`` unchanged while the round is still in progress;
        otherwise a new table produced by ``start()`` after the submission
        counts have been reloaded.
    """
    global SUBMITTED_IDS

    if row_number != NUM_QUESTIONS:
        return dataframe

    # start() takes no arguments and reads the module-level SUBMITTED_IDS,
    # so the reloaded counts are published there instead of being passed in
    # (passing them raised a TypeError).
    SUBMITTED_IDS = load_submissions()
    return start()
def start():
    """Draw the questions for a new round, favouring rarely rated images.

    Image ids are ranked by how often they occur in ``SUBMITTED_IDS``
    (ties in random order), the ``10 * NUM_QUESTIONS`` lowest-ranked ids form
    the candidate pool, and ``NUM_QUESTIONS`` of them are sampled.

    Returns:
        DataFrame: one row per question with the columns "prompt", "result",
        "id", "Challenge", "Category", "Note" and one "choice_<n>" column per
        model holding the model's position in ``submission_names`` that is
        displayed in slot ``n``.
    """
    # Every image starts with zero votes; known submission counts override.
    vote_counts = dict.fromkeys(range(num_images), 0)
    vote_counts.update(SUBMITTED_IDS)

    # Group ids by vote count, buckets ordered from least to most rated.
    by_count = defaultdict(list)
    for image_id, count in sorted(vote_counts.items(), key=lambda item: item[1]):
        by_count[count].append(image_id)

    # Randomise the order inside each bucket, then flatten bucket by bucket.
    for bucket in by_count.values():
        shuffle(bucket)
    ranked_ids = [image_id for bucket in by_count.values() for image_id in bucket]

    # Pick the round's questions from the least-rated candidates.
    candidates = ranked_ids[: 10 * NUM_QUESTIONS]
    image_ids = sample(candidates, k=NUM_QUESTIONS)

    # Prompt metadata is read from the first model's dataset.
    reference = SUBMISSIONS[submission_names[0]]

    questions = {}
    for question_no, image_id in enumerate(image_ids):
        # Random display order of the models for this question.
        order = list(range(len(SUBMISSIONS)))
        shuffle(order)

        row = reference[image_id]
        record = {
            "prompt": row["Prompt"],
            "result": "",
            "id": image_id,
            "Challenge": row["Challenge"],
            "Category": row["Category"],
            "Note": row["Note"],
        }
        record.update(
            {f"choice_{slot}": model_idx for slot, model_idx in enumerate(order)}
        )
        questions[question_no] = record

    return DataFrame.from_dict(questions, orient="index")
def process(dataframe, row_number=0):
    """Produce the widget values for the question at ``row_number``.

    Args:
        dataframe: Question table as built by ``start``.
        row_number: Index of the question to display.

    Returns:
        tuple: one image per model, one ``False`` per checkbox, the prompt
        heading and the counter text. When ``row_number`` equals
        ``NUM_QUESTIONS`` the images are ``None`` and both texts are empty.
    """
    if row_number == NUM_QUESTIONS:
        # Round finished: blank out every image and checkbox.
        blanks = [None for _ in RESULT]
        unchecked = [False for _ in RESULT]
        return (*blanks, *unchecked, "", "")

    image_id = int(dataframe.iloc[row_number]["id"])

    # Translate the stored per-slot model positions back into model keys.
    choices = [
        submission_names[dataframe.iloc[row_number][f"choice_{slot}"]]
        for slot in range(len(SUBMISSIONS))
    ]

    images = [SUBMISSIONS[model][image_id]["images"] for model in choices]
    prompt_text = SUBMISSIONS[choices[0]][image_id]["Prompt"]
    heading = f'# "{prompt_text}"'
    counter = f"***{row_number + 1}/{NUM_QUESTIONS} {PROMPT_FORMAT}***"

    return (*images, *([False] * len(images)), heading, counter)
def write_result(user_choice, row_number, dataframe):
    """Store the user's selection for the current question.

    Args:
        user_choice: Digit string with one "0"/"1" flag per image slot.
        row_number: Index of the question being answered.
        dataframe: Question table; its "result" cell is updated in place.

    Returns:
        tuple: ``(next_row_number, dataframe)``. Both are returned unchanged
        when ``row_number`` already equals ``NUM_QUESTIONS``.
    """
    if row_number == NUM_QUESTIONS:
        return row_number, dataframe

    # Slots whose flag digit is non-zero are the ticked images.
    ticked_slots = [
        slot for slot, flag in enumerate(str(user_choice)) if bool(int(flag))
    ]

    # Map each ticked slot back to the model that was displayed there.
    chosen_models = [
        submission_names[dataframe.iloc[row_number][f"choice_{slot}"]]
        for slot in ticked_slots
    ]
    print(chosen_models)

    dataframe.loc[row_number, "result"] = ",".join(chosen_models)
    return row_number + 1, dataframe
def get_index(evt: gr.SelectData) -> int:
    """Return the ``index`` attribute of a Gradio select event."""
    selected = evt.index
    return selected
def change_view(row_number, dataframe):
    """Switch between the question gallery and the result screen.

    While the round is running only the gallery is made visible. Once
    ``row_number`` equals ``NUM_QUESTIONS`` the answers are pushed to the Hub
    under ``SUBMISSION_ORG`` and the result screen is shown.

    Args:
        row_number: Index of the current question.
        dataframe: Question table including the recorded "result" column.

    Returns:
        dict: Gradio component -> update for ``intro_view``, ``result_view``,
        ``gallery_view``, ``start_view`` and the ``result`` markdown.
    """
    if row_number != NUM_QUESTIONS:
        return {
            intro_view: gr.update(visible=False),
            result_view: gr.update(visible=False),
            gallery_view: gr.update(visible=True),
            start_view: gr.update(visible=False),
            result: "",
        }

    # Each "result" cell is a comma-separated list of model keys; empty
    # entries (questions without any ticked image) are skipped.
    votes = Counter(
        model
        for cell in dataframe["result"].values
        for model in cell.split(",")
        if model
    )
    if votes:
        result_text = RESULT[votes.most_common(1)[0][0]]
    else:
        # Leaving every question unticked is a valid way to answer, so show
        # a neutral message instead of failing on an empty vote count.
        result_text = (
            "## No match this time!\n\n"
            "You did not select any image, so no model could be matched to you."
        )

    # Only the image id and the chosen models are uploaded.
    dataset = datasets.Dataset.from_pandas(dataframe)
    dataset = dataset.remove_columns(
        [name for name in dataset.column_names if name not in ("id", "result")]
    )
    # Hub repo ids use "/" regardless of the local path separator.
    repo_id = f"{SUBMISSION_ORG}/{generate_random_hash()}"
    dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))

    return {
        intro_view: gr.update(visible=False),
        result_view: gr.update(visible=True),
        gallery_view: gr.update(visible=False),
        start_view: gr.update(visible=True),
        result: result_text,
    }
# Static markdown shown on the landing page.
TITLE = "# What AI model is best for you? 👩‍⚕️"
DESCRIPTION = """
***How it works*** 📖 \n\n
- Upon clicking start, you are shown image descriptions alongside four AI generated images.
\n- Select the image that best matches the prompt. If multiple images match the prompt equally well, select all images. If no image matches the prompt, leave all images unchecked.
\n- Answer **10** questions to find out what AI generator most resonates with you.
\n- Your submissions contribute to [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard) ❤️.
\n\n
"""
NOTE = """\n\n\n\n
The prompts you are shown originate from the [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts) dataset.
Parti Prompts is designed to test text-to-image AI models on 1600+ prompts of varying difficulty and categories.
The images you are shown have been pre-generated with 4 state-of-the-art open-sourced text-to-image models.
Your answers will be used to contribute to the official [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard).
Every couple months, the generated images will be updated with possibly improved models. The current models and code that was used to generate the images can be verified here:\n
- [kandinsky-2-2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) \n
- [wuerstchen](https://huggingface.co/warp-ai/wuerstchen) \n
- [sdxl-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) \n
- [karlo](https://huggingface.co/datasets/diffusers-parti-prompts/karlo-v1) \n
"""
# One gallery column per compared model.
GALLERY_COLUMN_NUM = len(SUBMISSIONS)

# NOTE(review): the container nesting below was reconstructed from a
# flattened copy of this file -- confirm it against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Column(visible=True) as intro_view:
        gr.Markdown(DESCRIPTION)

    # Column layout of the per-round question table.
    headers = [
        "prompt",
        "result",
        "id",
        "Challenge",
        "Category",
        "Note",
        *(f"choice_{i}" for i in range(len(SUBMISSIONS))),
    ]
    datatype = [
        "str",
        "str",
        "number",
        "str",
        "str",
        "str",
        *(["number"] * len(SUBMISSIONS)),
    ]

    # Hidden state: index of the question currently displayed.
    with gr.Column(visible=False):
        row_number = gr.Number(
            label="Current row selection index",
            value=0,
            precision=0,
            interactive=False,
        )

    # Result screen: personality card plus the table of recorded answers.
    with gr.Column(visible=False) as result_view:
        result = gr.Markdown("")
        dataframe = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            row_count=NUM_QUESTIONS,
            col_count=(len(headers), "fixed"),
            interactive=False,
        )
        gr.Markdown("Click on start to play again!")

    with gr.Column(visible=True) as start_view:
        start_button = gr.Button("Start").style(full_width=True)
        gr.Markdown(NOTE)

    # Hidden state: the four checkbox values encoded as a digit string.
    with gr.Column(visible=False):
        selected_image = gr.Textbox(label="Selected indexes")

    with gr.Column(visible=False) as gallery_view:
        with gr.Row():
            counter = gr.Markdown(f"***1/{NUM_QUESTIONS} {PROMPT_FORMAT}***")
        with gr.Row():
            prompt = gr.Markdown("")
        with gr.Blocks():
            with gr.Row():
                with gr.Column(min_width=200) as c1:
                    image_1 = gr.Image(interactive=False)
                    image_1_button = gr.Checkbox(False, label="Image 1").style(full_width=True)
                with gr.Column(min_width=200) as c2:
                    image_2 = gr.Image(interactive=False)
                    image_2_button = gr.Checkbox(False, label="Image 2").style(full_width=True)
                with gr.Column(min_width=200) as c3:
                    image_3 = gr.Image(interactive=False)
                    image_3_button = gr.Checkbox(False, label="Image 3").style(full_width=True)
                with gr.Column(min_width=200) as c4:
                    image_4 = gr.Image(interactive=False)
                    image_4_button = gr.Checkbox(False, label="Image 4").style(full_width=True)
        with gr.Row():
            submit_button = gr.Button("Submit").style(full_width=True)

    # Component groups shared by both event chains.
    view_outputs = [intro_view, result_view, gallery_view, start_view, result]
    checkbox_inputs = [image_1_button, image_2_button, image_3_button, image_4_button]
    question_outputs = [
        image_1,
        image_2,
        image_3,
        image_4,
        *checkbox_inputs,
        prompt,
        counter,
    ]

    # Start: build a new table, wrap the row counter, show the first question.
    start_button.click(
        fn=start,
        inputs=[],
        outputs=dataframe,
        show_progress=True,
    ).then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=view_outputs,
    ).then(
        fn=process,
        inputs=[dataframe],
        outputs=question_outputs,
    )

    def integerize(x1, x2, x3, x4):
        """Encode the four checkbox states as a string of 0/1 digits."""
        return "".join(str(int(flag)) for flag in (x1, x2, x3, x4))

    # Submit: record the answer, switch view if needed, show the next question.
    # NOTE(review): `refresh` runs after the row counter has been wrapped back
    # to 0, so it always receives a value different from NUM_QUESTIONS here.
    submit_button.click(
        fn=integerize,
        inputs=checkbox_inputs,
        outputs=[selected_image],
    ).then(
        fn=write_result,
        inputs=[selected_image, row_number, dataframe],
        outputs=[row_number, dataframe],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=view_outputs,
    ).then(
        fn=process,
        inputs=[dataframe, row_number],
        outputs=question_outputs,
    ).then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        fn=refresh,
        inputs=[row_number, dataframe],
        outputs=[dataframe],
    )

demo.launch()