# explicit stdlib/third-party imports so the module is self-contained
# (utils may also re-export some of these via the wildcard import below)
import os
import math
import random

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from utils import *

#### SVD
from dragnuwa.svd.modules.diffusionmodules.video_model_flow import VideoUNet_flow, VideoResBlock_Embed
from dragnuwa.svd.modules.diffusionmodules.denoiser import Denoiser
from dragnuwa.svd.modules.diffusionmodules.denoiser_scaling import VScalingWithEDMcNoise
from dragnuwa.svd.modules.encoders.modules import *
from dragnuwa.svd.models.autoencoder import AutoencodingEngine
from dragnuwa.svd.modules.diffusionmodules.wrappers import OpenAIWrapper
from dragnuwa.svd.modules.diffusionmodules.sampling import EulerEDMSampler
from dragnuwa.lora import inject_trainable_lora, inject_trainable_lora_extended, extract_lora_ups_down, _find_modules
def get_gaussian_kernel(kernel_size, sigma, channels):
    """Build a frozen depthwise Conv2d that applies a per-channel Gaussian blur."""
    print('parameters of gaussian kernel: kernel_size: {}, sigma: {}, channels: {}'.format(kernel_size, sigma, channels))
    # 2D grid of (x, y) coordinates over the kernel window
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.
    gaussian_kernel = torch.exp(
        -torch.sum((xy_grid - mean) ** 2., dim=-1) /
        (2 * variance)
    )
    # normalize so the kernel sums to 1 (otherwise the blur rescales drag magnitudes)
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # one copy of the kernel per channel, applied depthwise (groups=channels)
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels,
                                bias=False, padding=kernel_size // 2)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False
    return gaussian_filter
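
#### minimal usage sketch (illustrative only, never called by the pipeline):
#### blur a 2-channel drag map the way Net does below; the shapes are assumptions.
def _demo_gaussian_filter():
    blur = get_gaussian_kernel(kernel_size=199, sigma=20, channels=2)
    drag_map = torch.randn(1, 2, 320, 576)  # (batch, x/y flow channels, H, W)
    blurred = blur(drag_map)                # padding=kernel_size//2 keeps the spatial size
    return blurred.shape                    # torch.Size([1, 2, 320, 576])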
def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, r=16):
    injector = (
        inject_trainable_lora if not is_extended
        else inject_trainable_lora_extended
    )

    params = None
    negation = None

    if use_lora:
        REPLACE_MODULES = replace_modules
        injector_args = {
            "model": model,
            "target_replace_module": REPLACE_MODULES,
            "r": r,
        }
        if not is_extended:
            injector_args['dropout_p'] = dropout

        params, negation = injector(**injector_args)
        for _up, _down in extract_lora_ups_down(
                model,
                target_replace_module=REPLACE_MODULES):
            if all(x is not None for x in [_up, _down]):
                print(f"Lora successfully injected into {model.__class__.__name__}.")
                break

    return params, negation
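
#### sketch of a typical call (illustrative): the target list holds whatever
#### module class names the caller wants LoRA injected into; Net below targets
#### its 'OpenAIWrapper' UNet wrapper.
def _demo_inject_lora(model):
    # rank-32 LoRA on every submodule whose class name is 'OpenAIWrapper'
    lora_params, negation = inject_lora(True, model, ['OpenAIWrapper'],
                                        is_extended=False, dropout=0.0, r=32)
    return lora_params, negation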
class Args:
    ### basic
    fps = 4
    height = 320
    width = 576

    ### lora
    unet_lora_rank = 32

    ### gaussian filter parameters
    kernel_size = 199
    sigma = 20

    ### model
    denoiser_config = {
        'scaling_config': {
            'target': 'dragnuwa.svd.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise',
        }
    }
    network_config = {
        'adm_in_channels': 768,
        'num_classes': 'sequential',
        'use_checkpoint': True,
        'in_channels': 8,
        'out_channels': 4,
        'model_channels': 320,
        'attention_resolutions': [4, 2, 1],
        'num_res_blocks': 2,
        'channel_mult': [1, 2, 4, 4],
        'num_head_channels': 64,
        'use_linear_in_transformer': True,
        'transformer_depth': 1,
        'context_dim': 1024,
        'spatial_transformer_attn_type': 'softmax-xformers',
        'extra_ff_mix_layer': True,
        'use_spatial_context': True,
        'merge_strategy': 'learned_with_images',
        'video_kernel_size': [3, 1, 1],
        'flow_dim_scale': 1,
    }
    conditioner_emb_models = [
        {'is_trainable': False,
         'input_key': 'cond_frames_without_noise',  # crossattn
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder',
         'params': {
             'n_cond_frames': 1,
             'n_copies': 1,
             'open_clip_embedding_config': {
                 'target': 'dragnuwa.svd.modules.encoders.modules.FrozenOpenCLIPImageEmbedder',
                 'params': {
                     'freeze': True,
                 }
             }
         }
         },
        {'input_key': 'fps_id',  # vector
         'is_trainable': False,
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         },
        {'input_key': 'motion_bucket_id',  # vector
         'ucg_rate': 0.1,
         'is_trainable': False,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         },
        {'input_key': 'cond_frames',  # concat
         'is_trainable': False,
         'ucg_rate': 0.1,
         'target': 'dragnuwa.svd.modules.encoders.modules.VideoPredictionEmbedderWithEncoder',
         'params': {
             'en_and_decode_n_samples_a_time': 1,
             'disable_encoder_autocast': True,
             'n_cond_frames': 1,
             'n_copies': 1,
             'is_ae': True,
             'encoder_config': {
                 'target': 'dragnuwa.svd.models.autoencoder.AutoencoderKLModeOnly',
                 'params': {
                     'embed_dim': 4,
                     'monitor': 'val/rec_loss',
                     'ddconfig': {
                         'attn_type': 'vanilla-xformers',
                         'double_z': True,
                         'z_channels': 4,
                         'resolution': 256,
                         'in_channels': 3,
                         'out_ch': 3,
                         'ch': 128,
                         'ch_mult': [1, 2, 4, 4],
                         'num_res_blocks': 2,
                         'attn_resolutions': [],
                         'dropout': 0.0,
                     },
                     'lossconfig': {
                         'target': 'torch.nn.Identity',
                     }
                 }
             }
         }
         },
        {'input_key': 'cond_aug',  # vector
         'ucg_rate': 0.1,
         'is_trainable': False,
         'target': 'dragnuwa.svd.modules.encoders.modules.ConcatTimestepEmbedderND',
         'params': {
             'outdim': 256,
         }
         }
    ]
    first_stage_config = {
        'loss_config': {'target': 'torch.nn.Identity'},
        'regularizer_config': {'target': 'dragnuwa.svd.modules.autoencoding.regularizers.DiagonalGaussianRegularizer'},
        'encoder_config': {
            'target': 'dragnuwa.svd.modules.diffusionmodules.model.Encoder',
            'params': {
                'attn_type': 'vanilla',
                'double_z': True,
                'z_channels': 4,
                'resolution': 256,
                'in_channels': 3,
                'out_ch': 3,
                'ch': 128,
                'ch_mult': [1, 2, 4, 4],
                'num_res_blocks': 2,
                'attn_resolutions': [],
                'dropout': 0.0,
            }
        },
        'decoder_config': {
            'target': 'dragnuwa.svd.modules.autoencoding.temporal_ae.VideoDecoder',
            'params': {
                'attn_type': 'vanilla',
                'double_z': True,
                'z_channels': 4,
                'resolution': 256,
                'in_channels': 3,
                'out_ch': 3,
                'ch': 128,
                'ch_mult': [1, 2, 4, 4],
                'num_res_blocks': 2,
                'attn_resolutions': [],
                'dropout': 0.0,
                'video_kernel_size': [3, 1, 1],
            }
        },
    }
    sampler_config = {
        'discretization_config': {
            'target': 'dragnuwa.svd.modules.diffusionmodules.discretizer.EDMDiscretization',
            'params': {'sigma_max': 700.0},
        },
        'guider_config': {
            'target': 'dragnuwa.svd.modules.diffusionmodules.guiders.LinearPredictionGuider',
            'params': {'max_scale': 2.5,
                       'min_scale': 1.0,
                       'num_frames': 14},
        },
        'num_steps': 25,
    }
    scale_factor = 0.18215
    num_frames = 14

    ### others
    seed = 42
    # seeding runs once, when the class body is executed at import time
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

args = Args()
def quick_freeze(model):
    for param in model.parameters():
        param.requires_grad = False
    return model
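
#### minimal sanity-check sketch (illustrative only):
def _demo_quick_freeze():
    layer = quick_freeze(nn.Linear(4, 4))
    return any(p.requires_grad for p in layer.parameters())  # -> False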
class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.device = 'cpu'

        ### unet
        model = VideoUNet_flow(**args.network_config)
        self.model = OpenAIWrapper(model)

        ### denoiser and sampler
        self.denoiser = Denoiser(**args.denoiser_config)
        self.sampler = EulerEDMSampler(**args.sampler_config)

        ### conditioner
        self.conditioner = GeneralConditioner(args.conditioner_emb_models)

        ### first stage model
        self.first_stage_model = AutoencodingEngine(**args.first_stage_config).eval()
        self.scale_factor = args.scale_factor
        self.en_and_decode_n_samples_a_time = 1  # decode 1 frame at a time to save GPU memory
        self.num_frames = args.num_frames

        ### frozen Gaussian blur for drag maps (the attribute name keeps the
        ### repo's original spelling so existing checkpoint keys still match)
        self.guassian_filter = quick_freeze(get_gaussian_kernel(kernel_size=args.kernel_size, sigma=args.sigma, channels=2))

        ### inject trainable LoRA adapters into the UNet wrapper
        unet_lora_params, unet_negation = inject_lora(
            True, self, ['OpenAIWrapper'], is_extended=False, r=args.unet_lora_rank
        )
    def to(self, *args, **kwargs):
        model_converted = super().to(*args, **kwargs)
        self.device = next(self.parameters()).device
        self.sampler.device = self.device
        for embedder in self.conditioner.embedders:
            if hasattr(embedder, "device"):
                embedder.device = self.device
        return model_converted

    def train(self, *args):
        # keep the conditioner and VAE in eval mode even while the rest trains
        super().train(*args)
        self.conditioner.eval()
        self.first_stage_model.eval()
        return self  # nn.Module.train is expected to return self (eval() relies on it)
    def apply_gaussian_filter_on_drag(self, drag):
        # drag: (batch, frames, H, W, 2); blur each frame's x/y flow channels
        b, l, h, w, c = drag.shape
        drag = rearrange(drag, 'b l h w c -> (b l) c h w')
        drag = self.guassian_filter(drag)
        drag = rearrange(drag, '(b l) c h w -> b l h w c', b=b)
        return drag
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = self.en_and_decode_n_samples_a_time  # 1
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        for n in range(n_rounds):
            # the VideoDecoder needs the number of frames in this chunk as `timesteps`
            kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
            out = self.first_stage_model.decode(
                z[n * n_samples : (n + 1) * n_samples], **kwargs
            )
            all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out
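
#### end-to-end decode sketch (illustrative only; assumes a CUDA device with
#### enough memory -- constructing Net builds the full UNet and VAE -- and that
#### latents are 1/8 the configured resolution, as ch_mult [1, 2, 4, 4] implies):
def _demo_decode():
    net = Net(args).to('cuda')
    net.eval()
    z = torch.randn(args.num_frames, 4, args.height // 8, args.width // 8, device='cuda')
    frames = net.decode_first_stage(z)  # expected: (14, 3, 320, 576) pixel-space frames
    return frames.shape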