|
from datasets import load_dataset |
|
from collections import Counter, defaultdict |
|
from random import sample, shuffle |
|
from collections import Counter |
|
import datasets |
|
from pandas import DataFrame |
|
from huggingface_hub import list_datasets |
|
import os |
|
import gradio as gr |
|
|
|
import secrets |
|
|
|
|
|
# NOTE(review): not referenced anywhere in the visible code.
parti_prompt_results = []

# Hub organisation that hosts the pre-generated Parti Prompts images per model.
ORG = "diffusers-parti-prompts"

# Submission key -> "train" split of that model's generated-images dataset.
# Hub repo ids always use "/" as separator, so they are built with f-strings rather
# than os.path.join (which would produce "org\\name" on Windows).
SUBMISSIONS = {
    "kand2": load_dataset(f"{ORG}/kandinsky-2-2")["train"],
    "sdxl": load_dataset(f"{ORG}/sdxl-1.0-refiner")["train"],
    "wuerst": load_dataset(f"{ORG}/wuerstchen")["train"],
    "karlo": load_dataset(f"{ORG}/karlo-v1")["train"],
}

# Submission key -> model page on the Hub.
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}
|
# Result-screen markdown for the "kand2" submission (see RESULT).
KANDINSKY = """

## The creative one 🎨!

![img](https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/5dfcb1aada6d0311fd3d5448/rETvCyoUD5Mr9wm6OxUhe.png?w=200&h=200&f=face)

\n You mostly resonate with **Kandinsky 2.2** released by AI Forever.

\n Kandinsky 2.2 has a similar architecture to DALLE-2 and works extremely well for artistic, colorful generations.

\n Check out your soulmate [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder).

"""
|
# Result-screen markdown for the "sdxl" submission (see RESULT).
SDXL_RESULT = """

## The powerful one ⚡!

![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/7vmYr2XwVcPtkLzac_jxQ.png)

\n You mostly resonate with **Stable Diffusion XL** released by Stability AI.

\n Stable Diffusion XL consists of two diffusion models that are chained together, a base model and a refiner model. Together, the system contains roughly 5 billion parameters.

\n It's the latest open-source release of Stable Diffusion and lets you render stunning images at much larger sizes than Stable Diffusion v1.

\n Try it out [here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).

"""
|
# Result-screen markdown for the "wuerst" submission (see RESULT).
WUERSTCHEN = """

## The innovative one ⚗️!

![img](https://www.gravatar.com/avatar/3219846609129e84790fb83793998d61?d=retro&size=100)

\n You mostly resonate with **Wuerstchen** released by the WARP team.

\n Wuerstchen is a three-stage diffusion model that introduced a very novel, innovative model architecture.

\n Wuerstchen is able to generate very large images (up to 1024x2048) in just a few seconds.

\n The model has an amazing image quality vs. speed trade-off.

\n Check out your new best friend [here](https://huggingface.co/warp-ai/wuerstchen).

"""
|
# Result-screen markdown for the "karlo" submission (see RESULT).
# The link target must not be quoted: ("https://...") is not a valid markdown URL.
KARLO = """

## The precise one 🎯!

![img](https://huggingface.co/datasets/OpenGenAI/logos/resolve/main/1670220967262-615ed619c807b26d117a49bd.png)

\n You mostly resonate with **Karlo** released by KakaoBrain.

\n Karlo is based on the same architecture as DALLE-2 and has been trained on the [well curated COYO dataset](https://huggingface.co/datasets/kakaobrain/coyo-700m).

\n Play around with it [here](https://huggingface.co/kakaobrain/karlo-v1-alpha).

"""
|
|
|
# Result-screen markdown for each submission key (keys match SUBMISSIONS).
RESULT = {
    "kand2": KANDINSKY,
    "wuerst": WUERSTCHEN,
    "sdxl": SDXL_RESULT,
    "karlo": KARLO,
}

# Number of prompts per game.
NUM_QUESTIONS = 10

# Per-game answers are pushed to, and later read back from, repos owned by this
# namespace (e.g. "results-kand2-sdxl-wuerst-karlo"), so each model line-up gets
# its own pool of results.
MODEL_KEYS = "-".join(SUBMISSIONS.keys())
SUBMISSION_ORG = f"results-{MODEL_KEYS}"

# Instruction appended to the progress counter shown above each question.
PROMPT_FORMAT = " Select all pictures that correctly match the prompt and click on 'Submit'. Remember that if no image matches the prompt, no image shall be selected."

# Stable ordering of submission keys; the choice_i columns index into this list.
submission_names = list(SUBMISSIONS.keys())

# Every sampled id is looked up in *every* submission, so only ids present in all of
# them are safe: use the shortest dataset rather than assuming all share the first
# one's length.
num_images = min(len(ds) for ds in SUBMISSIONS.values())
|
|
|
|
|
def load_submissions():
    """Tally how often each prompt id has been answered across all result repos.

    Lists every dataset authored by ``SUBMISSION_ORG``, downloads each one and
    counts the values of its ``id`` column.

    Returns:
        Counter: prompt id -> number of recorded submissions.
    """
    repo_ids = [info.id for info in list_datasets(author=SUBMISSION_ORG)]

    tally = Counter()
    for repo_id in repo_ids:
        tally.update(load_dataset(repo_id)["train"]["id"])
    return tally
|
|
|
|
|
# Per-prompt submission counts, loaded once when the module is imported.
# start() merges these into its id ranking so rarely answered prompts come first.
SUBMITTED_IDS = load_submissions()
|
|
|
|
|
def generate_random_hash(length=8):
    """
    Generate a random hexadecimal string of the given length.

    The value comes from :mod:`secrets`, so it is unpredictable and suitable as a
    collision-resistant repository name.

    Args:
        length (int): Number of hex characters to return. Each random byte is
            rendered as two hex digits, so this must be a non-negative even number.

    Returns:
        str: A random lowercase hex string of ``length`` characters.

    Raises:
        ValueError: If ``length`` is negative or odd.
    """
    if length < 0:
        raise ValueError("Length should be a non-negative number.")
    if length % 2 != 0:
        raise ValueError("Length should be an even number.")

    return secrets.token_hex(length // 2)
|
|
|
|
|
def refresh(row_number, dataframe):
    """Rebuild the question table after a finished game; otherwise keep it.

    Args:
        row_number: Index of the next question to show.
        dataframe: The current question table.

    Returns:
        A new table from ``start()`` when ``row_number == NUM_QUESTIONS``,
        otherwise ``dataframe`` unchanged.
    """
    global SUBMITTED_IDS

    if row_number != NUM_QUESTIONS:
        return dataframe

    # Update the module-level counts that start() reads by default, so the next
    # game is ranked with up-to-date submissions. The previous code passed the
    # counts to start() as an argument, which start() did not accept.
    SUBMITTED_IDS = load_submissions()
    return start()
|
|
|
def start(submitted_ids=None):
    """Build a new game of ``NUM_QUESTIONS`` prompts, each with a shuffled model order.

    Prompts with the fewest submissions so far are preferred. Ids are ranked by
    submission count, ties are shuffled, and the questions are drawn at random
    from the ``10 * NUM_QUESTIONS`` least-answered ids.

    Args:
        submitted_ids: Optional mapping ``image id -> submission count``.
            Defaults to the module-level ``SUBMITTED_IDS``.

    Returns:
        DataFrame indexed ``0..NUM_QUESTIONS-1`` with columns ``prompt``,
        ``result`` (empty), ``id``, ``Challenge``, ``Category``, ``Note`` and
        ``choice_0..choice_{k-1}``. ``choice_i`` is the index into
        ``submission_names`` of the model displayed in image slot ``i``.
    """
    if submitted_ids is None:
        submitted_ids = SUBMITTED_IDS

    # Only ids in range(num_images) are considered. A stale result repo could
    # otherwise contribute ids that index past the end of a submission dataset.
    counts = {
        image_id: submitted_ids.get(image_id, 0) for image_id in range(num_images)
    }

    # Group ids by submission count, shuffle inside each group, and concatenate
    # from least- to most-answered so rarely seen prompts come first.
    by_count = defaultdict(list)
    for image_id, count in counts.items():
        by_count[count].append(image_id)

    ranked_ids = []
    for count in sorted(by_count):
        group = by_count[count]
        shuffle(group)
        ranked_ids.extend(group)

    candidates = ranked_ids[: 10 * NUM_QUESTIONS]
    image_ids = sample(candidates, k=NUM_QUESTIONS)

    # Prompt metadata is identical across submissions; read it from the first one.
    reference = SUBMISSIONS[submission_names[0]]
    num_models = len(SUBMISSIONS)

    questions = {}
    for question_index, image_id in enumerate(image_ids):
        order = list(range(num_models))
        shuffle(order)

        row = reference[image_id]
        question = {
            "prompt": row["Prompt"],
            "result": "",
            "id": image_id,
            "Challenge": row["Challenge"],
            "Category": row["Category"],
            "Note": row["Note"],
        }
        for slot, model_index in enumerate(order):
            question[f"choice_{slot}"] = model_index
        questions[question_index] = question

    return DataFrame.from_dict(questions, orient="index")
|
|
|
|
|
def process(dataframe, row_number=0):
    """Build the Gradio outputs for question ``row_number``.

    Args:
        dataframe: Question table from ``start()``.
        row_number: Index of the question to render.

    Returns:
        A tuple containing, in order:

        - one image per submission, in the row's shuffled display order;
        - one ``False`` per checkbox, to clear the previous selection;
        - the prompt as a markdown heading;
        - the progress counter.

        After the last question, the images are ``None`` and the prompt and
        counter are empty strings.
    """
    num_slots = len(SUBMISSIONS)

    if row_number == NUM_QUESTIONS:
        # One output per image slot and per checkbox, same arity as below.
        cleared_images = [None] * num_slots
        cleared_boxes = [False] * num_slots
        return *cleared_images, *cleared_boxes, "", ""

    question = dataframe.iloc[row_number]
    image_id = int(question["id"])

    # choice_i is the index into submission_names of the model shown in slot i.
    # Cast it defensively in case the table's numeric columns come back as floats.
    choices = [
        submission_names[int(question[f"choice_{slot}"])] for slot in range(num_slots)
    ]

    # Fetch each dataset row once and reuse it for both the image and the prompt.
    rows = [SUBMISSIONS[name][image_id] for name in choices]
    images = [row["images"] for row in rows]

    prompt_text = rows[0]["Prompt"]
    prompt = f'# "{prompt_text}"'
    counter = f"***{row_number + 1}/{NUM_QUESTIONS} {PROMPT_FORMAT}***"
    image_buttons = [False] * len(images)

    return *images, *image_buttons, prompt, counter
|
|
|
|
|
def write_result(user_choice, row_number, dataframe):
    """Record which models the player picked for question ``row_number`` and advance.

    Args:
        user_choice: Checkbox states encoded as decimal digits by ``integerize``.
            Digit ``i`` is 1 when image slot ``i`` was ticked. That encoding goes
            through ``int``, so leading zeros are lost ("0101" arrives as 101).
            Values <= 0 mean nothing was ticked.
        row_number: Index of the question being answered.
        dataframe: Question table. ``choice_i`` holds the index into
            ``submission_names`` of the model shown in slot ``i``.

    Returns:
        ``(row_number + 1, dataframe)``, with the row's ``result`` set to the
        comma-separated names of the picked models. Returns the inputs
        unchanged once the game is over.
    """
    if row_number == NUM_QUESTIONS:
        return row_number, dataframe

    num_slots = len(SUBMISSIONS)

    # Restore the leading zeros dropped by the int encoding. Without this, a pick
    # of only image 2 ("0100" -> 100) would be credited to image 1.
    encoded = int(user_choice)
    flags = str(encoded).zfill(num_slots) if encoded > 0 else "0" * num_slots

    question = dataframe.iloc[row_number]
    chosen_models = [
        submission_names[int(question[f"choice_{slot}"])]
        for slot, flag in enumerate(flags)
        if flag == "1"
    ]

    print(chosen_models)
    dataframe.loc[row_number, "result"] = ",".join(chosen_models)
    return row_number + 1, dataframe
|
|
|
|
|
def get_index(evt: gr.SelectData) -> int:
    """Extract the position carried by a Gradio ``select`` event."""
    position = evt.index
    return position
|
|
|
|
|
# Shown instead of a model card when the player ticked no image in the whole game.
_NO_FAVORITE_RESULT = """

## No favorite this time!

\n You didn't select any image in this round, so we couldn't tell which model resonates most with you.

"""


def change_view(row_number, dataframe):
    """Switch between the question screen and the result screen.

    While questions remain, only the gallery is visible. After the last question
    (``row_number == NUM_QUESTIONS``) the following happens:

    - the prompt ids and picks are pushed to a new dataset repo under
      ``SUBMISSION_ORG``;
    - the markdown for the most-picked model is revealed.

    Args:
        row_number: Index of the next question to show.
        dataframe: Question table. Each row's ``result`` is a comma-separated
            list of the models picked for that prompt, possibly empty.

    Returns:
        dict: Gradio updates keyed by the view components and the result markdown.
    """
    if row_number != NUM_QUESTIONS:
        return {
            intro_view: gr.update(visible=False),
            result_view: gr.update(visible=False),
            gallery_view: gr.update(visible=True),
            start_view: gr.update(visible=False),
            result: "",
        }

    picks = [
        model
        for cell in dataframe["result"].fillna("")
        for model in str(cell).split(",")
        if model
    ]
    # most_common() returns an empty list when no image was ticked in any
    # question, so index into it only when it has an entry.
    ranking = Counter(picks).most_common(1)
    summary = RESULT[ranking[0][0]] if ranking else _NO_FAVORITE_RESULT

    # Keep only the prompt ids and picks. "No image matched" answers are still
    # pushed because they are valid responses.
    dataset = datasets.Dataset.from_pandas(dataframe)
    dataset = dataset.remove_columns(
        [name for name in dataset.column_names if name not in ("id", "result")]
    )
    # Hub repo ids always use "/", independent of the local OS path separator.
    repo_id = f"{SUBMISSION_ORG}/{generate_random_hash()}"
    dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))

    return {
        intro_view: gr.update(visible=False),
        result_view: gr.update(visible=True),
        gallery_view: gr.update(visible=False),
        start_view: gr.update(visible=True),
        result: summary,
    }
|
|
|
|
|
# Page heading (emoji: woman health worker).
TITLE = "# What AI model is best for you? \U0001F469\u200D\u2695\uFE0F"

# Intro text shown before the first game.
DESCRIPTION = """

***How it works*** 👇 \n\n

- Upon clicking start, you are shown image descriptions alongside four AI-generated images.

\n- Select all images that match the prompt. **Note** that you should leave all images *unchecked* if no image matches the prompt.

\n- Answer **10** questions to find out what AI generator most resonates with you.

\n- Your submissions contribute to the [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard) ❤️.

\n\n

"""

# Background on the prompts and models, shown under the start button.
NOTE = """\n\n\n\n

The prompts you are shown originate from the [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts) dataset.

Parti Prompts is designed to test text-to-image AI models on 1600+ prompts of varying difficulty and categories.

The images you are shown have been pre-generated with 4 state-of-the-art open-sourced text-to-image models.

Your answers will be used to contribute to the official [**Open Parti Prompts Leaderboard**](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard).

Every couple of months, the generated images will be updated with possibly improved models. The current models and the code used to generate the images can be verified here:\n

- [kandinsky-2-2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) \n

- [wuerstchen](https://huggingface.co/warp-ai/wuerstchen) \n

- [sdxl-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) \n

- [karlo](https://huggingface.co/datasets/diffusers-parti-prompts/karlo-v1) \n

"""
|
|
|
# One entry per submission; presumably meant as the gallery's column count.
# NOTE(review): not referenced anywhere in the visible code.
GALLERY_COLUMN_NUM = len(SUBMISSIONS)
|
|
|
# UI definition: statement order inside each context manager is the on-screen order.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    # Intro text; change_view() hides it once a game is running.
    with gr.Column(visible=True) as intro_view:
        gr.Markdown(DESCRIPTION)

    # Schema of the per-game question table built by start(): six metadata
    # columns, then one choice_i column per submission holding the index (into
    # submission_names) of the model displayed in image slot i.
    headers = ["prompt", "result", "id", "Challenge", "Category", "Note"] + [
        f"choice_{i}" for i in range(len(SUBMISSIONS))
    ]
    datatype = ["str", "str", "number", "str", "str", "str"] + len(SUBMISSIONS) * [
        "number"
    ]

    # Hidden state: index of the next question (0..NUM_QUESTIONS).
    with gr.Column(visible=False):
        row_number = gr.Number(
            label="Current row selection index",
            value=0,
            precision=0,
            interactive=False,
        )

    # Result screen: favourite-model markdown plus the read-only answer table.
    with gr.Column(visible=False) as result_view:
        result = gr.Markdown("")
        dataframe = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            row_count=NUM_QUESTIONS,
            col_count=(6 + len(SUBMISSIONS), "fixed"),
            interactive=False,
        )
        gr.Markdown("Click on start to play again!")

    # NOTE(review): `.style(...)` is a Gradio 3.x API that later major versions
    # removed; confirm the pinned gradio version before upgrading.
    with gr.Column(visible=True) as start_view:
        start_button = gr.Button("Start").style(full_width=True)
        gr.Markdown(NOTE)

    # Hidden state: checkbox selection encoded by integerize() below.
    with gr.Column(visible=False):
        selected_image = gr.Number(label="Selected index", value=-1, precision=0)

    # Question screen: progress counter, prompt, four image/checkbox pairs, submit.
    with gr.Column(visible=False) as gallery_view:
        with gr.Row():
            counter = gr.Markdown(f"***1/{NUM_QUESTIONS} {PROMPT_FORMAT}***")
        with gr.Row():
            prompt = gr.Markdown("")
        with gr.Blocks():
            with gr.Row():
                with gr.Column(min_width=200) as c1:
                    image_1 = gr.Image(interactive=False)
                    image_1_button = gr.Checkbox(False, label="Image 1").style(full_width=True)
                with gr.Column(min_width=200) as c2:
                    image_2 = gr.Image(interactive=False)
                    image_2_button = gr.Checkbox(False, label="Image 2").style(full_width=True)
                with gr.Column(min_width=200) as c3:
                    image_3 = gr.Image(interactive=False)
                    image_3_button = gr.Checkbox(False, label="Image 3").style(full_width=True)
                with gr.Column(min_width=200) as c4:
                    image_4 = gr.Image(interactive=False)
                    image_4_button = gr.Checkbox(False, label="Image 4").style(full_width=True)
            with gr.Row():
                submit_button = gr.Button("Submit").style(full_width=True)

    # Start: build a fresh question table, reset the row counter if the previous
    # game finished, switch to the question screen, then render question 0
    # (process() is called without row_number, so its default of 0 applies).
    start_button.click(
        fn=start,
        inputs=[],
        outputs=dataframe,
        show_progress=True
    ).then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=[intro_view, result_view, gallery_view, start_view, result],
    ).then(
        fn=process,
        inputs=[dataframe],
        outputs=[image_1, image_2, image_3, image_4, image_1_button, image_2_button, image_3_button, image_4_button, prompt, counter]
    )

    # Pack the four checkbox states into one number, one decimal digit per image
    # (e.g. True, False, True, True -> "1011" -> 1011). int() drops leading zeros,
    # so "0101" arrives as 101; whoever decodes it must left-pad to four digits.
    def integerize(x1, x2, x3, x4):
        number = f"{int(x1)}{int(x2)}{int(x3)}{int(x4)}"
        return int(number)

    # Submit chain, in order:
    #   1. encode the selection;
    #   2. record it and advance the row;
    #   3. switch to the result screen after the last question;
    #   4. render the next question (or clear the gallery);
    #   5. wrap the row counter back to 0;
    #   6. call refresh().
    # NOTE(review): step 5 already resets row_number from NUM_QUESTIONS to 0, so
    # refresh()'s `row_number == NUM_QUESTIONS` branch looks unreachable from this
    # chain. Confirm whether this order is intended.
    submit_button.click(
        fn=integerize,
        inputs=[image_1_button, image_2_button, image_3_button, image_4_button],
        outputs=[selected_image],
    ).then(
        fn=write_result,
        inputs=[selected_image, row_number, dataframe],
        outputs=[row_number, dataframe],
    ).then(
        fn=change_view,
        inputs=[row_number, dataframe],
        outputs=[intro_view, result_view, gallery_view, start_view, result]
    ).then(
        fn=process,
        inputs=[dataframe, row_number],
        outputs=[image_1, image_2, image_3, image_4, image_1_button, image_2_button, image_3_button, image_4_button, prompt, counter],
    ).then(
        fn=lambda x: 0 if x == NUM_QUESTIONS else x,
        inputs=[row_number],
        outputs=[row_number],
    ).then(
        fn=refresh,
        inputs=[row_number, dataframe],
        outputs=[dataframe],
    )

demo.launch()
|
|