|
import gradio as gr |
|
from random import randint |
|
from all_models import models |
|
|
|
|
|
def load_models(models):
    """Load every model named in *models* with ``gr.load`` and return them by name.

    Each name is loaded from ``models/<name>``; duplicate names are loaded only
    once. Loading is best-effort: if ``gr.load`` raises for a model, a warning is
    logged and a placeholder ``gr.Interface`` (text in, image out, always returns
    ``None``) is stored in its place so the rest of the app keeps working.

    Args:
        models: Iterable of model names.

    Returns:
        dict mapping each distinct model name to its loaded (or placeholder)
        interface.
    """
    import logging

    loaded = {}
    for model in models:
        if model in loaded:
            continue  # duplicate entry; already loaded
        try:
            interface = gr.load(f'models/{model}')
        except Exception as error:  # best-effort: any failure falls back to a stub
            # Previously this failure was swallowed silently; surface it.
            logging.getLogger(__name__).warning(
                "Failed to load model %r, using placeholder interface: %s",
                model, error,
            )
            interface = gr.Interface(lambda txt: None, ['text'], ['image'])
        loaded[model] = interface
    return loaded
|
|
|
# Load every configured model once at import time; keyed by model name.
models_load = load_models(models)

# Number of model/image slots shown in the UI, and the models preselected for them.
num_models = 6
default_models = models[0:num_models]
|
|
|
|
|
def extend_choices(choices, length=None):
    """Return *choices* as a list of exactly *length* entries.

    Shorter inputs are padded with the ``'NA'`` placeholder; longer inputs are
    truncated so callers never get more entries than there are slots.

    Args:
        choices: Sequence of selected model names. ``None`` is treated as an
            empty selection.
        length: Target list length; defaults to the module-level ``num_models``.

    Returns:
        A new list of length *length*.
    """
    if length is None:
        length = num_models
    # list() accepts tuples too; slicing caps oversized selections.
    selected = list(choices or [])[:length]
    return selected + ['NA'] * (length - len(selected))
|
|
|
|
|
def update_imgbox(choices):
    """Build one empty ``gr.Image`` per slot, labelled with its model name.

    Slots whose padded choice is the ``'NA'`` placeholder are hidden.
    """
    boxes = []
    for name in extend_choices(choices):
        shown = name != 'NA'
        boxes.append(gr.Image(None, label=name, visible=shown))
    return boxes
|
|
|
|
|
def generate_image(model_str, prompt):
    """Run the loaded model *model_str* on *prompt*; return ``None`` for 'NA'.

    A random number is appended to the prompt before the call. Presumably this
    makes repeated identical prompts distinct requests (e.g. to avoid cached
    results) -- confirm intent.
    """
    if model_str != 'NA':
        suffix = str(randint(0, 99999999999))
        model = models_load[model_str]
        return model(f'{prompt} {suffix}')
    return None
|
|
|
|
|
with gr.Blocks() as demo:
    model_dropdown = gr.Dropdown(models, label='Choose model', value=models[0], filterable=False)
    text_input = gr.Textbox(label='Prompt text')

    # Number of image slots; reuse num_models instead of repeating the literal.
    max_images = num_models
    num_images_slider = gr.Slider(1, max_images, value=max_images, step=1, label='Number of images')

    generate_button = gr.Button('Generate')
    stop_button = gr.Button('Stop', variant='secondary', interactive=False)

    # Enable Stop as soon as a generation is started.
    generate_button.click(lambda: gr.update(interactive=True), None, stop_button)

    with gr.Row():
        output_images = [gr.Image(label='') for _ in range(max_images)]

    # One generate event per image slot. Keep every event so Stop can cancel all
    # of them; previously only the last slot's event was cancelled because the
    # variable was overwritten on each iteration.
    generate_events = []
    for i, output in enumerate(output_images):
        # Hidden component that feeds this slot's index into the callbacks.
        img_index = gr.Number(i, visible=False)
        num_images_slider.change(
            lambda idx, n: gr.update(visible=(idx < n)),
            [img_index, num_images_slider], output,
        )
        generate_events.append(
            generate_button.click(
                lambda idx, n, model, prompt: generate_image(model, prompt) if idx < n else None,
                [img_index, num_images_slider, model_dropdown, text_input], output,
            )
        )

    stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=generate_events)

demo.launch()
|
|