Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from einops import rearrange | |
from .masactrl_utils import AttentionBase | |
from torchvision.utils import save_image | |
class MutualSelfAttentionControl(AttentionBase): | |
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50): | |
""" | |
Mutual self-attention control for Stable-Diffusion model | |
Args: | |
start_step: the step to start mutual self-attention control | |
start_layer: the layer to start mutual self-attention control | |
layer_idx: list of the layers to apply mutual self-attention control | |
step_idx: list the steps to apply mutual self-attention control | |
total_steps: the total number of steps | |
""" | |
super().__init__() | |
self.total_steps = total_steps | |
self.start_step = start_step | |
self.start_layer = start_layer | |
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) | |
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) | |
print("step_idx: ", self.step_idx) | |
print("layer_idx: ", self.layer_idx) | |
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
b = q.shape[0] // num_heads | |
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) | |
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) | |
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) | |
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") | |
attn = sim.softmax(-1) | |
out = torch.einsum("h i j, h j d -> h i d", attn, v) | |
out = rearrange(out, "h (b n) d -> b n (h d)", b=b) | |
return out | |
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
""" | |
Attention forward function | |
""" | |
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: | |
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) | |
qu, qc = q.chunk(2) | |
ku, kc = k.chunk(2) | |
vu, vc = v.chunk(2) | |
attnu, attnc = attn.chunk(2) | |
out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) | |
out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) | |
out = torch.cat([out_u, out_c], dim=0) | |
return out | |
class MutualSelfAttentionControlMask(MutualSelfAttentionControl): | |
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None): | |
""" | |
Maske-guided MasaCtrl to alleviate the problem of fore- and background confusion | |
Args: | |
start_step: the step to start mutual self-attention control | |
start_layer: the layer to start mutual self-attention control | |
layer_idx: list of the layers to apply mutual self-attention control | |
step_idx: list the steps to apply mutual self-attention control | |
total_steps: the total number of steps | |
mask_s: source mask with shape (h, w) | |
mask_t: target mask with same shape as source mask | |
""" | |
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps) | |
self.mask_s = mask_s # source mask with shape (h, w) | |
self.mask_t = mask_t # target mask with same shape as source mask | |
print("Using mask-guided MasaCtrl") | |
if mask_save_dir is not None: | |
os.makedirs(mask_save_dir, exist_ok=True) | |
save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png")) | |
save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png")) | |
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
B = q.shape[0] // num_heads | |
H = W = int(np.sqrt(q.shape[1])) | |
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) | |
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) | |
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) | |
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") | |
if kwargs.get("is_mask_attn") and self.mask_s is not None: | |
print("masked attention") | |
mask = self.mask_s.unsqueeze(0).unsqueeze(0) | |
mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0) | |
mask = mask.flatten() | |
# background | |
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min) | |
# object | |
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min) | |
sim = torch.cat([sim_fg, sim_bg], dim=0) | |
attn = sim.softmax(-1) | |
if len(attn) == 2 * len(v): | |
v = torch.cat([v] * 2) | |
out = torch.einsum("h i j, h j d -> h i d", attn, v) | |
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads) | |
return out | |
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
""" | |
Attention forward function | |
""" | |
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: | |
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) | |
B = q.shape[0] // num_heads // 2 | |
H = W = int(np.sqrt(q.shape[1])) | |
qu, qc = q.chunk(2) | |
ku, kc = k.chunk(2) | |
vu, vc = v.chunk(2) | |
attnu, attnc = attn.chunk(2) | |
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) | |
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) | |
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs) | |
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs) | |
if self.mask_s is not None and self.mask_t is not None: | |
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0) | |
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0) | |
mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W)) | |
mask = mask.reshape(-1, 1) # (hw, 1) | |
out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask) | |
out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask) | |
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0) | |
return out | |
class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl): | |
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None): | |
""" | |
MasaCtrl with mask auto generation from cross-attention map | |
Args: | |
start_step: the step to start mutual self-attention control | |
start_layer: the layer to start mutual self-attention control | |
layer_idx: list of the layers to apply mutual self-attention control | |
step_idx: list the steps to apply mutual self-attention control | |
total_steps: the total number of steps | |
thres: the thereshold for mask thresholding | |
ref_token_idx: the token index list for cross-attention map aggregation | |
cur_token_idx: the token index list for cross-attention map aggregation | |
mask_save_dir: the path to save the mask image | |
""" | |
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps) | |
print("using MutualSelfAttentionControlMaskAuto") | |
self.thres = thres | |
self.ref_token_idx = ref_token_idx | |
self.cur_token_idx = cur_token_idx | |
self.self_attns = [] | |
self.cross_attns = [] | |
self.cross_attns_mask = None | |
self.self_attns_mask = None | |
self.mask_save_dir = mask_save_dir | |
if self.mask_save_dir is not None: | |
os.makedirs(self.mask_save_dir, exist_ok=True) | |
def after_step(self): | |
self.self_attns = [] | |
self.cross_attns = [] | |
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
B = q.shape[0] // num_heads | |
H = W = int(np.sqrt(q.shape[1])) | |
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) | |
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads) | |
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads) | |
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale") | |
if self.self_attns_mask is not None: | |
# binarize the mask | |
mask = self.self_attns_mask | |
thres = self.thres | |
mask[mask >= thres] = 1 | |
mask[mask < thres] = 0 | |
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min) | |
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min) | |
sim = torch.cat([sim_fg, sim_bg]) | |
attn = sim.softmax(-1) | |
if len(attn) == 2 * len(v): | |
v = torch.cat([v] * 2) | |
out = torch.einsum("h i j, h j d -> h i d", attn, v) | |
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads) | |
return out | |
def aggregate_cross_attn_map(self, idx): | |
attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim) | |
B = attn_map.shape[0] | |
res = int(np.sqrt(attn_map.shape[-2])) | |
attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1]) | |
image = attn_map[..., idx] | |
if isinstance(idx, list): | |
image = image.sum(-1) | |
image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0] | |
image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0] | |
image = (image - image_min) / (image_max - image_min) | |
return image | |
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | |
""" | |
Attention forward function | |
""" | |
if is_cross: | |
# save cross attention map with res 16 * 16 | |
if attn.shape[1] == 16 * 16: | |
self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1)) | |
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: | |
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) | |
B = q.shape[0] // num_heads // 2 | |
H = W = int(np.sqrt(q.shape[1])) | |
qu, qc = q.chunk(2) | |
ku, kc = k.chunk(2) | |
vu, vc = v.chunk(2) | |
attnu, attnc = attn.chunk(2) | |
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) | |
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) | |
if len(self.cross_attns) == 0: | |
self.self_attns_mask = None | |
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) | |
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) | |
else: | |
mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W) | |
mask_source = mask[-2] # (H, W) | |
res = int(np.sqrt(q.shape[1])) | |
self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten() | |
if self.mask_save_dir is not None: | |
H = W = int(np.sqrt(self.self_attns_mask.shape[0])) | |
mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0) | |
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png")) | |
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs) | |
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs) | |
if self.self_attns_mask is not None: | |
mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W) | |
mask_target = mask[-1] # (H, W) | |
res = int(np.sqrt(q.shape[1])) | |
spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1) | |
if self.mask_save_dir is not None: | |
H = W = int(np.sqrt(spatial_mask.shape[0])) | |
mask_image = spatial_mask.reshape(H, W).unsqueeze(0) | |
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png")) | |
# binarize the mask | |
thres = self.thres | |
spatial_mask[spatial_mask >= thres] = 1 | |
spatial_mask[spatial_mask < thres] = 0 | |
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2) | |
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2) | |
out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask) | |
out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask) | |
# set self self-attention mask to None | |
self.self_attns_mask = None | |
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0) | |
return out | |