import torch
import torch.nn as nn
from einops import rearrange

from ldm.modules.encoders.modules import OpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder
from ldm.util import instantiate_from_config
from taming.models.vqgan import VQModelInterfaceEncoder, VQModel
from ldm.modules.attention import SpatialTransformer
from ldm.modules.attention import Normalize, BasicTransformerBlock  # , exists
from ldm.modules.diffusionmodules.util import zero_module, identity_init_fc, conv_nd
# from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def make_zero_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    return zero_module(conv_nd(2, in_channels, out_channels, kernel_size, stride=stride, padding=padding))


class SpatialTransformer_v2(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image.
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # change:
        # if exists(context_dim) and not isinstance(context_dim, list):
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        # change: switch
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


class trans_glyph_emb(nn.Module):
    """Project glyph-image-encoder feature maps to the text-embedding dimension
    with an fc, conv, or attention head; the output is flattened to (b, h*w, out_dim)."""
    def __init__(self,
                 type="fc",  # "conv", "attn"
                 input_dim=256,
                 out_dim=1024,
                 # fc
                 fc_init="zero",
                 # conv/attn
                 conv_ks=3,
                 conv_pad=1,
                 conv_stride=1,
                 # attn
                 ch=512,  # 1024
                 num_heads=8,  # 16
                 dim_head=64,
                 use_linear_in_transformer=True,
                 use_checkpoint=False,  # True,
                 ):
        super().__init__()
        if type == "fc":
            self.model = torch.nn.Linear(input_dim, out_dim)
            if fc_init == "zero":
                self.model = zero_module(self.model)
            elif fc_init == "identity":
                self.model = identity_init_fc(self.model)
        elif type == "conv":
            self.model = make_zero_conv(input_dim, out_dim, conv_ks, stride=conv_stride, padding=conv_pad)
        elif type == "attn":
            model = [
                # nn.Conv2d(input_dim, ch, 3, stride=1, padding=1),
                nn.Conv2d(input_dim, ch, conv_ks, stride=conv_stride, padding=conv_pad),
                SpatialTransformer_v2(  # SpatialTransformer(
                    ch, num_heads, dim_head, depth=1,
                    context_dim=None,  # ch,
                    disable_self_attn=False,
                    use_linear=use_linear_in_transformer,
                    use_checkpoint=use_checkpoint,  # False if the context is None
                ),
                make_zero_conv(ch, out_dim, 1, stride=1, padding=0)
                # make_zero_conv(ch, out_dim, conv_ks, stride=conv_stride, padding=conv_pad)
            ]
            self.model = nn.Sequential(*model)
        self.model_type = type

    def forward(self, x):
        if self.model_type == "fc":
            # b, c, h, w = x.shape
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
            x = self.model(x)
            # x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            # return x
        else:
            x = self.model(x)
            x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        return x


class glyph_control(nn.Module):
    """Encode rendered glyph images with a (frozen) image encoder, optionally remap the
    embeddings with trans_glyph_emb, and fuse them with the text embedding."""
    def __init__(self,
                 image_encoder="CLIP",  # "VQGAN"
                 image_encoder_config=None,
                 fuse_way="concat",
                 load_text_encoder=False,
                 text_encoder_config=None,
                 freeze_image_encoder=True,
                 trans_emb=False,
                 trans_emb_config=None,
                 # use_fp16=False,
                 ):
        super().__init__()
        if image_encoder_config is not None:
            image_encoder_config.params.freeze = freeze_image_encoder
            self.image_encoder = instantiate_from_config(image_encoder_config)
        else:
            if image_encoder == "CLIP":
                self.image_encoder = OpenCLIPImageEmbedder(freeze=freeze_image_encoder)
            elif image_encoder == "VQGAN":
                raise ValueError("VQGAN glyph image encoder is missing config")
            else:
                raise ValueError("Other types of glyph image encoder are not supported")
        if freeze_image_encoder:
            self.freeze_imenc()
        self.freeze_image_encoder = freeze_image_encoder
        self.image_encoder_type = image_encoder

        if load_text_encoder:
            if text_encoder_config is None:
                self.text_encoder = FrozenOpenCLIPEmbedder()
            else:
                self.text_encoder = instantiate_from_config(text_encoder_config)

        self.fuse_way = fuse_way
        # self.dtype = torch.float16 if use_fp16 else torch.float32
        if trans_emb:
            if trans_emb_config is not None:
                self.trans_glyph_emb_model = instantiate_from_config(trans_emb_config)
            else:
                self.trans_glyph_emb_model = trans_glyph_emb()
        else:
            self.trans_glyph_emb_model = None

    def freeze_imenc(self):
        self.image_encoder = self.image_encoder.eval()
        self.image_encoder.train = disabled_train
        for param in self.image_encoder.parameters():
            param.requires_grad = False

    def forward(self, glyph_image, text=None, text_embed=None):
        # per-sample number of glyph crops, used to split/pad the batched embeddings
        clgim_num_list = [img.shape[0] for img in glyph_image]
        # image_embeds = self.image_encoder(torch.concat(glyph_image, dim=0))
        gim_concat = torch.concat(glyph_image, dim=0)
        image_embeds = self.image_encoder(gim_concat)
        if self.trans_glyph_emb_model is not None:
            image_embeds = self.trans_glyph_emb_model(image_embeds)
        image_embeds = torch.split(image_embeds, clgim_num_list)

        # zero-pad each sample to the maximum number of glyph crops in the batch
        max_image_tokens = max(clgim_num_list)
        pad_image_embeds = []
        for image_embed in image_embeds:
            if image_embed.shape[0] < max_image_tokens:
                image_embed = torch.concat([
                    image_embed,
                    torch.zeros(
                        (max_image_tokens - image_embed.shape[0], *image_embed.shape[1:]),
                        device=image_embed.device,
                        dtype=image_embed.dtype,  # add dtype
                    )], dim=0
                )
            pad_image_embeds.append(image_embed)
        pad_image_embeds = torch.stack(pad_image_embeds, dim=0)

        if text_embed is None:
            assert self.text_encoder is not None and text is not None
            text_embed = self.text_encoder(text)

        if self.fuse_way == "concat":
            assert pad_image_embeds.shape[-1] == text_embed.shape[-1]
            if len(pad_image_embeds.shape) == 4:
                b, _, _, embdim = pad_image_embeds.shape
                pad_image_embeds = pad_image_embeds.view(b, -1, embdim)
            out_embed = torch.concat([text_embed, pad_image_embeds], dim=1)
            print("concat glyph_embed with text_embed:", out_embed.shape)
            return out_embed
        else:
            raise ValueError("Other fuse ways are not supported for now!")
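

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption-laden, not part of the training pipeline):
# it only exercises trans_glyph_emb on a random feature map to illustrate the
# expected tensor flow. The shapes below (batch of 2, a 256-channel 16x16
# feature map, 1024-dim output) are illustrative assumptions, not the shapes
# produced by the real glyph image encoder.
if __name__ == "__main__":
    dummy_feat = torch.randn(2, 256, 16, 16)

    # "fc" branch: flatten the spatial grid into h*w tokens, then project each
    # token with a zero-initialized Linear layer (its output is all zeros right
    # after construction, mirroring the zero-init used elsewhere in this module).
    fc_mapper = trans_glyph_emb(type="fc", input_dim=256, out_dim=1024, fc_init="zero")
    print("fc-mapped glyph embedding:", fc_mapper(dummy_feat).shape)      # (2, 256, 1024)

    # "conv" branch: a zero-initialized conv keeps the spatial layout inside
    # the model; the result is flattened to (b, h*w, out_dim) afterwards.
    conv_mapper = trans_glyph_emb(type="conv", input_dim=256, out_dim=1024)
    print("conv-mapped glyph embedding:", conv_mapper(dummy_feat).shape)  # (2, 256, 1024)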