Spaces:
Running
on
Zero
Running
on
Zero
import abc | |
LOW_RESOURCE = False | |
import torch | |
import cv2 | |
import torch | |
import os | |
import numpy as np | |
from collections import defaultdict | |
from functools import partial | |
from typing import Any, Dict, Optional | |
def register_attention_control(unet, config=None):
    """Patch the forward passes of a UNet's transformer blocks and attention layers.

    Every sub-module whose class name is ``BasicTransformerBlock`` gets its
    ``forward`` replaced by a variant that additionally hands the normalized
    hidden states *before* positional embedding (``origin_hidden_states``) to
    its attention layers.  Every sub-module whose class name is ``Attention``
    and whose enclosing block is listed in ``config.model.embedding_layers``
    gets a ``forward`` that can build the attention value from those states.

    Args:
        unet: Module tree to patch in place; walked via ``named_children()``.
        config: Optional config object.  When given, ``config.model.embedding_layers``
            (iterable of layer names) selects the attention layers to patch and
            ``config.strategy.get(...)`` supplies the ``removeMFromV`` and
            ``vSpatial_frameSubtraction`` flags (both default to ``False``).
            When ``None``, only the transformer blocks are patched.

    Returns:
        None.  The modules are modified in place.
    """

    def BasicTransformerBlock_forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.FloatTensor:
        """Transformer-block forward that also forwards ``origin_hidden_states``.

        Runs self-attention, optional GLIGEN fusion, optional cross-attention
        and the feed-forward layer, each preceded by the normalization selected
        by ``self.norm_type``.  The caller's ``cross_attention_kwargs`` is never
        modified.

        Raises:
            ValueError: If ``self.norm_type`` is not one of the handled kinds.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        # Work on a private copy from the start: the caller's dict may be shared
        # with other blocks, so writing into it would leak a stale tensor to them.
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        # save the origin_hidden_states w/o pos_embed, for the use of motion v embedding
        origin_hidden_states = None
        if self.pos_embed is not None or hasattr(self.attn1, 'vSpatial'):
            origin_hidden_states = norm_hidden_states.clone()
            cross_attention_kwargs["origin_hidden_states"] = origin_hidden_states

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0)

        # 2. Prepare GLIGEN inputs
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                # save the origin_hidden_states
                origin_hidden_states = norm_hidden_states.clone()
                norm_hidden_states = self.pos_embed(norm_hidden_states)
                cross_attention_kwargs["origin_hidden_states"] = origin_hidden_states

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # delete the origin_hidden_states (only the local copy holds it)
        cross_attention_kwargs.pop("origin_hidden_states", None)

        # 4. Feed-forward
        # i2vgen doesn't have this norm
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            # NOTE(review): `_chunked_feed_forward` is not defined or imported in the
            # visible part of this file; this branch raises NameError unless it is
            # provided elsewhere -- confirm and import it at module level.
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states

    def temp_attn_forward(self, additional_info=None):
        """Build a replacement ``forward`` for the attention module ``self``.

        Args:
            self: The attention module whose projections and helpers are used.
            additional_info: Optional dict with the keys ``removeMFromV`` and
                ``vSpatial_frameSubtraction``.  When ``None`` the value is
                projected from ``encoder_hidden_states``.

        Returns:
            A ``forward`` callable bound to ``self`` through closure.
        """
        # When to_out is a ModuleList only its first entry is applied.
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, origin_hidden_states=None):
            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                # Flatten the two trailing dims into one sequence axis.
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

            query = self.to_q(hidden_states)
            key = self.to_k(encoder_hidden_states)

            # strategies to manipulate the motion value embedding
            # NOTE: every `additional_info` branch reads `origin_hidden_states`, which the
            # patched transformer block only supplies when it has a pos_embed or its
            # attn1 has a `vSpatial` attribute.
            if additional_info is not None:
                # empirically, in the inference stage of camera motion
                # discarding the motion value embedding improves the text similarity of the generated video
                if additional_info['removeMFromV']:
                    value = self.to_v(origin_hidden_states)
                elif hasattr(self, 'vSpatial'):
                    # during inference, the debiasing operation helps to generate more diverse videos
                    # refer to the 'Figure.3 Right' in the paper for more details
                    if additional_info['vSpatial_frameSubtraction']:
                        value = self.to_v(self.vSpatial.forward_frameSubtraction(origin_hidden_states))
                    # during training, do not apply debias operation for motion learning
                    else:
                        value = self.to_v(self.vSpatial(origin_hidden_states))
                else:
                    value = self.to_v(origin_hidden_states)
            else:
                value = self.to_v(encoder_hidden_states)

            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = to_out(hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            return hidden_states

        return forward

    # Block-name prefixes selected by the config, computed once instead of once
    # per Attention module encountered during the walk.
    embedding_blocks = set()
    if config is not None:
        embedding_blocks = {
            layer.split('.attn')[0].split('.pos_embed')[0] for layer in config.model.embedding_layers
        }

    def register_recr(net_, count, name, config=None):
        """Recursively patch ``net_`` and its children; return the updated count.

        ``count`` is incremented once for every ``Attention`` module that gets
        the motion-value ``forward`` installed.  ``name`` is the dotted path of
        ``net_`` relative to the root module.
        """
        class_name = net_.__class__.__name__

        if class_name == 'BasicTransformerBlock':
            net_.forward = partial(BasicTransformerBlock_forward, net_)

        if class_name == 'Attention':
            block_name = name.split('.attn')[0]
            if config is not None and block_name in embedding_blocks:
                additional_info = {
                    'layer_name': name,
                    'removeMFromV': config.strategy.get('removeMFromV', False),
                    'vSpatial_frameSubtraction': config.strategy.get('vSpatial_frameSubtraction', False),
                }
                net_.forward = temp_attn_forward(net_, additional_info)
                print('register Motion V embedding at ', block_name)
                return count + 1
            return count

        if hasattr(net_, 'children'):
            for child_name, child in net_.named_children():
                count = register_recr(child, count, name=name + '.' + child_name, config=config)

        # Always hand back a number so the caller's accumulation never sees None.
        return count

    for net_name, net in unet.named_children():
        register_recr(net, 0, name=net_name, config=config)