Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import os | |
import cv2 | |
import abc | |
from typing import Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from diffusers.models.attention import Attention | |
from PIL import Image | |
import random | |
import matplotlib.pyplot as plt | |
import pdb | |
import math | |
from PIL import Image | |
class P2PCrossAttnProcessor:
    """Attention processor that exposes attention probabilities to a Prompt-to-Prompt controller.

    The computation is a plain (non-fused) attention forward pass. The only
    addition is that the softmaxed attention probabilities are handed to
    ``controller`` before they are applied to the values.
    """

    def __init__(self, controller, place_in_unet):
        super().__init__()
        self.controller = controller
        # Location label forwarded to the controller with every call.
        self.place_in_unet = place_in_unet

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Cross-attention iff a separate context was supplied; otherwise attend to the input itself.
        is_cross = encoder_hidden_states is not None
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # Controller hook: its return value is ignored, so any edit must happen in place.
        self.controller(attention_probs, is_cross, self.place_in_unet)

        out = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
        out = attn.to_out[0](out)  # linear projection
        out = attn.to_out[1](out)  # dropout
        return out
def create_controller(
    prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device, attn_res
) -> AttentionControl:
    """Build the attention controller described by ``cross_attention_kwargs``.

    Recognised keys of ``cross_attention_kwargs``:
        edit_type: one of "visualize", "replace", "refine", "reweight".
        local_blend_words: optional words for ``LocalBlend``; whenever it is not None a
            ``LocalBlend`` is attached to the edit controller (any edit type).
        equalizer_words / equalizer_strengths: required (same length) for "reweight".
        n_cross_replace / n_self_replace: replacement schedules (default 0.4 each).

    Returns:
        ``AttentionStore`` for "visualize"; otherwise ``AttentionReplace``,
        ``AttentionRefine`` or ``AttentionReweight``.

    Raises:
        ValueError: if ``edit_type`` is not recognised.
        AssertionError: if "reweight" lacks matching equalizer words/strengths.
    """
    edit_type = cross_attention_kwargs.get("edit_type", None)
    local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
    equalizer_words = cross_attention_kwargs.get("equalizer_words", None)
    equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None)
    n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4)
    n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4)

    if edit_type == 'visualize':
        return AttentionStore(device=device)

    if edit_type not in ("replace", "refine", "reweight"):
        raise ValueError(f"Edit type {edit_type} not recognized. Use one of: visualize, replace, refine, reweight.")

    equalizer = None
    if edit_type == "reweight":
        assert (
            equalizer_words is not None and equalizer_strengths is not None
        ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
        assert len(equalizer_words) == len(
            equalizer_strengths
        ), "equalizer_words and equalizer_strengths must be of same length."
        equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)

    # Every edit type may be restricted to a region with LocalBlend. Using the same
    # ``is not None`` test for all of them means an empty word list is treated the same way
    # everywhere (a truthiness test for "reweight" sent [] to the "not recognized" error).
    lb = None
    if local_blend_words is not None:
        lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res)

    if edit_type == "replace":
        return AttentionReplace(
            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res
        )
    if edit_type == "refine":
        return AttentionRefine(
            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res
        )
    return AttentionReweight(
        prompts,
        num_inference_steps,
        n_cross_replace,
        n_self_replace,
        tokenizer=tokenizer,
        device=device,
        equalizer=equalizer,
        attn_res=attn_res,
        local_blend=lb,
    )
class AttentionControl(abc.ABC):
    """Base class for Prompt-to-Prompt attention controllers.

    An instance is called once per attention layer with that layer's attention
    probabilities. ``cur_att_layer`` counts calls within a step. Once
    ``num_att_layers + num_uncond_att_layers`` calls have been seen, ``cur_step``
    advances and ``between_steps`` runs. ``num_att_layers`` starts at -1 and is
    expected to be set by whoever registers the controller (TODO confirm).
    """

    def step_callback(self, x_t):
        """Hook applied to the latents after a denoising step; identity by default."""
        return x_t

    def between_steps(self):
        """Hook run after every attention layer of a step has been processed."""
        return

    @property
    def num_uncond_att_layers(self):
        # Must be a property: __call__ compares it with and adds it to ints. As a plain
        # method, ``int >= bound method`` raises TypeError on the very first call.
        return 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Transform one layer's (edited half of the) attention probabilities; subclasses implement this."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # Only the second half of the batch dimension is passed to forward; presumably
            # the first half is the unconditional (guidance) branch — TODO confirm.
            h = attn.shape[0]
            attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self, attn_res=None):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.attn_res = attn_res
class EmptyControl(AttentionControl):
    """Controller that leaves every attention map untouched and records nothing."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity transform: the input maps are handed back as-is.
        unchanged = attn
        return unchanged
class AttentionStore(AttentionControl):
    """Controller that records attention maps and accumulates them across steps.

    ``step_store`` collects the current step's maps keyed by ``"<place>_<cross|self>"``.
    ``between_steps`` adds them into ``attention_store``, and
    ``get_average_attention`` divides the sums by the number of completed steps.
    Maps with more than ``32**2`` queries are not recorded.
    """

    @staticmethod
    def get_empty_store():
        # Static: it takes no ``self`` but is called as ``self.get_empty_store()``; without the
        # decorator that call raises TypeError.
        return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}

    def _device_type(self):
        """Return the device type ("cuda", "cpu", ...) of ``self.device``.

        Accepts a ``torch.device`` or a device string such as the default ``'cuda'`` or
        ``'cuda:0'``. A plain string has no ``.type`` attribute.
        """
        if isinstance(self.device, torch.device):
            return self.device.type
        return str(self.device).split(":")[0]

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32**2:  # avoid memory overhead
            # Keep the stored maps on the device only when it is CUDA; otherwise store a CPU copy.
            stored = attn if self._device_type() == 'cuda' else attn.cpu()
            self.step_store[key].append(stored)
        # The caller writes the result back into the original tensor, so return it unchanged.
        return attn

    def between_steps(self):
        # Fold this step's maps into the running sums (the first step simply adopts them).
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of completed steps."""
        average_attention = {
            key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
        }
        return average_attention

    def reset(self):
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, attn_res=None, device='cuda'):
        super().__init__(attn_res)
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.device = device
class LocalBlend:
    """Restrict a Prompt-to-Prompt edit to the region attended by selected words.

    At each step the stored cross-attention maps of the selected words are turned into
    a binary mask over the latent. Outside the mask every sample is reset to the first
    (source) sample.
    """

    def __call__(self, x_t, attention_store):
        # Note that this code works on the latent level.
        k = 1
        num_pixels = self.attn_res[0] * self.attn_res[1]
        # Use every stored cross-attention map whose query count matches attn_res.
        maps = [
            m
            for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"]
            if m.shape[1] == num_pixels
        ]
        if not maps:
            raise ValueError(
                f"LocalBlend found no cross-attention maps with {num_pixels} pixels (attn_res={self.attn_res})."
            )
        maps = [
            item.reshape(self.alpha_layers.shape[0], -1, 1, self.attn_res[0], self.attn_res[1], self.max_num_words)
            for item in maps
        ]
        maps = torch.cat(maps, dim=1)
        # alpha_layers is 1 only at the selected words' token positions, so the product keeps just
        # those tokens. Sum over tokens, then average over dim=1, which holds the heads of every
        # concatenated layer.
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = F.interpolate(mask, size=(x_t.shape[2:]))
        # Normalise each map by its spatial maximum before thresholding.
        mask = mask / mask.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
        mask = mask.gt(self.threshold)
        # Combine the source mask with each target mask (on bool tensors addition acts as OR).
        mask = mask[:1] + mask[1:]
        # Match the latent dtype: a fixed float16 mask would promote bfloat16 latents to float32.
        mask = mask.to(x_t.dtype)
        # Inside the mask keep each sample; outside it, fall back to the source sample x_t[:1].
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t

    def __init__(
        self, prompts: List[str], words: List[Union[str, List[str]]], tokenizer, device, threshold=0.3, attn_res=None
    ):
        """Args:
            prompts: the source prompt followed by the edited prompts.
            words: per prompt, a word or list of words whose attention defines the blend region.
            tokenizer: tokenizer used to locate the words' token positions.
            device: device for ``alpha_layers``.
            threshold: mask threshold applied to the normalised attention.
            attn_res: ``(h, w)`` of the attention maps to use.
        """
        self.max_num_words = 77
        self.attn_res = attn_res
        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = get_word_inds(prompt, word, tokenizer)
                if len(ind) == 0:
                    # Word not in the prompt: nothing to mark (an empty index array may be float-typed).
                    continue
                alpha_layers[i, :, :, :, :, ind] = 1
        # 1 at the token positions of each prompt's selected words, 0 elsewhere.
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base class for controllers that rewrite the edited prompts' attention from the source prompt.

    Within the conditional batch, sample 0 is the source and samples 1.. are edited.

    * Cross-attention is edited at every step. The result of
      ``replace_cross_attention`` is blended with the original using
      ``cross_replace_alpha[cur_step]``.
    * Self-attention is edited only for steps in
      ``[num_self_replace[0], num_self_replace[1])``. It is then copied from the
      source for low-resolution layers (see ``replace_self_attention``).
    """

    def step_callback(self, x_t):
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        # Layers with at most attn_res[0] * attn_res[1] query pixels take the source self-attention.
        # Using the pixel count (as LocalBlend does) keeps non-square attn_res correct; a squared
        # first side assumes a square map.
        if att_replace.shape[2] <= self.attn_res[0] * self.attn_res[1]:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    def replace_cross_attention(self, attn_base, att_replace):
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Record the maps first (AttentionStore bookkeeping), then edit them.
        super().forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            # Split the flat (batch * heads) dimension into (batch, heads).
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = (
                    self.replace_cross_attention(attn_base, attn_replace) * alpha_words
                    + (1 - alpha_words) * attn_replace
                )
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
        self_replace_steps: Union[float, Tuple[float, float]],
        local_blend: Optional[LocalBlend],
        tokenizer,
        device,
        attn_res=None,
    ):
        super().__init__(attn_res=attn_res)
        # add tokenizer and device here
        self.tokenizer = tokenizer
        self.device = device

        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(
            prompts, num_steps, cross_replace_steps, self.tokenizer
        ).to(self.device)
        if isinstance(self_replace_steps, (int, float)):
            # A single number is the end fraction of the window; ints (e.g. 1) are accepted too.
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend
class AttentionReplace(AttentionControlEdit):
    """Word-swap edit: target cross-attention is the source attention remapped token by token."""

    def replace_cross_attention(self, attn_base, att_replace):
        # get_replacement_mapper produces float16, while the attention may be float32 or bfloat16.
        # torch.einsum requires both operands to share a dtype, so cast the mapper
        # (a no-op when the dtypes already agree).
        mapper = self.mapper.to(attn_base.dtype)
        return torch.einsum("hpw,bwn->bhpn", attn_base, mapper)

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        local_blend: Optional[LocalBlend] = None,
        tokenizer=None,
        device=None,
        attn_res=None,
    ):
        super().__init__(
            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res
        )
        # One (max_len, max_len) source->target token mapper per edited prompt.
        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
class AttentionRefine(AttentionControlEdit):
    """Refinement edit: target tokens aligned to source tokens reuse the source cross-attention."""

    def replace_cross_attention(self, attn_base, att_replace):
        # Gather the source attention at the aligned token positions: (h, p, b, n) -> (b, h, p, n).
        gathered = attn_base[:, :, self.mapper]
        gathered = gathered.permute(2, 0, 1, 3)
        # alphas is 0 for target tokens without a source counterpart (they keep their own
        # attention) and 1 elsewhere.
        keep = self.alphas
        return gathered * keep + att_replace * (1 - keep)

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        local_blend: Optional[LocalBlend] = None,
        tokenizer=None,
        device=None,
        attn_res=None
    ):
        super().__init__(
            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res
        )
        mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
        self.mapper = mapper.to(self.device)
        alphas = alphas.to(self.device)
        # (num_targets, max_len) -> (num_targets, 1, 1, max_len) so it broadcasts over heads and pixels.
        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
class AttentionReweight(AttentionControlEdit):
    """Reweighting edit: scale the source cross-attention per token by ``equalizer``.

    When another edit controller is supplied, its cross-attention replacement is
    applied first and the result is then reweighted.
    """

    def replace_cross_attention(self, attn_base, att_replace):
        base = attn_base
        if self.prev_controller is not None:
            base = self.prev_controller.replace_cross_attention(base, att_replace)
        # Add a leading axis to the attention and broadcast one equalizer row per output sample.
        weights = self.equalizer[:, None, None, :]
        return base.unsqueeze(0) * weights

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        equalizer,
        local_blend: Optional[LocalBlend] = None,
        controller: Optional[AttentionControlEdit] = None,
        tokenizer=None,
        device=None,
        attn_res=None,
    ):
        super().__init__(
            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res
        )
        self.equalizer = equalizer.to(self.device)
        # Optional controller whose replacement is applied before reweighting.
        self.prev_controller = controller
### util functions for all Edits | |
def update_alpha_time_word(
    alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
):
    """Set ``alpha[:, prompt_ind, word_inds]`` to 1 inside a step window and to 0 outside it.

    Args:
        alpha: tensor indexed as ``[step, prompt, token]``; modified in place and returned.
        bounds: ``(start, end)`` fractions of ``alpha.shape[0]``; a single number ``end``
            (int or float) means ``(0, end)``.
        prompt_ind: index along dim 1.
        word_inds: token indices along dim 2; all tokens when None.

    Returns:
        ``alpha``.
    """
    if isinstance(bounds, (int, float)):
        # Ints (e.g. 1 for "every step") are accepted as well as floats.
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
def get_time_words_attention_alpha(
    prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
):
    """Build per-step, per-target, per-token cross-attention replacement weights.

    ``cross_replace_steps`` can take two forms:

    * A single schedule applied to every token.
    * A dict mapping words to schedules. The optional key ``"default_"`` gives the
      schedule for all other tokens and defaults to ``(0.0, 1.0)`` when absent.

    Schedules are interpreted by ``update_alpha_time_word``. The caller's dict is not modified.

    Returns:
        0/1 tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``. Entry
        ``[t, i, ..., k]`` is 1 when token ``k`` of ``prompts[i + 1]`` takes the source
        attention at step ``t``.
    """
    if isinstance(cross_replace_steps, dict):
        # Copy so that inserting "default_" does not leak into the caller's dict.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0.0, 1.0)
    num_targets = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_targets, max_num_words)
    for i in range(num_targets):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, num_targets, 1, 1, max_num_words)
    return alpha_time_words
### util functions for LocalBlend and ReplacementEdit | |
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    """Return the token positions that cover a word of ``text``.

    Args:
        text: the prompt.
        word_place: either the word itself (every space-separated occurrence matches) or
            its index in ``text.split(" ")``.
        tokenizer: object with ``encode`` / ``decode``. ``encode(text)`` is assumed to add
            exactly one leading and one trailing special token; both are dropped before
            matching, hence the ``+ 1`` on returned positions.

    Returns:
        ``np.ndarray`` of int64 positions into ``tokenizer.encode(text)``; empty if the word is absent.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        # Decoded sub-word pieces without the leading/trailing special tokens.
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0
        for i in range(len(words_encode)):
            if ptr >= len(split_text):
                # The decoded pieces outlasted the words (e.g. the tokenizer normalised the text);
                # stop rather than index past split_text.
                break
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    # Explicit integer dtype: np.array([]) would be float64, which torch rejects as an index.
    return np.array(out, dtype=np.int64)
### util functions for ReplacementEdit | |
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build a ``(max_len, max_len)`` token mapper from prompt ``x`` to prompt ``y`` for a word swap.

    Both prompts must have the same number of space-separated words.
    ``mapper[i, j]`` is the weight of source token ``i`` for target token ``j``:

    * Unchanged tokens map one-to-one.
    * A replaced word's source tokens map to its target tokens one-to-one when the
      counts match.
    * When the counts differ, every source token of the word gives weight
      ``1 / len(target tokens)`` to each target token.

    Returns:
        float16 tensor of shape ``(max_len, max_len)``.

    Raises:
        ValueError: if the word counts differ.
    """
    words_x = x.split(" ")
    words_y = y.split(" ")
    if len(words_x) != len(words_y):
        raise ValueError(
            f"attention replacement edit can only be applied on prompts with the same length"
            f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
        )
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))

    # i walks the source tokens, j the target tokens. They drift apart after a replaced word
    # whose source and target token counts differ.
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                mapper[inds_source_, inds_target_] = 1
            else:
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_:
                    mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        else:
            # Unchanged token before, between or after the replaced words: source i -> target j.
            # (Writing mapper[j, j] after the last replaced word misaligns every following token
            # whenever a replacement changed the token count.)
            mapper[i, j] = 1
            i += 1
            j += 1

    return torch.from_numpy(mapper).to(torch.float16)
def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack one source->target token mapper per edited prompt (``prompts[1:]`` against ``prompts[0]``)."""
    source = prompts[0]
    return torch.stack([get_replacement_mapper_(source, target, tokenizer, max_len) for target in prompts[1:]])
### util functions for ReweightEdit | |
def get_equalizer(
    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
):
    """Build per-token attention weights for a reweight edit.

    Args:
        text: the prompt whose tokens are weighted.
        word_select: a word / word index, or a sequence of them, in ``text``.
        values: one weight per entry of ``word_select``.
        tokenizer: used to locate each word's token positions.

    Returns:
        ``(len(values), 77)`` float tensor of ones. In every row, the tokens of
        ``word_select[i]`` are set to ``values[i]``. Words not found in ``text`` are left at 1.
    """
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    equalizer = torch.ones(len(values), 77)
    values = torch.tensor(values, dtype=torch.float32)
    for i, word in enumerate(word_select):
        inds = get_word_inds(text, word, tokenizer)
        if len(inds) == 0:
            # Nothing to reweight; an empty index array could also be float-typed and rejected by torch.
            continue
        # Assign the scalar directly instead of going through the legacy torch.FloatTensor constructor.
        equalizer[:, inds] = values[i]
    return equalizer
### util functions for RefinementEdit | |
class ScoreParams:
    """Gap / match / mismatch scores used by ``global_align``."""

    def __init__(self, gap, match, mismatch):
        self.gap = gap
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        """Return ``mismatch`` when ``x != y``, otherwise ``match``."""
        return self.mismatch if x != y else self.match
def get_matrix(size_x, size_y, gap):
    """Return the initial int32 alignment score matrix.

    The matrix has shape ``(size_x + 1, size_y + 1)``, holds cumulative gap
    penalties along row 0 and column 0, and is zero elsewhere.
    """
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = gap * np.arange(1, size_y + 1)
    matrix[1:, 0] = gap * np.arange(1, size_x + 1)
    return matrix
def get_traceback_matrix(size_x, size_y):
    """Return the initial int32 traceback matrix.

    Codes: 1 along row 0 (step left), 2 along column 0 (step up), and 4 at the
    origin (stop); 0 elsewhere.
    """
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, :] = 1
    matrix[:, 0] = 2
    matrix[0, 0] = 4
    return matrix
def global_align(x, y, score):
    """Fill the global-alignment score and traceback matrices for sequences ``x`` and ``y``.

    Traceback codes:

    * 1 = from the left (gap in ``x``)
    * 2 = from above (gap in ``y``)
    * 3 = diagonal (match or mismatch)

    On ties, left wins over up, which wins over diagonal.

    Returns:
        ``(matrix, trace_back)``.
    """
    matrix = get_matrix(len(x), len(y), score.gap)
    trace_back = get_traceback_matrix(len(x), len(y))
    for i, x_item in enumerate(x, start=1):
        for j, y_item in enumerate(y, start=1):
            left = matrix[i, j - 1] + score.gap
            up = matrix[i - 1, j] + score.gap
            diag = matrix[i - 1, j - 1] + score.mis_match_char(x_item, y_item)
            matrix[i, j] = max(left, up, diag)
            # Compare against the stored value so ties are resolved exactly as stored.
            best = matrix[i, j]
            if best == left:
                code = 1
            elif best == up:
                code = 2
            else:
                code = 3
            trace_back[i, j] = code
    return matrix, trace_back
def get_aligned_sequences(x, y, trace_back):
    """Walk ``trace_back`` from the bottom-right corner and recover the alignment.

    Returns:
        ``(x_seq, y_seq, mapper)``, where:

        * ``x_seq`` / ``y_seq`` are the aligned sequences, with ``"-"`` marking gaps,
          listed end-to-start in the order they were built.
        * ``mapper`` is an int64 tensor of ``(y_index, x_index)`` pairs in forward
          order, with ``x_index == -1`` for ``y`` items that have no counterpart in ``x``.
    """
    aligned_x, aligned_y, pairs = [], [], []
    i, j = len(x), len(y)
    while i > 0 or j > 0:
        step = trace_back[i, j]
        if step == 3:  # diagonal: x[i-1] aligned with y[j-1]
            i -= 1
            j -= 1
            aligned_x.append(x[i])
            aligned_y.append(y[j])
            pairs.append((j, i))
        elif step == 1:  # left: y[j-1] has no partner in x
            j -= 1
            aligned_x.append("-")
            aligned_y.append(y[j])
            pairs.append((j, -1))
        elif step == 2:  # up: x[i-1] has no partner in y
            i -= 1
            aligned_x.append(x[i])
            aligned_y.append("-")
        elif step == 4:  # origin marker
            break
    pairs.reverse()
    return aligned_x, aligned_y, torch.tensor(pairs, dtype=torch.int64)
def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the tokens of ``y`` to those of ``x`` for a refinement edit.

    Returns:
        mapper: int64 tensor of length ``max_len``. Entry ``j`` (for ``j < len(y tokens)``)
            is the index of the ``x`` token aligned with ``y`` token ``j``, or -1 if none.
            Later entries hold consecutive indices starting at ``len(y tokens)``.
        alphas: float tensor of length ``max_len``. It is 0 where a ``y`` token has no
            ``x`` counterpart and 1 elsewhere.
    """
    x_tokens = tokenizer.encode(x)
    y_tokens = tokenizer.encode(y)
    _, trace_back = global_align(x_tokens, y_tokens, ScoreParams(0, 1, -1))
    aligned = get_aligned_sequences(x_tokens, y_tokens, trace_back)[-1]
    n_aligned = aligned.shape[0]

    alphas = torch.ones(max_len)
    alphas[:n_aligned] = aligned[:, 1].ne(-1).float()

    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:n_aligned] = aligned[:, 1]
    mapper[n_aligned:] = len(y_tokens) + torch.arange(max_len - len(y_tokens))
    return mapper, alphas
def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Compute the refinement mapper and alphas of every ``prompts[1:]`` against ``prompts[0]``.

    Returns:
        ``(mappers, alphas)``, each stacked along a new dim 0 (one row per edited prompt).
    """
    source = prompts[0]
    results = [get_mapper(source, target, tokenizer, max_len) for target in prompts[1:]]
    mappers = [mapper for mapper, _ in results]
    alphas = [alpha for _, alpha in results]
    return torch.stack(mappers), torch.stack(alphas)
def aggregate_attention(prompts, attention_store: AttentionStore, height: int, width: int, from_where: List[str], is_cross: bool, select: int):
    """Average one prompt's stored attention maps at 1/32 of the image resolution.

    Args:
        prompts: prompt list; only its length is used, to split the batch dimension.
        attention_store: object providing ``get_average_attention()``.
        height, width: image size. Maps with ``(height // 32) * (width // 32)`` queries are used.
        from_where: UNet locations to read, e.g. ``["up", "down"]``.
        is_cross: read cross-attention (True) or self-attention (False) maps.
        select: index of the prompt whose maps are returned.

    Returns:
        CPU tensor of shape ``(width // 32, height // 32, key_len)``, averaged over the
        heads of every matching layer.
    """
    map_h = height // 32
    map_w = width // 32
    kind = 'cross' if is_cross else 'self'
    averaged = attention_store.get_average_attention()
    collected = []
    for location in from_where:
        for item in averaged[f"{location}_{kind}"]:
            if item.shape[1] != map_h * map_w:
                continue
            # NOTE(review): the flattened pixel axis is unflattened as (width, height). If the pixel
            # order is row-major with height outermost, non-square maps come out scrambled.
            # get_multidiffusion_prompts relies on this layout, so confirm before changing it.
            per_prompt = item.reshape(len(prompts), -1, map_w, map_h, item.shape[-1])
            collected.append(per_prompt[select])
    stacked = torch.cat(collected, dim=0)
    return (stacked.sum(0) / stacked.shape[0]).cpu()
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, t=0):
    """Save a grid with one captioned cross-attention map per token of ``prompts[select]``.

    Args:
        tokenizer: tokenizer providing ``encode`` / ``decode``.
        prompts: all prompts of the batch (needed to split the stored maps per prompt).
        attention_store: store whose averaged maps are visualised.
        res: attention-map resolution; maps with ``res * res`` pixels are used
            (assumed meaning, consistent with ``show_self_attention_comp`` — TODO confirm).
        from_where: UNet locations to aggregate.
        select: index of the prompt to visualise.
        t: file stem passed on to ``view_images``.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # aggregate_attention takes the image height AND width in pixels (each divided by 32).
    # Passing a single ``res`` shifted every later argument by one and left ``select`` missing.
    attention_maps = aggregate_attention(prompts, attention_store, res * 32, res * 32, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0), t=t, from_where=from_where)
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts: Optional[List[str]] = None):
    """Save the top ``max_com`` principal components of one prompt's averaged self-attention.

    Args:
        attention_store: store whose averaged maps are visualised.
        res: attention-map resolution; maps with ``res * res`` pixels are used.
        from_where: UNet locations to aggregate.
        max_com: number of components to render.
        select: index of the prompt to visualise.
        prompts: the prompts used for generation. Required, because
            ``aggregate_attention`` needs their number to split the stored batch.

    Raises:
        ValueError: if ``prompts`` is not given.
    """
    if prompts is None:
        raise ValueError("show_self_attention_comp requires `prompts` to split the stored attention maps per prompt.")
    attention_maps = aggregate_attention(prompts, attention_store, res * 32, res * 32, from_where, False, select)
    # float32: np.linalg.svd does not support float16, the usual dtype of half-precision attention.
    attention_maps = attention_maps.float().numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1), from_where=from_where)
def view_images(images, num_rows=1, offset_ratio=0.02, t=0, from_where: Optional[List[str]] = None):
    """Tile ``images`` into a grid and save it as ``./visualization/<from_where>/<t>.png``.

    Args:
        images: list of HxWxC arrays, a 4-D array, or a single 3-D array.
        num_rows: number of grid rows; the last row is padded with white tiles.
        offset_ratio: white gap between tiles as a fraction of the tile height.
        t: file stem (e.g. the denoising step).
        from_where: location names joined with "_" to form the sub-directory. A single
            string is used as-is. When omitted, the grid is saved in ``./visualization``.
    """
    if type(images) is list:
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        images = [images]
        num_images = 1
    # Blank tiles needed to complete the last row.
    num_empty = (-num_images) % num_rows

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    if from_where is None:
        save_path = './visualization'
    else:
        # Always join: a one-element list would otherwise be formatted as "['up']" in the path.
        sub_dir = from_where if isinstance(from_where, str) else '_'.join(from_where)
        save_path = f'./visualization/{sub_dir}'
    # makedirs also creates ./visualization itself and tolerates an existing directory.
    os.makedirs(save_path, exist_ok=True)
    pil_img.save(f"{save_path}/{t}.png")
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a new uint8 image: ``image`` on top of a white strip (20% of its height) with ``text`` centred in it."""
    height, width, channels = image.shape
    pad = int(height * .2)
    canvas = np.full((height + pad, width, channels), 255, dtype=np.uint8)
    canvas[:height] = image
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    origin = ((width - text_w) // 2, height + pad - text_h // 2)
    cv2.putText(canvas, text, origin, font, 1, text_color, 2)
    return canvas
def get_views(height, width, window_size=32, stride=16, random_jitter=False):
    """Tile a ``height x width`` grid with overlapping square windows of side ``window_size``.

    Windows start every ``stride`` cells, and windows that would overrun the border are
    shifted back inside it. With ``random_jitter``, each window is randomly offset by up
    to ``jitter_range = (window_size - stride) // 4`` cells and every coordinate is shifted
    by ``+jitter_range``. The result then indexes a grid padded by ``jitter_range`` on
    each side.

    Returns:
        list of ``(h_start, h_end, w_start, w_end)`` int tuples.
    """
    num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
    num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    # Must be an int: callers pass stride=window_size/2 (a float), and random.randint
    # raises TypeError for float arguments on Python >= 3.12.
    jitter_range = int((window_size - stride) // 4)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size

        # Pull windows that overrun the bottom/right edge back inside, then clamp at 0.
        if h_end > height:
            h_start = int(h_start + height - h_end)
            h_end = int(height)
        if w_end > width:
            w_start = int(w_start + width - w_end)
            w_end = int(width)
        if h_start < 0:
            h_end = int(h_end - h_start)
            h_start = 0
        if w_start < 0:
            w_end = int(w_end - w_start)
            w_start = 0

        if random_jitter:
            w_jitter = 0
            h_jitter = 0
            # Interior windows jitter both ways; edge windows only inwards.
            if (w_start != 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, jitter_range)
            elif (w_start == 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, 0)
            elif (w_start != 0) and (w_end == width):
                w_jitter = random.randint(0, jitter_range)
            if (h_start != 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, jitter_range)
            elif (h_start == 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, 0)
            elif (h_start != 0) and (h_end == height):
                h_jitter = random.randint(0, jitter_range)
            h_start += (h_jitter + jitter_range)
            h_end += (h_jitter + jitter_range)
            w_start += (w_jitter + jitter_range)
            w_end += (w_jitter + jitter_range)

        views.append((int(h_start), int(h_end), int(w_start), int(w_end)))
    return views
def get_multidiffusion_prompts(tokenizer, prompts, threthod, attention_store:AttentionStore, height:int, width:int, from_where: List[str], scale_num=4, random_jitter=False):
    """Derive a prompt for every view at every upscaling factor from the cross-attention maps.

    For each token of ``prompts[0]`` except the first and last (presumably the begin/end
    special tokens), a binary mask of its above-average attention region is built. Then,
    for each ``scale`` in ``2..scale_num``:

    * the masks are enlarged by ``scale``;
    * the enlarged grid is tiled with ``get_views``;
    * a token's word is put into a view's prompt when the mask pixels inside that view
      reach ``threthod * (ori_h * ori_w)``.

    Returns:
        ``(prompt_dict, view_dict)``, both keyed by scale. ``prompt_dict[scale]`` holds one
        prompt string per view in ``view_dict[scale]``.
    """
    tokens = tokenizer.encode(prompts[0])
    decoder = tokenizer.decode
    # Cross-attention maps of prompts[0], indexed as [:, :, token].
    attention_maps = aggregate_attention(prompts, attention_store, height, width, from_where, True, 0)

    # High-attention region of each token: above-mean pixels, then eroded and dilated to drop
    # small speckles.
    kernel = np.ones((3, 3), np.uint8)
    masks = []
    for i in range(len(tokens)):
        attention_map = attention_maps[:, :, i].to(torch.float32)
        mask = torch.where(attention_map > attention_map.mean(), 1, 0).numpy().astype(np.uint8)
        mask = mask * 255
        iterations = mask.shape[0] // 16
        eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
        masks.append(cv2.dilate(eroded_mask, kernel, iterations=iterations))

    prompt_dict = {}
    view_dict = {}
    # NOTE(review): aggregate_attention lays maps out as (width // 32, height // 32), hence
    # "w" first here. cv2.resize below takes (width, height). Confirm this axis order before
    # relying on non-square inputs.
    ori_w, ori_h = masks[0].shape
    window_size = max(ori_h, ori_w)
    stride = window_size / 2
    # Must equal the padding get_views adds with random_jitter: (window_size - stride) // 4.
    jitter_range = int((window_size - stride) // 4)
    for scale in range(2, scale_num + 1):
        cur_w = ori_w * scale
        cur_h = ori_h * scale
        views = get_views(height=cur_h, width=cur_w, window_size=window_size, stride=stride, random_jitter=random_jitter)
        words_in_patch = []
        # Skip the first and last tokens (begin/end of text).
        for token_idx in range(1, len(masks) - 1):
            mask = cv2.resize(masks[token_idx], (cur_w, cur_h), interpolation=cv2.INTER_NEAREST)
            if random_jitter:
                mask = np.pad(mask, ((jitter_range, jitter_range), (jitter_range, jitter_range)), 'constant', constant_values=(0, 0))
            word = decoder(int(tokens[token_idx]))
            word_in_patch = []
            for h_start, h_end, w_start, w_end in views:
                view_mask = mask[h_start:h_end, w_start:w_end]
                in_view = (view_mask / 255).sum() / (ori_h * ori_w) >= threthod
                word_in_patch.append(word if in_view else '')  # '' = word not in this view
            words_in_patch.append(word_in_patch)
        # One prompt per view: the words present in it, in token order, single-spaced.
        prompt_dict[scale] = [" ".join(" ".join(strings).split()) for strings in zip(*words_in_patch)]
        view_dict[scale] = views
    return prompt_dict, view_dict
class ScaledAttnProcessor:
    r"""
    Wrapper processor that rescales ``attn.scale`` when running at a resolution other than the training one.

    With ``L`` the current query sequence length and
    ``L_train = L * (train_res / test_res) ** 2``, the scale is multiplied by
    ``sqrt(log_{L_train}(L))``. The wrapped ``processor`` is called with the adjusted
    ``attn.scale`` (also passed as ``scale=``), and the original ``attn.scale`` is
    restored afterwards, even if the processor raises.
    """

    def __init__(self, processor, test_res, train_res):
        self.processor = processor
        self.test_res = test_res
        self.train_res = train_res

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        if hidden_states.ndim == 4:
            _, _, height, width = hidden_states.shape
            sequence_length = height * width
        else:
            _, sequence_length, _ = hidden_states.shape

        test_train_ratio = (self.test_res ** 2.0) / (self.train_res ** 2.0)
        train_sequence_length = sequence_length / test_train_ratio
        scale_factor = math.log(sequence_length, train_sequence_length) ** 0.5

        original_scale = attn.scale
        attn.scale = attn.scale * scale_factor
        try:
            # NOTE(review): ``scale=`` is forwarded as a keyword; confirm the wrapped processor
            # treats it as the attention scale (some processors use a ``scale`` argument for LoRA).
            hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale=attn.scale)
        finally:
            # Never leave the shared module rescaled if the wrapped processor fails.
            attn.scale = original_scale
        return hidden_states