# Web-viewer header residue: zxl / first commit / 07c6a04 / raw / history blame / 30.8 kB
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import glob
import importlib
import os
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from torch import nn
def Normalize(in_channels, num_groups=32):
    """Build the affine GroupNorm layer used by the encoder and decoder.

    Args:
        in_channels: Number of channels the layer normalizes.
        num_groups: Number of groups the channels are divided into.

    Returns:
        A ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine parameters.
    """
    norm_kwargs = {
        "num_groups": num_groups,
        "num_channels": in_channels,
        "eps": 1e-6,
        "affine": True,
    }
    return torch.nn.GroupNorm(**norm_kwargs)
def tensor_to_video(x):
    """Convert a 4-D tensor with values nominally in [-1, 1] to a uint8 NumPy array.

    The tensor is detached, moved to CPU, clamped to [-1, 1], rescaled to
    [0, 1], its first two axes are swapped, and the result is scaled by 255
    and cast (truncating) to ``uint8``.

    Args:
        x: 4-D tensor; the permutation ``(1, 0, 2, 3)`` fixes the rank.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the first two axes of ``x`` swapped.
    """
    clipped = torch.clamp(x.detach().cpu(), -1, 1)
    unit_range = (clipped + 1) / 2
    # Swap the first two axes before handing the data to NumPy.
    frames = unit_range.permute(1, 0, 2, 3).float().numpy()
    return (255 * frames).astype(np.uint8)
def nonlinearity(x):
    """Return ``x * sigmoid(x)`` computed element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian described by a tensor holding mean and log-variance.

    ``parameters`` is split into two equal chunks along dim 1: the first chunk
    is the mean, the second the log-variance (clamped to ``[-30, 20]``).

    Args:
        parameters: Tensor whose dim 1 holds ``[mean, logvar]`` concatenated.
        deterministic: When true, variance and standard deviation are zero, so
            ``sample()`` returns the mean and ``kl()`` / ``nll()`` return a
            single zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one sample ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is still drawn from the default (CPU) generator, as before,
        so seeded runs see the same draws. It is then cast to the dtype of the
        mean: without the cast, float32 noise silently promotes half/bfloat16
        parameters to a float32 sample.
        """
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device, dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to ``other`` (or to a standard normal when ``other`` is None).

        The sum runs over dims 1, 2 and 3 only. In deterministic mode a
        one-element CPU tensor containing ``0.0`` is returned.
        """
        if self.deterministic:
            return torch.Tensor([0.0])
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample`` under this distribution, summed over ``dims``.

        In deterministic mode a one-element CPU tensor containing ``0.0`` is returned.
        """
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        # list() keeps torch.sum happy if a caller passes any other sequence type.
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=list(dims),
        )

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
def resolve_str_to_obj(str_val, append=True):
    """Import and return the object named by a dotted path.

    Args:
        str_val: Dotted path, or a bare name when ``append`` is true.
        append: When true, ``str_val`` is prefixed with
            ``"videosys.models.open_sora_plan.modules."`` before resolution.

    Returns:
        The attribute named by the last path component, looked up on the
        module named by everything before it.

    Any occurrence of the ``"opensora.models.ae.videobase."`` prefix is
    rewritten to ``"videosys.models.open_sora_plan."`` before importing.
    """
    dotted_path = "videosys.models.open_sora_plan.modules." + str_val if append else str_val
    legacy_prefix = "opensora.models.ae.videobase."
    if legacy_prefix in dotted_path:
        dotted_path = dotted_path.replace(legacy_prefix, "videosys.models.open_sora_plan.")
    module_path, attr_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)
class VideoBaseAE_PL(ModelMixin, ConfigMixin):
    """Base class for the video auto-encoders in this file.

    Adds a ``from_pretrained`` that can also initialise a model from a
    ``*.ckpt`` file found in the given directory.
    """

    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Placeholder that does nothing and returns ``None``; subclasses override it."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Placeholder that does nothing and returns ``None``; subclasses override it."""
        pass

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices.

        NOTE(review): relies on ``self.trainer`` and ``self.train_dataloader()``,
        which are not defined in this class — presumably supplied by a
        PyTorch Lightning trainer; confirm before use.
        """
        if self.trainer.max_steps:
            return self.trainer.max_steps
        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        # An int limit is an absolute batch count; anything else is treated as a fraction.
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)
        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)
        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        """Load a model from a directory containing ``*.ckpt`` files or via the parent loader.

        If the directory holds at least one ``*.ckpt`` file, the model is built
        from ``config.json`` in the same directory and weights are loaded with
        ``init_from_ckpt`` (which the subclass must provide). Otherwise loading
        is delegated to the parent ``from_pretrained`` with ``**kwargs``.
        """
        ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
        if not ckpt_files:
            print(f"Loading model from {pretrained_model_name_or_path}")
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Adapt to PyTorch Lightning.
        # glob returns entries in arbitrary order, so pick the lexicographically
        # last name to make the choice deterministic when several checkpoints exist.
        last_ckpt_file = max(ckpt_files)
        config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
        model = cls.from_config(config_file)
        print("init from {}".format(last_ckpt_file))
        model.init_from_ckpt(last_ckpt_file)
        return model
class Encoder(nn.Module):
    """Down-sampling encoder assembled from classes named by string.

    Every building block (convolutions, resnet blocks, attention, spatial and
    temporal down-samplers) is looked up with ``resolve_str_to_obj``. One level
    is built per entry of ``hidden_size_mult``.

    Args:
        z_channels: Latent channels; the output has ``2 * z_channels`` channels
            when ``double_z`` is true.
        hidden_size: Base channel width.
        hidden_size_mult: Per-level multiplier applied to ``hidden_size``.
        attn_resolutions: Resolutions at which an attention block follows each
            resnet block.
        conv_in / conv_out / attention / mid_resnet: Class names of the
            corresponding layers.
        resnet_blocks: Per-level resnet block class name; must have the same
            length as ``hidden_size_mult``.
        spatial_downsample / temporal_downsample: Per-level down-sampler class
            name; an empty string means "none at this level".
        dropout: Passed to every resnet block.
        resolution: Input resolution used to track ``attn_resolutions``.
        num_res_blocks: Resnet blocks per level.
        double_z: Whether the output carries twice ``z_channels`` channels.

    NOTE(review): the default ``conv_out="CasualConv3d"`` differs in spelling
    from the ``"CausalConv3d"`` used by ``CausalVAEModel``; confirm that name
    exists in the modules package before relying on the default.
    """

    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int, ...] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int, ...] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CasualConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str, ...] = (
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock3D",
        ),
        spatial_downsample: Tuple[str, ...] = (
            "Downsample",
            "Downsample",
            "Downsample",
            "",
        ),
        temporal_downsample: Tuple[str, ...] = ("", "", "TimeDownsampleRes2x", ""),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
        double_z: bool = True,
    ) -> None:
        super().__init__()
        # The old `assert cond, print(...)` produced an AssertionError with a None message.
        assert len(resnet_blocks) == len(hidden_size_mult), (
            f"resnet_blocks {resnet_blocks} and hidden_size_mult {hidden_size_mult} must have the same length"
        )
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        # The input convolution is hard-wired to 3 input channels.
        self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)

        # ---- Downsample ----
        curr_res = resolution
        in_ch_mult = (1,) + tuple(hidden_size_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = hidden_size * in_ch_mult[i_level]
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if spatial_downsample[i_level]:
                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
                curr_res = curr_res // 2
            if temporal_downsample[i_level]:
                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
            self.down.append(down)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution.

        A single running activation is threaded through all levels. The
        previous implementation kept every intermediate activation in a list
        and fed the mid block with the output of the last resnet block, which
        silently skipped any down-sampler configured on the final level. With
        the default configuration (no down-sampler on the last level) the
        result is unchanged.
        """
        h = self.conv_in(x)
        for level in self.down:
            has_attn = len(level.attn) > 0
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if has_attn:
                    h = level.attn[i_block](h)
            if hasattr(level, "downsample"):
                h = level.downsample(h)
            if hasattr(level, "time_downsample"):
                h = level.time_downsample(h)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Up-sampling decoder assembled from classes named by string.

    Layers are resolved with ``resolve_str_to_obj``. Levels are built from the
    last entry of ``hidden_size_mult`` down to the first, each with
    ``num_res_blocks + 1`` resnet blocks, and stored in ``self.up`` so that
    index ``i`` corresponds to ``hidden_size_mult[i]``. An empty string in
    ``spatial_upsample`` / ``temporal_upsample`` means "none at this level".
    The output convolution is hard-wired to 3 output channels.
    """

    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CasualConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
    ):
        super().__init__()

        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        deepest = self.num_resolutions - 1
        width = hidden_size * hidden_size_mult[deepest]
        res_now = resolution // 2**deepest
        self.conv_in = resolve_str_to_obj(conv_in)(z_channels, width, kernel_size=3, padding=1)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=width, out_channels=width, dropout=dropout
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(width)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=width, out_channels=width, dropout=dropout
        )

        # ---- Upsample ----
        self.up = nn.ModuleList()
        for level_idx in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            level_width = hidden_size * hidden_size_mult[level_idx]
            for _ in range(self.num_res_blocks + 1):
                resnet_cls = resolve_str_to_obj(resnet_blocks[level_idx])
                res_blocks.append(resnet_cls(in_channels=width, out_channels=level_width, dropout=dropout))
                width = level_width
                if res_now in attn_resolutions:
                    attn_blocks.append(resolve_str_to_obj(attention)(width))

            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            spatial_name = spatial_upsample[level_idx]
            if spatial_name:
                stage.upsample = resolve_str_to_obj(spatial_name)(width, width)
                res_now = res_now * 2
            temporal_name = temporal_upsample[level_idx]
            if temporal_name:
                stage.time_upsample = resolve_str_to_obj(temporal_name)(width, width)
            # Prepend so that self.up[i] matches hidden_size_mult[i].
            self.up.insert(0, stage)

        # ---- Out ----
        self.norm_out = Normalize(width)
        self.conv_out = resolve_str_to_obj(conv_out)(width, 3, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode latent ``z``: input conv, mid blocks, up levels (last to first), output head."""
        h = self.conv_in(z)

        for mid_layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = mid_layer(h)

        for level_idx in reversed(range(self.num_resolutions)):
            stage = self.up[level_idx]
            for block_idx in range(self.num_res_blocks + 1):
                h = stage.block[block_idx](h)
                if len(stage.attn) > 0:
                    h = stage.attn[block_idx](h)
            if hasattr(stage, "upsample"):
                h = stage.upsample(h)
            if hasattr(stage, "time_upsample"):
                h = stage.time_upsample(h)

        h = nonlinearity(self.norm_out(h))
        return self.conv_out(h)
class CausalVAEModel(VideoBaseAE_PL):
    """Video VAE made of an ``Encoder``, a ``Decoder`` and two 1-sized quant convolutions.

    Supports optional tiled encoding/decoding: inputs larger than the tile
    thresholds are split along time (dim 2) and space (dims 3 and 4),
    processed tile by tile and blended back together.

    NOTE(review): the training/validation methods use ``self.log``,
    ``self.optimizers``, ``self.manual_backward``, ``self.clip_gradients``,
    ``self.global_step`` and ``self.logger``, none of which are defined in the
    classes visible in this file — presumably PyTorch Lightning hooks; confirm
    before using them.
    """

    # The defaults below are recorded by `register_to_config`; they are left
    # exactly as they were (including the list/dict literals, which are never
    # mutated here) so the stored configuration does not change.
    @register_to_config
    def __init__(
        self,
        lr: float = 1e-5,
        hidden_size: int = 128,
        z_channels: int = 4,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = [],
        dropout: float = 0.0,
        resolution: int = 256,
        double_z: bool = True,
        embed_dim: int = 4,
        num_res_blocks: int = 2,
        loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
        loss_params: dict = {
            "kl_weight": 0.000001,
            "logvar_init": 0.0,
            "disc_start": 2001,
            "disc_weight": 0.5,
        },
        q_conv: str = "CausalConv3d",
        encoder_conv_in: str = "CausalConv3d",
        encoder_conv_out: str = "CausalConv3d",
        encoder_attention: str = "AttnBlock3D",
        encoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        encoder_spatial_downsample: Tuple[str] = (
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "",
        ),
        encoder_temporal_downsample: Tuple[str] = (
            "",
            "TimeDownsample2x",
            "TimeDownsample2x",
            "",
        ),
        encoder_mid_resnet: str = "ResnetBlock3D",
        decoder_conv_in: str = "CausalConv3d",
        decoder_conv_out: str = "CausalConv3d",
        decoder_attention: str = "AttnBlock3D",
        decoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        decoder_spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
        decoder_mid_resnet: str = "ResnetBlock3D",
    ) -> None:
        super().__init__()

        # ---- Tiling configuration ----
        self.tile_sample_min_size = 256
        self.tile_sample_min_size_t = 65
        # Spatial latent tile size: one halving per level except the last.
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
        # Temporal latent tile size: one halving per non-empty temporal down-sampler.
        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        self.learning_rate = lr
        self.lr_g_factor = 1.0

        self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)

        self.encoder = Encoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=encoder_conv_in,
            conv_out=encoder_conv_out,
            attention=encoder_attention,
            resnet_blocks=encoder_resnet_blocks,
            spatial_downsample=encoder_spatial_downsample,
            temporal_downsample=encoder_temporal_downsample,
            mid_resnet=encoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
            double_z=double_z,
        )

        self.decoder = Decoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=decoder_conv_in,
            conv_out=decoder_conv_out,
            attention=decoder_attention,
            resnet_blocks=decoder_resnet_blocks,
            spatial_upsample=decoder_spatial_upsample,
            temporal_upsample=decoder_temporal_upsample,
            mid_resnet=decoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
        )

        quant_conv_cls = resolve_str_to_obj(q_conv)
        self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)

        # With a discriminator the two optimizers are stepped by hand in `_training_step_gan`.
        if hasattr(self.loss, "discriminator"):
            self.automatic_optimization = False

    # ------------------------------------------------------------------
    # Encode / decode
    # ------------------------------------------------------------------
    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution``.

        Uses ``tiled_encode`` when tiling is enabled and any of the last three
        dims of ``x`` exceeds its tile threshold.
        """
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
            or x.shape[-3] > self.tile_sample_min_size_t
        ):
            return self.tiled_encode(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latent ``z``; uses ``tiled_decode`` when tiling is enabled and ``z`` is large."""
        if self.use_tiling and (
            z.shape[-1] > self.tile_latent_min_size
            or z.shape[-2] > self.tile_latent_min_size
            or z.shape[-3] > self.tile_latent_min_size_t
        ):
            return self.tiled_decode(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode ``input``, sample (or take the mode of) the posterior, and decode.

        Returns:
            ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor, adding a trailing axis to 3-D inputs."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        """Dispatch to the GAN step when the loss has a discriminator, else to the plain step."""
        if hasattr(self.loss, "discriminator"):
            return self._training_step_gan(batch, batch_idx=batch_idx)
        else:
            return self._training_step(batch, batch_idx=batch_idx)

    def _training_step(self, batch, batch_idx):
        """Single-optimizer step: compute, log and return the auto-encoder loss."""
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            split="train",
        )

        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def _training_step_gan(self, batch, batch_idx):
        """Manual two-optimizer step: auto-encoder update first, then discriminator update.

        The loss is called with ``0`` for the auto-encoder pass and ``1`` for
        the discriminator pass; gradients are clipped to norm 1 before each step.
        """
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        opt1, opt2 = self.optimizers()

        # ---- AE Loss ----
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt1.zero_grad()
        self.manual_backward(aeloss)
        self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt1.step()

        # ---- GAN Loss ----
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt2.zero_grad()
        self.manual_backward(discloss)
        self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt2.step()

        self.log_dict(
            {**log_dict_ae, **log_dict_disc},
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

    def configure_optimizers(self):
        """Build the Adam optimizer(s): one for the auto-encoder, one for the discriminator if present.

        Auto-encoder parameters are split into two groups depending on whether
        ``"time"`` appears in the parameter name; both groups currently use the
        same learning rate.

        Returns:
            ``(optimizers, [])`` — a list of optimizers and an empty scheduler list.
        """
        from itertools import chain

        lr = self.learning_rate

        modules_to_train = [
            self.encoder.named_parameters(),
            self.decoder.named_parameters(),
            self.post_quant_conv.named_parameters(),
            self.quant_conv.named_parameters(),
        ]
        params_with_time = []
        params_without_time = []
        for name, param in chain(*modules_to_train):
            if "time" in name:
                params_with_time.append(param)
            else:
                params_without_time.append(param)

        optimizers = []
        opt_ae = torch.optim.Adam(
            [
                {"params": params_with_time, "lr": lr},
                {"params": params_without_time, "lr": lr},
            ],
            lr=lr,
            betas=(0.5, 0.9),
        )
        optimizers.append(opt_ae)

        if hasattr(self.loss, "discriminator"):
            opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
            optimizers.append(opt_disc)

        return optimizers, []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution (unwrapping a ``.conv`` attribute if present)."""
        if hasattr(self.decoder.conv_out, "conv"):
            return self.decoder.conv_out.conv.weight
        else:
            return self.decoder.conv_out.weight

    # ------------------------------------------------------------------
    # Tiling helpers
    # ------------------------------------------------------------------
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the last rows (dim 3) of ``a`` into the first rows of ``b``.

        ``b`` is modified in place and returned.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the last columns (dim 4) of ``a`` into the first columns of ``b``.

        ``b`` is modified in place and returned.
        """
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    @staticmethod
    def _temporal_chunks(t, stride):
        """Split ``range(t)`` into ``[start, end)`` chunks that overlap by one frame.

        Chunk starts are multiples of ``stride``; every chunk but the last
        covers ``stride + 1`` frames, so consecutive chunks share one frame.
        """
        starts = list(range(0, t, stride))
        if len(starts) == 1:
            return [[0, t]]
        chunks = [[begin, nxt + 1] for begin, nxt in zip(starts, starts[1:])]
        if chunks[-1][-1] > t:
            chunks[-1][-1] = t
        elif chunks[-1][-1] < t:
            # Remaining frames after the last full chunk.
            chunks.append([starts[-1], t])
        return chunks

    def _merge_tiles(self, rows, blend_extent, row_limit):
        """Blend each tile with its upper and left neighbours, crop, and stitch the grid together."""
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # Blend the tile above and the tile to the left into the current tile.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))
        return torch.cat(result_rows, dim=3)

    def tiled_encode(self, x):
        """Encode ``x`` chunk by chunk along dim 2, tiling each chunk spatially.

        Chunks overlap by one frame; for every chunk after the first, the first
        latent frame is dropped before concatenation.
        """
        chunks = self._temporal_chunks(x.shape[2], self.tile_sample_min_size_t - 1)
        moments = []
        for idx, (start, end) in enumerate(chunks):
            moment = self.tiled_encode2d(x[:, :, start:end], return_moments=True)
            if idx != 0:
                moment = moment[:, :, 1:]
            moments.append(moment)
        moments = torch.cat(moments, dim=2)
        return DiagonalGaussianDistribution(moments)

    def tiled_decode(self, x):
        """Decode latent ``x`` chunk by chunk along dim 2, tiling each chunk spatially.

        Chunks overlap by one latent frame; for every chunk after the first,
        the first decoded frame is dropped before concatenation.
        """
        chunks = self._temporal_chunks(x.shape[2], self.tile_latent_min_size_t - 1)
        decoded = []
        for idx, (start, end) in enumerate(chunks):
            dec = self.tiled_decode2d(x[:, :, start:end])
            if idx != 0:
                dec = dec[:, :, 1:]
            decoded.append(dec)
        return torch.cat(decoded, dim=2)

    def tiled_encode2d(self, x, return_moments=False):
        """Encode ``x`` in overlapping spatial tiles and blend the results.

        Returns the raw moments tensor when ``return_moments`` is true,
        otherwise a ``DiagonalGaussianDistribution`` built from it.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping tiles of side `tile_sample_min_size`
        # and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)

        moments = self._merge_tiles(rows, blend_extent, row_limit)
        if return_moments:
            return moments
        return DiagonalGaussianDistribution(moments)

    def tiled_decode2d(self, z):
        """Decode ``z`` in overlapping spatial tiles and blend the results."""
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of side `tile_latent_min_size` and decode
        # them separately. The tiles overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)

        return self._merge_tiles(rows, blend_extent, row_limit)

    def enable_tiling(self, use_tiling: bool = True):
        """Turn tiled encoding/decoding on (default) or off."""
        self.use_tiling = use_tiling

    def disable_tiling(self):
        """Turn tiled encoding/decoding off."""
        self.enable_tiling(False)

    # ------------------------------------------------------------------
    # Checkpoint loading / validation
    # ------------------------------------------------------------------
    def init_from_ckpt(self, path, ignore_keys=(), remove_loss=False):
        """Load weights from a checkpoint file, non-strictly.

        Args:
            path: Checkpoint path passed to ``torch.load``.
            ignore_keys: Key prefixes to drop from the state dict before loading.
            remove_loss: Accepted for compatibility; not used by this method.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")
        print("init from " + path)
        if "state_dict" in sd:
            sd = sd["state_dict"]
        prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            # A single startswith over all prefixes deletes each key at most once;
            # the nested loop used before raised KeyError when two prefixes matched.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)

    def validation_step(self, batch, batch_idx):
        """Reconstruct the batch's ``"video"`` entry and log every reconstruction as a video."""
        inputs = self.get_input(batch, "video")
        latents = self.encode(inputs).sample()
        video_recon = self.decode(latents)
        for idx, video in enumerate(video_recon):
            self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video)], fps=[10])
class CausalVAEModelWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained ``CausalVAEModel``.

    ``encode`` samples the posterior and multiplies the sample by 0.18215;
    ``decode`` divides by the same constant before decoding and moves the
    time axis in front of the channel axis.
    """

    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
        super().__init__()
        self.vae = CausalVAEModel.from_pretrained(
            model_path,
            subfolder=subfolder,
            cache_dir=cache_dir,
            **kwargs,
        )

    def encode(self, x):  # b c t h w
        """Return a scaled sample from the VAE posterior of ``x``."""
        latents = self.vae.encode(x).sample()
        latents.mul_(0.18215)  # in-place scaling of the freshly sampled tensor
        return latents

    def decode(self, x):
        """Decode scaled latents and return the result laid out as ``b t c h w``."""
        video = self.vae.decode(x / 0.18215)
        return rearrange(video, "b c t h w -> b t c h w").contiguous()

    def dtype(self):
        """Return the dtype reported by the wrapped VAE (a method, not a property)."""
        return self.vae.dtype
# Lookup tables keyed by auto-encoder name: stride triple, channel count and wrapper class.
videobase_ae_stride = {"CausalVAEModel_4x8x8": [4, 8, 8]}
videobase_ae_channel = {"CausalVAEModel_4x8x8": 4}
videobase_ae = {"CausalVAEModel_4x8x8": CausalVAEModelWrapper}

# Aggregated configuration tables, seeded with shallow copies of the tables above.
ae_stride_config = dict(videobase_ae_stride)
ae_channel_config = dict(videobase_ae_channel)
def getae_wrapper(ae):
    """Return the wrapper class registered under the name ``ae`` in ``videobase_ae``.

    NOTE(review): the previous docstring read only "deprecation" — presumably
    this helper is deprecated; confirm before adding new callers.

    Args:
        ae: Auto-encoder name, e.g. ``"CausalVAEModel_4x8x8"``.

    Raises:
        AssertionError: If ``ae`` is not a registered name.
    """
    wrapper_cls = videobase_ae.get(ae, None)
    if wrapper_cls is None:
        # Raised explicitly rather than via `assert` so the check survives `python -O`;
        # the exception type is the same one callers saw before.
        raise AssertionError(f"Unknown auto-encoder {ae!r}; available: {sorted(videobase_ae)}")
    return wrapper_cls