import os
import types
from typing import Tuple
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import math
import comfy
import folder_paths
from .model_patch import add_model_patch_option, patch_model_function_wrapper
from .brushnet.brushnet import BrushNetModel
from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
current_directory = os.path.dirname(os.path.abspath(__file__))
brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
sd15_scaling_factor = 0.18215
sdxl_scaling_factor = 0.13025
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
                  comfy.ldm.models.autoencoder.AutoencoderKL
                  ]
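
# Node: BrushNetLoader
# Loads a BrushNet / PowerPaint checkpoint from the 'inpaint' models folder, detects its
# type (SD1.5 BrushNet, SD1.5 PowerPaint or SDXL BrushNet) from the state_dict layout and
# instantiates it with the matching config at the requested dtype. The output is a dict
# with keys "brushnet", "SDXL", "PP" and "dtype" consumed by the BrushNet / PowerPaint nodes.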
class BrushNetLoader:

    @classmethod
    def INPUT_TYPES(self):
        self.inpaint_files = get_files_with_extension('inpaint')
        return {"required":
                    {
                        "brushnet": ([file for file in self.inpaint_files], ),
                        "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("BRMODEL",)
    RETURN_NAMES = ("brushnet",)

    FUNCTION = "brushnet_loading"
    def brushnet_loading(self, brushnet, dtype):
        brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)

        is_SDXL = False
        is_PP = False

        sd = comfy.utils.load_torch_file(brushnet_file)
        brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
        del sd

        if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
            is_SDXL = False
            if keys == 322:
                is_PP = False
                print('BrushNet model type: SD1.5')
            else:
                is_PP = True
                print('PowerPaint model type: SD1.5')
        elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
            print('BrushNet model type: SDXL')
            is_SDXL = True
            is_PP = False
        else:
            raise Exception("Unknown BrushNet model")

        with init_empty_weights():
            if is_SDXL:
                brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)
            elif is_PP:
                brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
                brushnet_model = PowerPaintModel.from_config(brushnet_config)
            else:
                brushnet_config = BrushNetModel.load_config(brushnet_config_file)
                brushnet_model = BrushNetModel.from_config(brushnet_config)

        if is_PP:
            print("PowerPaint model file:", brushnet_file)
        else:
            print("BrushNet model file:", brushnet_file)

        if dtype == 'float16':
            torch_dtype = torch.float16
        elif dtype == 'bfloat16':
            torch_dtype = torch.bfloat16
        elif dtype == 'float32':
            torch_dtype = torch.float32
        else:
            torch_dtype = torch.float64

        brushnet_model = load_checkpoint_and_dispatch(
            brushnet_model,
            brushnet_file,
            device_map="sequential",
            max_memory=None,
            offload_folder=None,
            offload_state_dict=False,
            dtype=torch_dtype,
            force_hooks=False,
        )

        if is_PP:
            print("PowerPaint model is loaded")
        elif is_SDXL:
            print("BrushNet SDXL model is loaded")
        else:
            print("BrushNet SD1.5 model is loaded")

        return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )

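# Node: PowerPaintCLIPLoader
# Builds the PowerPaint text encoder: loads a base SD1.5 CLIP, wraps its tokenizer,
# registers the P_ctxt / P_shape / P_obj placeholder tokens and then loads the fine-tuned
# PowerPaint text encoder weights on top (non-strict).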
class PowerPaintCLIPLoader:

    @classmethod
    def INPUT_TYPES(self):
        self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
        self.clip_files = get_files_with_extension('clip')
        return {"required":
                    {
                        "base": ([file for file in self.clip_files], ),
                        "powerpaint": ([file for file in self.inpaint_files], ),
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("CLIP",)
    RETURN_NAMES = ("clip",)

    FUNCTION = "ppclip_loading"
    def ppclip_loading(self, base, powerpaint):
        base_CLIP_file = os.path.join(self.clip_files[base], base)
        pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)

        pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])

        print('PowerPaint base CLIP file: ', base_CLIP_file)

        pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
        pp_text_encoder = pp_clip.patcher.model.clip_l.transformer

        add_tokens(
            tokenizer = pp_tokenizer,
            text_encoder = pp_text_encoder,
            placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
            initialize_tokens = ["a", "a", "a"],
            num_vectors_per_token = 10,
        )

        pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)

        print('PowerPaint CLIP file: ', pp_CLIP_file)

        pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
        pp_clip.patcher.model.clip_l.transformer = pp_text_encoder

        return (pp_clip,)

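# Node: PowerPaint
# Patches a cloned SD1.5 model with the PowerPaint variant of BrushNet. The selected
# "function" decides which placeholder tokens drive the positive/negative prompts, and
# "fitting" blends the two embeddings before they are handed to the BrushNet patch.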
class PowerPaint:

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": ("MODEL",),
                        "vae": ("VAE", ),
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "powerpaint": ("BRMODEL", ),
                        "clip": ("CLIP", ),
                        "positive": ("CONDITIONING", ),
                        "negative": ("CONDITIONING", ),
                        "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
                        "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
                        "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                        "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                        "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                        "save_memory": (['none', 'auto', 'max'], ),
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
    RETURN_NAMES = ("model","positive","negative","latent",)

    FUNCTION = "model_update"
    def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):

        is_SDXL, is_PP = check_compatibilty(model, powerpaint)
        if not is_PP:
            raise Exception("BrushNet model was loaded, please use BrushNet node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = prepare_image(image, mask)

        batch = masked_image.shape[0]
        #width = masked_image.shape[2]
        #height = masked_image.shape[1]

        if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
            scaling_factor = model.model.model_config.latent_format.scale_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = powerpaint['dtype']

        # prepare conditioning latents
        conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)

        # prepare embeddings
        if function == "object removal":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"
            print('You should add to positive prompt: "empty scene blur"')
            #positive = positive + " empty scene blur"
        elif function == "context aware":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = ""
            negative_promptB = ""
            #positive = positive + " empty scene"
            print('You should add to positive prompt: "empty scene"')
        elif function == "shape guided":
            promptA = "P_shape"
            promptB = "P_ctxt"
            negative_promptA = "P_shape"
            negative_promptB = "P_ctxt"
        elif function == "image outpainting":
            promptA = "P_ctxt"
            promptB = "P_ctxt"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"
            #positive = positive + " empty scene"
            print('You should add to positive prompt: "empty scene"')
        else:
            promptA = "P_obj"
            promptB = "P_obj"
            negative_promptA = "P_obj"
            negative_promptB = "P_obj"

        tokens = clip.tokenize(promptA)
        prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(negative_promptA)
        negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(promptB)
        prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

        tokens = clip.tokenize(negative_promptB)
        negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

        prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
        negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)

        # unload vae and CLIPs
        del vae
        del clip
        for loaded_model in comfy.model_management.current_loaded_models:
            if type(loaded_model.model.model) in ModelsToUnload:
                comfy.model_management.current_loaded_models.remove(loaded_model)
                loaded_model.model_unload()
                del loaded_model

        # apply patch to model
        brushnet_conditioning_scale = scale
        control_guidance_start = start_at
        control_guidance_end = end_at

        if save_memory != 'none':
            powerpaint['brushnet'].set_attention_slice(save_memory)

        add_brushnet_patch(model,
                           powerpaint['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                           negative_prompt_embeds_pp, prompt_embeds_pp,
                           None, None, None,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)

        return (model, positive, negative, {"samples":latent},)

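# Node: BrushNet
# Patches a cloned SD1.5 / SDXL model with BrushNet. The masked image and mask are
# VAE-encoded into conditioning latents, the prompt embeddings are padded to a common
# token length, and everything is stored in the model patch used at sampling time.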
class BrushNet:

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": ("MODEL",),
                        "vae": ("VAE", ),
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "brushnet": ("BRMODEL", ),
                        "positive": ("CONDITIONING", ),
                        "negative": ("CONDITIONING", ),
                        "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                        "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                        "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
    RETURN_NAMES = ("model","positive","negative","latent",)

    FUNCTION = "model_update"
    def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):

        is_SDXL, is_PP = check_compatibilty(model, brushnet)
        if is_PP:
            raise Exception("PowerPaint model was loaded, please use PowerPaint node")

        # Make a copy of the model so that we're not patching it everywhere in the workflow.
        model = model.clone()

        # prepare image and mask
        # no batches for original image and mask
        masked_image, mask = prepare_image(image, mask)

        batch = masked_image.shape[0]
        width = masked_image.shape[2]
        height = masked_image.shape[1]

        if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
            scaling_factor = model.model.model_config.latent_format.scale_factor
        elif is_SDXL:
            scaling_factor = sdxl_scaling_factor
        else:
            scaling_factor = sd15_scaling_factor

        torch_dtype = brushnet['dtype']

        # prepare conditioning latents
        conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
        conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

        # unload vae
        del vae
        for loaded_model in comfy.model_management.current_loaded_models:
            if type(loaded_model.model.model) in ModelsToUnload:
                comfy.model_management.current_loaded_models.remove(loaded_model)
                loaded_model.model_unload()
                del loaded_model

        # prepare embeddings
        prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

        max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        if prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
            prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
            print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
        if negative_prompt_embeds.shape[1] < max_tokens:
            multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
            negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
            print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')

        if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
            pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        else:
            print('BrushNet: positive conditioning has no pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)

        if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
            negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
        else:
            print('BrushNet: negative conditioning has no pooled_output')
            if is_SDXL:
                print('BrushNet will not produce correct results')
            negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)

        time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)

        if not is_SDXL:
            pooled_prompt_embeds = None
            negative_pooled_prompt_embeds = None
            time_ids = None

        # apply patch to model
        brushnet_conditioning_scale = scale
        control_guidance_start = start_at
        control_guidance_end = end_at

        add_brushnet_patch(model,
                           brushnet['brushnet'],
                           torch_dtype,
                           conditioning_latents,
                           (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                           prompt_embeds, negative_prompt_embeds,
                           pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                           False)

        latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)

        return (model, positive, negative, {"samples":latent},)

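# Node: BlendInpaint
# Composites an inpainted result back into the original image: the mask is blurred with a
# Gaussian kernel and used as the blend weight, optionally pasting into the crop described
# by the "origin" vector produced by CutForInpaint.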
class BlendInpaint:

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "inpaint": ("IMAGE",),
                        "original": ("IMAGE",),
                        "mask": ("MASK",),
                        "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
                        "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
                    },
                "optional":
                    {
                        "origin": ("VECTOR",),
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("IMAGE","MASK",)
    RETURN_NAMES = ("image","MASK",)

    FUNCTION = "blend_inpaint"
    def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma: float, origin=None):

        original, mask = check_image_mask(original, mask, 'Blend Inpaint')

        if len(inpaint.shape) < 4:
            inpaint = inpaint[None,:,:,:]

        if inpaint.shape[0] < original.shape[0]:
            original = original[:inpaint.shape[0],:,:]
            mask = mask[:inpaint.shape[0],:,:]

        if inpaint.shape[0] > original.shape[0]:
            count = 0
            original_list = []
            mask_list = []
            origin_list = []
            while (count < inpaint.shape[0]):
                for i in range(original.shape[0]):
                    original_list.append(original[i][None,:,:,:])
                    mask_list.append(mask[i][None,:,:])
                    if origin is not None:
                        origin_list.append(origin[i][None,:])
                    count += 1
                    if count >= inpaint.shape[0]:
                        break
            original = torch.concat(original_list, dim=0)
            mask = torch.concat(mask_list, dim=0)
            if origin is not None:
                origin = torch.concat(origin_list, dim=0)

        if kernel % 2 == 0:
            kernel += 1
        transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))

        ret = []
        blurred = []
        for i in range(inpaint.shape[0]):
            height, width, _ = original[i].shape

            if origin is not None:
                x0, y0, cut_width, cut_height = origin[i]
            else:
                x0, y0 = 0, 0
                cut_width, cut_height = width, height

            # Ensure cut dimensions don't exceed original image dimensions
            cut_width = min(cut_width, width - x0)
            cut_height = min(cut_height, height - y0)

            # Scale inpainted image to match the cut size
            scaled_inpaint = F.interpolate(inpaint[i].permute(2, 0, 1).unsqueeze(0), size=(cut_height, cut_width), mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)

            # Create a mask for the inpainted region
            inpaint_mask = torch.zeros((height, width), device=mask.device, dtype=mask.dtype)
            inpaint_mask[y0:y0+cut_height, x0:x0+cut_width] = F.interpolate(mask[i][None, None, :, :], size=(cut_height, cut_width), mode='nearest').squeeze()

            # Apply Gaussian blur to the inpaint mask
            blurred_mask = transform(inpaint_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
            blurred.append(blurred_mask)

            # Create the result by blending only the masked area
            result = original[i].clone()
            result[y0:y0+cut_height, x0:x0+cut_width] = (
                original[i][y0:y0+cut_height, x0:x0+cut_width] * (1 - blurred_mask[y0:y0+cut_height, x0:x0+cut_width, None]) +
                scaled_inpaint * blurred_mask[y0:y0+cut_height, x0:x0+cut_width, None]
            )
            ret.append(result)

        return (torch.stack(ret), torch.stack(blurred))

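# Crop the image around the mask bounding box (made square, plus a percentage margin) and,
# if the crop area falls outside the [min_side^2, max_side^2] range, rescale it into that
# range. Returns the (possibly scaled) crop, its mask and the (x0, y0, width, height) origin.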
def scale_mask_and_image(image, mask, side_margin, min_side, max_side):
    min_area = min_side * min_side
    max_area = max_side * max_side
    h0, w0 = mask.shape
    iy, ix = (mask == 1).nonzero(as_tuple=True)

    if iy.numel() == 0:
        x_c, y_c = w0 / 2.0, h0 / 2.0
        mask_width, mask_height = 1, 1
    else:
        x_min, x_max = ix.min().item(), ix.max().item()
        y_min, y_max = iy.min().item(), iy.max().item()
        x_c, y_c = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
        mask_width, mask_height = x_max - x_min + 1, y_max - y_min + 1

    mask_aspect_ratio = mask_width / mask_height
    if mask_aspect_ratio > 1:
        new_mask_width = mask_width
        new_mask_height = mask_width
    else:
        new_mask_height = mask_height
        new_mask_width = mask_height

    margin = side_margin / 100.0
    cut_width = int(new_mask_width * (1 + 2 * margin))
    cut_height = int(new_mask_height * (1 + 2 * margin))

    x0 = max(0, min(w0 - cut_width, int(x_c - cut_width / 2)))
    y0 = max(0, min(h0 - cut_height, int(y_c - cut_height / 2)))
    cut_width = min(cut_width, w0 - x0)
    cut_height = min(cut_height, h0 - y0)

    cut_image = image[y0:y0+cut_height, x0:x0+cut_width]
    cut_mask = mask[y0:y0+cut_height, x0:x0+cut_width]

    current_area = cut_width * cut_height
    print(f"current_area: {current_area} min_area: {min_area} max_area: {max_area}")
    if current_area > max_area or current_area <= min_area:
        if current_area > max_area:
            print("current_area > max_area")
            scale_factor = math.sqrt(max_area / current_area)
        elif current_area <= min_area:
            print("current_area <= min_area")
            scale_factor = math.sqrt(min_area / current_area)
        new_width = int(cut_width * scale_factor)
        new_height = int(cut_height * scale_factor)
        scaled_image = F.interpolate(cut_image.permute(2, 0, 1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)
        scaled_mask = F.interpolate(cut_mask.unsqueeze(0).unsqueeze(0).float(), size=(new_height, new_width), mode='nearest').squeeze(0).squeeze(0)
        return scaled_image, scaled_mask, (x0, y0, cut_width, cut_height)
    else:
        print("original size mask")
        return cut_image, cut_mask, (x0, y0, cut_width, cut_height)

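# Node: CutForInpaint
# Applies scale_mask_and_image to every image in the batch and returns the crops, their
# masks and the origin vectors that BlendInpaint needs to paste the results back.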
class CutForInpaint:

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "image": ("IMAGE",),
                        "mask": ("MASK",),
                        "side_margin_percent": ("INT", {"default": 10, "min": 0, "max": 1000}),
                        "min_side": ("INT", {"default": 512, "min": 128, "max": 4096}),
                        "max_side": ("INT", {"default": 1536, "min": 128, "max": 4096})
                    },
               }

    CATEGORY = "inpaint"

    RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
    RETURN_NAMES = ("image","mask","origin",)

    FUNCTION = "cut_for_inpaint"

    def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, side_margin_percent: int, min_side: int, max_side: int):
        ret = []
        msk = []
        org = []
        for i in range(image.shape[0]):
            cut_image, cut_mask, (x0, y0, cut_width, cut_height) = scale_mask_and_image(image[i], mask[i], side_margin_percent, min_side, max_side)
            ret.append(cut_image)
            msk.append(cut_mask)
            org.append(torch.IntTensor([x0, y0, cut_width, cut_height]))
        return (torch.stack(ret), torch.stack(msk), torch.stack(org))

#### Utility function
def get_files_with_extension(folder_name, extension=['.safetensors']):

    try:
        folders = folder_paths.get_folder_paths(folder_name)
    except:
        folders = []

    if not folders:
        folders = [os.path.join(folder_paths.models_dir, folder_name)]
    if not os.path.isdir(folders[0]):
        folders = [os.path.join(folder_paths.base_path, folder_name)]
    if not os.path.isdir(folders[0]):
        return {}

    filtered_folders = []
    for x in folders:
        if not os.path.isdir(x):
            continue
        the_same = False
        for y in filtered_folders:
            if os.path.samefile(x, y):
                the_same = True
                break
        if not the_same:
            filtered_folders.append(x)

    if not filtered_folders:
        return {}

    output = {}
    for x in filtered_folders:
        files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
        filtered_files = folder_paths.filter_files_extensions(files, extension)
        for f in filtered_files:
            output[f] = x

    return output

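# Illustrative shape of the mapping returned above (hypothetical file names and paths):
#   {'brushnet.safetensors': '/path/to/ComfyUI/models/inpaint',
#    'brushnet_xl.safetensors': '/path/to/ComfyUI/models/inpaint'}
# i.e. file name -> folder it was found in; the loader nodes join them back with os.path.join.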
# get blocks from state_dict so we can tell which model it is
def brushnet_blocks(sd):
    brushnet_down_block = 0
    brushnet_mid_block = 0
    brushnet_up_block = 0
    for key in sd:
        if 'brushnet_down_block' in key:
            brushnet_down_block += 1
        if 'brushnet_mid_block' in key:
            brushnet_mid_block += 1
        if 'brushnet_up_block' in key:
            brushnet_up_block += 1
    return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))

# Check models compatibility
def check_compatibilty(model, brushnet):
    is_SDXL = False
    is_PP = False
    if isinstance(model.model.model_config, comfy.supported_models.SD15):
        print('Base model type: SD1.5')
        is_SDXL = False
        if brushnet["SDXL"]:
            raise Exception("Base model is SD15, but BrushNet is SDXL type")
        if brushnet["PP"]:
            is_PP = True
    elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
        print('Base model type: SDXL')
        is_SDXL = True
        if not brushnet["SDXL"]:
            raise Exception("Base model is SDXL, but BrushNet is SD15 type")
    else:
        print('Base model type: ', type(model.model.model_config))
        raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
    return (is_SDXL, is_PP)

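# Normalize image/mask tensors to [B, H, W, C] and [B, H, W] and make the mask batch match
# the image batch (copying a single mask or padding with empty masks as needed).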
def check_image_mask(image, mask, name):
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C], but the batch dimension is missing
        image = image[None,:,:,:]

    if len(mask.shape) > 3:
        # mask tensor shape should be [B, H, W], but we got [B, H, W, C]; this is probably an image,
        # so take the first (red) channel as the mask
        mask = (mask[:,:,:,0])[:,:,:]
    elif len(mask.shape) < 3:
        # mask tensor shape should be [B, H, W], but the batch dimension is missing
        mask = mask[None,:,:]

    if image.shape[0] > mask.shape[0]:
        print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
        if mask.shape[0] == 1:
            print(name, "will copy the mask to fill batch")
            mask = torch.cat([mask] * image.shape[0], dim=0)
        else:
            print(name, "will add empty masks to fill batch")
            empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
            mask = torch.cat([mask, empty_mask], dim=0)
    elif image.shape[0] < mask.shape[0]:
        print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
        mask = mask[:image.shape[0],:,:]

    return (image, mask)

# Prepare image and mask
def prepare_image(image, mask):

    image, mask = check_image_mask(image, mask, 'BrushNet')

    print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)

    if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
        raise Exception("Image and mask should be the same size")

    # As suggested by inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
    mask = mask.round()

    masked_image = image * (1.0 - mask[:,:,:,None])

    return (masked_image, mask)

# Prepare conditioning_latents
@torch.inference_mode()
def get_image_latents(masked_image, mask, vae, scaling_factor):
    processed_image = masked_image.to(vae.device)
    image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
    processed_mask = 1. - mask[:,None,:,:]
    interpolated_mask = torch.nn.functional.interpolate(
        processed_mask,
        size=(
            image_latents.shape[-2],
            image_latents.shape[-1]
        )
    )
    interpolated_mask = interpolated_mask.to(image_latents.device)

    conditioning_latents = [image_latents, interpolated_mask]

    print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)

    return conditioning_latents

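# The sampling-time entry point follows. It is invoked from brushnet_forward on every UNet
# call; it rebuilds the per-sample conditioning (image latent and mask concatenated on the
# channel dim), pairs cond/uncond passes with the positive/negative embeddings and returns
# the (input_samples, mid_sample, output_samples) residuals produced by BrushNet.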
# Main function where magic happens
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        print('BrushNet inference: there is no brushnet key in model_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    torch_dtype = bo['dtype']
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    #do_classifier_free_guidance = mp['free_guidance']
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)

    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    total_steps = mp['total_steps']
    step = mp['step']

    added_cond_kwargs = {}

    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
              % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')

        batch = len(sub_idx)

        if do_classifier_free_guidance:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                latents_got = -i
                continue_batch = False

        if continue_batch:
            # we don't have full batch yet
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0

    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)
    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    if ppe is not None:
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)

        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)

    if debug: print('BrushNet: step =', step)

    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )

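# The patching helpers follow. add_brushnet_patch stores all BrushNet inputs in the model
# patch options, wraps the model forward with brushnet_forward, and monkey-patches every
# UNet block's forward (forward_patched_by_brushnet) so the residuals returned by BrushNet
# can be added after the matching input / middle / output blocks.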
# This is main patch function
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
                       controls,
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       debug):

    is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)

    if is_SDXL:
        input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
                        [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
                         [1, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
    else:
        input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
                        [1, comfy.ldm.modules.attention.SpatialTransformer],
                        [2, comfy.ldm.modules.attention.SpatialTransformer],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer],
                        [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.attention.SpatialTransformer],
                         [7, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [9, comfy.ldm.modules.attention.SpatialTransformer],
                         [10, comfy.ldm.modules.attention.SpatialTransformer],
                         [11, comfy.ldm.modules.attention.SpatialTransformer]]

    def last_layer_index(block, tp):
        layer_list = []
        for layer in block:
            layer_list.append(type(layer))
        layer_list.reverse()
        if tp not in layer_list:
            return -1, layer_list
        return len(layer_list) - 1 - layer_list.index(tp), layer_list

    def brushnet_forward(model, x, timesteps, transformer_options, control):
        if 'brushnet' not in transformer_options['model_patch']:
            input_samples = []
            mid_sample = 0
            output_samples = []
        else:
            # brushnet inference
            input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)
        # give additional samples to blocks
        for i, tp in input_blocks:
            idx, layer_list = last_layer_index(model.input_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
                continue
            model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
        idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
        if idx < 0:
            print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
        model.middle_block[idx].add_sample_after = mid_sample
        for i, tp in output_blocks:
            idx, layer_list = last_layer_index(model.output_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "output block:", layer_list)
                continue
            model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0

    patch_model_function_wrapper(model, brushnet_forward)
    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'brushnet' not in mp:
        mp['brushnet'] = {}
    bo = mp['brushnet']

    bo['model'] = brushnet
    bo['dtype'] = torch_dtype
    bo['latents'] = conditioning_latents
    bo['controls'] = controls
    bo['prompt_embeds'] = prompt_embeds
    bo['negative_prompt_embeds'] = negative_prompt_embeds
    bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
    bo['latent_id'] = 0

    # patch layers `forward` so we can apply brushnet
    def forward_patched_by_brushnet(self, x, *args, **kwargs):
        h = self.original_forward(x, *args, **kwargs)
        if hasattr(self, 'add_sample_after') and type(self):
            to_add = self.add_sample_after
            if torch.is_tensor(to_add):
                # interpolate due to RAUNet
                if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                    to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
                h += to_add.to(h.dtype).to(h.device)
            else:
                h += self.add_sample_after
            self.add_sample_after = 0
        return h

    for i, block in enumerate(model.model.diffusion_model.input_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0

    for j, layer in enumerate(model.model.diffusion_model.middle_block):
        if not hasattr(layer, 'original_forward'):
            layer.original_forward = layer.forward
        layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
        layer.add_sample_after = 0

    for i, block in enumerate(model.model.diffusion_model.output_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0