import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _resize_embeddings(embed, seq_length, trained_length):
    """Return the frames of ``embed`` to add to an input with ``seq_length`` frames.

    ``embed`` is a ``(wh, max_seq_length, dim)`` table. When ``trained_length``
    is -1 (unset) or already equals ``seq_length``, the first ``seq_length``
    frames are returned as-is. Otherwise the first ``trained_length`` frames
    (the ones that were actually trained) are linearly resampled along the
    frame axis to exactly ``seq_length`` frames.

    Raises:
        ValueError: if the resampled table does not have ``seq_length`` frames.
    """
    if trained_length == -1 or seq_length == trained_length:
        return embed[:, :seq_length]

    # Slice to the trained frames *before* resampling. Slicing to seq_length
    # first would make the interpolation below an identity whenever
    # seq_length <= max_seq_length, so the length would never be adapted.
    embeddings = embed[:, :trained_length]
    # F.interpolate resamples the last axis: (wh, frames, dim) -> (wh, dim, frames).
    embeddings = embeddings.permute(0, 2, 1)
    embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False)
    embeddings = embeddings.permute(0, 2, 1)  # back to (wh, frames, dim)
    if embeddings.shape[1] != seq_length:
        raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}")
    return embeddings


def _add_embeddings(x, embeddings, scale):
    """Return ``x + embeddings * scale``.

    When the leading dimensions differ, ``embeddings`` is first tiled along
    dim 0 ``x.shape[0] // embeddings.shape[0]`` times; ``x.shape[0]`` is
    expected to be a multiple of ``embeddings.shape[0]``.
    """
    if x.shape[0] != embeddings.shape[0]:
        embeddings = embeddings.repeat(x.shape[0] // embeddings.shape[0], 1, 1)
    return x + embeddings * scale


class MotionEmbedding(nn.Module):
    """Learnable additive embedding over frames.

    Holds a zero-initialised parameter ``embed`` of shape
    ``(wh, max_seq_length, embed_dim)``. It is scaled by ``scale`` and added to
    3-D inputs ``x`` whose second axis is the frame/sequence axis.
    """

    def __init__(self, embed_dim: int = None, max_seq_length: int = 32, wh: int = 1):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(wh, max_seq_length, embed_dim))
        self.scale = 1.0
        # -1 means "no trained length recorded": frames are sliced, never resampled.
        self.trained_length = -1

    def set_scale(self, scale: float):
        """Set the multiplier applied to the embedding before it is added."""
        self.scale = scale

    def set_lengths(self, trained_length: int):
        """Record how many frames the embedding was trained on.

        Raises:
            ValueError: if ``trained_length`` is not in ``[1, max_seq_length]``.
        """
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def forward(self, x):
        """Add the length-adapted, scaled embedding to ``x`` and return the sum."""
        _, seq_length, _ = x.shape
        embeddings = _resize_embeddings(self.embed, seq_length, self.trained_length)
        return _add_embeddings(x, embeddings, self.scale)

    def forward_average(self, x):
        """Like ``forward``, but subtract the mean over frames from the embedding first."""
        _, seq_length, _ = x.shape
        embeddings = _resize_embeddings(self.embed, seq_length, self.trained_length)
        embeddings = embeddings - embeddings.mean(dim=1, keepdim=True)
        return _add_embeddings(x, embeddings, self.scale)

    def forward_frameSubtraction(self, x):
        """Like ``forward``, but add frame-to-frame differences of the embedding.

        Frame 0 keeps its own value; every later frame ``i`` is replaced by
        ``embed[i] - embed[i - 1]``.
        """
        _, seq_length, _ = x.shape
        embeddings = _resize_embeddings(self.embed, seq_length, self.trained_length)
        frame_deltas = embeddings[:, 1:] - embeddings[:, :-1]
        # NOTE(review): clone().detach() cuts the autograd link for frame 0, so
        # only the difference frames propagate gradient back to ``embed``.
        # Confirm this is intended if this path is ever used for training.
        embeddings = embeddings.clone().detach()
        embeddings[:, 1:] = frame_deltas
        return _add_embeddings(x, embeddings, self.scale)


class MotionEmbeddingSpatial(nn.Module):
    """Motion embedding whose table has ``h * w`` rows.

    ``embed`` has shape ``(h * w, max_seq_length, embed_dim)`` (presumably one
    row per spatial location). ``forward`` behaves exactly like
    ``MotionEmbedding.forward``.
    """

    def __init__(self, h: int = None, w: int = None, embed_dim: int = None, max_seq_length: int = 32):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(h * w, max_seq_length, embed_dim))
        self.scale = 1.0
        # -1 means "no trained length recorded": frames are sliced, never resampled.
        self.trained_length = -1

    def set_scale(self, scale: float):
        """Set the multiplier applied to the embedding before it is added."""
        self.scale = scale

    def set_lengths(self, trained_length: int):
        """Record how many frames the embedding was trained on.

        Raises:
            ValueError: if ``trained_length`` is not in ``[1, max_seq_length]``.
        """
        if trained_length > self.embed.shape[1] or trained_length <= 0:
            raise ValueError("Trained length is out of bounds")
        self.trained_length = trained_length

    def forward(self, x):
        """Add the length-adapted, scaled embedding to ``x`` and return the sum."""
        _, seq_length, _ = x.shape
        embeddings = _resize_embeddings(self.embed, seq_length, self.trained_length)
        return _add_embeddings(x, embeddings, self.scale)


# Every module type that carries a trainable motion-embedding table.
_MOTION_EMBEDDING_TYPES = (MotionEmbedding, MotionEmbeddingSpatial)


def inject_motion_embeddings(model, combinations=None, config=None):
    """Attach trainable motion embeddings to selected temporal transformer blocks.

    A block is selected when its name contains ``'.temp_attentions.'``, ends in
    ``transformer_blocks.<n>``, and its ``[name prefix before the first '_',
    norm1 width]`` pair is in ``combinations``. For each selected block this sets:

    * ``<block>.pos_embed``: a ``MotionEmbedding`` with ``wh=1``;
    * ``<block>.attn1.vSpatial`` and ``<block>.attn2.vSpatial``: a
      ``MotionEmbedding`` with ``wh=h*w``, where ``(h, w)`` is
      ``ceil((height, width) / 16 or 32)`` depending on the block.

    Afterwards every model parameter is frozen except those whose name
    contains ``'pos_embed'`` or ``'vSpatial'``.

    Args:
        model: module exposing ``named_modules``, ``named_parameters``,
            ``dtype`` and ``device``.
        combinations: list of ``[prefix, embed_dim]`` lists, e.g.
            ``[['down', 640]]``. Membership is tested by comparing against a
            list, so entries must be lists (tuples never match). Must be
            provided despite the ``None`` default.
        config: object providing ``config.dataset.height`` and
            ``config.dataset.width``.

    Returns:
        ``(parameters_list, inject_layers)``: the parameters left trainable and
        the sorted names of the injected embedding modules.

    Raises:
        ValueError: if a selected block is not an ``up``/``down`` block with
            index 1 or 2 (no spatial grid is defined for it).
    """
    spatial_shape = np.array([config.dataset.height, config.dataset.width])
    shape32 = np.ceil(spatial_shape / 32).astype(int)
    shape16 = np.ceil(spatial_shape / 16).astype(int)
    spatial_name = 'vSpatial'
    # (block prefix, block index) -> (h, w) grid seen by that block's temporal attention.
    spatial_grids = {
        ('up', '1'): shape32,
        ('up', '2'): shape16,
        ('down', '1'): shape16,
        ('down', '2'): shape32,
    }
    # support for 32 frames max
    max_seq_length = 32

    replacement_dict = {}
    for name, module in model.named_modules():
        # Only transformer blocks that live inside temporal attention modules.
        if '.temp_attentions.' not in name or not re.search(r'transformer_blocks\.\d+$', name):
            continue
        embed_dim = module.norm1.normalized_shape[0]
        if [name.split('_')[0], embed_dim] not in combinations:
            continue
        replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(
            max_seq_length=max_seq_length, embed_dim=embed_dim
        ).to(dtype=model.dtype, device=model.device)

    # One spatial embedding per attention (attn1/attn2) of each selected block,
    # sized to the block's spatial grid and to the block's embedding width.
    pos_embed_keys = list(replacement_dict)
    for attn_name in ('attn1', 'attn2'):
        for pos_key in pos_embed_keys:
            temp_attn = pos_key.replace('pos_embed', attn_name)
            match = re.search(r'(\d+)\.temp_attentions', temp_attn)
            block_index = match.group(1) if match else None
            grid = spatial_grids.get((temp_attn.split('_')[0], block_index))
            if grid is None:
                raise ValueError(f"No spatial resolution is defined for '{temp_attn}'")
            h, w = grid
            replacement_dict[f'{temp_attn}.{spatial_name}'] = MotionEmbedding(
                wh=int(h) * int(w),
                embed_dim=replacement_dict[pos_key].embed.shape[2],
                max_seq_length=max_seq_length,
            ).to(dtype=model.dtype, device=model.device)

    # Resolve parents once: they are pre-existing blocks/attentions that are
    # never themselves replaced, so the mapping stays valid across setattr calls.
    modules_by_name = dict(model.named_modules())
    inject_layers = []
    for name, new_module in replacement_dict.items():
        parent_name, _, module_name = name.rpartition('.')
        parent_module = modules_by_name[parent_name] if parent_name else model
        if [parent_name.split('_')[0], new_module.embed.shape[-1]] in combinations:
            inject_layers.append(name)
        setattr(parent_module, module_name, new_module)

    # Train only the motion embeddings; freeze everything else.
    parameters_list = []
    for name, param in model.named_parameters():
        trainable = 'pos_embed' in name or spatial_name in name
        param.requires_grad = trainable
        if trainable:
            parameters_list.append(param)

    return parameters_list, sorted(set(inject_layers))


def save_motion_embeddings(model, file_path):
    """Save the ``embed`` parameter of every motion-embedding module in ``model``.

    Writes a ``{module name: embed}`` dict with ``torch.save``. After
    ``torch.load`` the dict can be passed to ``load_motion_embeddings``.
    """
    motion_embeddings = {
        name: module.embed
        for name, module in model.named_modules()
        if isinstance(module, _MOTION_EMBEDDING_TYPES)
    }
    torch.save(motion_embeddings, file_path)


def load_motion_embeddings(model, saved_embeddings):
    """Install saved embedding tables into ``model`` as ``MotionEmbedding`` modules.

    Args:
        model: module exposing ``named_modules``, ``dtype`` and ``device``.
        saved_embeddings: mapping of dotted module name -> tensor, expected to
            be shaped ``(wh, max_seq_length, embed_dim)`` (e.g. the dict written
            by ``save_motion_embeddings``).

    Each entry sets the attribute at that name to a new ``MotionEmbedding``
    (``scale=1.0``, no trained length) whose ``embed`` is the saved tensor cast
    to the model's dtype and device.
    """
    # Parents are pre-existing modules that are not replaced here; look them up once.
    modules_by_name = dict(model.named_modules())
    for key, embedding in saved_embeddings.items():
        parent_name, _, module_name = key.rpartition('.')
        parent_module = modules_by_name[parent_name] if parent_name else model
        new_module = MotionEmbedding(
            wh=embedding.shape[0],
            embed_dim=embedding.shape[-1],
            max_seq_length=embedding.shape[-2],
        )
        new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device))
        setattr(parent_module, module_name, new_module)


def set_motion_embedding_scale(model, scale_value):
    """Set ``scale`` on every ``MotionEmbedding``/``MotionEmbeddingSpatial`` in ``model``."""
    for module in model.modules():
        if isinstance(module, _MOTION_EMBEDDING_TYPES):
            module.scale = scale_value


def set_motion_embedding_length(model, trained_length):
    """Set ``trained_length`` on every motion-embedding module in ``model``.

    The value is assigned directly, without the bounds check done by
    ``set_lengths``, so ``-1`` can be used to switch length resampling off.
    """
    for module in model.modules():
        if isinstance(module, _MOTION_EMBEDDING_TYPES):
            module.trained_length = trained_length