Spaces:
Running
Running
File size: 7,967 Bytes
aac338f 0824d60 aac338f c61580a 0824d60 c61580a 0824d60 aac338f c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 aac338f c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a bbc9212 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a bbc9212 c61580a 0824d60 bbc9212 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 9db2d10 c61580a aac338f c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a e7c4130 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 c61580a 0824d60 aac338f c61580a 0824d60 c61580a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load, randomize_seed
import asyncio
import os
from threading import RLock
# Reentrant lock serialising access to shared resources (taken around saving
# generated images in `infer`).
lock = RLock()
# Hugging Face token from the environment; an unset or empty HF_TOKEN becomes None.
# If private or gated models aren't used, setting this variable is unnecessary.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
# Function to load all models specified in the 'models' list
def load_fn(models):
    """Load every model in *models* into the global ``models_load`` dict.

    Each entry maps the model id to the interface returned by
    ``gr_Interface_load('models/<id>')``. When loading fails, the error is
    printed and a placeholder ``gr.Interface`` is stored instead; its function
    ignores all arguments and returns ``None``, so inference on that model
    yields no image instead of raising a TypeError.

    Args:
        models: iterable of model ids; a repeated id is loaded only once.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            continue  # duplicate id: already loaded (or replaced by a placeholder)
        try:
            # Log model loading attempt
            print(f"Attempting to load model: {model}")
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            print(f"Successfully loaded model: {model}")
        except Exception as error:
            # In case of an error, print it and create a placeholder interface
            print(f"Error loading model {model}: {error}")
            # `infer` calls `.fn(prompt=..., seed=..., token=...)`; accept and ignore
            # any arguments so the placeholder returns None rather than raising.
            m = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])
        models_load[model] = m
# ---- Configuration -------------------------------------------------------
# How many models (and image slots) are compared side by side in the UI
num_models = 6
# Per-model time limit, in seconds, handed to asyncio.wait_for in `infer`
inference_timeout = 600
# Largest seed selectable on the seed slider
MAX_SEED = 3999999999

# ---- Model loading -------------------------------------------------------
# Load every model listed in 'models' before the UI is built.
# NOTE(review): the success message below is printed even if some models
# fell back to placeholders inside load_fn.
print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

# ---- Defaults ------------------------------------------------------------
# The first `num_models` entries are preselected in the UI
default_models = models[:num_models]
# Initial seed: a random integer in the inclusive range 1941..2024
starting_seed = randint(1941, 2024)
print(f"Starting seed: {starting_seed}")
# Pad or truncate the selected models to exactly 'num_models' entries
def extend_choices(choices):
    """Return exactly ``num_models`` model ids for the image slots.

    The first ``num_models`` entries of *choices* are kept; any slots left
    over are filled with the placeholder string 'NA'.
    """
    print(f"Extending choices: {choices}")
    kept = choices[:num_models]
    missing = num_models - len(kept)
    extended = kept + ['NA'] * missing
    print(f"Extended choices: {extended}")
    return extended
# Rebuild the image boxes to match the selected models
def update_imgbox(choices):
    """Build a fresh, empty ``gr.Image`` for every image slot.

    *choices* is cut/padded to ``num_models`` entries by ``extend_choices``;
    each image is labelled with its model id and is visible unless the slot
    holds the 'NA' filler.
    """
    print(f"Updating image boxes with choices: {choices}")

    def make_box(model_id):
        # Empty image labelled with the model; hidden for padding slots
        return gr.Image(None, label=model_id, visible=(model_id != 'NA'))

    imgboxes = list(map(make_box, extend_choices(choices[:num_models])))
    print(f"Updated image boxes: {imgboxes}")
    return imgboxes
# Asynchronous function to perform inference on a given model
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run one model on *prompt* and return the path of the saved PNG, or None.

    The model's blocking ``fn`` runs in a worker thread (``asyncio.to_thread``)
    and is awaited for at most *timeout* seconds. Any error or timeout is
    printed and yields ``None``.

    Args:
        model_str: key into the global ``models_load`` dict.
        prompt: text prompt; a space plus ``noise`` (currently empty) is
            appended before it is passed to the model.
        seed: forwarded to the model function as ``seed``.
        timeout: seconds to wait for the model before giving up.

    Returns:
        Absolute path (str) of a newly created PNG file, or ``None``.
    """
    import tempfile
    from pathlib import Path

    noise = ""
    kwargs = {"seed": seed}
    # Create an asynchronous task to run the model inference
    print(f"Starting inference for model: {model_str} with prompt: '{prompt}' and seed: {seed}")
    task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn,
                                                 prompt=f'{prompt} {noise}', **kwargs, token=HF_TOKEN))
    await asyncio.sleep(0)  # Allow other tasks to run
    try:
        # Wait for the task to complete within the specified timeout
        result = await asyncio.wait_for(task, timeout=timeout)
        print(f"Inference completed for model: {model_str}")
    except (Exception, asyncio.TimeoutError) as e:
        # Handle any exceptions or timeout errors
        print(f"Error during inference for model {model_str}: {e}")
        if not task.done():
            # Cancelling only stops the wait; the worker thread running the
            # model call cannot be interrupted and keeps running on its own.
            task.cancel()
            print(f"Task cancelled for model: {model_str}")
        result = None
    if task.done() and result is not None:
        # Give every result its own file. Several models run concurrently, and a
        # single shared "image.png" could be overwritten by another call before
        # Gradio reads it, showing the wrong model's image.
        fd, png_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        with lock:
            result.save(png_path)
        image = str(Path(png_path).resolve())
        print(f"Result saved as image: {image}")
        return image
    print(f"No result for model: {model_str}")
    return None
# Function to generate an image based on the given model, prompt, and seed
def gen_fnseed(model_str, prompt, seed=1):
    """Synchronously run ``infer`` for one model and return its result.

    A private event loop is created for each call so that this synchronous
    handler can drive the async ``infer`` coroutine.

    Args:
        model_str: model id, or 'NA' for an unused image slot.
        prompt: text prompt forwarded to ``infer``.
        seed: seed forwarded to ``infer``.

    Returns:
        Whatever ``infer`` returns (an image path or ``None``); ``None`` for the
        'NA' slot or when any error occurs.
    """
    if model_str == 'NA':
        print("Model is 'NA', skipping generation.")
        return None
    loop = None
    try:
        print(f"Generating image for model: {model_str} with prompt: '{prompt}' and seed: {seed}")
        # A manual loop is used instead of asyncio.run(): asyncio.run() also
        # shuts down the default executor, which would block until a timed-out
        # model thread finishes.
        loop = asyncio.new_event_loop()
        result = loop.run_until_complete(infer(model_str, prompt, seed, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        # Handle any exceptions or cancelled tasks
        print(f"Error during generation for model {model_str}: {e}")
        result = None
    finally:
        # Only close a loop that was actually created; otherwise a failure in
        # new_event_loop() would be masked by an UnboundLocalError here.
        if loop is not None:
            loop.close()
            print(f"Event loop closed for model: {model_str}")
    return result
# Create the Gradio Blocks interface with a custom theme
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")
    with gr.Tab('Compare-6'):
        # Text input for user prompt
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        # Button to generate images.
        # NOTE(review): the label promises "up to 3 minutes total", but
        # inference_timeout lets each model run for up to 600 s — confirm intent.
        gen_button = gr.Button('Generate up to 6 images in up to 3 minutes total')
        with gr.Row():
            # Slider to select a seed for reproducibility; the label is built
            # from MAX_SEED so the stated maximum cannot drift from the real one.
            seed = gr.Slider(label=f"Use a seed to replicate the same image later (maximum {MAX_SEED})", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed, scale=3)
            # Button to randomize the seed
            seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary", scale=1)
        # Set up click event to randomize the seed
        seed_rand.click(randomize_seed, None, [seed], queue=False)
        print("Seed randomization button set up.")
        with gr.Row():
            # One output image per default model
            output = [gr.Image(label=m, min_width=480) for m in default_models]
            # Hidden textboxes holding the model id currently assigned to each slot
            current_models = [gr.Textbox(m, visible=False) for m in default_models]
            # Each slot runs gen_fnseed on button click or prompt submit
            for m, o in zip(current_models, output):
                print(f"Setting up generation event for model: {m.value}")
                gen_event = gr.on(triggers=[gen_button.click, txt_input.submit], fn=gen_fnseed,
                                  inputs=[m, txt_input, seed], outputs=[o], concurrency_limit=None, queue=False)
        print("Generation button set up.")
        # Accordion to allow model selection
        with gr.Accordion('Model selection'):
            # Checkbox group to select up to 'num_models' different models
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {int(num_models)} different models from the {len(models)} available!', value=default_models, interactive=True)
            # On change: relabel/hide the image slots and update the hidden model ids
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)
            print("Model selection setup complete.")
        with gr.Row():
            # Placeholder HTML to add additional UI elements if needed
            gr.HTML()
# Queue settings for handling multiple concurrent requests
print("Setting up queue...")
demo.queue(default_concurrency_limit=200, max_size=200)
print("Launching Gradio interface...")
demo.launch(show_api=False, max_threads=400)
# NOTE: launch() blocks the main thread by default when run as a script, so
# this line is only reached after the server stops.
print("Gradio interface launched successfully.")