# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import glob
import importlib
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from torch import nn


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def tensor_to_video(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 0, 2, 3).float().numpy()  # c t h w -> t c h w
    x = (255 * x).astype(np.uint8)
    return x


def nonlinearity(x):
    return x * torch.sigmoid(x)


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean


def resolve_str_to_obj(str_val, append=True):
    """Resolve a dotted class-name string to the class object, remapping legacy
    opensora module paths to their videosys equivalents."""
    if append:
        str_val = "videosys.models.open_sora_plan.modules." + str_val
    if "opensora.models.ae.videobase." in str_val:
        str_val = str_val.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
    module_name, class_name = str_val.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


class VideoBaseAE_PL(ModelMixin, ConfigMixin):
    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        pass

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
        if self.trainer.max_steps:
            return self.trainer.max_steps

        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, "*.ckpt"))
        if ckpt_files:
            # Adapt to PyTorch Lightning
            last_ckpt_file = ckpt_files[-1]
            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            model = cls.from_config(config_file)
            print("init from {}".format(last_ckpt_file))
            model.init_from_ckpt(last_ckpt_file)
            return model
        else:
            print(f"Loading model from {pretrained_model_name_or_path}")
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)


class Encoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock2D",
            "ResnetBlock3D",
        ),
        spatial_downsample: Tuple[str] = (
            "Downsample",
            "Downsample",
            "Downsample",
            "",
        ),
        temporal_downsample: Tuple[str] = ("", "", "TimeDownsampleRes2x", ""),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
        double_z: bool = True,
    ) -> None:
        super().__init__()
        assert len(resnet_blocks) == len(hidden_size_mult), f"{hidden_size_mult} {resnet_blocks}"

        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        self.conv_in = resolve_str_to_obj(conv_in)(3, hidden_size, kernel_size=3, stride=1, padding=1)

        # ---- Downsample ----
        curr_res = resolution
        in_ch_mult = (1,) + tuple(hidden_size_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = hidden_size * in_ch_mult[i_level]
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if spatial_downsample[i_level]:
                down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in)
                curr_res = curr_res // 2
            if temporal_downsample[i_level]:
                down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in)
            self.down.append(down)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))
            if hasattr(self.down[i_level], "time_downsample"):
                hs_down = self.down[i_level].time_downsample(hs[-1])
                hs.append(hs_down)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        z_channels: int,
        hidden_size: int,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = (16,),
        conv_in: str = "Conv2d",
        conv_out: str = "CausalConv3d",
        attention: str = "AttnBlock",
        resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        temporal_upsample: Tuple[str] = ("", "", "", "TimeUpsampleRes2x"),
        mid_resnet: str = "ResnetBlock3D",
        dropout: float = 0.0,
        resolution: int = 256,
        num_res_blocks: int = 2,
    ):
        super().__init__()
        # ---- Config ----
        self.num_resolutions = len(hidden_size_mult)
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks

        # ---- In ----
        block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.conv_in = resolve_str_to_obj(conv_in)(z_channels, block_in, kernel_size=3, padding=1)

        # ---- Mid ----
        self.mid = nn.Module()
        self.mid.block_1 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid.attn_1 = resolve_str_to_obj(attention)(block_in)
        self.mid.block_2 = resolve_str_to_obj(mid_resnet)(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # ---- Upsample ----
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = hidden_size * hidden_size_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    resolve_str_to_obj(resnet_blocks[i_level])(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(resolve_str_to_obj(attention)(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if spatial_upsample[i_level]:
                up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in)
                curr_res = curr_res * 2
            if temporal_upsample[i_level]:
                up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in)
            self.up.insert(0, up)

        # ---- Out ----
        self.norm_out = Normalize(block_in)
        self.conv_out = resolve_str_to_obj(conv_out)(block_in, 3, kernel_size=3, padding=1)

    def forward(self, z):
        h = self.conv_in(z)
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)
            if hasattr(self.up[i_level], "time_upsample"):
                h = self.up[i_level].time_upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class CausalVAEModel(VideoBaseAE_PL):
    @register_to_config
    def __init__(
        self,
        lr: float = 1e-5,
        hidden_size: int = 128,
        z_channels: int = 4,
        hidden_size_mult: Tuple[int] = (1, 2, 4, 4),
        attn_resolutions: Tuple[int] = [],
        dropout: float = 0.0,
        resolution: int = 256,
        double_z: bool = True,
        embed_dim: int = 4,
        num_res_blocks: int = 2,
        loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
        loss_params: dict = {
            "kl_weight": 0.000001,
            "logvar_init": 0.0,
            "disc_start": 2001,
            "disc_weight": 0.5,
        },
        q_conv: str = "CausalConv3d",
        encoder_conv_in: str = "CausalConv3d",
        encoder_conv_out: str = "CausalConv3d",
        encoder_attention: str = "AttnBlock3D",
        encoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        encoder_spatial_downsample: Tuple[str] = (
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "SpatialDownsample2x",
            "",
        ),
        encoder_temporal_downsample: Tuple[str] = (
            "",
            "TimeDownsample2x",
            "TimeDownsample2x",
            "",
        ),
        encoder_mid_resnet: str = "ResnetBlock3D",
        decoder_conv_in: str = "CausalConv3d",
        decoder_conv_out: str = "CausalConv3d",
        decoder_attention: str = "AttnBlock3D",
        decoder_resnet_blocks: Tuple[str] = (
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
            "ResnetBlock3D",
        ),
        decoder_spatial_upsample: Tuple[str] = (
            "",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
            "SpatialUpsample2x",
        ),
        decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
        decoder_mid_resnet: str = "ResnetBlock3D",
    ) -> None:
        super().__init__()
        self.tile_sample_min_size = 256
        self.tile_sample_min_size_t = 65
        self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
        t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
        self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        self.learning_rate = lr
        self.lr_g_factor = 1.0

        self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)

        self.encoder = Encoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=encoder_conv_in,
            conv_out=encoder_conv_out,
            attention=encoder_attention,
            resnet_blocks=encoder_resnet_blocks,
            spatial_downsample=encoder_spatial_downsample,
            temporal_downsample=encoder_temporal_downsample,
            mid_resnet=encoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
            double_z=double_z,
        )

        self.decoder = Decoder(
            z_channels=z_channels,
            hidden_size=hidden_size,
            hidden_size_mult=hidden_size_mult,
            attn_resolutions=attn_resolutions,
            conv_in=decoder_conv_in,
            conv_out=decoder_conv_out,
            attention=decoder_attention,
            resnet_blocks=decoder_resnet_blocks,
            spatial_upsample=decoder_spatial_upsample,
            temporal_upsample=decoder_temporal_upsample,
            mid_resnet=decoder_mid_resnet,
            dropout=dropout,
            resolution=resolution,
            num_res_blocks=num_res_blocks,
        )

        quant_conv_cls = resolve_str_to_obj(q_conv)
        self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)

        if hasattr(self.loss, "discriminator"):
            self.automatic_optimization = False

    def encode(self, x):
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
            or x.shape[-3] > self.tile_sample_min_size_t
        ):
            return self.tiled_encode(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        if self.use_tiling and (
            z.shape[-1] > self.tile_latent_min_size
            or z.shape[-2] > self.tile_latent_min_size
            or z.shape[-3] > self.tile_latent_min_size_t
        ):
            return self.tiled_decode(z)
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        if hasattr(self.loss, "discriminator"):
            return self._training_step_gan(batch, batch_idx=batch_idx)
        else:
            return self._training_step(batch, batch_idx=batch_idx)

    def _training_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
        return aeloss

    def _training_step_gan(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        reconstructions, posterior = self(inputs)
        opt1, opt2 = self.optimizers()

        # ---- AE Loss ----
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt1.zero_grad()
        self.manual_backward(aeloss)
        self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt1.step()

        # ---- GAN Loss ----
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        opt2.zero_grad()
        self.manual_backward(discloss)
        self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm")
        opt2.step()

        self.log_dict(
            {**log_dict_ae, **log_dict_disc},
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

    def configure_optimizers(self):
        from itertools import chain

        lr = self.learning_rate
        modules_to_train = [
            self.encoder.named_parameters(),
            self.decoder.named_parameters(),
            self.post_quant_conv.named_parameters(),
            self.quant_conv.named_parameters(),
        ]
        params_with_time = []
        params_without_time = []
        for name, param in chain(*modules_to_train):
            if "time" in name:
                params_with_time.append(param)
            else:
                params_without_time.append(param)
        optimizers = []
        opt_ae = torch.optim.Adam(
            [
                {"params": params_with_time, "lr": lr},
                {"params": params_without_time, "lr": lr},
            ],
            lr=lr,
            betas=(0.5, 0.9),
        )
        optimizers.append(opt_ae)

        if hasattr(self.loss, "discriminator"):
            opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
            optimizers.append(opt_disc)

        return optimizers, []

    def get_last_layer(self):
        if hasattr(self.decoder.conv_out, "conv"):
            return self.decoder.conv_out.conv.weight
        else:
            return self.decoder.conv_out.weight

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x):
        # Encode in overlapping temporal chunks; the first latent frame of every
        # chunk after the first overlaps the previous chunk and is dropped below.
        t = x.shape[2]
        t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        moments = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
            else:
                moment = self.tiled_encode2d(chunk_x, return_moments=True)
            moments.append(moment)
        moments = torch.cat(moments, dim=2)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def tiled_decode(self, x):
        t = x.shape[2]
        t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)]
        if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
            t_chunk_start_end = [[0, t]]
        else:
            t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)]
            if t_chunk_start_end[-1][-1] > t:
                t_chunk_start_end[-1][-1] = t
            elif t_chunk_start_end[-1][-1] < t:
                last_start_end = [t_chunk_idx[-1], t]
                t_chunk_start_end.append(last_start_end)
        dec_ = []
        for idx, (start, end) in enumerate(t_chunk_start_end):
            chunk_x = x[:, :, start:end]
            if idx != 0:
                dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
            else:
                dec = self.tiled_decode2d(chunk_x)
            dec_.append(dec)
        dec_ = torch.cat(dec_, dim=2)
        return dec_

    def tiled_encode2d(self, x, return_moments=False):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the input into overlapping spatial tiles of size tile_sample_min_size
        # and encode each tile separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        posterior = DiagonalGaussianDistribution(moments)
        if return_moments:
            return moments
        return posterior

    def tiled_decode2d(self, z):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping latent tiles of size tile_latent_min_size and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def enable_tiling(self, use_tiling: bool = True):
        self.use_tiling = use_tiling

    def disable_tiling(self):
        self.enable_tiling(False)

    def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=False):
        sd = torch.load(path, map_location="cpu")
        print("init from " + path)
        if "state_dict" in sd:
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, "video")
        latents = self.encode(inputs).sample()
        video_recon = self.decode(latents)
        for idx in range(len(video_recon)):
            self.logger.log_video(f"recon {batch_idx} {idx}", [tensor_to_video(video_recon[idx])], fps=[10])


class CausalVAEModelWrapper(nn.Module):
    def __init__(self, model_path, subfolder=None, cache_dir=None, **kwargs):
        super(CausalVAEModelWrapper, self).__init__()
        # if os.path.exists(ckpt):
        #     self.vae = CausalVAEModel.load_from_checkpoint(ckpt)
        self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs)

    def encode(self, x):  # b c t h w
        # x = self.vae.encode(x).sample()
        x = self.vae.encode(x).sample().mul_(0.18215)
        return x

    def decode(self, x):
        # x = self.vae.decode(x)
        x = self.vae.decode(x / 0.18215)
        x = rearrange(x, "b c t h w -> b t c h w").contiguous()
        return x

    def dtype(self):
        return self.vae.dtype

    # def device(self):
    #     return self.vae.device


videobase_ae_stride = {
    "CausalVAEModel_4x8x8": [4, 8, 8],
}

videobase_ae_channel = {
    "CausalVAEModel_4x8x8": 4,
}

videobase_ae = {
    "CausalVAEModel_4x8x8": CausalVAEModelWrapper,
}

ae_stride_config = {}
ae_stride_config.update(videobase_ae_stride)

ae_channel_config = {}
ae_channel_config.update(videobase_ae_channel)


def getae_wrapper(ae):
    """Deprecated: resolve an autoencoder name to its wrapper class."""
    ae_cls = videobase_ae.get(ae, None)
    assert ae_cls is not None, f"Unknown autoencoder type: {ae}"
    return ae_cls
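

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). The checkpoint
# directory below is a hypothetical placeholder; it is assumed to contain either
# diffusers-style weights with a config.json or a Lightning *.ckpt next to
# config.json, as handled by VideoBaseAE_PL.from_pretrained above.
if __name__ == "__main__":
    vae = CausalVAEModelWrapper("path/to/CausalVAEModel_4x8x8")  # hypothetical local path
    vae.vae.enable_tiling()  # optional: tile oversized inputs to bound memory use
    video = torch.randn(1, 3, 17, 256, 256)  # b c t h w, values expected in [-1, 1]
    latents = vae.encode(video)  # latents scaled by 0.18215, downsampled 4x in t and 8x in h/w
    recon = vae.decode(latents)  # reconstructed video, returned as b t c h w
    print(recon.shape)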