from PIL import Image, ImageFilter
import torch
import math
from nodes import common_ksampler, VAEEncode, VAEDecode, VAEDecodeTiled
from comfy_extras.nodes_custom_sampler import SamplerCustom
from utils import pil_to_tensor, tensor_to_pil, get_crop_region, expand_crop, crop_cond
from modules import shared

# Older Pillow releases have no ``Image.Resampling``; alias it to the module
# so ``Image.Resampling.LANCZOS`` resolves either way.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image


class StableDiffusionProcessing:
    """Bundle of the image, sampler inputs and options consumed by ``process_images``."""

    def __init__(
        self,
        init_img,
        model,
        positive,
        negative,
        vae,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        denoise,
        upscale_by,
        uniform_tile_mode,
        tiled_decode,
        custom_sampler=None,
        custom_sigmas=None
    ):
        """Store every input as an attribute and create the VAE helper nodes.

        Prints a warning when exactly one of ``custom_sampler`` /
        ``custom_sigmas`` is supplied; ``sample`` only takes the custom path
        when both are present.
        """
        # Attributes read by the USDU script
        self.init_images = [init_img]
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0
        self.width = init_img.width
        self.height = init_img.height

        # Inputs forwarded to the ComfyUI sampler
        self.model = model
        self.positive = positive
        self.negative = negative
        self.vae = vae
        self.seed = seed
        self.steps = steps
        self.cfg = cfg
        self.sampler_name = sampler_name
        self.scheduler = scheduler
        self.denoise = denoise

        # Optional custom sampler / sigmas; only usable as a pair
        self.custom_sampler = custom_sampler
        self.custom_sigmas = custom_sigmas
        only_one_given = (custom_sampler is None) != (custom_sigmas is None)
        if only_one_given:
            print("[USDU] Both custom sampler and custom sigmas must be provided, defaulting to widget sampler and sigmas")

        # Attributes used only inside this module
        self.init_size = (init_img.width, init_img.height)
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode
        self.tiled_decode = tiled_decode
        self.vae_decoder = VAEDecode()
        self.vae_encoder = VAEEncode()
        self.vae_decoder_tiled = VAEDecodeTiled()

        # A1111 attribute expected by the USDU script; not read in this module
        self.extra_generation_params = {}


class Processed:
    """Result container holding the output images, the seed and an info value."""

    def __init__(self, p: StableDiffusionProcessing, images: list, seed: int, info: str):
        # ``p`` is accepted but not stored.
        self.images = images
        self.seed = seed
        self.info = info

    def infotext(self, p: StableDiffusionProcessing, index):
        """Always return ``None``; both arguments are ignored."""
        return None


def fix_seed(p: StableDiffusionProcessing):
    """Do nothing and return ``None``.

    NOTE(review): presumably a stand-in for the A1111 function of the same
    name so the USDU script can call it -- confirm against the caller.
    """


def sample(model, seed, steps, cfg, sampler_name, scheduler, positive, negative,
           latent, denoise, custom_sampler, custom_sigmas):
    """Run one sampling pass and return the resulting samples.

    Uses ``common_ksampler`` with the widget settings unless both
    ``custom_sampler`` and ``custom_sigmas`` are given, in which case the
    ``SamplerCustom`` node is invoked instead (``steps``, ``sampler_name``,
    ``scheduler`` and ``denoise`` are not passed on that path).
    """
    # Default path: at least one of the custom inputs is missing
    if custom_sampler is None or custom_sigmas is None:
        (samples,) = common_ksampler(
            model, seed, steps, cfg, sampler_name, scheduler,
            positive, negative, latent, denoise=denoise
        )
        return samples

    # Custom sampler and sigmas
    node = SamplerCustom()
    run_node = getattr(node, node.FUNCTION)
    samples, _ = run_node(
        model=model,
        add_noise=True,
        noise_seed=seed,
        cfg=cfg,
        positive=positive,
        negative=negative,
        sampler=custom_sampler,
        sigmas=custom_sigmas,
        latent_image=latent
    )
    return samples


def _blend_tile(base_image, tile, mask, position):
    """Return an RGB copy of ``base_image`` with ``tile`` composited in through ``mask``.

    ``tile`` is pasted at ``position`` on a blank RGBA canvas the size of
    ``base_image``; ``mask`` then becomes the alpha used for compositing.
    """
    canvas = Image.new('RGBA', base_image.size)
    canvas.paste(tile, position)

    # Add the mask as an alpha channel.
    # Must work on a copy due to the possibility of an edge becoming black.
    with_alpha = canvas.copy()
    with_alpha.putalpha(mask)
    canvas.paste(with_alpha, canvas)

    # Composite over the base image according to the alpha channel, then drop alpha
    blended = base_image.convert('RGBA')
    blended.alpha_composite(canvas)
    return blended.convert('RGB')


def process_images(p: StableDiffusionProcessing) -> Processed:
    """Re-generate the masked tile of every image in ``shared.batch``.

    This is where the main image generation happens in A1111. The region
    outlined by ``p.image_mask`` is cropped from each batch image, encoded,
    sampled, decoded and composited back in place; ``shared.batch`` is updated
    in place. The returned ``Processed`` carries only the first batch image.
    """
    # Setup
    mask = p.image_mask.convert('L')
    init_image = p.init_images[0]
    lanczos = Image.Resampling.LANCZOS

    # Locate the white region of the mask outlining the tile and add padding
    region = get_crop_region(mask, p.inpaint_full_res_padding)
    left, top, right, bottom = region
    region_w = right - left
    region_h = bottom - top

    if p.uniform_tile_mode:
        # Expand the crop region to match the processing size ratio; the tile
        # is later resized to the processing size
        region_ratio = region_w / region_h
        proc_ratio = p.width / p.height
        if region_ratio > proc_ratio:
            target_w, target_h = region_w, round(region_w / proc_ratio)
        else:
            target_w, target_h = round(region_h * proc_ratio), region_h
        region, _ = expand_crop(region, mask.width, mask.height, target_w, target_h)
        tile_size = (p.width, p.height)
    else:
        # Use the minimal size that can fit the mask, rounded up to a multiple
        # of 8. Minimizes tile size but may lead to image sizes that the model
        # is not trained on
        target_w = math.ceil(region_w / 8) * 8
        target_h = math.ceil(region_h / 8) * 8
        region, tile_size = expand_crop(region, mask.width, mask.height, target_w, target_h)

    # Blur the mask
    if p.mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # Crop the images to get the tiles that will be used for generation
    tiles = [image.crop(region) for image in shared.batch]

    # Assume the same size for all images in the batch
    initial_tile_size = tiles[0].size

    # Resize if necessary
    tiles = [
        tile if tile.size == tile_size else tile.resize(tile_size, lanczos)
        for tile in tiles
    ]

    # Crop conditioning
    cond_geometry = (region, p.init_size, init_image.size, tile_size)
    positive_cropped = crop_cond(p.positive, *cond_geometry)
    negative_cropped = crop_cond(p.negative, *cond_geometry)

    # Encode the image
    batched_tiles = torch.cat([pil_to_tensor(tile) for tile in tiles], dim=0)
    (latent,) = p.vae_encoder.encode(p.vae, batched_tiles)

    # Generate samples
    samples = sample(
        model=p.model,
        seed=p.seed,
        steps=p.steps,
        cfg=p.cfg,
        sampler_name=p.sampler_name,
        scheduler=p.scheduler,
        positive=positive_cropped,
        negative=negative_cropped,
        latent=latent,
        denoise=p.denoise,
        custom_sampler=p.custom_sampler,
        custom_sigmas=p.custom_sigmas,
    )

    # Decode the sample
    if p.tiled_decode:
        print("[USDU] Using tiled decode")
        (decoded,) = p.vae_decoder_tiled.decode(p.vae, samples, 512)  # Default tile size is 512
    else:
        (decoded,) = p.vae_decoder.decode(p.vae, samples)

    # Convert every decoded sample to a PIL image before touching the batch
    tiles_sampled = [tensor_to_pil(decoded, index) for index in range(len(decoded))]

    for index, tile_sampled in enumerate(tiles_sampled):
        # Resize back to the original size
        if tile_sampled.size != initial_tile_size:
            tile_sampled = tile_sampled.resize(initial_tile_size, lanczos)

        # Put the tile into position and merge it into the batch image
        shared.batch[index] = _blend_tile(shared.batch[index], tile_sampled, mask, region[:2])

    return Processed(p, [shared.batch[0]], p.seed, None)