
Repo & Config Structure

Repo Structure

Open-Sora
├── README.md
├── docs
│   ├── acceleration.md            -> Acceleration & Speed benchmark
│   ├── command.md                 -> Commands for training & inference
│   ├── datasets.md                -> Datasets used in this project
│   ├── structure.md               -> This file
│   └── report_v1.md               -> Report for Open-Sora v1
├── scripts
│   ├── train.py                   -> diffusion training script
│   └── inference.py               -> diffusion inference script
├── configs                        -> Configs for training & inference
├── opensora
│   ├── __init__.py
│   ├── registry.py                -> Registry helper
│   ├── acceleration               -> Acceleration related code
│   ├── dataset                    -> Dataset related code
│   ├── models
│   │   ├── layers                 -> Common layers
│   │   ├── vae                    -> VAE as image encoder
│   │   ├── text_encoder           -> Text encoder
│   │   │   ├── classes.py         -> Class id encoder (inference only)
│   │   │   ├── clip.py            -> CLIP encoder
│   │   │   └── t5.py              -> T5 encoder
│   │   ├── dit
│   │   ├── latte
│   │   ├── pixart
│   │   └── stdit                  -> Our STDiT related code
│   ├── schedulers                 -> Diffusion schedulers
│   │   ├── iddpm                  -> IDDPM for training and inference
│   │   └── dpms                   -> DPM-Solver for fast inference
│   └── utils
└── tools                          -> Tools for data processing and more
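Configs select components by a string `type` (for example `type="STDiT-XL/2"` in the config demos), and `registry.py` connects those names to code. The sketch below shows the general registry pattern using only the standard library. `Registry`, `MODELS`, and `TinyModel` are hypothetical names, and Open-Sora's own `registry.py` may work differently.

```python
from typing import Any, Callable, Dict


class Registry:
    """Hypothetical minimal registry: maps a string name to a class or factory."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._entries: Dict[str, Callable[..., Any]] = {}

    def register(self, key: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Decorator that stores the decorated class/function under `key`."""
        def decorator(obj: Callable[..., Any]) -> Callable[..., Any]:
            if key in self._entries:
                raise KeyError(f"{key!r} is already registered in {self.name}")
            self._entries[key] = obj
            return obj
        return decorator

    def build(self, cfg: Dict[str, Any]) -> Any:
        """Instantiate cfg["type"], passing the remaining keys as keyword arguments."""
        cfg = dict(cfg)  # copy so the caller's config dict is not mutated by pop()
        key = cfg.pop("type")
        if key not in self._entries:
            raise KeyError(f"{key!r} is not registered in {self.name}")
        return self._entries[key](**cfg)


MODELS = Registry("model")


@MODELS.register("TinyModel")  # hypothetical component, for illustration only
class TinyModel:
    def __init__(self, hidden_size: int = 8, enable_flashattn: bool = False) -> None:
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn


model_cfg = dict(type="TinyModel", hidden_size=16, enable_flashattn=True)
model = MODELS.build(model_cfg)
print(type(model).__name__, model.hidden_size)  # TinyModel 16
```

Copying the dict inside `build` lets the same config be reused, for example when building a model once for training and once for evaluation.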

Configs

Our config files follow MMEngine. MMEngine reads the config file (a .py file) and parses it into a dictionary-like object.
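The sketch below illustrates what "parsing a .py config into a dictionary-like object" means. It uses only the standard library: it executes a config file with `runpy.run_path` and keeps its public top-level variables. This is not MMEngine's implementation. MMEngine's `Config.fromfile` also supports features such as `_base_` inheritance and attribute-style access. The helper name `load_py_config` is hypothetical.

```python
import os
import runpy
import tempfile
import types
from typing import Any, Dict


def load_py_config(path: str) -> Dict[str, Any]:
    """Hypothetical helper: run a .py config and return its public top-level variables."""
    namespace = runpy.run_path(path)
    return {
        key: value
        for key, value in namespace.items()
        # drop private/dunder names (e.g. __builtins__) and any imported modules
        if not key.startswith("_") and not isinstance(value, types.ModuleType)
    }


# A tiny config written in the same style as the demos in this document.
config_source = '''
num_frames = 64
image_size = (512, 512)
model = dict(type="STDiT-XL/2", space_scale=1.0, time_scale=2 / 3)
'''

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "demo_config.py")
    with open(config_path, "w") as f:
        f.write(config_source)
    cfg = load_py_config(config_path)

print(cfg["num_frames"], cfg["model"]["type"])  # 64 STDiT-XL/2
```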

Open-Sora
└── configs                        -> Configs for training & inference
    ├── opensora                   -> STDiT related configs
    │   ├── inference
    │   │   ├── 16x256x256.py      -> Sample videos 16 frames 256x256
    │   │   ├── 16x512x512.py      -> Sample videos 16 frames 512x512
    │   │   └── 64x512x512.py      -> Sample videos 64 frames 512x512
    │   └── train
    │       ├── 16x256x256.py      -> Train on videos 16 frames 256x256
    │       ├── 16x512x512.py      -> Train on videos 16 frames 512x512
    │       └── 64x512x512.py      -> Train on videos 64 frames 512x512
    ├── dit                        -> DiT related configs
    │   ├── inference
    │   │   ├── 1x256x256-class.py -> Sample images with ckpts from DiT
    │   │   ├── 1x256x256.py       -> Sample images with clip condition
    │   │   └── 16x256x256.py      -> Sample videos
    │   └── train
    │       ├── 1x256x256.py       -> Train on images with clip condition
    │       └── 16x256x256.py      -> Train on videos
    ├── latte                      -> Latte related configs
    └── pixart                     -> PixArt related configs
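The config file names in this listing appear to encode the sampling size as `{num_frames}x{height}x{width}`, sometimes followed by a suffix such as `-class`. The hypothetical helper below parses that convention. The convention is inferred from the listing, and because every listed example is square, the height/width order is an assumption.

```python
import re
from typing import NamedTuple, Optional


class SampleSize(NamedTuple):
    num_frames: int
    height: int
    width: int
    suffix: Optional[str]


# {frames}x{height}x{width}[-suffix].py, e.g. "16x256x256.py" or "1x256x256-class.py"
_NAME_PATTERN = re.compile(r"^(\d+)x(\d+)x(\d+)(?:-([\w-]+))?\.py$")


def parse_config_name(filename: str) -> SampleSize:
    """Hypothetical helper: split a config file name into its sampling size."""
    match = _NAME_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"unexpected config name: {filename!r}")
    frames, height, width, suffix = match.groups()
    return SampleSize(int(frames), int(height), int(width), suffix)


print(parse_config_name("64x512x512.py"))
# SampleSize(num_frames=64, height=512, width=512, suffix=None)
print(parse_config_name("1x256x256-class.py"))
# SampleSize(num_frames=1, height=256, width=256, suffix='class')
```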

Inference config demos

To change the inference settings, you can modify the corresponding config file directly, or pass command-line arguments that override fields in the config file (see config_utils.py). To change the sampling prompts, edit the .txt file passed to the --prompt_path argument. The arguments below map to these config fields:

--prompt_path ./assets/texts/t2v_samples.txt  -> prompt_path
--ckpt-path ./path/to/your/ckpt.pth           -> model["from_pretrained"]
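The sketch below uses `argparse` to show how the two overrides listed above could be merged into a loaded config. It only mirrors those two mappings. The function name `merge_args` and its structure are hypothetical and are not the code in `config_utils.py`.

```python
import argparse
import copy
from typing import Any, Dict, List, Optional


def merge_args(cfg: Dict[str, Any], argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: apply command-line overrides on top of a config dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str, default=None)
    parser.add_argument("--ckpt-path", type=str, default=None)  # argparse stores it as args.ckpt_path
    args = parser.parse_args(argv)

    merged = copy.deepcopy(cfg)  # leave the original config untouched
    if args.prompt_path is not None:
        merged["prompt_path"] = args.prompt_path
    if args.ckpt_path is not None:
        merged["model"]["from_pretrained"] = args.ckpt_path
    return merged


cfg = dict(
    prompt_path="./assets/texts/t2v_samples.txt",
    model=dict(type="STDiT-XL/2", from_pretrained="PRETRAINED_MODEL"),
)
merged = merge_args(cfg, ["--ckpt-path", "./path/to/your/ckpt.pth"])
print(merged["model"]["from_pretrained"])  # ./path/to/your/ckpt.pth
```

Both flags default to `None`, so a flag that is not passed leaves the value from the config file unchanged.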

Each field is explained in the comments below.

# Define sampling size
num_frames = 64               # number of frames
fps = 24 // 2                 # frames per second (divided by 2 for frame_interval=2)
image_size = (512, 512)       # image size (height, width)

# Define model
model = dict(
    type="STDiT-XL/2",        # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
    space_scale=1.0,          # (Optional) Space positional encoding scale (new height / old height)
    time_scale=2 / 3,         # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
    enable_flashattn=True,    # (Optional) Speed up training and inference with flash attention
    enable_layernorm_kernel=True, # (Optional) Speed up training and inference with fused kernel
    from_pretrained="PRETRAINED_MODEL",  # (Optional) Load from pretrained model
    no_temporal_pos_emb=True,  # (Optional) Disable temporal positional encoding (for image)
)
vae = dict(
    type="VideoAutoencoderKL", # Select VAE type
    from_pretrained="stabilityai/sd-vae-ft-ema", # Load from pretrained VAE
    micro_batch_size=128,      # VAE with micro batch size to save memory
)
text_encoder = dict(
    type="t5",                 # Select text encoder type (t5, clip)
    from_pretrained="./pretrained_models/t5_ckpts", # Load from pretrained text encoder
    model_max_length=120,      # Maximum length of input text
)
scheduler = dict(
    type="iddpm",              # Select scheduler type (iddpm, dpm-solver)
    num_sampling_steps=100,    # Number of sampling steps
    cfg_scale=7.0,             # classifier-free guidance scale
)
dtype = "fp16"                 # Computation type (fp16, fp32, bf16)

# Other settings
batch_size = 1                 # batch size
seed = 42                      # random seed
prompt_path = "./assets/texts/t2v_samples.txt"  # path to prompt file
save_dir = "./samples"         # path to save samples
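The `fps`, `space_scale`, and `time_scale` values above follow from their comments. Here is a small worked example under stated assumptions: a 24 fps source video and `frame_interval = 2` (as in the training config). It also assumes a pretrained model with height 512 and frame interval 3, which is what the demo's `space_scale = 1.0` and `time_scale = 2 / 3` imply under the comments' definitions.

```python
from fractions import Fraction

source_fps = 24          # assumption: frame rate of the source videos
frame_interval = 2       # sample every 2nd frame, as in the training config
num_frames = 64
image_size = (512, 512)  # (height, width)

# One generated frame stands for `frame_interval` source frames.
fps = source_fps // frame_interval

# Seconds of source time covered by one generated clip.
duration_s = num_frames / fps

# Positional-encoding scales as defined in the config comments.
pretrained_height = 512          # assumption: "old height"
pretrained_frame_interval = 3    # assumption: "old frame_interval"
space_scale = Fraction(image_size[0], pretrained_height)
time_scale = Fraction(frame_interval, pretrained_frame_interval)

print(fps, round(duration_s, 2), space_scale, time_scale)  # 12 5.33 1 2/3
```

`Fraction` keeps the scales exact, so the result can be compared directly with the `2 / 3` written in the config.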

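`cfg_scale` sets the strength of classifier-free guidance. In the standard formulation, the model predicts noise once with the text condition and once without it. The two predictions are then combined as `uncond + cfg_scale * (cond - uncond)`. The sketch below applies that formula element-wise to plain Python lists. The function name `apply_cfg` is hypothetical, and it does not show how the samplers in `opensora/schedulers` batch the two passes.

```python
from typing import List


def apply_cfg(eps_cond: List[float], eps_uncond: List[float], cfg_scale: float) -> List[float]:
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + cfg_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


# cfg_scale = 1.0 returns the conditional prediction; larger values follow the prompt more strongly.
print(apply_cfg([0.5, -0.25], [0.25, 0.0], cfg_scale=1.0))  # [0.5, -0.25]
print(apply_cfg([0.5, -0.25], [0.25, 0.0], cfg_scale=7.0))  # [2.0, -1.75]
```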
Training config demos

# Define sampling size
num_frames = 64
frame_interval = 2             # sample every 2 frames
image_size = (512, 512)

# Define dataset
root = None                    # root path to the dataset
data_path = "CSV_PATH"         # path to the csv file
use_image_transform = False    # True if training on images
num_workers = 4                # number of workers for dataloader

# Define acceleration
dtype = "bf16"                 # Computation type (fp16, bf16)
grad_checkpoint = True         # Use gradient checkpointing
plugin = "zero2"               # Plugin for distributed training (zero2, zero2-seq)
sp_size = 1                    # Sequence parallelism size (1 for no sequence parallelism)

# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=1.0,
    time_scale=2 / 3,
    from_pretrained="YOUR_PRETRAINED_MODEL",
    enable_flashattn=True,        # Enable flash attention
    enable_layernorm_kernel=True, # Enable layernorm kernel
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=128,
)
text_encoder = dict(
    type="t5",
    from_pretrained="./pretrained_models/t5_ckpts",
    model_max_length=120,
    shardformer=True,           # Enable shardformer for T5 acceleration
)
scheduler = dict(
    type="iddpm",
    timestep_respacing="",      # Default 1000 timesteps
)

# Others
seed = 42
outputs = "outputs"             # path to save checkpoints
wandb = False                   # Use wandb for logging

epochs = 1000                   # number of epochs (set large enough; stop the run manually when satisfied)
log_every = 10                  # log training metrics every 10 steps
ckpt_every = 250                # save a checkpoint every 250 steps
load = None                     # path to resume training

batch_size = 4                  # batch size
lr = 2e-5                       # learning rate
grad_clip = 1.0                 # gradient clipping
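With `num_frames = 64` and `frame_interval = 2`, each training clip comes from a longer window of the source video. The sketch below shows strided frame-index sampling, assuming a uniformly random start position. The helper `sample_frame_indices` is hypothetical, and Open-Sora's dataset code may choose the window differently.

```python
import random
from typing import List, Optional


def sample_frame_indices(total_frames: int, num_frames: int, frame_interval: int,
                         rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical helper: pick `num_frames` indices spaced `frame_interval` apart."""
    span = (num_frames - 1) * frame_interval + 1  # source frames covered by one clip
    if total_frames < span:
        raise ValueError(f"video has {total_frames} frames, need at least {span}")
    if rng is None:
        rng = random.Random()
    start = rng.randint(0, total_frames - span)  # randint includes both endpoints
    return [start + i * frame_interval for i in range(num_frames)]


indices = sample_frame_indices(total_frames=300, num_frames=64, frame_interval=2,
                               rng=random.Random(42))
print(len(indices), indices[1] - indices[0])  # 64 2
```

Rejecting videos shorter than the window is one design choice. Padding the clip or lowering the interval for short videos are possible alternatives.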