import torch import numpy as np from PIL import Image def resize_image(image)->Image.Image: pixel_number = 1024*1024 granularity_val = 64 ratio = image.size[0] / image.size[1] width = int((pixel_number * ratio) ** 0.5) width = width - (width % granularity_val) height = int(pixel_number / width) height = height - (height % granularity_val) return image.resize((width, height)) def get_masked_background_image(image, image_mask)->tuple: image_mask_pil = image_mask.resize(image.size) # fg is white image = np.array(image.convert("RGB")).transpose(2, 0, 1).astype(np.float32) / 255.0 image_mask = np.array(image_mask_pil.convert("L")).astype(np.float32) / 255.0 image[:,image_mask < 0.5] = 0 # mask background return image, image_mask def get_control_image_tensor(vae, image, mask)->torch.Tensor: masked_image, image_mask = get_masked_background_image(image, mask) masked_image_tensor = torch.from_numpy(masked_image) masked_image_tensor = (masked_image_tensor - 0.5) / 0.5 # normalize for vae masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda:0") # encode the image to get the control latents control_latents = vae.encode( masked_image_tensor[:, :3, :, :].to(vae.dtype) ).latent_dist.sample() control_latents = control_latents * vae.config.scaling_factor mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, None, ...].to(device="cuda:0") mask_tensor = torch.where(mask_tensor > 0.5, 1.0, 0) # binarize the mask mask_resized = torch.nn.functional.interpolate(mask_tensor, size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest') control_tensor = torch.cat([control_latents, mask_resized], dim=1) return control_tensor def remove_bg_from_image(image)->Image.Image: from transformers import pipeline pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True) mask = pipe(image, return_mask = True) # outputs a pillow mask return mask def paste_fg_over_image(gen_image: Image.Image, orig_image: Image.Image, fg_mask: Image.Image)->Image.Image: fg_mask = fg_mask.convert("L") fg_mask = fg_mask.resize(orig_image.size, Image.NEAREST) gen_image = gen_image.convert("RGBA") orig_image = orig_image.convert("RGBA") gen_image.paste(orig_image, (0, 0), fg_mask) return gen_image.convert("RGB")