
Pyramid Flow's VAE Training Guide

This is the training guide for a MAGVIT-v2-like continuous 3D VAE, which is designed to be flexible. Feel free to build your own video generative model on top of this VAE training code. For DiT finetuning, please refer to the separate document.

Hardware Requirements

  • VAE training: At least 8 A100 GPUs.

Prepare the Dataset

The causal video VAE is trained on both image and video data. Both should be listed in a JSON annotation file, where each entry has either a video or an image field. The final training annotation file should follow this format:

# For Video
{"video": video_path}

# For Image
{"image": image_path}
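
As a minimal sketch of how such an annotation file could be assembled, the Python snippet below maps a list of media paths to {"video": ...} or {"image": ...} records. The helper names, the extension sets, and the choice to write one JSON object per line are assumptions made for illustration; they are not part of the training code, so check the layout the dataset loader actually expects before using it.

```python
import json
from pathlib import Path

# Assumed extension sets, for illustration only; adjust to your data.
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def build_annotation_records(paths):
    """Map each media path to a {"video": ...} or {"image": ...} record."""
    records = []
    for p in paths:
        ext = Path(p).suffix.lower()
        if ext in VIDEO_EXTS:
            records.append({"video": str(p)})
        elif ext in IMAGE_EXTS:
            records.append({"image": str(p)})
        # Paths with any other extension are skipped.
    return records


def write_annotation_file(records, out_path):
    """Write one JSON object per line (an assumed layout, not confirmed)."""
    with open(out_path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
```

For example, build_annotation_records(["a/clip.mp4", "b/photo.JPG"]) yields one video record and one image record, which write_annotation_file can then save to disk.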

Run Training

The causal video VAE is trained in two stages.

  • Stage-1: mixed image and video training
  • Stage-2: pure video training, using context parallelism to load videos with more frames

The VAE training script is scripts/train_causal_video_vae.sh. Run it as follows:

sh scripts/train_causal_video_vae.sh

We also provide a VAE demo notebook, causal_video_vae_demo.ipynb, for image and video reconstruction.

Tips

  • For stage-1, we train on a mix of images and videos. Pass --use_image_video_mixed_training to enable mixed training. The image ratio is set to 0.1 by default.
  • A resolution of 256 is sufficient for VAE training.
  • For stage-1, max_frames is set to 17, meaning that 17 sampled video frames are used for training.
  • For stage-2, we enable use_context_parallel to distribute the frames of long videos across multiple GPUs. Make sure that GPUS % CONTEXT_SIZE == 0 and NUM_FRAMES = 17 * CONTEXT_SIZE + 1.
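
The two stage-2 constraints can be checked before launching a run. The Python sketch below does only that arithmetic; the function names are hypothetical and are not part of the training script, while GPUS, CONTEXT_SIZE, and NUM_FRAMES refer to the values named in the tip above.

```python
def stage2_num_frames(context_size: int) -> int:
    """Frame count required by the tip: NUM_FRAMES = 17 * CONTEXT_SIZE + 1."""
    return 17 * context_size + 1


def check_context_parallel(gpus: int, context_size: int, num_frames: int) -> None:
    """Raise ValueError if the stage-2 settings violate either constraint."""
    if gpus <= 0 or context_size <= 0:
        raise ValueError("GPUS and CONTEXT_SIZE must be positive")
    if gpus % context_size != 0:
        raise ValueError(
            f"GPUS ({gpus}) must be divisible by CONTEXT_SIZE ({context_size})"
        )
    expected = stage2_num_frames(context_size)
    if num_frames != expected:
        raise ValueError(
            f"NUM_FRAMES should be {expected} for CONTEXT_SIZE={context_size}, "
            f"got {num_frames}"
        )
```

With 8 GPUs and CONTEXT_SIZE = 4, for instance, the required frame count is 17 * 4 + 1 = 69, so check_context_parallel(8, 4, 69) passes silently, whereas a CONTEXT_SIZE of 3 is rejected because 8 is not divisible by 3.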