# Repo & Config Structure
## Repo Structure
```plaintext
Open-Sora
├── README.md
├── docs
│   ├── acceleration.md            -> Acceleration & Speed benchmark
│   ├── command.md                 -> Commands for training & inference
│   ├── datasets.md                -> Datasets used in this project
│   ├── structure.md               -> This file
│   └── report_v1.md               -> Report for Open-Sora v1
├── scripts
│   ├── train.py                   -> diffusion training script
│   └── inference.py               -> diffusion inference script
├── configs                        -> Configs for training & inference
├── opensora
│   ├── __init__.py
│   ├── registry.py                -> Registry helper
│   ├── acceleration               -> Acceleration related code
│   ├── dataset                    -> Dataset related code
│   ├── models
│   │   ├── layers                 -> Common layers
│   │   ├── vae                    -> VAE as image encoder
│   │   ├── text_encoder           -> Text encoder
│   │   │   ├── classes.py         -> Class id encoder (inference only)
│   │   │   ├── clip.py            -> CLIP encoder
│   │   │   └── t5.py              -> T5 encoder
│   │   ├── dit
│   │   ├── latte
│   │   ├── pixart
│   │   └── stdit                  -> Our STDiT related code
│   ├── schedulers                 -> Diffusion schedulers
│   │   ├── iddpm                  -> IDDPM for training and inference
│   │   └── dpms                   -> DPM-Solver for fast inference
│   └── utils
└── tools                          -> Tools for data processing and more
```
## Configs
Our config files follow [MMEngine](https://github.com/open-mmlab/mmengine). MMEngine reads a config file (a `.py` file) and parses it into a dictionary-like object.
```plaintext
Open-Sora
└── configs                        -> Configs for training & inference
    ├── opensora                   -> STDiT related configs
    │   ├── inference
    │   │   ├── 16x256x256.py      -> Sample videos 16 frames 256x256
    │   │   ├── 16x512x512.py      -> Sample videos 16 frames 512x512
    │   │   └── 64x512x512.py      -> Sample videos 64 frames 512x512
    │   └── train
    │       ├── 16x256x256.py      -> Train on videos 16 frames 256x256
    │       ├── 16x512x512.py      -> Train on videos 16 frames 512x512
    │       └── 64x512x512.py      -> Train on videos 64 frames 512x512
    ├── dit                        -> DiT related configs
    │   ├── inference
    │   │   ├── 1x256x256-class.py -> Sample images with ckpts from DiT
    │   │   ├── 1x256x256.py       -> Sample images with clip condition
    │   │   └── 16x256x256.py      -> Sample videos
    │   └── train
    │       ├── 1x256x256.py       -> Train on images with clip condition
    │       └── 16x256x256.py      -> Train on videos
    ├── latte                      -> Latte related configs
    └── pixart                     -> PixArt related configs
```
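Since each config is a plain Python file in MMEngine style, it can be loaded and inspected programmatically before training or inference. The snippet below is an illustrative sketch and assumes the `mmengine` package is installed; Open-Sora itself parses configs through its own [config_utils.py](/opensora/utils/config_utils.py), which follows the same convention.

```python
from mmengine.config import Config

# Parse the Python config file into a dictionary-like Config object
cfg = Config.fromfile("configs/opensora/inference/16x256x256.py")

print(cfg.num_frames)        # fields support attribute-style access ...
print(cfg["model"]["type"])  # ... as well as dict-style access
cfg.num_frames = 32          # fields can be overridden before use
```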
## Inference config demos
To change the inference settings, you can directly modify the corresponding config file, or pass command-line arguments that overwrite fields in it (see [config_utils.py](/opensora/utils/config_utils.py)). To change the sampling prompts, modify the `.txt` file passed to the `--prompt_path` argument.
```plaintext
--prompt_path ./assets/texts/t2v_samples.txt -> prompt_path
--ckpt-path ./path/to/your/ckpt.pth -> model["from_pretrained"]
```
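For example, an inference launch might look like the sketch below; the launcher, number of processes, and config path are assumptions and may need to be adapted to your setup (see [command.md](/docs/command.md) for the exact commands).

```plaintext
# Illustrative command: the config file comes first, flags overwrite its fields
torchrun --standalone --nproc_per_node 1 scripts/inference.py \
    configs/opensora/inference/16x256x256.py \
    --prompt_path ./assets/texts/t2v_samples.txt \
    --ckpt-path ./path/to/your/ckpt.pth
```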
The explanation of each field is provided below.
```python
# Define sampling size
num_frames = 64 # number of frames
fps = 24 // 2 # frames per second (divided by 2 for frame_interval=2)
image_size = (512, 512) # image size (height, width)
# Define model
model = dict(
type="STDiT-XL/2", # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
space_scale=1.0, # (Optional) Space positional encoding scale (new height / old height)
time_scale=2 / 3, # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
enable_flashattn=True, # (Optional) Speed up training and inference with flash attention
enable_layernorm_kernel=True, # (Optional) Speed up training and inference with fused kernel
from_pretrained="PRETRAINED_MODEL", # (Optional) Load from pretrained model
no_temporal_pos_emb=True, # (Optional) Disable temporal positional encoding (for image)
)
vae = dict(
type="VideoAutoencoderKL", # Select VAE type
from_pretrained="stabilityai/sd-vae-ft-ema", # Load from pretrained VAE
micro_batch_size=128, # VAE with micro batch size to save memory
)
text_encoder = dict(
type="t5", # Select text encoder type (t5, clip)
from_pretrained="./pretrained_models/t5_ckpts", # Load from pretrained text encoder
model_max_length=120, # Maximum length of input text
)
scheduler = dict(
type="iddpm", # Select scheduler type (iddpm, dpm-solver)
num_sampling_steps=100, # Number of sampling steps
cfg_scale=7.0, # hyper-parameter for classifier-free diffusion
)
dtype = "fp16" # Computation type (fp16, fp32, bf16)
# Other settings
batch_size = 1 # batch size
seed = 42 # random seed
prompt_path = "./assets/texts/t2v_samples.txt" # path to prompt file
save_dir = "./samples" # path to save samples
```
## Training config demos
```python
# Define sampling size
num_frames = 64 # number of frames per training sample
frame_interval = 2 # sample every 2 frames
image_size = (512, 512) # image size (height, width)
# Define dataset
root = None # root path to the dataset
data_path = "CSV_PATH" # path to the csv file
use_image_transform = False # True if training on images
num_workers = 4 # number of workers for dataloader
# Define acceleration
dtype = "bf16" # Computation type (fp16, bf16)
grad_checkpoint = True # Use gradient checkpointing
plugin = "zero2" # Plugin for distributed training (zero2, zero2-seq)
sp_size = 1 # Sequence parallelism size (1 for no sequence parallelism)
# Define model
model = dict(
type="STDiT-XL/2",
space_scale=1.0,
time_scale=2 / 3,
from_pretrained="YOUR_PRETRAINED_MODEL",
enable_flashattn=True, # Enable flash attention
enable_layernorm_kernel=True, # Enable layernorm kernel
)
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=128,
)
text_encoder = dict(
type="t5",
from_pretrained="./pretrained_models/t5_ckpts",
model_max_length=120,
shardformer=True, # Enable shardformer for T5 acceleration
)
scheduler = dict(
type="iddpm",
timestep_respacing="", # Default 1000 timesteps
)
# Others
seed = 42
outputs = "outputs" # path to save checkpoints
wandb = False # Use wandb for logging
epochs = 1000 # number of epochs (just large enough, kill when satisfied)
log_every = 10 # log every 10 steps
ckpt_every = 250 # save a checkpoint every 250 steps
load = None # path to resume training
batch_size = 4 # batch size
lr = 2e-5 # learning rate
grad_clip = 1.0 # gradient clipping
```
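Training is launched the same way: pass a training config to `scripts/train.py` and overwrite individual fields on the command line if needed. The command below is a sketch; the launcher, GPU count, and flag values are assumptions, so check [command.md](/docs/command.md) for the supported options.

```plaintext
# Illustrative command: distributed training on one node with 8 GPUs
torchrun --standalone --nproc_per_node 8 scripts/train.py \
    configs/opensora/train/64x512x512.py \
    --data-path YOUR_CSV_PATH \
    --ckpt-path YOUR_PRETRAINED_CKPT
```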