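# app.py -- uploaded by yonishafir; commit message "Create app.py" (commit 884e760, verified).
# File-viewer residue captured with the source ("raw", "history blame", size 2.14 kB) is
# kept here as a comment because it is not part of the program.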
import gradio as gr
from PIL import Image
import requests
from io import BytesIO
import torch
from torchvision import transforms
from diffusers import AutoencoderKL, LCMScheduler
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from controlnet import ControlNetModel
# Define helper functions
def download_image(url: str, timeout: float = 30.0):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely on a stalled
            connection.

    Returns:
        The downloaded image, decoded by PIL and converted to "RGB" mode.

    Raises:
        requests.exceptions.HTTPError: If the server replies with a 4xx/5xx
            status code.
        requests.exceptions.Timeout: If no response arrives within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures here; otherwise the body of an error page would be
    # passed to PIL and show up as an unrelated image-decoding error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def load_model():
    """Assemble the ControlNet inpainting pipeline and move it to the GPU.

    Loads the ControlNet ("briaai/DEV-ControlNetInpaintingFast"), the VAE
    ("madebyollin/sdxl-vae-fp16-fix") and the base pipeline
    ("briaai/BRIA-2.3"), all requested in ``torch.float16``.

    Returns:
        The ``StableDiffusionXLControlNetPipeline`` instance after
        ``.to('cuda')`` has been called on it.
    """
    half = torch.float16

    # NOTE(review): ``from_pretrained`` is called on a freshly constructed
    # ``ControlNetModel()`` instance. If it is a classmethod in the local
    # ``controlnet`` module, that instance is built only to be discarded --
    # confirm before simplifying to ``ControlNetModel.from_pretrained``.
    controlnet = ControlNetModel().from_pretrained(
        "briaai/DEV-ControlNetInpaintingFast",
        torch_dtype=half,
    )
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=half,
    )
    pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
        "briaai/BRIA-2.3",
        controlnet=controlnet.to(dtype=half),
        torch_dtype=half,
        vae=vae,
    )
    pipeline.to('cuda')
    return pipeline
# Build the pipeline once at import time; inpaint() reads this module-level
# object on every call instead of reloading the models.
pipe = load_model()
# Define the inpainting function
def inpaint(image, mask, prompt="A park bench"):
    """Inpaint the masked area of ``image`` using the module-level ``pipe``.

    Args:
        image: PIL image to edit. It is resized to 1024x1024 and converted
            to "RGB".
        mask: PIL image marking the area to repaint. It is resized to
            1024x1024, converted to single-channel "L" and binarized with a
            0.5 threshold after conversion to a tensor.
        prompt: Text prompt handed to the pipeline. Defaults to the value
            that used to be hard-coded, so two-argument callers are
            unaffected.

    Returns:
        The inpainted result as a PIL image.

    Raises:
        ValueError: If ``image`` or ``mask`` is ``None`` (for example when a
            UI image field was left empty).
    """
    if image is None or mask is None:
        raise ValueError("Both an image and a mask are required for inpainting.")

    # Process image and mask
    image = image.resize((1024, 1024)).convert("RGB")
    mask = mask.resize((1024, 1024)).convert("L")

    # Transform to tensor and add a leading batch dimension
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(image).unsqueeze(0).to('cuda')
    mask_tensor = to_tensor(mask).unsqueeze(0).to('cuda')
    mask_tensor = (mask_tensor > 0.5).float()  # binarize mask

    # Generate image
    # NOTE(review): the ``init_image`` / ``mask_image`` keyword names must
    # match the ``__call__`` signature of the local
    # ``pipeline_controlnet_sd_xl`` module -- confirm against that file.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            init_image=image_tensor,
            mask_image=mask_tensor,
            num_inference_steps=50,
        ).images[0]

    # When the pipeline already yields a PIL image, return it directly:
    # PIL images have no ``squeeze`` and the tensor conversion would fail.
    if isinstance(result, Image.Image):
        return result
    return transforms.ToPILImage()(result.squeeze(0))
# Define the interface
# gr.Image serves as both input and output component; the older
# gr.inputs / gr.outputs namespaces are deprecated and absent from current
# Gradio releases.
interface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Mask Image"),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    title="Stable Diffusion XL ControlNet Inpainting",
    description="Upload an image and its corresponding mask to inpaint the specified area.",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()