# File size: 4,087 Bytes
# 3d5837a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Original (unpatched) callables keyed by attribute name. The wrappers built by
# patch1/patch2 look their target up here each time they are called.
store = dict()

# ==================== Hook into sampling functions for ControlNet ====================

import comfy.samplers

def patch1(fn_name):
    """Build the replacement for ``comfy.samplers.calc_cond_batch``.

    The returned wrapper attaches the call's ``model_options`` to the ``x_in``
    object as an attribute (only when ``x_in`` does not already carry a
    ``model_options`` attribute), so code that only receives ``x_in`` can read
    them. It then forwards every argument unchanged to ``store[fn_name]``.

    Args:
        fn_name: key under which the original callable is kept in ``store``.

    Returns:
        The wrapper function.
    """
    def calc_cond_batch(*args, **kwargs):
        # When passed positionally, x_in is args[2] and model_options is args[4].
        if 'x_in' in kwargs:
            x_in = kwargs['x_in']
        else:
            x_in = args[2]
        if 'model_options' in kwargs:
            model_options = kwargs['model_options']
        else:
            model_options = args[4]
        already_tagged = hasattr(x_in, 'model_options')
        if not already_tagged:
            x_in.model_options = model_options
        original = store[fn_name]
        return original(*args, **kwargs)
    return calc_cond_batch

def patch2(fn_name):
    """Build the replacement for ``comfy.samplers.get_area_and_mult``.

    The wrapper reads a ``model_options`` attribute from ``x_in`` (the
    calc_cond_batch wrapper in this file sets it). Then:

    * If ``'tiled_diffusion'`` is in those options, the ControlNet under
      ``conds['control']`` has its ``get_control`` replaced by a function that
      returns the ControlNet object itself. The previous ``get_control`` is
      saved once as ``get_control_orig``.
    * Otherwise, a previously replaced ``get_control`` is restored from
      ``get_control_orig``.

    In both cases every argument is forwarded unchanged to ``store[fn_name]``.

    Args:
        fn_name: key under which the original callable is kept in ``store``.

    Returns:
        The wrapper function.
    """
    def get_area_and_mult(*args, **kwargs):
        # When passed positionally, conds is args[0] and x_in is args[1].
        x_in = kwargs['x_in'] if 'x_in' in kwargs else args[1]
        conds = kwargs['conds'] if 'conds' in kwargs else args[0]
        opts = getattr(x_in, 'model_options', None)
        tiled = opts is not None and 'tiled_diffusion' in opts
        if 'control' in conds:
            cn = conds['control']
            if tiled:
                if not hasattr(cn, 'get_control_orig'):
                    cn.get_control_orig = cn.get_control
                # Return the ControlNet object itself instead of computing control.
                cn.get_control = lambda *a, **kw: cn
            elif hasattr(cn, 'get_control_orig') and cn.get_control != cn.get_control_orig:
                cn.get_control = cn.get_control_orig
        return store[fn_name](*args, **kwargs)
    return get_area_and_mult

# (module, attribute name, wrapper factory) triples. Each factory receives the
# attribute name and returns a wrapper that forwards to ``store[fn_name]``.
patches = [
    (comfy.samplers, 'calc_cond_batch', patch1),
    (comfy.samplers, 'get_area_and_mult', patch2),
]
for parent, fn_name, create_patch in patches:
    original = getattr(parent, fn_name)
    # If this module's code runs again (e.g. importlib.reload, which keeps the
    # module dict), the attribute already holds the previous run's wrapper.
    # Storing that wrapper as the "original" would make both wrappers look
    # themselves up in the new ``store`` and recurse forever. Unwrap it first.
    original = getattr(original, '_tiled_diffusion_original', original)
    store[fn_name] = original
    wrapper = create_patch(fn_name)
    # Remember the real function so a later re-patch can unwrap.
    wrapper._tiled_diffusion_original = original
    setattr(parent, fn_name, wrapper)

# ==================== Patch pre_run_control ====================

# TODO: Verify whether this pre_run_control override is still necessary.
def pre_run_control(model, conds):
    """Prepare every ControlNet found in ``conds`` for a new sampling run.

    For each cond dict that has a ``'control'`` entry, the control object's
    ``cleanup()`` is attempted first; any exception it raises is ignored.
    Then ``pre_run(model, percent_to_timestep_function)`` is called. The
    conversion function forwards to ``model.model_sampling.percent_to_sigma``.

    Args:
        model: object exposing ``model_sampling.percent_to_sigma``.
        conds: iterable of cond dicts.
    """
    s = model.model_sampling
    # Built once and shared; percent_to_sigma is looked up on each call.
    percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
    for x in conds:
        if 'control' not in x:
            continue
        control = x['control']
        # Deliberate best effort: a control that fails to clean up still gets
        # pre_run, so cleanup errors are intentionally ignored.
        try:
            control.cleanup()
        except Exception:
            pass
        control.pre_run(model, percent_to_timestep_function)
# Rebind comfy.samplers.pre_run_control to this module's pre_run_control.
# Only lookups through the comfy.samplers module attribute see the replacement;
# a name imported earlier via `from comfy.samplers import ...` keeps the original.
comfy.samplers.pre_run_control = pre_run_control

# ==================== Patch SAG ====================

import math
import torch.nn.functional as F
import comfy_extras.nodes_sag
from comfy_extras.nodes_sag import gaussian_blur_2d 
def _attention_grid_shape(tokens, lh, lw):
    """Return the (rows, cols) factorization of ``tokens`` whose aspect ratio best matches lh:lw."""
    # Compare aspect ratios in log space so (a, b) and (b, a) are symmetric.
    # log(x) - log(y) is exactly the negation of log(y) - log(x), so mirror
    # pairs tie exactly and the first one seen (rows <= cols) wins.
    target = math.log(lh) - math.log(lw)
    best_shape, best_err = None, math.inf
    for d in range(1, math.isqrt(tokens) + 1):
        if tokens % d:
            continue
        for rows, cols in ((d, tokens // d), (tokens // d, d)):
            err = abs((math.log(rows) - math.log(cols)) - target)
            if err < best_err:
                best_shape, best_err = (rows, cols), err
    return best_shape


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    """Blend a Gaussian-blurred ``x0`` into ``x0`` where attention is high.

    The attention map is reshaped to ``(b, -1, hw1, hw2)`` and averaged over
    dim 1. It is then summed over the ``hw1`` dim, and entries above
    ``threshold`` form a boolean mask of ``hw2`` tokens. The mask is laid out
    on a 2-D grid whose shape is the factor pair of ``hw1`` with the aspect
    ratio closest to ``x0``'s ``lh:lw``. It is upsampled to ``(lh, lw)`` with
    ``F.interpolate``'s default mode. The result takes
    ``gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)`` where the mask is 1
    and ``x0`` where it is 0.

    NOTE(review): assumes attn is (b * heads, hw, hw) with tokens in row-major
    order over a grid whose aspect ratio tracks x0's — TODO confirm.

    Args:
        x0: 4-D tensor; its shape is unpacked as (b, _, lh, lw).
        attn: 3-D attention tensor (_, hw1, hw2).
        sigma: Gaussian blur sigma.
        threshold: mask cutoff on the per-token summed attention.

    Returns:
        Tensor shaped like ``x0``.
    """
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    # The closest-to-square factor pair is wrong for elongated latents
    # (e.g. a 24x8 grid would become 16x12), so match the latent's aspect ratio.
    mid_shape = _attention_grid_shape(hw1, lh, lw)

    # Reshape
    mask = (
        mask.reshape(b, *mid_shape)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred
# Rebind the SAG module's create_blur_map to this module's create_blur_map.
# Code in nodes_sag that looks up create_blur_map as a module global at call
# time uses the replacement.
comfy_extras.nodes_sag.create_blur_map = create_blur_map

# ==================== Patch Gligen ====================

def _set_position(self, boxes, masks, positive_embeddings):
    """Build the GLIGEN position callback for a set of grounding inputs.

    ``self.position_net(boxes, masks, positive_embeddings)`` is evaluated once
    to produce ``objs``. The returned callback selects
    ``self.module_list[extra_options["transformer_index"]]`` and applies it to
    ``(x, objs)``, with ``objs`` moved to ``x``'s device and dtype.

    When ``x`` has a larger batch than ``objs``, ``objs`` is tiled along dim 0
    (ceil of the ratio) and then trimmed to exactly ``x.shape[0]`` rows. The
    batch sizes therefore agree even when ``x.shape[0]`` is not a multiple of
    ``objs.shape[0]``.

    Returns:
        Callable ``func(x, extra_options)``.
    """
    objs = self.position_net(boxes, masks, positive_embeddings)

    def func(x, extra_options):
        key = extra_options["transformer_index"]
        module = self.module_list[key]
        batch = x.shape[0]
        _objs = objs
        if batch > objs.shape[0]:
            reps = -(batch // -objs.shape[0])  # ceil division
            # repeat() with 3 sizes assumes objs is 3-D, as in the original.
            _objs = objs.repeat(reps, 1, 1)[:batch]
        return module(x, _objs.to(device=x.device, dtype=x.dtype))
    return func

import comfy.gligen
# Replace the method on the Gligen class itself, so existing and future
# instances resolve _set_position to this module's implementation (unless an
# instance attribute of the same name shadows it).
comfy.gligen.Gligen._set_position = _set_position