tonyassi's picture
Update app.py
00486f9 verified
raw
history blame
No virus
4.63 kB
import spaces
import gradio as gr
from diffusers import AutoPipelineForInpainting, AutoencoderKL
import torch
from PIL import Image, ImageOps
# Load models once at import time; `generate` uses the module-level `pipeline`.
# VAE weights come from madebyollin/sdxl-vae-fp16-fix and are swapped into the
# SDXL inpainting pipeline; everything is loaded in float16.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
def get_select_index(evt: gr.SelectData):
    """Return the index carried by a gallery ``select`` event.

    Wired to ``version_gallery.select``; the returned value is written into the
    hidden ``selected`` number so ``restore_version`` can read it later.
    """
    chosen = evt.index
    return chosen
def squarify_image(img):
    """Pad ``img`` onto a white square canvas whose side is its longest edge.

    The image is anchored at the top-left corner (0, 0), not centered.
    ``generate`` depends on this: it pastes the mask layer at (0, 0) and crops
    the pipeline output starting from (0, 0).

    This is plain CPU work with PIL, so it is not decorated with
    ``@spaces.GPU``; it is called from inside ``generate``, which already holds
    the GPU.

    Args:
        img: PIL image (``generate`` passes an RGB image).

    Returns:
        A new RGB image of size ``(side, side)`` where ``side = max(img.size)``,
        with ``img`` in the top-left corner and white padding elsewhere.
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # Anchor at the origin so mask placement and the final crop line up.
    canvas.paste(img, (0, 0))
    return canvas
def divisible_by_8(image):
    """Resize ``image`` so both dimensions are multiples of 8.

    Each side is rounded down to the nearest multiple of 8, but never below 8.
    Without that floor, an image under 8 px on a side would be resized to a
    zero-size image, which PIL does not accept.

    This is plain CPU work, so it is not decorated with ``@spaces.GPU``; it is
    called from inside ``generate``, which already holds the GPU.

    Args:
        image: Object with a ``size`` ``(width, height)`` tuple and a
            ``resize(size)`` method (a PIL image in this app).

    Returns:
        The result of ``image.resize((new_width, new_height))``.
    """
    width, height = image.size
    new_width = max(8, width - width % 8)
    new_height = max(8, height - height % 8)
    return image.resize((new_width, new_height))
def restore_version(index, versions):
    """Load a previously generated version back into the sketch pad.

    Not decorated with ``@spaces.GPU``: it only builds a dict, so requesting a
    GPU for every "Restore Version" click would waste the allocation.

    Args:
        index: Position of the selected gallery item, as stored in the hidden
            ``selected`` number by ``get_select_index``; ``None`` if the user
            has not selected anything yet.
        versions: Current gallery value. Each entry is indexed with ``[0]`` to
            get the image (presumably ``(image, caption)`` pairs from
            ``gr.Gallery`` — confirm).

    Returns:
        An ImageEditor value dict with the chosen image as background and
        composite and no layers, or ``gr.update()`` (leave the editor
        unchanged) when there is nothing to restore.
    """
    print('restore version:', index)
    if index is None or not versions:
        return gr.update()
    # gr.Number may deliver the index as a float; list indices must be int.
    image = versions[int(index)][0]
    return {'background': image, 'layers': None, 'composite': image}
@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions):
    """Inpaint the masked region of the sketch-pad image with the SDXL pipeline.

    Args:
        image_editor: ``gr.ImageMask`` value: a dict whose ``'background'`` is a
            PIL image and whose ``'layers'`` is a list of PIL images (the first
            one is used as the mask).
        prompt: Text prompt for the inpainting pipeline.
        neg_prompt: Negative prompt; an empty string falls back to the
            pipeline's default (no negative prompt).
        versions: Current versions-gallery value, or ``None`` before the first
            generation.

    Returns:
        A 3-tuple for the ``[sketch_pad, version_gallery, restore_button]``
        outputs: the new ImageEditor value dict, a visible gallery with the
        result appended, and an update that shows the restore button.

    Raises:
        gr.Error: If no image was uploaded or no mask layer is present.
    """
    if not image_editor or image_editor.get('background') is None:
        raise gr.Error("Please upload an image first.")
    if not image_editor.get('layers'):
        raise gr.Error("Please draw over the area you want to inpaint.")

    # convert() returns a copy, so the in-place thumbnail() below does not
    # touch the uploaded background (which is added to the gallery later).
    image = image_editor['background'].convert('RGB')
    # Fit within 1024x1024 (keeps aspect ratio), then snap sides to /8.
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    width, height = image.size

    # Mask layer, scaled to the resized image.
    layer = image_editor['layers'][0].resize(image.size)

    # Pad to a square; squarify_image keeps the image at the top-left (0, 0).
    image = squarify_image(image)

    # Composite the layer (through its own alpha) over a white canvas at (0, 0)
    # and invert: unpainted areas, including the square padding, become black.
    # NOTE(review): assumes dark brush strokes, which end up white (inpainted).
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    # Inpaint
    pipeline.to("cuda")
    final_image = pipeline(prompt=prompt,
                           negative_prompt=neg_prompt or None,
                           image=image,
                           mask_image=mask).images[0]

    # The output may be at a different resolution than the padded input
    # (assumed square, like the input — SDXL inpainting defaults to 1024; confirm).
    # Scale the pre-padding extent by the same factor and crop off the padding.
    scale = final_image.width / image.width
    final_image = final_image.crop((0, 0, round(width * scale), round(height * scale)))

    # gradio.ImageEditor expects a dict value.
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Add the generated image to the version gallery without mutating the input.
    if versions is None:
        final_gallery = [image_editor['background'], final_image]
    else:
        final_gallery = [*versions, final_image]
    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True)
with gr.Blocks() as demo:
    # NOTE(review): the mailto address below looks scrape-garbled
    # ("[email protected]") — restore the real address.
    gr.Markdown("""
# Inpainting SDXL Sketch Pad
by [Tony Assi](https://www.tonyassi.com/)
Please ❤️ this Space. I build custom AI apps for companies. <a href="mailto: [email protected]">Email me</a> for business inquiries.
""")
    with gr.Row():
        # Left column: editor, prompt, and generation controls.
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button("Generate")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
        # Right column: generated versions, hidden until the first generation.
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)

    # Hidden holder for the selected gallery index; precision=0 makes Gradio
    # hand it back as an int, so it can be used directly as a list index.
    selected = gr.Number(show_label=False, visible=False, precision=0)

    gr.Examples(
        examples=[[{'background':'./tony.jpg', 'layers':['./tony-mask.jpg'], 'composite':'./tony.jpg'}, 'tuxedo', 'ugly', None]],
        inputs=[sketch_pad, prompt, neg_prompt, version_gallery],
        outputs=[sketch_pad, version_gallery, restore_button],
        fn=generate,
        cache_examples=True,
    )

    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad, prompt, neg_prompt, version_gallery], outputs=[sketch_pad, version_gallery, restore_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)

demo.launch()