# MotionInversion/utils/func_utils.py
import os
import fnmatch
import itertools
import random

import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from diffusers.optimization import get_scheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf

from dataset import *
from models.unet.motion_embeddings import *
from .lora import *
from .lora_handler import *


def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')):
    video_files = []
    for root, dirs, files in os.walk(directory):
        for extension in extensions:
            for filename in fnmatch.filter(files, '*' + extension):
                video_files.append(os.path.join(root, filename))
    return video_files
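
# A minimal usage sketch (the directory path below is hypothetical):
#
#     clips = find_videos('./assets/videos')
#     # -> e.g. ['./assets/videos/jump.mp4', './assets/videos/pans/left.gif']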


def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    # Normalize a missing or empty dict to None so downstream checks are uniform.
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation
    }
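
# Sketch of the dict this builds (values here are illustrative):
#
#     param_optim(unet, True, extra_params={'lr': 1e-4}, is_lora=True)
#     # -> {'model': unet, 'condition': True,
#     #     'extra_params': {'lr': 1e-4}, 'is_lora': True, 'negation': None}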


def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    param_group = {
        "name": name,
        "params": params,
        "lr": lr
    }
    # Entries in `extra_params` override the defaults above (including "lr").
    if extra_params is not None:
        for k, v in extra_params.items():
            param_group[k] = v
    return param_group
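
# Example of the override behavior (names and values are illustrative):
#
#     create_optim_params('embed', some_params, lr=5e-6, extra_params={'lr': 1e-4})
#     # -> {'name': 'embed', 'params': some_params, 'lr': 1e-4}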


def create_optimizer_params(model_list, lr):
    optimizer_params = []

    for optim in model_list:
        # Relies on the insertion order of the dict built by `param_optim`.
        model, condition, extra_params, is_lora, negation = optim.values()

        # LoRA training with a list of parameter groups: chain them into one
        # group. The learning rate is expected to arrive via `extra_params`.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        # LoRA training on a single module: pick out only the LoRA weights.
        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # Otherwise train every parameter, skipping LoRA weights when this
        # is not a LoRA group.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n and not is_lora
                if should_negate:
                    continue
                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params
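
# Sketch: turning `param_optim` dicts into optimizer param groups (assumes
# `unet` is an nn.Module; the learning rate is illustrative):
#
#     groups = create_optimizer_params([param_optim(unet, True)], lr=1e-4)
#     optimizer = torch.optim.AdamW(groups)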


def get_optimizer(use_8bit_adam):
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW
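
# Both returned classes share the AdamW constructor signature, so the caller
# can use them interchangeably:
#
#     optimizer_cls = get_optimizer(use_8bit_adam=False)
#     optimizer = optimizer_cls(model.parameters(), lr=1e-4)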


# Initialize the temporal optimizer (and, for DebiasedHybrid, one spatial
# LoRA optimizer per LoRA) together with their LR schedulers.
def prepare_optimizers(params, config, **extra_params):
    optimizer_cls = get_optimizer(config.train.use_8bit_adam)

    optimizer_temporal = optimizer_cls(
        params,
        lr=config.loss.learning_rate
    )

    lr_scheduler_temporal = get_scheduler(
        config.loss.lr_scheduler,
        optimizer=optimizer_temporal,
        num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
        num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
    )

    optimizer_spatial_list = []
    lr_scheduler_spatial_list = []

    # Insert spatial LoRAs.
    if config.loss.type == 'DebiasedHybrid':
        unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', [])
        spatial_lora_num = extra_params.get('spatial_lora_num', 1)

        for i in range(spatial_lora_num):
            unet_lora_params_spatial = unet_lora_params_spatial_list[i]

            optimizer_spatial = optimizer_cls(
                create_optimizer_params(
                    [
                        param_optim(
                            unet_lora_params_spatial,
                            config.loss.use_unet_lora,
                            is_lora=True,
                            extra_params={"lr": config.loss.learning_rate_spatial}
                        )
                    ],
                    config.loss.learning_rate_spatial
                ),
                lr=config.loss.learning_rate_spatial
            )
            optimizer_spatial_list.append(optimizer_spatial)

            # Scheduler
            lr_scheduler_spatial = get_scheduler(
                config.loss.lr_scheduler,
                optimizer=optimizer_spatial,
                num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
                num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
            )
            lr_scheduler_spatial_list.append(lr_scheduler_spatial)

    return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list
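
# Sketch of the return shape (assumes a config with the loss/train fields
# accessed above):
#
#     optimizers, schedulers = prepare_optimizers(params, config)
#     # optimizers[0] / schedulers[0] -> the temporal pair; any remaining
#     # entries are one spatial-LoRA pair per `spatial_lora_num`.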


def sample_noise(latents, noise_strength, use_offset_noise=False):
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)

    if use_offset_noise:
        # Offset noise: add a per-(batch, channel, frame) random constant,
        # broadcast over height and width.
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents
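
# Example (illustrative shapes):
#
#     latents = torch.randn(1, 4, 16, 32, 32)   # (b, c, f, h, w)
#     noise = sample_noise(latents, noise_strength=0.1, use_offset_noise=True)
#     # noise.shape == (1, 4, 16, 32, 32)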


@torch.no_grad()
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    # Fold the frame axis into the batch axis so the image VAE can encode all
    # frames at once, then restore it on the latents.
    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    latents = latents * 0.18215  # Stable Diffusion VAE scaling factor

    return latents
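
# Shape sketch (illustrative sizes; assumes a diffusers AutoencoderKL `vae`
# with 8x spatial downsampling and 4 latent channels):
#
#     frames = torch.randn(1, 16, 3, 256, 256)   # (b, f, c, h, w)
#     latents = tensor_to_vae_latent(frames, vae)
#     # latents.shape == (1, 4, 16, 32, 32)      # (b, c, f, h, w)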


def prepare_data(config, tokenizer):
    # Get the training dataset based on type (json, single_video, image, folder).
    # Assuming config.dataset is a DictConfig object.
    dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True)

    # Remove the 'type' key; the default prevents a KeyError if it is absent.
    dataset_params_dict.pop('type', None)

    train_datasets = []

    # Loop through all available dataset classes, match each requested type by
    # name, and instantiate it.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in config.dataset.type:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer))

    if len(train_datasets) == 0:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")

    train_dataset = train_datasets[0]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.train_batch_size,
        shuffle=True
    )

    return train_dataloader, train_dataset
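
# Sketch of consuming the returned dataloader (batch keys depend on the
# concrete dataset class; 'pixel_values' is an assumption, not guaranteed):
#
#     train_dataloader, train_dataset = prepare_data(config, tokenizer)
#     for batch in train_dataloader:
#         pixel_values = batch['pixel_values']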


# Create parameters for optimization.
def prepare_params(unet, config, train_dataset):
    extra_params = {}

    params, embedding_layers = inject_motion_embeddings(
        unet,
        combinations=config.model.motion_embeddings.combinations,
        config=config
    )
    config.model.embedding_layers = embedding_layers

    if config.loss.type == "DebiasedHybrid":
        # -1 means one spatial LoRA per training sample.
        if config.loss.spatial_lora_num == -1:
            config.loss.spatial_lora_num = len(train_dataset)

        lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras(
            unet=unet,
            use_unet_lora=True,
            lora_unet_dropout=0.1,
            lora_path='',
            lora_rank=32,
            spatial_lora_num=config.loss.spatial_lora_num,
        )

        extra_params['lora_managers_spatial'] = lora_managers_spatial
        extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list
        extra_params['unet_negation_all'] = unet_negation_all
        extra_params['spatial_lora_num'] = config.loss.spatial_lora_num

    return params, extra_params
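
# End-to-end sketch of how these helpers compose during training setup
# (assumes `unet`, `config`, and `tokenizer` are already constructed):
#
#     train_dataloader, train_dataset = prepare_data(config, tokenizer)
#     params, extra_params = prepare_params(unet, config, train_dataset)
#     optimizers, schedulers = prepare_optimizers(params, config, **extra_params)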