# I modified app.py so that it can generate up to 100 images sequentially.
import gradio as gr
from PIL import Image
import os
import argparse
import random
from datetime import datetime
from OmniGen import OmniGenPipeline
# Initialize the pipeline.
# This runs once at import time; the resulting object is the module-level
# `pipe` that generate_image() calls for every requested image.
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
def generate_image(
    text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale,
    inference_steps, seed, separate_cfg_infer, offload_model,
    use_input_image_size_as_output, max_input_image_size, randomize_seed,
    save_images, num_generated_images
):
    """Run the module-level ``pipe`` repeatedly and collect the results.

    The pipeline is invoked ``int(num_generated_images)`` times, one image
    per call, strictly one after another.

    Args:
        text: Prompt forwarded to the pipeline as ``prompt``.
        img1, img2, img3: Optional input images; ``None`` entries are
            dropped. If all three are ``None`` the pipeline receives
            ``input_images=None`` instead of an empty list.
        height, width: Forwarded unchanged as the requested output size.
        guidance_scale, img_guidance_scale: Forwarded unchanged.
        inference_steps: Forwarded as ``num_inference_steps``.
        seed: Base seed. Used only when ``randomize_seed`` is falsy, in
            which case image ``i`` is generated with ``int(seed) + i``.
        separate_cfg_infer, offload_model, use_input_image_size_as_output,
        max_input_image_size: Forwarded unchanged to the pipeline.
        randomize_seed: When truthy, every image gets a fresh seed drawn
            with ``random.randint(0, 10000000)`` and ``seed`` is ignored.
        save_images: When truthy, each image is also written to
            ``outputs/<timestamp>_<i>.png``.
        num_generated_images: Number of pipeline calls to make.

    Returns:
        A list with the first element of each pipeline output, in
        generation order (presumably PIL images, since ``.save(path)`` is
        called on them -- confirm against the pipeline's return type).
    """
    # Keep only the images that were actually provided; fall back to None
    # (not an empty list) when there are none.
    input_images = [img for img in (img1, img2, img3) if img is not None] or None

    # Creating the output directory is loop-invariant, so do it once up front.
    if save_images:
        os.makedirs('outputs', exist_ok=True)

    images = []
    for i in range(int(num_generated_images)):
        if randomize_seed:
            current_seed = random.randint(0, 10000000)
        else:
            # Offset the base seed so that each image in the batch differs
            # while the batch as a whole stays reproducible.
            current_seed = int(seed) + i

        output = pipe(
            prompt=text,
            input_images=input_images,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            img_guidance_scale=img_guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=True,
            offload_kv_cache=True,
            offload_model=offload_model,
            use_input_image_size_as_output=use_input_image_size_as_output,
            seed=current_seed,
            max_input_image_size=max_input_image_size,
        )
        img = output[0]

        if save_images:
            # The loop index is part of the file name so that images
            # finished within the same second do not overwrite each other.
            timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            output_path = os.path.join('outputs', f'{timestamp}_{i}.png')
            img.save(output_path)

        images.append(img)

    return images  # Return the list of generated images
# Gradio Interface
# Layout: a single row with two columns. The left column holds the prompt,
# up to three input images, all generation settings and the trigger button;
# the right column holds the gallery that receives the generated images.
with gr.Blocks() as demo:
    gr.Markdown("# OmniGen: Unified Image Generation paper code")
    # You can include the description here if needed
    with gr.Row():
        with gr.Column():
            # Text prompt
            prompt_input = gr.Textbox(
                label=(
                    "Enter your prompt (use <img><|image_i|></img> "
                    "to represent i-th input image)"
                ),
                placeholder="Type your prompt here...",
            )
            with gr.Row(equal_height=True):
                # Input images (keep type as "filepath")
                image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
                image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
                image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
            # Sliders and inputs
            height_input = gr.Slider(
                label="Height", minimum=128, maximum=2048, value=1024, step=16
            )
            width_input = gr.Slider(
                label="Width", minimum=128, maximum=2048, value=1024, step=16
            )
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=5.0, value=2.5, step=0.1
            )
            img_guidance_scale_input = gr.Slider(
                label="Image Guidance Scale", minimum=1.0, maximum=2.0, value=1.6, step=0.1
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=100, value=50, step=1
            )
            # The seed value only takes effect when "Randomize Seed" is
            # unchecked; generate_image ignores it otherwise.
            seed_input = gr.Slider(
                label="Seed", minimum=0, maximum=2147483647, value=42, step=1
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            max_input_image_size = gr.Slider(
                label="Max Input Image Size", minimum=128, maximum=2048, value=1024, step=16
            )
            separate_cfg_infer = gr.Checkbox(
                label="Separate CFG Inference",
                info="Use separate inference processes for different guidance to reduce memory cost.",
                value=True,
            )
            offload_model = gr.Checkbox(
                label="Offload Model",
                info="Offload model to CPU to reduce memory cost but slow down generation.",
                value=False,
            )
            use_input_image_size_as_output = gr.Checkbox(
                label="Use Input Image Size as Output",
                info="Automatically adjust the output image size to match input image size.",
                value=False,
            )
            save_images = gr.Checkbox(label="Save Generated Images", value=True)
            # Added slider for the number of images to generate
            num_generated_images_input = gr.Slider(
                label="Number of Generated Images", minimum=1, maximum=100, value=1, step=1
            )
            # Generate button
            generate_button = gr.Button("Generate Images")
        with gr.Column():
            # Output gallery to display multiple images
            output_gallery = gr.Gallery(label="Output Images")
    # Generate button click event.
    # The order of `inputs` must match the positional parameter order of
    # generate_image, because the values are passed positionally.
    generate_button.click(
        generate_image,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            height_input,
            width_input,
            guidance_scale_input,
            img_guidance_scale_input,
            num_inference_steps,
            seed_input,
            separate_cfg_infer,
            offload_model,
            use_input_image_size_as_output,
            max_input_image_size,
            randomize_seed,
            save_images,
            num_generated_images_input,  # Include the new input
        ],
        outputs=output_gallery,  # Output the list of images to the gallery
    )
# Script entry point: only parse arguments and start the server when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run the OmniGen app')
    parser.add_argument('--share', action='store_true', help='Share the Gradio app publicly')
    args = parser.parse_args()
    # Launch the app; --share is passed straight through to Gradio.
    demo.launch(share=args.share)