# Pyramid Flow's VAE Training Guide

This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737)-like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on top of this VAE training code. Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT) for DiT finetuning.

## Hardware Requirements

+ VAE training: At least 8 A100 GPUs.


## Prepare the Dataset

Our causal video VAE is trained on both image and video data. Both should be listed in a JSON annotation file, where each entry has either a `video` or an `image` field. The final training annotation file should look like this:

```
# For video
{"video": video_path}

# For image
{"image": image_path}
```
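An annotation file in this format can be generated from a directory of media files. The sketch below assumes one JSON object per line (JSON Lines) and a fixed set of file extensions; `build_annotation`, the extension sets, and the example paths are hypothetical, so check them against the dataset loader in the repository before use.

```python
import json
from pathlib import Path

# Assumed extensions; adjust to whatever the training data loader can decode.
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}


def build_annotation(data_root: str, out_path: str) -> int:
    """Write one {"video": ...} or {"image": ...} object per line.

    Returns the number of entries written. Files with other extensions
    (captions, metadata, etc.) are skipped.
    """
    count = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for path in sorted(Path(data_root).rglob("*")):
            if not path.is_file():
                continue
            ext = path.suffix.lower()
            if ext in VIDEO_EXTS:
                key = "video"
            elif ext in IMAGE_EXTS:
                key = "image"
            else:
                continue
            # Absolute paths keep the annotation valid regardless of the working directory.
            f.write(json.dumps({key: str(path.resolve())}) + "\n")
            count += 1
    return count
```

For example, `build_annotation("/data/vae_train", "vae_train_annotation.jsonl")` (hypothetical paths) would index every image and video under `/data/vae_train`. If the loader expects a single JSON array rather than JSON Lines, collect the entries into a list and write it with `json.dump` instead.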

## Run Training

The causal video VAE is trained in two stages:
+ Stage-1: mixed image and video training.
+ Stage-2: pure video training, using context parallelism to load videos with more frames.
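Stage-1 mixes images and videos in the same training stream. As an illustration only, the sketch below draws an image with probability 0.1 and otherwise a 17-frame clip, matching the stage-1 defaults listed under Tips. `MixedSampler`, the `(path, frame_count)` video metadata, per-sample (rather than per-batch) mixing, and the contiguous frame window are all assumptions, not the repository's data pipeline.

```python
import random
from dataclasses import dataclass, field


@dataclass
class MixedSampler:
    """Hypothetical stage-1 sampler: returns an image entry with probability
    `image_ratio`, otherwise a window of `max_frames` consecutive frame
    indices from a randomly chosen video."""

    images: list[str]
    videos: list[tuple[str, int]]  # (path, total frame count); assumed metadata
    image_ratio: float = 0.1       # stage-1 default from the tips
    max_frames: int = 17           # stage-1 clip length from the tips
    seed: int = 0
    rng: random.Random = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self.rng = random.Random(self.seed)

    def sample(self) -> dict:
        # Fall back to images only when there are no videos at all.
        if self.images and (not self.videos or self.rng.random() < self.image_ratio):
            return {"image": self.rng.choice(self.images)}
        path, total = self.rng.choice(self.videos)
        if total < self.max_frames:
            raise ValueError(f"{path} has {total} frames, need {self.max_frames}")
        # Contiguous window; the real loader may use a frame stride instead.
        start = self.rng.randint(0, total - self.max_frames)
        return {"video": path, "frame_indices": list(range(start, start + self.max_frames))}
```

Passing `image_ratio=0.0` reduces this to pure video sampling, which is the setting of stage-2, where the longer clips are then spread across GPUs with context parallelism.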

The VAE training script is `scripts/train_causal_video_vae.sh`. Run it as follows:

```bash
sh scripts/train_causal_video_vae.sh
```

We also provide a VAE demo notebook, `causal_video_vae_demo.ipynb`, for image and video reconstruction.


## Tips

+ For stage-1, we use mixed image and video training. Add the `--use_image_video_mixed_training` param to enable it. The image ratio is set to 0.1 by default.
+ Setting `resolution` to 256 is sufficient for VAE training.
+ For stage-1, `max_frames` is set to 17, meaning each training video clip consists of 17 sampled frames.
+ For stage-2, we enable the `use_context_parallel` param to distribute the frames of long videos across multiple GPUs. Make sure that `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES=17 * CONTEXT_SIZE + 1`.
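The stage-2 constraints above can be checked before launching a run. The following is a minimal Python sketch, not the repository's code: `plan_context_parallel` and `split_frames` are hypothetical helpers, treating the GPUs as `GPUS // CONTEXT_SIZE` groups that each share one video is an assumption, and the way `split_frames` assigns frames (the extra `+ 1` frame treated as the causal first frame and given to rank 0, the remaining `17 * CONTEXT_SIZE` frames cut into contiguous chunks of 17) is only one plausible reading of the formula.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ContextParallelPlan:
    gpus: int
    context_size: int
    num_frames: int     # NUM_FRAMES=17 * CONTEXT_SIZE + 1
    num_cp_groups: int  # assumed: groups of CONTEXT_SIZE GPUs that share one video


def plan_context_parallel(gpus: int, context_size: int) -> ContextParallelPlan:
    """Check the stage-2 constraints from the tips and derive NUM_FRAMES."""
    if context_size < 1:
        raise ValueError("CONTEXT_SIZE must be at least 1")
    if gpus % context_size != 0:
        raise ValueError(f"GPUS={gpus} is not divisible by CONTEXT_SIZE={context_size}")
    return ContextParallelPlan(
        gpus=gpus,
        context_size=context_size,
        num_frames=17 * context_size + 1,
        num_cp_groups=gpus // context_size,
    )


def split_frames(num_frames: int, context_size: int) -> list[range]:
    """Hypothetical split: rank 0 takes the causal first frame plus one chunk,
    every other rank takes one contiguous chunk of equal length."""
    rest = num_frames - 1
    if rest % context_size != 0:
        raise ValueError("NUM_FRAMES - 1 must be divisible by CONTEXT_SIZE")
    chunk = rest // context_size
    bounds = [0] + [1 + (r + 1) * chunk for r in range(context_size)]
    return [range(bounds[r], bounds[r + 1]) for r in range(context_size)]


plan = plan_context_parallel(gpus=8, context_size=4)
print(plan.num_frames, plan.num_cp_groups)  # 69 2
print([(c.start, c.stop) for c in split_frames(plan.num_frames, plan.context_size)])
# [(0, 18), (18, 35), (35, 52), (52, 69)]
```

Under these assumptions, 8 GPUs with `CONTEXT_SIZE=4` give 2 context-parallel groups, each processing a 69-frame video, with rank 0 holding 18 frames and the other ranks 17 each. A setting such as `CONTEXT_SIZE=3` on 8 GPUs is rejected because `8 % 3 != 0`.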