|
from __future__ import annotations |
|
|
|
from argparse import ArgumentParser |
|
|
|
import datasets |
|
import gradio as gr |
|
import numpy as np |
|
import openai |
|
|
|
from dataset_creation.generate_txt_dataset import generate |
|
|
|
|
|
def main(openai_model: str): |
|
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train") |
|
captions = dataset[np.random.permutation(len(dataset))]["TEXT"] |
|
index = 0 |
|
|
|
def click_random(): |
|
nonlocal index |
|
output = captions[index] |
|
index = (index + 1) % len(captions) |
|
return output |
|
|
|
def click_generate(input: str): |
|
if input == "": |
|
raise gr.Error("Input caption is missing!") |
|
edit_output = generate(openai_model, input) |
|
if edit_output is None: |
|
return "Failed :(", "Failed :(" |
|
return edit_output |
|
|
|
with gr.Blocks(css="footer {visibility: hidden}") as demo: |
|
txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...") |
|
txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False) |
|
txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False) |
|
|
|
with gr.Row(): |
|
clear_btn = gr.Button("Clear") |
|
random_btn = gr.Button("Random Input") |
|
generate_btn = gr.Button("Generate Instruction + Edited Caption") |
|
|
|
clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output]) |
|
random_btn.click(fn=click_random, inputs=[], outputs=[txt_input]) |
|
generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output]) |
|
|
|
demo.launch(share=True) |
|
|
|
|
|
if __name__ == "__main__":
    import os

    parser = ArgumentParser()
    # Passing a secret on the command line exposes it in the process list and
    # shell history, so also accept it from the environment. The flag is only
    # mandatory when OPENAI_API_KEY is unset or empty.
    env_api_key = os.environ.get("OPENAI_API_KEY") or None
    parser.add_argument(
        "--openai-api-key",
        type=str,
        default=env_api_key,
        required=env_api_key is None,
        help="OpenAI API key (defaults to the OPENAI_API_KEY environment variable).",
    )
    parser.add_argument("--openai-model", required=True, type=str, help="OpenAI model name passed to main().")
    args = parser.parse_args()

    openai.api_key = args.openai_api_key
    main(args.openai_model)
|
|