# Open-Sora-Plan-v1-0-0 / examples/get_latents_std.py
# Uploaded by fffiloni ("Upload 244 files"), commit b3f324b (verified), 1.16 kB.
import torch
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append(".")
from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset
# --- Run configuration -------------------------------------------------------
# DataLoader settings.
num_workers, batch_size = 4, 12

# Inference-only script: disable autograd globally and fix the RNG seed so the
# sampled latents (and hence the printed normalizer) are reproducible.
torch.set_grad_enabled(False)
torch.manual_seed(0)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Checkpoint and dataset locations (machine-specific paths).
pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000'
data_path = '/remote-home1/dataset/UCF-101'

# Clip parameters forwarded to CausalVAEDataset.
video_num_frames, resolution, sample_rate = 17, 128, 10
# Load the trained CausalVAE weights and prepare the model for inference.
vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path)
vae.to(device)
# Switch any train/eval-dependent layers (e.g. dropout, if present) to
# inference behaviour; calling .eval() is harmless if already in eval mode.
vae.eval()

# Estimate the statistic on at most the first `num_samples` clips. Clamping to
# len(dataset) avoids indexing past the end of a smaller dataset.
num_samples = 1000
dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate)
subset_indices = list(range(min(num_samples, len(dataset))))
subset_dataset = Subset(dataset, subset_indices)
# Batch size and worker count come from the run configuration.
loader = DataLoader(
    subset_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin_memory=device.type == 'cuda',
)
# Encode every clip and collect the sampled *latents* (not the input video):
# the normalizer printed below is meant to rescale the VAE's latent space.
# Tensors are moved to the CPU so device memory does not grow per batch.
all_latents = []
for batch in loader:
    video_data = batch['video'].to(device)
    # encode() returns an object with .sample(); presumably a diagonal-Gaussian
    # posterior, so this draws one latent sample per clip — TODO confirm.
    latents = vae.encode(video_data).sample()
    all_latents.append(latents.cpu())

if not all_latents:
    raise RuntimeError('no clips were loaded; check data_path and the dataset settings')

all_latents_tensor = torch.cat(all_latents)
# A single global std over every latent element; 1/std rescales the latents
# to roughly unit standard deviation.
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')