Spaces:
Running
Running
File size: 5,219 Bytes
7c1eee1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import json
from functools import partial
from .utils import *
from .vote_utils import (
upvote_last_response_t2s as upvote_last_response,
downvote_last_response_t2s as downvote_last_response,
flag_last_response_t2s as flag_last_response,
)
from .inference import(
sample_prompt,
generate_t2s
)
from .constants import TEXT_PROMPT_PATH
# Pool of text prompts, loaded once at import time from the JSON file at
# TEXT_PROMPT_PATH and used by the UI builder below. The file is decoded
# explicitly as UTF-8 (the JSON interchange encoding) so non-ASCII prompt text
# does not depend on the platform's default locale encoding.
with open(TEXT_PROMPT_PATH, "r", encoding="utf-8") as f:
    prompt_list = json.load(f)
def build_single_model_ui(models, promotion=""):
    """Build the single-model text-to-3D tab: component layout plus event wiring.

    Creates a model dropdown, a prompt box with Sample/Send buttons, "Normal"
    and "RGB" image panels, Clear/Regenerate buttons and three rows of
    Upvote/Downvote/Flag buttons (element ids "Geometry Quality",
    "Texture Quality" and "Alignment Quality"), then registers every handler.
    No ``gr.Blocks`` is opened here, so this is presumably called from inside
    an enclosing Blocks context -- TODO confirm at the call site.

    Args:
        models: Object providing ``get_t2s_models()`` (the dropdown choices)
            and the ``inference_parallel`` / ``render_parallel`` attributes,
            which are bound as the first two positional arguments of
            ``generate_t2s``.
        promotion: Markdown substituted for the ``{promotion}`` placeholder in
            the page header. Defaults to ``""`` so the header renders without
            a promotion block instead of showing the raw placeholder.
    """
    # NOTE(review): `gr`, get_model_description_md, acknowledgment_md,
    # enable_buttons, disable_buttons and clear_history are not imported
    # explicitly; presumably they come from `from .utils import *`.
    notice_markdown = """
# 🏔️ Play with Image Generation Models
{promotion}
## 🤖 Choose any model to generate
"""
    model_list = models.get_t2s_models()
    gen_func = partial(generate_t2s, models.inference_parallel, models.render_parallel)

    gr.Markdown(notice_markdown.format(promotion=promotion), elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=model_list,
            value=model_list[0] if model_list else "",
            interactive=True,
            show_label=False,
        )

    with gr.Row():
        with gr.Accordion("🔍 Expand to see all Arena players", open=False):
            model_description_md = get_model_description_md(model_list)
            gr.Markdown(model_description_md, elem_id="model_description_markdown")

    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt or Sample a random prompt, and press ENTER",
            container=True,
            elem_id="input_box",
        )
        sample_btn = gr.Button(value="🎲 Sample", variant="primary", scale=0)
        send_btn = gr.Button(value="📤 Send", variant="primary", scale=0)

    with gr.Row():
        normal = gr.Image(width=512, label="Normal", show_copy_button=True)
        rgb = gr.Image(width=512, label="RGB", show_copy_button=True)

    with gr.Row():
        clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)

    with gr.Row(elem_id="Geometry Quality"):
        geo_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        geo_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        geo_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    with gr.Row(elem_id="Texture Quality"):
        text_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        text_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        text_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    with gr.Row(elem_id="Alignment Quality"):
        align_upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
        align_downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
        align_flag_btn = gr.Button(value="⚠️ Flag", interactive=False)

    gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    state = gr.State()
    # Event inputs must be Gradio components; a bare Python list cannot be
    # registered as an input. Wrapping the module-level prompt pool in a State
    # still delivers the same value to the handlers, in the same position.
    prompt_state = gr.State(prompt_list)

    geo_btn_list = [geo_upvote_btn, geo_downvote_btn, geo_flag_btn]
    text_btn_list = [text_upvote_btn, text_downvote_btn, text_flag_btn]
    align_btn_list = [align_upvote_btn, align_downvote_btn, align_flag_btn]
    # Every button enabled after a generation and disabled again on clear.
    toggled_btns = geo_btn_list + text_btn_list + align_btn_list + [regenerate_btn, clear_btn]

    # NOTE(review): all three rows call the same handlers with identical inputs
    # ([state, model_selector]), so nothing tells the handler which quality
    # row (geometry / texture / alignment) was voted on -- confirm whether the
    # vote_utils handlers need an extra dimension argument.
    vote_handlers = (upvote_last_response, downvote_last_response, flag_last_response)
    for btn_list in (geo_btn_list, text_btn_list, align_btn_list):
        for btn, handler in zip(btn_list, vote_handlers):
            # Each handler updates the textbox and its own row of buttons.
            btn.click(handler, [state, model_selector], [textbox] + btn_list)

    sample_btn.click(
        sample_prompt,
        [state, model_selector, prompt_state],
        [state, textbox],
        api_name="sample_btn_single",
    )

    gen_inputs = [state, textbox, model_selector, prompt_state]
    gen_outputs = [state, normal, rgb]

    def wire_generation(trigger, api_name):
        """Run generation on *trigger*, then enable the post-generation buttons."""
        trigger(
            gen_func,
            gen_inputs,
            gen_outputs,
            api_name=api_name,
            show_progress="full",
        ).then(enable_buttons, None, toggled_btns)

    # Registration order matches the original: submit, send, clear, regenerate.
    wire_generation(textbox.submit, "submit_btn_single")
    wire_generation(send_btn.click, "send_btn_single")

    clear_btn.click(
        clear_history,
        None,
        [state, textbox, normal, rgb],
        api_name="clear_history_single",
        show_progress="full",
    ).then(disable_buttons, None, toggled_btns)

    wire_generation(regenerate_btn.click, "regenerate_btn_single")
|