# Spaces: Running on Zero
# Original (unpatched) functions, keyed by the attribute name they were
# replaced under; the wrappers built by patch1/patch2 call through to these.
store = dict()
# ==================== Hook into sampling functions for ControlNet ==================== | |
import comfy.samplers | |
def patch1(fn_name):
    """Return a stand-in for the function saved as ``store[fn_name]``.

    The stand-in copies the call's ``model_options`` onto its ``x_in`` object
    (only if ``x_in`` has no ``model_options`` attribute yet), then forwards
    every argument unchanged to ``store[fn_name]`` and returns its result.
    Presumably this lets the ``get_area_and_mult`` wrapper, which only reads
    ``x_in.model_options``, see the sampler options.
    """
    def calc_cond_batch(*args, **kwargs):
        # A keyword argument wins; otherwise fall back to positional slots
        # 2 (x_in) and 4 (model_options).
        if 'x_in' in kwargs:
            x_in = kwargs['x_in']
        else:
            x_in = args[2]
        if 'model_options' in kwargs:
            options = kwargs['model_options']
        else:
            options = args[4]
        if not hasattr(x_in, 'model_options'):
            x_in.model_options = options
        wrapped = store[fn_name]
        return wrapped(*args, **kwargs)

    return calc_cond_batch
def patch2(fn_name):
    """Return a stand-in for the function saved as ``store[fn_name]``
    (``comfy.samplers.get_area_and_mult``) that toggles ControlNet stubbing.

    If ``x_in.model_options`` exists and contains ``'tiled_diffusion'``, the
    cond's ``'control'`` object gets its ``get_control`` replaced by a stub
    returning the control object itself; the real method is kept as
    ``get_control_orig``. Otherwise any stub is removed again, so ordinary
    sampling uses the real method. All arguments are then forwarded unchanged
    to ``store[fn_name]`` and its result is returned.
    """
    def get_area_and_mult(*args, **kwargs):
        # A keyword argument wins; otherwise positional slots 1 (x_in) and 0 (conds).
        x_in = kwargs['x_in'] if 'x_in' in kwargs else args[1]
        conds = kwargs['conds'] if 'conds' in kwargs else args[0]
        model_options = getattr(x_in, 'model_options', None)
        tiled = model_options is not None and 'tiled_diffusion' in model_options
        if 'control' in conds:
            control = conds['control']
            if tiled:
                if not hasattr(control, 'get_control_orig'):
                    control.get_control_orig = control.get_control
                # Install the stub whenever the real method is active. Checking
                # only for get_control_orig is not enough: a previous non-tiled
                # run restores the real method but leaves get_control_orig set.
                if control.get_control == control.get_control_orig:
                    control.get_control = lambda *a, **kw: control
            elif hasattr(control, 'get_control_orig') and control.get_control != control.get_control_orig:
                control.get_control = control.get_control_orig
        return store[fn_name](*args, **kwargs)

    return get_area_and_mult
# Each entry: (owning module, attribute to hook, factory building the wrapper).
patches = [
    (comfy.samplers, 'calc_cond_batch', patch1),
    (comfy.samplers, 'get_area_and_mult', patch2),
]
# Record the original in `store` before installing the wrapper; each wrapper
# looks up store[fn_name] when called to reach the function it replaced.
for parent, fn_name, create_patch in patches:
    store.update({fn_name: getattr(parent, fn_name)})
    setattr(parent, fn_name, create_patch(fn_name))
# ==================== Patch pre_run_control ==================== | |
# Is this necessary anymore? | |
def pre_run_control(model, conds):
    """Reset and prepare every control attached to *conds* before sampling.

    For each cond dict that has a ``'control'`` entry, ``cleanup()`` is called
    first (best-effort: failures are logged at debug level and ignored), then
    ``pre_run(model, percent_to_timestep_function)``. The function passed in
    maps a percentage to a sigma via ``model.model_sampling.percent_to_sigma``.

    Args:
        model: object exposing ``model_sampling``.
        conds: iterable of cond dicts.
    """
    import logging

    sampling = model.model_sampling
    # Loop-invariant: one mapping function shared by every control.
    percent_to_timestep_function = lambda a: sampling.percent_to_sigma(a)
    for cond in conds:
        if 'control' not in cond:
            continue
        control = cond['control']
        try:
            control.cleanup()
        except Exception:
            # Deliberately best-effort: still pre-run a control that failed to clean up.
            logging.getLogger(__name__).debug("control cleanup() failed; continuing", exc_info=True)
        control.pre_run(model, percent_to_timestep_function)
# Swap comfy's sampler-side pre_run_control for the version defined above.
setattr(comfy.samplers, 'pre_run_control', pre_run_control)
# ==================== Patch SAG ==================== | |
import math | |
import torch.nn.functional as F | |
import comfy_extras.nodes_sag | |
from comfy_extras.nodes_sag import gaussian_blur_2d | |
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blur the parts of *x0* that the self-attention map marks as salient.

    Replacement for ``comfy_extras.nodes_sag.create_blur_map``. Attention is
    averaged over dim 1 of the reshaped map (presumably heads) and summed over
    the query axis. Key positions whose total exceeds *threshold* form a
    binary mask. The mask is laid out on a ``(rows, cols)`` grid with
    ``rows * cols == hw1``, chosen so its aspect ratio best matches the
    latent's ``lh / lw``. It is then upsampled to ``(lh, lw)`` and used to
    blend a Gaussian-blurred copy of *x0* with *x0*.

    Args:
        x0: 4-D tensor ``(b, c, lh, lw)``.
        attn: 3-D tensor ``(b * k, hw1, hw2)``; ``hw2`` must equal ``hw1``
            for the mask reshape to succeed.
        sigma: standard deviation for ``gaussian_blur_2d`` (kernel size 9).
        threshold: minimum summed attention for a key position to be masked.

    Returns:
        Tensor shaped like *x0*: blurred where the mask is 1, unchanged elsewhere.
    """
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool over dim 1, then total attention received per key.
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    def grid_shape(total):
        # Plain "closest factors" ignores the aspect ratio, e.g. a 9x16 token grid
        # (16:9 latent) would become 12x12 and scramble the mask. Instead, pick the
        # divisor pair whose log-aspect is closest to the latent's. Orientation
        # follows the latent (rows >= cols only when lh > lw), and ties prefer the
        # more square pair, so square latents get the same grid as before.
        target = math.log(lh) - math.log(lw)
        best, best_err = None, math.inf
        for small in range(1, math.isqrt(total) + 1):
            if total % small:
                continue
            large = total // small
            rows, cols = (large, small) if lh > lw else (small, large)
            err = abs(math.log(rows) - math.log(cols) - target)
            if err <= best_err:
                best, best_err = (rows, cols), err
        return best

    mid_shape = grid_shape(hw1)
    # Reshape to the token grid and add a channel axis for interpolate.
    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample to the latent resolution.
    mask = F.interpolate(mask, (lh, lw))
    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred
# Point the SAG node module at the aspect-aware create_blur_map defined above.
setattr(comfy_extras.nodes_sag, 'create_blur_map', create_blur_map)
# ==================== Patch Gligen ==================== | |
def _set_position(self, boxes, masks, positive_embeddings): | |
objs = self.position_net(boxes, masks, positive_embeddings) | |
def func(x, extra_options): | |
key = extra_options["transformer_index"] | |
module = self.module_list[key] | |
nonlocal objs | |
_objs = objs.repeat(-(x.shape[0] // -objs.shape[0]),1,1) if x.shape[0] > objs.shape[0] else objs | |
return module(x, _objs.to(device=x.device, dtype=x.dtype)) | |
return func | |
import comfy.gligen | |
# Install the batch-tolerant _set_position on the Gligen class.
setattr(comfy.gligen.Gligen, '_set_position', _set_position)