Spaces:
Running
Running
import torch | |
from torch import nn, einsum | |
from .ldm.modules.attention import CrossAttention | |
from inspect import isfunction | |
def exists(val):
    """Report whether *val* holds a value, i.e. is not None.

    Falsy values such as 0, "" or [] still count as existing.
    """
    return not (val is None)
def uniq(arr):
    """Return the distinct elements of *arr* as a dict keys view.

    Order of first occurrence is preserved (dicts keep insertion order);
    elements must be hashable.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*.

    If *d* is a plain Python function (as judged by ``inspect.isfunction``),
    it is called with no arguments to produce the fallback lazily. Any other
    object, including builtins and callable instances, is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
# feedforward | |
class GEGLU(nn.Module):
    """Gated-GELU projection.

    A single Linear maps ``dim_in`` features to ``2 * dim_out``. The result is
    split in half along the last axis, and the first half is multiplied by
    GELU of the second half.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One Linear produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        activated_gate = torch.nn.functional.gelu(gate)
        return value * activated_gate
class FeedForward(nn.Module):
    """Transformer MLP: input projection + activation, dropout, output Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to ``dim`` when None.
        mult: hidden-width multiplier (hidden size is ``int(dim * mult)``).
        glu: if True, use a GEGLU input projection instead of Linear + GELU.
        dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        out_features = default(dim_out, dim)
        if glu:
            stem = GEGLU(dim, hidden_dim)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())
        # The submodule order (0: stem, 1: dropout, 2: output Linear) fixes the
        # state-dict key names (e.g. "net.0.proj.weight", "net.2.weight"),
        # which checkpoints are loaded against.
        layers = [stem, nn.Dropout(dropout), nn.Linear(hidden_dim, out_features)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class GatedCrossAttentionDense(nn.Module):
    """Gated residual block: attention from ``x`` to ``objs``, then feed-forward.

    Each branch output is added to ``x`` after being scaled by
    ``self.scale * tanh(alpha)``. Both alphas are learnable scalars that start
    at 0, so tanh(alpha) == 0 and the block returns ``x`` unchanged until the
    gates are trained.

    Args:
        query_dim: channel size of ``x``.
        context_dim: channel size of ``objs``.
        n_heads: number of attention heads.
        d_head: size of each attention head.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()
        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)
        # Learnable scalar gates. Assigning an nn.Parameter attribute
        # registers it exactly like register_parameter does.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))
        # External knob on the gate magnitude. Setting it to 0 disables both
        # residual branches, so the block reduces to the identity.
        self.scale = 1

    def forward(self, x, objs):
        # objs is passed as both the second and third positional arguments of
        # CrossAttention (presumably context and value; confirm against
        # CrossAttention.forward).
        attn_gate = self.scale * torch.tanh(self.alpha_attn)
        x = x + attn_gate * self.attn(self.norm1(x), objs, objs)
        ff_gate = self.scale * torch.tanh(self.alpha_dense)
        return x + ff_gate * self.ff(self.norm2(x))
class GatedSelfAttentionDense(nn.Module):
    """Gated block that attends jointly over visual and grounding tokens.

    Grounding tokens ``objs`` are projected from ``context_dim`` to
    ``query_dim`` and concatenated after the visual tokens ``x`` along dim 1.
    Attention runs over the joint sequence, and only the outputs at the
    visual positions are kept as the residual. A gated feed-forward branch
    follows. Both branches are scaled by ``self.scale * tanh(alpha)``; the
    alphas start at 0, so the block is an identity mapping at initialization.

    Args:
        query_dim: channel size of the visual tokens.
        context_dim: channel size of the grounding tokens before projection.
        n_heads: number of attention heads.
        d_head: size of each attention head.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()
        # A projection is required so grounding tokens match the visual
        # channel size before concatenation.
        self.linear = nn.Linear(context_dim, query_dim)
        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)
        # Learnable scalar gates, both starting at 0.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))
        # External knob on the gate magnitude. Setting it to 0 disables both
        # residual branches, so the block reduces to the identity.
        self.scale = 1

    def forward(self, x, objs):
        # Assumes x is (batch, N_visual, query_dim); dim 1 is the token axis
        # given the concatenation below. TODO confirm with callers.
        n_visual = x.shape[1]
        grounding = self.linear(objs)
        joint_tokens = torch.cat([x, grounding], dim=1)
        # CrossAttention is called with a single input here (presumably
        # self-attention over joint_tokens; confirm its context default).
        attended = self.attn(self.norm1(joint_tokens))
        visual_update = attended[:, :n_visual, :]
        x = x + self.scale * torch.tanh(self.alpha_attn) * visual_update
        dense_update = self.ff(self.norm2(x))
        return x + self.scale * torch.tanh(self.alpha_dense) * dense_update
class GatedSelfAttentionDense2(nn.Module):
    """Gated block whose attention residual is read from the grounding tokens.

    Visual tokens ``x`` and projected grounding tokens are concatenated along
    dim 1 and passed through attention. The outputs at the grounding
    positions are arranged into a ``size_g x size_g`` grid, bicubically
    resized to ``size_v x size_v`` and flattened back to ``N_visual`` tokens,
    which become the residual added to ``x``. A gated feed-forward branch
    follows. Both branches are scaled by ``self.scale * tanh(alpha)``; the
    alphas start at 0, so the block is an identity mapping at initialization.

    Args:
        query_dim: channel size of the visual tokens.
        context_dim: channel size of the grounding tokens before projection.
        n_heads: number of attention heads.
        d_head: size of each attention head.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()
        # Project grounding tokens to query_dim so they can be concatenated
        # with the visual tokens.
        self.linear = nn.Linear(context_dim, query_dim)
        # Fix: n_heads used to be accepted but never forwarded, so attention
        # silently used CrossAttention's default head count. Forward it, as
        # the other gated blocks do.
        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)
        # Learnable scalar gates, initialized to 0 so tanh(alpha) == 0.
        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
        # External knob on the gate magnitude. Setting it to 0 disables both
        # residual branches, so the block returns x unchanged.
        self.scale = 1

    def forward(self, x, objs):
        """Apply the gated grounding-residual and feed-forward branches.

        Args:
            x: visual tokens, unpacked as (B, N_visual, C).
            objs: grounding tokens, unpacked as (B, N_ground, context_dim).

        Both N_visual and N_ground must be perfect squares (AssertionError
        otherwise), because they are viewed as square spatial grids.
        """
        # Fix: `math` was used here but the file header never imports it, so
        # every call raised NameError. Import it locally.
        import math

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape
        objs = self.linear(objs)

        # Sanity check. isqrt is exact for any non-negative int, unlike
        # comparing a float sqrt with its truncation.
        size_v = math.isqrt(N_visual)
        size_g = math.isqrt(N_ground)
        assert size_v * size_v == N_visual, "Visual tokens must be square rootable"
        assert size_g * size_g == N_ground, "Grounding tokens must be square rootable"

        # Attend over [visual; grounding] and keep the grounding positions.
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, N_visual:, :]
        # (B, N_ground, C) -> (B, C, size_g, size_g) for spatial resizing.
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        # (B, C, size_v, size_v) -> (B, N_visual, C), used as the residual.
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
        return x
class FourierEmbedder():
    """Sinusoidal (Fourier-feature) embedding with geometric frequency bands.

    Band i has frequency ``temperature ** (i / num_freqs)``, for
    i = 0 .. num_freqs - 1.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    def __call__(self, x, cat_dim=-1):
        """Embed tensor *x* of any shape.

        For each band, sin(freq * x) and cos(freq * x) are produced in that
        order and all of them are concatenated along *cat_dim*. The size along
        that axis therefore grows by a factor of 2 * num_freqs.
        """
        pieces = []
        for freq in self.freq_bands:
            scaled = freq * x
            pieces.extend((torch.sin(scaled), torch.cos(scaled)))
        return torch.cat(pieces, cat_dim)
class PositionNet(nn.Module):
    """Fuse per-object embeddings with Fourier-encoded boxes into object tokens.

    Padded (masked-out) slots have both their embedding and their box encoding
    replaced by learnable "null" vectors. The concatenation is then mapped to
    ``out_dim`` by a three-layer SiLU MLP.

    Args:
        in_dim: size of each positive (object) embedding.
        out_dim: size of each output object token.
        fourier_freqs: number of Fourier frequency bands per box coordinate.
    """

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        # sin and cos for each of the 4 box coordinates (x1, y1, x2, y2)
        self.position_dim = fourier_freqs * 2 * 4
        hidden = 512
        # Indices 0, 2 and 4 are the Linear layers; checkpoint keys such as
        # "linears.4.weight" depend on this layout.
        self.linears = nn.Sequential(
            nn.Linear(in_dim + self.position_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_dim),
        )
        self.null_positive_feature = nn.Parameter(torch.zeros([in_dim]))
        self.null_position_feature = nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        """Return object tokens of shape (B, N, out_dim).

        Args:
            boxes: (B, N, 4) box coordinates (may include zero padding).
            masks: presumably (B, N); 1 marks a real object, 0 a padded slot.
            positive_embeddings: (B, N, in_dim) per-object embeddings.
        """
        batch, n_objs, _ = boxes.shape
        # Work in the MLP's dtype, e.g. when the model runs in half precision.
        dtype = self.linears[0].weight.dtype
        keep = masks.unsqueeze(-1).to(dtype)
        drop = 1 - keep
        positive = positive_embeddings.to(dtype)
        box_features = self.fourier_embedder(boxes.to(dtype))

        positive = positive * keep + \
            drop * self.null_positive_feature.view(1, 1, -1)
        box_features = box_features * keep + \
            drop * self.null_position_feature.view(1, 1, -1)

        objs = self.linears(torch.cat([positive, box_features], dim=-1))
        assert objs.shape == torch.Size([batch, n_objs, self.out_dim])
        return objs
class Gligen(nn.Module):
    """Holds the gated grounding blocks plus the position net that feeds them.

    ``set_position`` / ``set_empty`` compute the object tokens once and return
    a callback ``func(x, extra_options)``. The callback picks a gated block by
    ``extra_options["transformer_index"]`` and applies it to ``x`` with those
    tokens.
    """

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        # Fixed number of object slots; unused slots are zero-padded.
        self.max_objs = 30
        # Not read anywhere in this class.
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(x, extra_options):
            block = self.module_list[extra_options["transformer_index"]]
            return block(x, objs)

        return func

    def set_position(self, latent_image_shape, position_params, device):
        """Build the grounding callback from a list of boxed conditions.

        Args:
            latent_image_shape: (batch, channels, height, width) of the latent.
            position_params: sequence of ``p`` where ``p[0]`` is an embedding
                that is concatenated along dim 0, ``p[1]``/``p[2]`` are box
                height/width, and ``p[3]``/``p[4]`` are the top/left offsets.
                Offsets and sizes are divided by the latent height/width.
                More than ``max_objs`` entries raises IndexError.
            device: device the padded tensors are moved to.
        """
        batch, _, height, width = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        embeddings = []
        for slot, p in enumerate(position_params):
            left, top = p[4], p[3]
            x1 = left / width
            y1 = top / height
            x2 = (left + p[2]) / width
            y2 = (top + p[1]) / height
            masks[slot] = 1.0
            boxes.append(torch.tensor((x1, y1, x2, y2)).unsqueeze(0))
            embeddings.append(p[0])

        # Zero-pad the unused slots up to max_objs.
        missing = self.max_objs - len(boxes)
        if missing > 0:
            boxes.append(torch.zeros([missing, 4], device="cpu"))
            embeddings.append(
                torch.zeros([missing, self.key_dim], device="cpu"))

        box_out = torch.cat(boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(embeddings).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device), masks.to(device), conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Build the grounding callback with every object slot masked out."""
        batch, _, _, _ = latent_image_shape
        masks = torch.zeros([batch, self.max_objs], device="cpu")
        box_out = torch.zeros([batch, self.max_objs, 4], device="cpu")
        conds = torch.zeros([batch, self.max_objs, self.key_dim], device="cpu")
        return self._set_position(
            box_out.to(device), masks.to(device), conds.to(device))
def load_gligen(sd):
    """Build a Gligen model from a GLIGEN checkpoint state dict.

    For each UNet block "<input_blocks|middle_block|output_blocks>.<i>." with
    ".fuser." weights, a GatedSelfAttentionDense is created. Its sizes are
    taken from the fuser's "linear.weight", and its weights are loaded with
    ``strict=False``. The PositionNet sizes come from the "position_net.*"
    entries.

    Args:
        sd: mapping of state-dict key -> tensor.

    Returns:
        A Gligen wrapping the gated blocks, in input -> middle -> output order.

    Raises:
        ValueError: if ``sd`` has no "position_net.null_positive_feature"
            entry. This previously surfaced as an obscure NameError.
    """
    sd_k = sd.keys()
    # Only ".fuser." keys matter for the gated blocks. Collect them once
    # instead of rescanning every key for each of the 60 candidate prefixes.
    fuser_keys = [k for k in sd_k if ".fuser." in k]
    output_list = []
    # Kept when no fuser block is found; otherwise the last block's value wins.
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            prefix = "{}.{}.".format(a, b)
            # Strip everything up to ".fuser." so the keys match the
            # GatedSelfAttentionDense attribute names (e.g. "linear.weight").
            n_sd = {k.split(".fuser.")[-1]: sd[k]
                    for k in fuser_keys if prefix in k}
            if not n_sd:
                continue
            query_dim = n_sd["linear.weight"].shape[0]
            key_dim = n_sd["linear.weight"].shape[1]
            if key_dim == 768:  # SD1.x
                n_heads = 8
                d_head = query_dim // n_heads
            else:
                d_head = 64
                n_heads = query_dim // d_head
            gated = GatedSelfAttentionDense(
                query_dim, key_dim, n_heads, d_head)
            gated.load_state_dict(n_sd, strict=False)
            output_list.append(gated)

    # Fix: without these weights the original fell through and crashed with
    # NameError on the undefined `w`; report the real problem instead.
    if "position_net.null_positive_feature" not in sd_k:
        raise ValueError(
            "state dict has no 'position_net.null_positive_feature'; "
            "it does not look like a GLIGEN checkpoint")
    in_dim = sd["position_net.null_positive_feature"].shape[0]
    out_dim = sd["position_net.linears.4.weight"].shape[0]

    # Wrapping PositionNet as the attribute "position_net" lets the full
    # state dict load directly: its "position_net.*" keys line up, and the
    # other keys are ignored via strict=False.
    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    w.position_net = PositionNet(in_dim, out_dim)
    w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen