3DGen-Arena / serve /gradio_web_t2i_single.py
ZhangYuhan's picture
add serve
7c1eee1
raw
history blame
5.22 kB
import json
from functools import partial
from .utils import *
from .vote_utils import (
upvote_last_response_t2s as upvote_last_response,
downvote_last_response_t2s as downvote_last_response,
flag_last_response_t2s as flag_last_response,
)
from .inference import(
sample_prompt,
generate_t2s
)
from .constants import TEXT_PROMPT_PATH
# Prompt pool loaded once at import time; it is passed to both the
# "Sample" handler and the generation handler in build_single_model_ui.
# Decode explicitly as UTF-8: the platform default encoding is
# locale-dependent and would mis-read (or fail on) non-ASCII prompts.
with open(TEXT_PROMPT_PATH, "r", encoding="utf-8") as f:
    prompt_list = json.load(f)
def _build_vote_row(elem_id):
    """Create one row of (initially disabled) Upvote / Downvote / Flag buttons.

    Returns the three buttons as a list in that order.
    """
    with gr.Row(elem_id=elem_id):
        return [
            gr.Button(value="👍 Upvote", interactive=False),
            gr.Button(value="👎 Downvote", interactive=False),
            gr.Button(value="⚠️ Flag", interactive=False),
        ]


def build_single_model_ui(models):
    """Build the single-model text-to-3D playground and wire its events.

    Components are created and event handlers are registered directly. This
    presumably runs inside an enclosing ``gr.Blocks`` context opened by the
    caller (none is opened here).

    Args:
        models: Model registry. This function reads
            ``models.get_t2s_models()`` (names shown in the model dropdown)
            and ``models.inference_parallel`` / ``models.render_parallel``,
            which are bound as the first two arguments of ``generate_t2s``.

    Returns:
        None.
    """
    notice_markdown = """
# 🏔️ Play with Image Generation Models
## 🤖 Choose any model to generate
"""
    model_list = models.get_t2s_models()
    gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)

    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else "",
            interactive=True,
            show_label=False,
        )

    with gr.Row():
        with gr.Accordion("🔍 Expand to see all Arena players", open=False):
            model_description_md = get_model_description_md(model_list)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
            container=True,
            elem_id="input_box",
        )
        sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
        send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)

    with gr.Row():
        normal = gr.Image(width=512, label="Normal", show_copy_button=True)
        rgb = gr.Image(width=512, label="RGB", show_copy_button=True)

    with gr.Row():
        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)

    geo_btn_list = _build_vote_row("Geometry Quality")
    text_btn_list = _build_vote_row("Texture Quality")
    align_btn_list = _build_vote_row("Alignment Quality")

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    state = gr.State()
    # Event inputs must be Gradio components; a bare Python list here makes
    # event registration fail. Wrapping it in a State delivers the same
    # value, in the same argument position, to the handlers.
    prompt_state = gr.State(prompt_list)

    # Every vote, generation or clear action toggles all of these together.
    all_buttons = geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]

    vote_handlers = (upvote_last_response, downvote_last_response, flag_last_response)
    for btn_list in (geo_btn_list, text_btn_list, align_btn_list):
        # NOTE(review): all three quality rows call the same handlers with the
        # same inputs, so a recorded vote does not say whether it was about
        # geometry, texture or alignment. Confirm against vote_utils.
        for btn, handler in zip(btn_list, vote_handlers):
            btn.click(handler, [state, model_selector], [textbox] + btn_list)

    sample_btn.click(
        sample_prompt,
        [state, model_selector, prompt_state],
        [state, textbox],
        api_name="sample_btn_single",
    )

    gen_inputs = [state, textbox, model_selector, prompt_state]
    gen_outputs = [state, normal, rgb]

    def wire_generation(trigger, api_name):
        # Run generation, then re-enable the vote/regenerate/clear buttons.
        trigger(
            gen_func,
            gen_inputs,
            gen_outputs,
            api_name=api_name,
            show_progress="full",
        ).then(enable_buttons, None, all_buttons)

    # Registration order (submit, send, clear, regenerate) matches the
    # original layout of event dependencies.
    wire_generation(textbox.submit, "submit_btn_single")
    wire_generation(send_btn.click, "send_btn_single")

    clear_btn.click(
        clear_history,
        None,
        [state, textbox, normal, rgb],
        api_name="clear_history_single",
        show_progress="full",
    ).then(disable_buttons, None, all_buttons)

    wire_generation(regenerate_btn.click, "regenerate_btn_single")