# BrushNET / raunet_nodes.py
# Uploaded by erikayurika ("Upload 94 files", commit ede32b0, verified).
import torch.nn.functional as F
import comfy
from .model_patch import add_model_patch_option, patch_model_function_wrapper
class RAUNet:
    """ComfyUI node that applies the RAUNet patch to a MODEL.

    All four step bounds describe half-open ``[start, end)`` windows. The hooks
    installed by ``add_raunet_patch`` compare them against
    ``model_patch['step']``:

    * ``du_start`` / ``du_end``: window in which the Downsample convolution is
      switched to stride 4 / padding 2 / dilation 2.
    * ``xa_start`` / ``xa_end``: window in which the ``("input", 4)`` hidden
      state is 2x2 average-pooled, then resized back at a matching output block.

    Output: a clone of the input model carrying the patches.
    """

    CATEGORY = "inpaint"
    FUNCTION = "model_update"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input spec: the model plus four integer step bounds."""

        def int_step(default):
            # Every step bound shares the same 0..10000 widget range.
            return ("INT", {"default": default, "min": 0, "max": 10000})

        return {
            "required": {
                "model": ("MODEL",),
                "du_start": int_step(0),
                "du_end": int_step(4),
                "xa_start": int_step(4),
                "xa_end": int_step(10),
            },
        }

    def model_update(self, model, du_start, du_end, xa_start, xa_end):
        """Clone *model*, attach the RAUNet patches to the clone and return it."""
        patched = model.clone()
        add_raunet_patch(
            patched,
            du_start=du_start,
            du_end=du_end,
            xa_start=xa_start,
            xa_end=xa_end,
        )
        return (patched,)
# This is the main patch function.
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
    """Attach the RAUNet hooks and step windows to *model* (mutated in place).

    Installs three hooks:

    * ``raunet_forward`` via ``patch_model_function_wrapper``. It is presumably
      invoked around each diffusion-model forward call; see ``.model_patch`` to
      confirm. It reconfigures the Downsample convolution at
      ``input_blocks[6]`` (SDXL) or ``input_blocks[3]`` (otherwise).
    * ``in_xattn_patch`` as the input-block patch.
    * ``out_xattn_patch`` as the output-block patch.

    The four bounds are stored in ``model_patch['raunet']`` of the model's
    transformer options, where the hooks read them at sampling time.

    Args:
        model: object exposing ``set_model_input_block_patch`` /
            ``set_model_output_block_patch`` (the cloned ComfyUI model).
        du_start, du_end: half-open step window ``[du_start, du_end)`` for the
            stride-4 / dilation-2 downsampling.
        xa_start, xa_end: half-open step window ``[xa_start, xa_end)`` for the
            pool / resize patches.
    """

    def raunet_forward(model, x, timesteps, transformer_options, control):
        """Switch the Downsample conv between RAUNet and default settings.

        *model* here is the diffusion model (it must expose ``input_blocks``).
        ``x``, ``timesteps`` and ``control`` are unused. The hook returns None
        and only mutates ``block.op``. It skips, with a printed message, when
        options are missing or the expected block is not a Downsample.
        """
        if 'model_patch' not in transformer_options:
            print("RAUNet: 'model_patch' not in transformer_options, skip")
            return
        mp = transformer_options['model_patch']
        is_SDXL = mp['SDXL']

        downsample_cls = comfy.ldm.modules.diffusionmodules.openaimodel.Downsample
        block = model.input_blocks[6 if is_SDXL else 3][0]
        # isinstance (not an exact type match) so Downsample subclasses that
        # still expose `.op` are accepted.
        if not isinstance(block, downsample_cls):
            if is_SDXL:
                print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
            else:
                print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
            return
        if 'raunet' not in mp:
            print('RAUNet: "raunet" not in model_patch options, skip')
            return

        step = mp['step']
        ro = mp['raunet']
        if ro['du_start'] <= step < ro['du_end']:
            # RAUNet window: dilated, stride-4 downsampling.
            stride, padding, dilation = (4, 4), (2, 2), (2, 2)
        else:
            # Outside the window: restore the stride-2 settings this hook
            # treats as the default.
            stride, padding, dilation = (2, 2), (1, 1), (1, 1)
        # NOTE(review): these attributes are set on the module object itself
        # and are not restored after sampling; if the UNet is shared with
        # unpatched clones, a (4, 4) stride could leak into them. Confirm.
        block.op.stride = stride
        block.op.padding = padding
        block.op.dilation = dilation

    patch_model_function_wrapper(model, raunet_forward)
    model.set_model_input_block_patch(in_xattn_patch)
    model.set_model_output_block_patch(out_xattn_patch)

    # Store the step windows where the hooks above look them up.
    mp = add_model_patch_option(model)['model_patch']
    if 'raunet' not in mp:
        mp['raunet'] = {}
    ro = mp['raunet']
    ro['du_start'] = du_start
    ro['du_end'] = du_end
    ro['xa_start'] = xa_start
    ro['xa_end'] = xa_end
def in_xattn_patch(h, transformer_options):
    """Input-block patch: halve the spatial size of *h* inside the xa window.

    Acts only on block ``("input", 4)``, which is the same index for SDXL and
    SD1.5. It also acts only while ``xa_start <= step < xa_end``, with the
    bounds read from ``model_patch['raunet']``. In that window *h* is
    average-pooled with a 2x2 kernel; otherwise *h* is returned unchanged.
    Missing options print a message and leave *h* untouched.
    """
    if transformer_options["block"] != ("input", 4):
        return h  # not the block RAUNet targets

    if 'model_patch' not in transformer_options:
        print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
        return h
    patch_opts = transformer_options['model_patch']

    if 'raunet' not in patch_opts:
        print("RAUNet (i-x-p): 'raunet' not in model_patch options")
        return h

    current = patch_opts['step']
    raunet_opts = patch_opts['raunet']
    # Read both bounds up front so a missing key fails the same way regardless
    # of where the step falls.
    lo, hi = raunet_opts['xa_start'], raunet_opts['xa_end']

    if lo <= current and current < hi:
        return F.avg_pool2d(h, kernel_size=(2, 2))
    return h
def out_xattn_patch(h, hsp, transformer_options):
    """Output-block patch: resize *h* back to the skip tensor's spatial size.

    Counterpart of ``in_xattn_patch``. Acts only on block ``("output", 5)`` for
    SDXL and ``("output", 8)`` otherwise. It also acts only while
    ``xa_start <= step < xa_end``, with the bounds read from
    ``model_patch['raunet']``. In that window *h* is bicubically interpolated
    to the last two dimensions of *hsp*.

    Returns:
        The pair ``(h, hsp)`` on every path. Previously the missing-``'raunet'``
        path returned a bare ``h``, inconsistent with all other paths.
    """
    if 'model_patch' not in transformer_options:
        print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
        return h, hsp
    mp = transformer_options['model_patch']
    if 'raunet' not in mp:
        print("RAUNet (o-x-p): 'raunet' not in model_patch options")
        return h, hsp

    step = mp['step']
    is_SDXL = mp['SDXL']
    ro = mp['raunet']
    xa_start = ro['xa_start']
    xa_end = ro['xa_end']

    target_block = ("output", 5) if is_SDXL else ("output", 8)
    if transformer_options["block"] != target_block:
        return h, hsp  # wrong block
    if step < xa_start or step >= xa_end:
        return h, hsp  # outside the xa window

    # Match the skip tensor's exact size instead of doubling h (as the
    # HiDiffusion reference does): avg_pool2d in the input patch floors odd
    # sizes, so 2x would not restore them.
    re_size = (hsp.shape[-2], hsp.shape[-1])
    h = F.interpolate(h, size=re_size, mode='bicubic')
    return h, hsp