"""Gradio demo for SDXL ControlNet inpainting with BRIA's ControlNet-Inpainting model."""

from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import AutoencoderKL, LCMScheduler
from PIL import Image
from torchvision import transforms

from controlnet import ControlNetModel
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline

# Device the pipeline and all conditioning tensors are placed on.
DEVICE = "cuda"
# Inputs are resized to a RESOLUTION x RESOLUTION square before inference.
RESOLUTION = 1024
# Gray level (in [0, 1]) written into the masked pixels of the conditioning image.
# NOTE(review): 127 follows the briaai ControlNet-Inpainting model card example — confirm.
MASK_FILL_VALUE = 127 / 255


# Define helper functions
def download_image(url, timeout=30):
    """Download *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the request exceeds *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Without this, an HTML error page would be handed to PIL and surface as a
    # confusing "cannot identify image file" error instead of the HTTP failure.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def load_model():
    """Build the float16 SDXL ControlNet pipeline and move it to DEVICE."""
    # from_pretrained is a classmethod: call it on the class. Calling it on
    # ControlNetModel() would first build a throwaway, randomly initialised model.
    controlnet = ControlNetModel.from_pretrained(
        "briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16
    )
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "briaai/BRIA-2.3",
        controlnet=controlnet.to(dtype=torch.float16),
        torch_dtype=torch.float16,
        vae=vae,
    )
    pipe.to(DEVICE)
    return pipe


pipe = load_model()


def _build_control_image(image, mask):
    """Return the ControlNet conditioning tensor for *image* and *mask*.

    Masked pixels are filled with MASK_FILL_VALUE. The result is VAE-encoded,
    and the binary mask, downsampled to latent resolution, is appended as one
    extra channel. Output shape is (1, latent_channels + 1, h, w).
    NOTE(review): this recipe mirrors the briaai ControlNet-Inpainting model card.
    Confirm against the checkpoint if it changes.

    Must be called under torch.no_grad(); it runs the VAE encoder.
    """
    to_tensor = transforms.ToTensor()
    vae = pipe.vae

    image_tensor = to_tensor(image)                  # (3, H, W), values in [0, 1]
    mask_tensor = (to_tensor(mask) > 0.5).float()    # (1, H, W), binarized
    masked = image_tensor * (1 - mask_tensor) + MASK_FILL_VALUE * mask_tensor
    masked = (masked - 0.5) / 0.5  # rescale to [-1, 1] before VAE encoding
    masked = masked.unsqueeze(0).to(DEVICE, dtype=vae.dtype)

    latents = vae.encode(masked).latent_dist.sample()
    latents = latents * vae.config.scaling_factor

    mask_latent = torch.nn.functional.interpolate(
        mask_tensor.unsqueeze(0).to(DEVICE, dtype=latents.dtype),
        size=latents.shape[2:],
        mode="nearest",
    )
    return torch.cat([latents, mask_latent], dim=1)


# Define the inpainting function
def inpaint(image, mask, prompt="A park bench", num_inference_steps=50):
    """Repaint the region of *image* selected by *mask*.

    Args:
        image: PIL image to edit (any mode; converted to RGB).
        mask: PIL image; pixels brighter than 50% gray mark the area to repaint.
        prompt: text describing what to paint into the masked area.
        num_inference_steps: number of denoising steps.

    Returns:
        The inpainted image at RESOLUTION x RESOLUTION, as returned by the pipeline.

    Raises:
        gr.Error: if either input image is missing (Gradio passes None).
    """
    if image is None or mask is None:
        raise gr.Error("Please upload both an image and a mask.")

    size = (RESOLUTION, RESOLUTION)
    image = image.convert("RGB").resize(size)
    # NEAREST keeps the mask edge crisp instead of blending in gray ramps.
    mask = mask.convert("L").resize(size, Image.NEAREST)

    with torch.no_grad():
        control_image = _build_control_image(image, mask)
        # NOTE(review): init_image / mask_image are passed the way the model card
        # passes them to this custom pipeline. The ControlNet input is `image`.
        result = pipe(
            prompt=prompt,
            image=control_image,
            init_image=image,
            mask_image=mask,
            height=RESOLUTION,
            width=RESOLUTION,
            num_inference_steps=num_inference_steps,
        ).images[0]
    # With the default output type the pipeline already returns a PIL image,
    # so no tensor-to-PIL conversion is needed (it would fail on a PIL image).
    return result


# Define the interface
interface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Mask Image"),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    title="Stable Diffusion XL ControlNet Inpainting",
    description="Upload an image and its corresponding mask to inpaint the specified area.",
)

if __name__ == "__main__":
    interface.launch()