# DragNUWA/DragNUWA_net.py
from utils import *  # expected to provide os, math, random, np, torch, nn and rearrange used below
#### SVD
from dragnuwa.svd.modules.diffusionmodules.video_model_flow import VideoUNet_flow, VideoResBlock_Embed
from dragnuwa.svd.modules.diffusionmodules.denoiser import Denoiser
from dragnuwa.svd.modules.diffusionmodules.denoiser_scaling import VScalingWithEDMcNoise
from dragnuwa.svd.modules.encoders.modules import *
from dragnuwa.svd.models.autoencoder import AutoencodingEngine
from dragnuwa.svd.modules.diffusionmodules.wrappers import OpenAIWrapper
from dragnuwa.svd.modules.diffusionmodules.sampling import EulerEDMSampler
from dragnuwa.lora import inject_trainable_lora, inject_trainable_lora_extended, extract_lora_ups_down, _find_modules
def get_gaussian_kernel(kernel_size, sigma, channels):
    print('parameters of gaussian kernel: kernel_size: {}, sigma: {}, channels: {}'.format(kernel_size, sigma, channels))
    # Build a (kernel_size, kernel_size) grid of (x, y) coordinates.
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    # Evaluate an isotropic Gaussian centered on the kernel.
    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.
    gaussian_kernel = torch.exp(
        -torch.sum((xy_grid - mean) ** 2., dim=-1) /
        (2 * variance)
    )
    # Normalize so the filter preserves the overall magnitude of its input.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to the depthwise-convolution weight layout: (channels, 1, k, k).
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    # Wrap the fixed kernel in a frozen depthwise convolution.
    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels,
                                bias=False, padding=kernel_size // 2)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False
    return gaussian_filter
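
# Usage sketch (illustrative only; the model below builds this filter with
# kernel_size=199, sigma=20, channels=2 — the small shapes here are hypothetical):
#
#   blur = get_gaussian_kernel(kernel_size=5, sigma=1.0, channels=2)
#   drag = torch.randn(8, 2, 320, 576)   # (batch, channels, H, W)
#   smoothed = blur(drag)                # depthwise blur; shape is preserved
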
def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, r=16):
    # Select the plain or extended LoRA injector.
    injector = inject_trainable_lora if not is_extended else inject_trainable_lora_extended

    params = None
    negation = None

    if use_lora:
        injector_args = {
            "model": model,
            "target_replace_module": replace_modules,
            "r": r,
        }
        if not is_extended:
            injector_args['dropout_p'] = dropout
        params, negation = injector(**injector_args)

        # Confirm that at least one LoRA up/down pair was actually created.
        for _up, _down in extract_lora_ups_down(model, target_replace_module=replace_modules):
            if all(x is not None for x in [_up, _down]):
                print(f"Lora successfully injected into {model.__class__.__name__}.")
                break

    return params, negation
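
# Usage sketch (hypothetical target module, shown for illustration): inject
# rank-16 LoRA adapters into blocks whose class name matches the target list:
#
#   lora_params, negation = inject_lora(
#       use_lora=True, model=some_module, replace_modules=['OpenAIWrapper'], r=16
#   )
#   # `lora_params` holds the newly created trainable LoRA parameters; pass
#   # them to the optimizer while the base weights stay frozen.
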
class Args:
### basic
fps = 4
height = 320
width = 576
### lora
unet_lora_rank = 32
### gaussian filter parameters
kernel_size = 199
sigma = 20
    ### model
denoiser_config = {
'scaling_config':{
'target': 'dragnuwa.svd.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise',
}
}
    network_config = {
        'adm_in_channels': 768,
        'num_classes': 'sequential',
        'use_checkpoint': True,
        'in_channels': 8,
        'out_channels': 4,
        'model_channels': 320,
        'attention_resolutions': [4, 2, 1],
        'num_res_blocks': 2,
        'channel_mult': [1, 2, 4, 4],
        'num_head_channels': 64,
        'use_linear_in_transformer': True,
        'transformer_depth': 1,
        'context_dim': 1024,
        'spatial_transformer_attn_type': 'softmax-xformers',
        'extra_ff_mix_layer': True,
        'use_spatial_context': True,
        'merge_strategy': 'learned_with_images',
        'video_kernel_size': [3, 1, 1],
        'flow_dim_scale': 1,
    }
conditioner_emb_models = [
{'is_trainable': False,
'input_key': 'cond_frames_without_noise', # crossattn
'ucg_rate': 0.1,
'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder',
'params':{
'n_cond_frames': 1,
'n_copies': 1,
'open_clip_embedding_config': {
'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImageEmbedder',
'params': {
'freeze':True,
}
}
}
},
{'input_key': 'fps_id', # vector
'is_trainable': False,
'ucg_rate': 0.1,
'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
'params': {
'outdim': 256,
}
},
{'input_key': 'motion_bucket_id', # vector
'ucg_rate': 0.1,
'is_trainable': False,
'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
'params': {
'outdim': 256,
}
},
{'input_key': 'cond_frames', # concat
'is_trainable': False,
'ucg_rate': 0.1,
'target': 'dragnuwa.svd.modules.encoders.modules.VideoPredictionEmbedderWithEncoder',
'params': {
'en_and_decode_n_samples_a_time': 1,
'disable_encoder_autocast': True,
'n_cond_frames': 1,
'n_copies': 1,
'is_ae': True,
'encoder_config': {
'target': 'dragnuwa.svd.models.autoencoder.AutoencoderKLModeOnly',
'params': {
'embed_dim': 4,
'monitor': 'val/rec_loss',
'ddconfig': {
'attn_type': 'vanilla-xformers',
'double_z': True,
'z_channels': 4,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 4, 4],
'num_res_blocks': 2,
'attn_resolutions': [],
'dropout': 0.0,
},
'lossconfig': {
'target': 'torch.nn.Identity',
}
}
}
}
},
{'input_key': 'cond_aug', # vector
'ucg_rate': 0.1,
'is_trainable': False,
'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
'params': {
'outdim': 256,
}
}
]
first_stage_config = {
'loss_config': {'target': 'torch.nn.Identity'},
'regularizer_config': {'target': 'dragnuwa.svd.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'},
'encoder_config':{'target': 'dragnuwa.svd.modules.diffusionmodules.model.Encoder',
'params': { 'attn_type':'vanilla',
'double_z': True,
'z_channels': 4,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 4, 4],
'num_res_blocks': 2,
'attn_resolutions': [],
'dropout': 0.0,
}
},
'decoder_config':{'target': 'dragnuwa.svd.modules.autoencoding.temporal_ae.VideoDecoder',
'params': {'attn_type': 'vanilla',
'double_z': True,
'z_channels': 4,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 4, 4],
'num_res_blocks': 2,
'attn_resolutions': [],
'dropout': 0.0,
'video_kernel_size': [3, 1, 1],
}
},
}
sampler_config = {
'discretization_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.discretizer.EDMDiscretization',
'params': {'sigma_max': 700.0,},
},
'guider_config': {'target': 'dragnuwa.svd.modules.diffusionmodules.guiders.LinearPredictionGuider',
'params': {'max_scale':2.5,
'min_scale':1.0,
'num_frames':14},
},
'num_steps': 25,
}
scale_factor = 0.18215
num_frames = 14
    ### reproducibility (these seeds are set once, when the class body executes at import time)
seed = 42
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
args = Args()
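
# The nested {'target': ..., 'params': ...} dicts above follow the sgm-style
# instantiate-from-config pattern. A minimal sketch of how such a dict is
# typically resolved (assumed helper for illustration; not defined in this file):
#
#   import importlib
#
#   def instantiate_from_config(cfg):
#       module_name, cls_name = cfg['target'].rsplit('.', 1)
#       cls = getattr(importlib.import_module(module_name), cls_name)
#       return cls(**cfg.get('params', {}))
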
def quick_freeze(model):
for name, param in model.named_parameters():
param.requires_grad = False
return model
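
# e.g. quick_freeze(nn.Linear(4, 4)) returns the same module with every
# parameter excluded from gradient tracking (requires_grad=False).
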
class Net(nn.Module):
def __init__(self, args):
super(Net, self).__init__()
self.args = args
self.device = 'cpu'
### unet
model = VideoUNet_flow(**args.network_config)
self.model = OpenAIWrapper(model)
### denoiser and sampler
self.denoiser = Denoiser(**args.denoiser_config)
self.sampler = EulerEDMSampler(**args.sampler_config)
### conditioner
self.conditioner = GeneralConditioner(args.conditioner_emb_models)
### first stage model
self.first_stage_model = AutoencodingEngine(**args.first_stage_config).eval()
self.scale_factor = args.scale_factor
self.en_and_decode_n_samples_a_time = 1 # decode 1 frame each time to save GPU memory
self.num_frames = args.num_frames
        self.gaussian_filter = quick_freeze(get_gaussian_kernel(kernel_size=args.kernel_size, sigma=args.sigma, channels=2))

        ### inject trainable LoRA adapters into the UNet wrapper
        unet_lora_params, unet_negation = inject_lora(
            True, self, ['OpenAIWrapper'], is_extended=False, r=args.unet_lora_rank
        )
def to(self, *args, **kwargs):
model_converted = super().to(*args, **kwargs)
self.device = next(self.parameters()).device
self.sampler.device = self.device
for embedder in self.conditioner.embedders:
if hasattr(embedder, "device"):
embedder.device = self.device
return model_converted
    def train(self, *args):
        # Keep the conditioner and the autoencoder in eval mode even during training.
        super().train(*args)
        self.conditioner.eval()
        self.first_stage_model.eval()
        return self
def apply_gaussian_filter_on_drag(self, drag):
b, l, h, w, c = drag.shape
drag = rearrange(drag, 'b l h w c -> (b l) c h w')
        drag = self.gaussian_filter(drag)
drag = rearrange(drag, '(b l) c h w -> b l h w c', b=b)
return drag
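
    # Shape walkthrough (hypothetical sizes): a drag field of shape
    # (b=2, l=14, h=320, w=576, c=2) is flattened to (28, 2, 320, 576),
    # blurred depthwise by the frozen Gaussian filter, and restored to
    # (2, 14, 320, 576, 2).
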
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = self.en_and_decode_n_samples_a_time # 1
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
for n in range(n_rounds):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
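
# Minimal smoke-test sketch (assumes the dragnuwa package is importable; the
# network is randomly initialized here, so outputs are only shape-checked):
#
#   net = Net(args).to('cuda')
#   latents = torch.randn(args.num_frames, 4, args.height // 8, args.width // 8,
#                         device=net.device)
#   frames = net.decode_first_stage(latents)   # -> (num_frames, 3, height, width)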