import gradio as gr
from random import randint
from all_models import models

from externalmod import gr_Interface_load, randomize_seed

import asyncio
import os
from threading import RLock

# Create a lock to ensure thread safety when accessing shared resources
lock = RLock()
# Load Hugging Face token from environment variable, if available
HF_TOKEN = os.environ.get("HF_TOKEN")  # Only needed for private or gated models; None if unset
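# For local runs the token can be supplied via the environment, e.g.
#   export HF_TOKEN=hf_xxx   # placeholder value, not a real token
# On Hugging Face Spaces it is typically configured as a repository secret.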

# Function to load all models specified in the 'models' list
def load_fn(models):
    global models_load
    models_load = {}
    
    # Iterate through all models to load them
    for model in models:
        if model not in models_load:
            try:
                # Log model loading attempt
                print(f"Attempting to load model: {model}")
                # Load model interface using externalmod function
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
                print(f"Successfully loaded model: {model}")
            except Exception as error:
                # On failure, log the error and fall back to a placeholder interface
                # that accepts any call signature and returns no image
                print(f"Error loading model {model}: {error}")
                m = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])
            # Register the loaded (or placeholder) interface
            models_load[model] = m

# Load all models defined in the 'models' list
print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

num_models = 6  # Number of models compared side by side

# Set the default models to use for inference
default_models = models[:num_models]
inference_timeout = 600  # Per-model inference timeout, in seconds
MAX_SEED = 3999999999  # Upper bound for the seed slider
# Generate a starting seed randomly between 1941 and 2024
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")

# Extend the choices list to ensure it contains 'num_models' elements
def extend_choices(choices):
    print(f"Extending choices: {choices}")
    extended = choices[:num_models] + (num_models - len(choices[:num_models])) * ['NA']
    print(f"Extended choices: {extended}")
    return extended
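
# Worked example (hypothetical model names): with num_models = 6,
#   extend_choices(['model-a', 'model-b'])
#   -> ['model-a', 'model-b', 'NA', 'NA', 'NA', 'NA']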

# Update the image boxes based on selected models
def update_imgbox(choices):
    print(f"Updating image boxes with choices: {choices}")
    choices_plus = extend_choices(choices[:num_models])
    imgboxes = [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
    print(f"Updated image boxes: {imgboxes}")
    return imgboxes
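
# The 'NA' placeholders map to hidden gr.Image components, so deselecting a
# model hides its output slot rather than re-flowing the layout.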

# Asynchronous function to perform inference on a given model
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    from pathlib import Path
    kwargs = {"seed": seed}
    noise = ""  # Optional suffix appended to the prompt; empty by default
    # Create an asynchronous task to run the model inference
    print(f"Starting inference for model: {model_str} with prompt: '{prompt}' and seed: {seed}")
    task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn,
                               prompt=f'{prompt} {noise}', **kwargs, token=HF_TOKEN))
    await asyncio.sleep(0)  # Allow other tasks to run
    try:
        # Wait for the task to complete within the specified timeout
        result = await asyncio.wait_for(task, timeout=timeout)
        print(f"Inference completed for model: {model_str}")
    except (Exception, asyncio.TimeoutError) as e:
        # Handle any exceptions or timeout errors
        print(f"Error during inference for model {model_str}: {e}")
        if not task.done():
            task.cancel()
            print(f"Task cancelled for model: {model_str}")
        result = None
    # If the task completed successfully, save the result as an image
    if task.done() and result is not None:
        with lock:
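            # Every generation writes to this one shared path; the lock
            # serializes the save so concurrent tasks don't interleave writes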
            png_path = "image.png"
            result.save(png_path)
            image = str(Path(png_path).resolve())
            print(f"Result saved as image: {image}")
        return image
    print(f"No result for model: {model_str}")
    return None
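
# Usage sketch (assumes no event loop is already running, and 'some/model'
# stands in for a key actually present in models_load):
#   path = asyncio.run(infer('some/model', 'a watercolor fox', seed=42))
#   # -> absolute path to the saved PNG on success, None on error or timeout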

# Function to generate an image based on the given model, prompt, and seed
def gen_fnseed(model_str, prompt, seed=1):
    if model_str == 'NA':
        print("Model is 'NA', skipping generation.")
        return None
    # Create the event loop before the try block so the finally clause can
    # always close it, even if inference raises immediately
    print(f"Generating image for model: {model_str} with prompt: '{prompt}' and seed: {seed}")
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_str, prompt, seed, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        # Handle any exceptions or cancelled tasks
        print(f"Error during generation for model {model_str}: {e}")
        result = None
    finally:
        # Close the event loop
        loop.close()
        print(f"Event loop closed for model: {model_str}")
    return result
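
# Synchronous usage sketch (placeholder model name):
#   path = gen_fnseed('some/model', 'a watercolor fox', seed=42)
# This is the form the Gradio event handlers call below, once per selected model.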

# Create the Gradio Blocks interface with a custom theme
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        # Text input for user prompt
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # Button to generate images
        gen_button = gr.Button('Generate up to 6 images in up to 3 minutes total')
        with gr.Row():
            # Slider to select a seed for reproducibility
            seed = gr.Slider(label="Use a seed to replicate the same image later (maximum 3999999999)", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed, scale=3)
            # Button to randomize the seed
            seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary", scale=1)
        # Set up click event to randomize the seed
        seed_rand.click(randomize_seed, None, [seed], queue=False)
        print("Seed randomization button set up.")
        # No-op click handler; the per-model generation events are wired up below
        gen_button.click(lambda: gr.update(interactive=True), None)
        print("Generation button set up.")

        with gr.Row():
            # Create image output components for each model
            output = [gr.Image(label=m, min_width=480) for m in default_models]
            # Create hidden textboxes to store the current models
            current_models = [gr.Textbox(m, visible=False) for m in default_models]
            
            # Set up generation events for each model and output image
            for m, o in zip(current_models, output):
                print(f"Setting up generation event for model: {m.value}")
                gen_event = gr.on(triggers=[gen_button.click, txt_input.submit], fn=gen_fnseed,
                                  inputs=[m, txt_input, seed], outputs=[o], concurrency_limit=None, queue=False)
                # The commented stop button could be used to cancel the generation event
                #stop_button.click(lambda s: gr.update(interactive=False), None, stop_button, cancels=[gen_event])
        # Accordion to allow model selection
        with gr.Accordion('Model selection'):
            # Checkbox group to select up to 'num_models' different models
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} different models from the {len(models)} available!', value=default_models, interactive=True)
            # Update image boxes and current models based on model selection
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
            print("Model selection setup complete.")
        with gr.Row():
            # Placeholder HTML to add additional UI elements if needed
            gr.HTML("")

# Queue settings for handling multiple concurrent requests
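# default_concurrency_limit caps simultaneous runs per event handler;
# max_size caps how many requests may wait in the queue at once.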
print("Setting up queue...")
demo.queue(default_concurrency_limit=200, max_size=200)
print("Launching Gradio interface...")
demo.launch(show_api=False, max_threads=400)
print("Gradio interface launched successfully.")