# flatcherlee's picture
# Upload 2334 files
# 3d5837a verified
# raw
# history blame
# 4.09 kB
# Original (unpatched) functions keyed by attribute name. The install loop below
# fills it before replacing each module attribute, and the wrappers built by
# patch1/patch2 call through store[fn_name] to reach the real implementation.
store = {}
# ==================== Hook into sampling functions for ControlNet ====================
import comfy.samplers
def patch1(fn_name):
    """Create the ``calc_cond_batch`` wrapper for the original stored under ``fn_name``.

    The wrapper attaches the call's ``model_options`` to the ``x_in`` object as an
    attribute, unless ``x_in`` already carries one. Code that only receives
    ``x_in`` (the ``get_area_and_mult`` hook in this file reads it back with
    ``getattr``) can then see them. All arguments are forwarded unchanged to
    ``store[fn_name]``, and its result is returned.

    When not passed by keyword, ``x_in`` is taken from positional slot 2 and
    ``model_options`` from positional slot 4.
    """
    def _arg(args, kwargs, name, position):
        # Keyword wins; otherwise fall back to the positional slot.
        if name in kwargs:
            return kwargs[name]
        return args[position]

    def calc_cond_batch(*args, **kwargs):
        latent = _arg(args, kwargs, 'x_in', 2)
        options = _arg(args, kwargs, 'model_options', 4)
        # First tag wins: an x_in that already has model_options keeps them.
        if not hasattr(latent, 'model_options'):
            latent.model_options = options
        original = store[fn_name]
        return original(*args, **kwargs)

    return calc_cond_batch
def patch2(fn_name):
    """Create the ``get_area_and_mult`` wrapper for the original stored under ``fn_name``.

    Two cases, depending on ``x_in``:

    * ``x_in`` carries a ``model_options`` attribute containing a
      ``'tiled_diffusion'`` key: the ControlNet in ``conds['control']`` (if any)
      has its ``get_control`` replaced by a function returning the control
      itself. The real method is remembered once as ``get_control_orig``.
      Presumably this is so tiled diffusion can drive the ControlNet per tile
      itself; confirm against the tiling code.
    * Otherwise, a ``get_control`` that differs from the remembered
      ``get_control_orig`` is restored.

    In both cases every argument is then forwarded unchanged to ``store[fn_name]``.
    When not passed by keyword, ``conds`` is positional slot 0 and ``x_in`` is
    slot 1.
    """
    def get_area_and_mult(*args, **kwargs):
        x_in = kwargs['x_in'] if 'x_in' in kwargs else args[1]
        conds = kwargs['conds'] if 'conds' in kwargs else args[0]
        options = getattr(x_in, 'model_options', None)
        tiled = options is not None and 'tiled_diffusion' in options

        if 'control' in conds:
            control = conds['control']
            if tiled:
                # Remember the genuine method only the first time it is overridden.
                if not hasattr(control, 'get_control_orig'):
                    control.get_control_orig = control.get_control
                control.get_control = lambda *a, **kw: control
            elif (hasattr(control, 'get_control_orig')
                  and control.get_control != control.get_control_orig):
                control.get_control = control.get_control_orig

        return store[fn_name](*args, **kwargs)

    return get_area_and_mult
# Functions to wrap: (owner module, attribute name, wrapper factory).
patches = [
    (comfy.samplers, 'calc_cond_batch', patch1),
    (comfy.samplers, 'get_area_and_mult', patch2),
]
for parent, fn_name, create_patch in patches:
    current = getattr(parent, fn_name)
    # If this module is executed a second time (e.g. importlib.reload), `current`
    # is the wrapper installed by the previous run. Storing it as the "original"
    # would make the new wrapper call the old one. The old wrapper would then look
    # up store[fn_name] (itself) and recurse forever. Unwrap to the real function.
    original = getattr(current, '_tiled_diffusion_original', current)
    store[fn_name] = original
    wrapper = create_patch(fn_name)
    # Mark the wrapper so a later re-run can recover the genuine original.
    wrapper._tiled_diffusion_original = original
    setattr(parent, fn_name, wrapper)
# ==================== Patch pre_run_control ====================
# TODO: check whether this override is still needed. It calls cleanup() on each ControlNet before pre_run().
def pre_run_control(model, conds):
    """Reset and prepare every ControlNet attached to ``conds``.

    For each condition dict that has a ``'control'`` entry, the control's
    ``cleanup()`` is attempted first (best effort). Then ``pre_run(model, fn)``
    is called, where ``fn`` converts a sampling percentage to a sigma via
    ``model.model_sampling.percent_to_sigma``.

    Args:
        model: object exposing ``model_sampling``; it is also forwarded to
            each control's ``pre_run``.
        conds: iterable of condition dicts; dicts without ``'control'`` are
            skipped.
    """
    import logging

    sampling = model.model_sampling

    # Loop-invariant, so it is built once instead of once per condition.
    def percent_to_timestep_function(a):
        return sampling.percent_to_sigma(a)

    for cond in conds:
        if 'control' not in cond:
            continue
        control = cond['control']
        try:
            control.cleanup()
        except Exception:
            # Best effort: a failing cleanup() must not prevent pre_run(), but
            # leave a trace instead of discarding the error silently.
            logging.getLogger(__name__).debug(
                "ControlNet cleanup() failed before pre_run()", exc_info=True
            )
        control.pre_run(model, percent_to_timestep_function)
# Install the override on the module object. Callers that resolve
# `pre_run_control` through the comfy.samplers namespace at call time get this
# version; references bound earlier (e.g. via `from ... import`) are unaffected.
comfy.samplers.pre_run_control = pre_run_control
# ==================== Patch SAG ====================
import math
import torch.nn.functional as F
import comfy_extras.nodes_sag
from comfy_extras.nodes_sag import gaussian_blur_2d
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blur the parts of ``x0`` that the attention map marks as salient.

    The attention scores are thresholded into a coarse boolean grid. That grid's
    shape comes from the most nearly square factor pair of the query count,
    oriented like ``x0``: the larger factor becomes the height only when ``x0``
    is taller than it is wide. The grid is then stretched to ``x0``'s spatial
    size and used to blend a Gaussian-blurred copy of ``x0`` with ``x0``.

    Args:
        x0: tensor unpacked as ``(batch, channels, height, width)``.
        attn: 3-D tensor ``(N, n_query, n_key)`` whose first dim is folded as
            ``(batch, -1)``; presumably ``N == batch * heads`` (confirm).
            The reshape below also requires ``n_key == n_query``.
        sigma: standard deviation passed to ``gaussian_blur_2d``.
        threshold: summed attention above which a position counts as salient.

    Returns:
        Tensor shaped like ``x0``: blurred where the mask is set, else ``x0``.
    """
    _, n_query, n_key = attn.shape
    batch, _, lat_h, lat_w = x0.shape
    heads = attn.reshape(batch, -1, n_query, n_key)
    # Average over dim 1 (the folded axis), then sum over the query axis:
    # one score per key position, thresholded to shape (batch, n_key).
    salient = heads.mean(dim=1).sum(dim=1) > threshold

    # Most nearly square factor pair of n_query: walk down from floor(sqrt) to 1.
    short_side = next(
        d for d in range(int(math.sqrt(n_query)), 0, -1) if n_query % d == 0
    )
    long_side = n_query // short_side
    # Long side vertical only when the latent is taller than it is wide.
    if lat_h > lat_w:
        grid = (long_side, short_side)
    else:
        grid = (short_side, long_side)

    # (batch, n_key) -> (batch, 1, grid_h, grid_w), cast to the attention dtype.
    mask = salient.reshape(batch, *grid).unsqueeze(1).type(heads.dtype)
    # Stretch the coarse mask to the latent's spatial size.
    mask = F.interpolate(mask, (lat_h, lat_w))
    smoothed = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    # mask weights the blurred copy, (1 - mask) weights the original latent.
    return smoothed * mask + x0 * (1 - mask)
# Replace SAG's blur-map builder with the factorisation-based version above.
# Code resolving `create_blur_map` through the comfy_extras.nodes_sag namespace
# at call time sees this; references bound earlier are unaffected.
comfy_extras.nodes_sag.create_blur_map = create_blur_map
# ==================== Patch Gligen ====================
def _set_position(self, boxes, masks, positive_embeddings):
    """Replacement for ``Gligen._set_position`` that broadcasts over the batch.

    Runs ``self.position_net`` once on the grounding inputs and returns a
    closure. The closure applies the module ``self.module_list[transformer_index]``
    to ``x`` and to the grounding tokens, converted to ``x``'s device and dtype.
    When ``x`` has a larger batch than the grounding tokens, the tokens are
    repeated along dim 0 and trimmed to exactly ``x.shape[0]`` rows.

    Returns:
        ``func(x, extra_options)``, where ``extra_options["transformer_index"]``
        selects the module.
    """
    objs = self.position_net(boxes, masks, positive_embeddings)

    def func(x, extra_options):
        module = self.module_list[extra_options["transformer_index"]]
        batch = x.shape[0]
        grounding = objs
        if batch > objs.shape[0]:
            # Ceil-divide so the repeat covers the whole batch. Then trim the
            # excess: without the slice, a batch that is not a multiple of objs'
            # batch (e.g. 3 vs 2) would get 4 grounding rows for 3 inputs.
            repeats = -(batch // -objs.shape[0])
            grounding = objs.repeat(repeats, 1, 1)[:batch]
        return module(x, grounding.to(device=x.device, dtype=x.dtype))

    return func
import comfy.gligen
# Patch the class attribute itself, so every Gligen instance (existing or future)
# without an instance-level override of _set_position uses the version above.
comfy.gligen.Gligen._set_position = _set_position