Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch import einsum | |
import torch.nn.functional as F | |
import math | |
from einops import rearrange, repeat | |
import os | |
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION | |
import comfy.samplers | |
# Adapted from comfy/ldm/modules/attention.py, modified to also return the
# attention probabilities alongside the attention output.
def attention_basic_with_sim(q, k, v, heads, mask=None):
    """Run plain multi-head attention and return the output and attention map.

    Args:
        q, k, v: 3-D tensors whose last dimension packs ``heads * dim_head``
            features (shape (batch, tokens, heads * dim_head)).
        heads: number of attention heads packed into the last dimension.
        mask: optional boolean mask with leading batch dimension; it is
            flattened to (batch, key_tokens) and positions where it is False
            are filled with the most negative finite value before the softmax.

    Returns:
        Tuple ``(out, sim)``: ``out`` has shape (batch, q_tokens,
        heads * dim_head); ``sim`` holds the softmaxed attention map with
        shape (batch * heads, q_tokens, k_tokens). ``sim`` is float32 when
        ``_ATTN_PRECISION == "fp32"``, otherwise it keeps q's dtype.
    """
    batch, _, packed_dim = q.shape
    dim_head = packed_dim // heads
    scale = dim_head ** -0.5

    def split_heads(t):
        # (batch, tokens, heads * dim_head) -> (batch * heads, tokens, dim_head)
        per_head = t.reshape(batch, -1, heads, dim_head).permute(0, 2, 1, 3)
        return per_head.reshape(batch * heads, -1, dim_head).contiguous()

    q, k, v = (split_heads(t) for t in (q, k, v))

    if _ATTN_PRECISION == "fp32":
        # upcast so the q.k products cannot overflow in reduced precision
        q, k = q.float(), k.float()
    sim = einsum('b i d, b j d -> b i j', q, k) * scale
    del q, k

    if mask is not None:
        key_mask = rearrange(mask, 'b ... -> b (...)')
        key_mask = repeat(key_mask, 'b j -> (b h) () j', h=heads)
        sim.masked_fill_(~key_mask, -torch.finfo(sim.dtype).max)

    sim = sim.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)

    # (batch * heads, tokens, dim_head) -> (batch, tokens, heads * dim_head)
    merged = out.reshape(batch, heads, -1, dim_head).permute(0, 2, 1, 3)
    merged = merged.reshape(batch, -1, heads * dim_head)
    return (merged, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0, kernel_size=9):
    """Blur the parts of ``x0`` that the attention map marks as heavily attended.

    Args:
        x0: latent batch of shape (batch, channels, height, width).
        attn: attention probabilities of shape (batch * heads, queries, keys),
            e.g. the ``sim`` returned by ``attention_basic_with_sim``.
        sigma: standard deviation of the Gaussian blur.
        threshold: a key position is masked when the attention it receives
            (averaged over heads, summed over queries) exceeds this value.
        kernel_size: side length of the Gaussian kernel; defaults to 9, the
            value that used to be hard-coded.

    Returns:
        Tensor shaped and typed like ``x0``: blurred where the mask is set,
        ``x0`` unchanged elsewhere.
    """
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Average over heads, then total the attention each key position receives
    # from all queries.
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    # Downsampling factor between the latent and the attention resolution,
    # rounded up to a power of two.
    ratio = 2 ** (math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
    mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
    # Build a (batch, 1, h, w) mask in x0's dtype so the blend below keeps
    # x0's dtype instead of being promoted to attn's (possibly fp32) dtype.
    mask = mask.reshape(b, *mid_shape).unsqueeze(1).to(dtype=x0.dtype)
    # Upsample the mask to the latent resolution (F.interpolate's default
    # mode, nearest).
    mask = F.interpolate(mask, (lh, lw))
    blurred = gaussian_blur_2d(x0, kernel_size=kernel_size, sigma=sigma)
    return blurred * mask + x0 * (1 - mask)
def gaussian_blur_2d(img, kernel_size, sigma):
    """Blur every channel of ``img`` with the same 2-D Gaussian kernel.

    Args:
        img: tensor of shape (batch, channels, height, width).
        kernel_size: side length of the square kernel; odd values keep the
            spatial size unchanged (padding is ``kernel_size // 2`` per side).
        sigma: standard deviation of the Gaussian. ``0`` means "no blur" and
            returns an unchanged copy, because the kernel formula would
            otherwise divide by zero and produce NaNs.

    Returns:
        Blurred tensor with the same dtype and device as ``img``.
    """
    if sigma == 0:
        # A zero-width Gaussian is the identity; avoid a 0/0 -> NaN kernel.
        return img.clone()

    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = (pdf / pdf.sum()).to(device=img.device, dtype=img.dtype)
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])

    channels = img.shape[-3]
    # One identical kernel per channel, applied as a grouped (depthwise) conv.
    kernel2d = kernel2d.expand(channels, 1, kernel2d.shape[0], kernel2d.shape[1])

    pad = kernel_size // 2
    # Reflect padding requires pad < spatial size; use replicate padding for
    # tiny inputs instead of raising.
    mode = "reflect" if min(img.shape[-2:]) > pad else "replicate"
    img = F.pad(img, [pad, pad, pad, pad], mode=mode)
    return F.conv2d(img, kernel2d, groups=channels)
def get_denoised_ranges(latent, measure="hard", top_k=0.25):
    """Estimate a value range for each channel of ``latent``.

    For each channel, the ``k`` largest and ``k`` smallest entries are taken
    along the channel's last dimension with ``k = int(len(channel) * top_k)``.
    The channel's range is the average of the mean of the largest entries and
    the magnitude of the smallest entries.

    Args:
        latent: tensor whose first dimension indexes channels (e.g. one sample
            shaped (channels, height, width)).
        measure: "hard" averages the absolute values of the smallest entries;
            "range" first subtracts the channel mean and then does the same;
            any other value takes the absolute value of the mean of the
            smallest entries.
        top_k: fraction of ``len(channel)`` used to derive ``k``.

    Returns:
        A list with one float per channel.
    """
    ranges = []
    for channel in latent:
        values = channel - channel.mean() if measure == "range" else channel
        # NOTE(review): k is derived from the channel's first dim while topk
        # runs along its last dim; for non-square inputs this may not be the
        # intended fraction — confirm. Clamp k so tiny inputs neither select
        # nothing (NaN mean) nor request more entries than the dim holds.
        k = min(max(1, int(len(channel) * top_k)), values.shape[-1])
        top_mean = torch.mean(torch.topk(values, k=k, largest=True).values).item()
        bottom = torch.topk(values, k=k, largest=False).values
        if measure in ("hard", "range"):
            bottom_mag = torch.mean(torch.abs(bottom)).item()
        else:
            bottom_mag = abs(torch.mean(bottom).item())
        ranges.append((top_mean + bottom_mag) / 2)
    return ranges
class SelfAttentionGuidanceCustom:
    """Node that patches a model with Self-Attention Guidance (SAG).

    ``patch`` clones the model and replaces the middle block's attn1 so that
    the attention map of the unconditional branch is recorded. It also
    registers a post-CFG hook that, for each step, does three things: blurs
    the regions that map highlights, re-runs the model on the blurred input
    with the unconditioning, and pushes the CFG result away from that
    prediction.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Must be a classmethod: without the decorator,
        # SelfAttentionGuidanceCustom.INPUT_TYPES() raises TypeError
        # (missing positional argument).
        return {"required": {
            "model": ("MODEL",),
            "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 100.0, "step": 0.1}),
            "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
            "sigma_start": ("FLOAT", {"default": 15.0, "min": 0.0, "max": 1000.0, "step": 0.1, "round": 0.1}),
            "sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1, "round": 0.1}),
            "auto_scale": ("BOOLEAN", {"default": False}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "model_patches"

    def patch(self, model, scale, blur_sigma, sigma_start, sigma_end, auto_scale):
        """Return a patched clone of ``model`` with SAG applied.

        Args:
            model: model patcher to clone; only ``clone``,
                ``set_model_sampler_post_cfg_function`` and
                ``set_model_attn1_replace`` are used on it.
            scale: SAG strength. With ``auto_scale`` it is divided by 10 and
                then normalized per channel.
            blur_sigma: Gaussian sigma used to build the degraded input.
            sigma_start, sigma_end: SAG is skipped when the current sigma is
                above ``sigma_start`` or below ``sigma_end``.
            auto_scale: normalize the guidance strength per channel with
                ``get_denoised_ranges``.

        Returns:
            A 1-tuple holding the patched model.
        """
        m = model.clone()
        # Attention map from the most recent pass through the patched layer
        # that contained an uncond chunk; shared by the two closures below.
        attn_scores = None

        # TODO: make this work properly with chunked batches
        # currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            nonlocal attn_scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            b = q.shape[0] // len(cond_or_uncond)
            if 1 not in cond_or_uncond:
                return optimized_attention(q, k, v, heads=heads)

            # The batch contains the uncond chunk: run the explicit attention
            # so its attention map can be kept.
            uncond_index = cond_or_uncond.index(1)
            (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
            # NOTE(review): assumes the batch is the concatenation of chunks in
            # cond_or_uncond order, b samples each, with every sample owning
            # `heads` consecutive slices of `sim` — confirm for batch sizes > 1.
            n_slices = heads * b
            attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
            return out

        def post_cfg_function(args):
            uncond_attn = attn_scores
            sag_scale = scale
            sag_sigma = blur_sigma
            sag_threshold = 1.0
            inner_model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]

            if uncond_pred is None or uncond is None or uncond_attn is None:
                return cfg_result
            if min(cfg_result.shape[2:]) <= 4:  # skip when too small to add padding
                return cfg_result
            if sigma[0] > sigma_start or sigma[0] < sigma_end:
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet with the unconditioning only
            (sag, _) = comfy.samplers.calc_cond_batch(inner_model, [uncond, None], degraded_noised, sigma, model_options)
            correction = degraded - sag

            if not auto_scale:
                return cfg_result + correction * sag_scale

            # Provisional result at a fixed scale of 8, used only to measure
            # each channel's value range; every channel is then overwritten.
            denoised_tmp = cfg_result + correction * 8
            for bi in range(len(denoised_tmp)):
                denoised_ranges = get_denoised_ranges(denoised_tmp[bi])
                for ci, channel_range in enumerate(denoised_ranges):
                    if channel_range == 0:
                        # Dividing by a zero range would fill the channel with
                        # inf/NaN; leave it without SAG correction instead.
                        denoised_tmp[bi][ci] = cfg_result[bi][ci]
                        continue
                    fixed_scale = (sag_scale / 10) / channel_range
                    denoised_tmp[bi][ci] = cfg_result[bi][ci] + correction[bi][ci] * fixed_scale
            return denoised_tmp

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=False)
        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
        return (m, )