import os
import fnmatch
import itertools
import random

import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
from diffusers.optimization import get_scheduler
from einops import rearrange, repeat
from omegaconf import OmegaConf

from dataset import *
from models.unet.motion_embeddings import *
from .lora import *
from .lora_handler import *
def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')):
    """Recursively collect the paths of all video files under `directory`."""
    video_files = []
    for root, dirs, files in os.walk(directory):
        for extension in extensions:
            for filename in fnmatch.filter(files, '*' + extension):
                video_files.append(os.path.join(root, filename))
    return video_files
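
# Usage sketch (illustrative; './data' is a made-up path, not a repo default):
#
#   videos = find_videos('./data')
#   print(f'Found {len(videos)} videos')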
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model with its training condition for create_optimizer_params."""
    # Guard against both None and an empty dict (the original `len(extra_params.keys())`
    # raised AttributeError when the default None was used).
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    params = {
        "name": name,
        "params": params,
        "lr": lr,
    }

    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v

    return params
def create_optimizer_params(model_list, lr):
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()

        # LoRA training with a list of parameter collections.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        # LoRA training on a single model: collect only its LoRA parameters.
        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # Regular training: collect everything except LoRA parameters.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n and not is_lora
                if should_negate:
                    continue
                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params
def get_optimizer(use_8bit_adam):
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW
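
# Sketch of how the three helpers above compose (illustrative; `my_model` is a
# stand-in for any nn.Module, and the learning rate is an assumption):
#
#   model_list = [param_optim(my_model, condition=True)]
#   param_groups = create_optimizer_params(model_list, lr=5e-6)
#   optimizer = get_optimizer(use_8bit_adam=False)(param_groups, lr=5e-6)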
# Initialize the optimizers and learning-rate schedulers.
def prepare_optimizers(params, config, **extra_params):
    optimizer_cls = get_optimizer(config.train.use_8bit_adam)

    optimizer_temporal = optimizer_cls(
        params,
        lr=config.loss.learning_rate
    )

    lr_scheduler_temporal = get_scheduler(
        config.loss.lr_scheduler,
        optimizer=optimizer_temporal,
        num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
        num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
    )

    # Insert spatial LoRAs: one optimizer/scheduler pair per LoRA when the
    # loss type is 'DebiasedHybrid'; otherwise the lists stay empty.
    optimizer_spatial_list = []
    lr_scheduler_spatial_list = []
    if config.loss.type == 'DebiasedHybrid':
        unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', [])
        spatial_lora_num = extra_params.get('spatial_lora_num', 1)

        for i in range(spatial_lora_num):
            unet_lora_params_spatial = unet_lora_params_spatial_list[i]

            optimizer_spatial = optimizer_cls(
                create_optimizer_params(
                    [
                        param_optim(
                            unet_lora_params_spatial,
                            config.loss.use_unet_lora,
                            is_lora=True,
                            extra_params={"lr": config.loss.learning_rate_spatial}
                        )
                    ],
                    config.loss.learning_rate_spatial
                ),
                lr=config.loss.learning_rate_spatial
            )
            optimizer_spatial_list.append(optimizer_spatial)

            lr_scheduler_spatial = get_scheduler(
                config.loss.lr_scheduler,
                optimizer=optimizer_spatial,
                num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps,
                num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps,
            )
            lr_scheduler_spatial_list.append(lr_scheduler_spatial)

    return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list
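
# Return-shape sketch: the first element of each list is always the temporal
# optimizer/scheduler; spatial LoRA pairs follow (DebiasedHybrid only).
#
#   optimizers, schedulers = prepare_optimizers(params, config, **extra_params)
#   optimizer_temporal, *optimizers_spatial = optimizers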
def sample_noise(latents, noise_strength, use_offset_noise=False):
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)

    if use_offset_noise:
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents
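
# Offset noise adds a constant per-(batch, channel, frame) shift on top of the
# i.i.d. Gaussian noise, which helps diffusion models learn global brightness
# changes. Shape sketch with made-up latent dimensions:
#
#   latents = torch.randn(1, 4, 16, 32, 32)  # (b, c, f, h, w)
#   noise = sample_noise(latents, noise_strength=0.1, use_offset_noise=True)
#   assert noise.shape == latents.shape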
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    # Fold the frame axis into the batch axis so the 2D VAE can encode each frame.
    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    # Scale by the Stable Diffusion VAE latent scaling factor.
    latents = latents * 0.18215

    return latents
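
# Shape sketch, assuming a Stable-Diffusion-style VAE (8x spatial downsampling,
# 4 latent channels); the input/output sizes here are illustrative:
#
#   pixels = torch.randn(1, 16, 3, 256, 256)      # (b, f, c, h, w)
#   latents = tensor_to_vae_latent(pixels, vae)   # -> (1, 4, 16, 32, 32)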
def prepare_data(config, tokenizer):
    # Get the training dataset based on its type (json, single_video, folder, image).
    # config.dataset is assumed to be a DictConfig object.
    dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True)

    # Remove the 'type' key; the default None avoids a KeyError if it is absent.
    dataset_params_dict.pop('type', None)

    train_datasets = []

    # Loop through the available dataset classes and instantiate those whose
    # name matches an entry in config.dataset.type.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in config.dataset.type:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer))

    if len(train_datasets) == 0:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")

    train_dataset = train_datasets[0]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train.train_batch_size,
        shuffle=True
    )

    return train_dataloader, train_dataset
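
# Usage sketch, assuming config.dataset.type names one of the four dataset
# classes ('json', 'single_video', 'folder', 'image'):
#
#   train_dataloader, train_dataset = prepare_data(config, tokenizer)
#   batch = next(iter(train_dataloader))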
# Create the parameters for optimization.
def prepare_params(unet, config, train_dataset):
    extra_params = {}

    params, embedding_layers = inject_motion_embeddings(
        unet,
        combinations=config.model.motion_embeddings.combinations,
        config=config
    )
    config.model.embedding_layers = embedding_layers

    if config.loss.type == "DebiasedHybrid":
        if config.loss.spatial_lora_num == -1:
            config.loss.spatial_lora_num = len(train_dataset)
        lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras(
            unet=unet,
            use_unet_lora=True,
            lora_unet_dropout=0.1,
            lora_path='',
            lora_rank=32,
            spatial_lora_num=config.loss.spatial_lora_num,
        )

        extra_params['lora_managers_spatial'] = lora_managers_spatial
        extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list
        extra_params['unet_negation_all'] = unet_negation_all
        # Forward the LoRA count so prepare_optimizers builds one optimizer per LoRA
        # instead of falling back to its default of 1.
        extra_params['spatial_lora_num'] = config.loss.spatial_lora_num

    return params, extra_params
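
# End-to-end setup sketch tying the helpers in this module together (`unet`,
# `tokenizer`, and `config` are assumed to come from the training script):
#
#   train_dataloader, train_dataset = prepare_data(config, tokenizer)
#   params, extra_params = prepare_params(unet, config, train_dataset)
#   optimizers, schedulers = prepare_optimizers(params, config, **extra_params)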