hpoghos committed
Commit f949b3f
1 Parent(s): 4e189bc
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +17 -0
  3. README.md +103 -2
  4. __assets__/github/teaser/teaser_final.png +3 -0
  5. app.py +7 -0
  6. examples/underwater.png +0 -0
  7. pyproject.toml +39 -0
  8. requirements.txt +37 -0
  9. t2v_enhanced/__init__.py +4 -0
  10. t2v_enhanced/checkpoints/streaming_t2v.ckpt +3 -0
  11. t2v_enhanced/configs/inference/inference_long_video.yaml +37 -0
  12. t2v_enhanced/configs/text_to_video/config.yaml +227 -0
  13. t2v_enhanced/gradio_demo.py +189 -0
  14. t2v_enhanced/inference.py +82 -0
  15. t2v_enhanced/inference_utils.py +101 -0
  16. t2v_enhanced/model/__init__.py +0 -0
  17. t2v_enhanced/model/callbacks.py +102 -0
  18. t2v_enhanced/model/datasets/prompt_reader.py +80 -0
  19. t2v_enhanced/model/datasets/video_dataset.py +57 -0
  20. t2v_enhanced/model/diffusers_conditional/__init__.py +0 -0
  21. t2v_enhanced/model/diffusers_conditional/models/__init__.py +0 -0
  22. t2v_enhanced/model/diffusers_conditional/models/controlnet/__init__.py +0 -0
  23. t2v_enhanced/model/diffusers_conditional/models/controlnet/attention.py +291 -0
  24. t2v_enhanced/model/diffusers_conditional/models/controlnet/attention_processor.py +444 -0
  25. t2v_enhanced/model/diffusers_conditional/models/controlnet/conditioning.py +100 -0
  26. t2v_enhanced/model/diffusers_conditional/models/controlnet/controlnet.py +865 -0
  27. t2v_enhanced/model/diffusers_conditional/models/controlnet/cross_attention.py +30 -0
  28. t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py +211 -0
  29. t2v_enhanced/model/diffusers_conditional/models/controlnet/mask_generator.py +27 -0
  30. t2v_enhanced/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py +925 -0
  31. t2v_enhanced/model/diffusers_conditional/models/controlnet/processor.py +240 -0
  32. t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_2d.py +333 -0
  33. t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal.py +190 -0
  34. t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py +182 -0
  35. t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py +930 -0
  36. t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_condition.py +635 -0
  37. t2v_enhanced/model/flags.py +1 -0
  38. t2v_enhanced/model/layers/conv_channel_extension.py +143 -0
  39. t2v_enhanced/model/pl_module_extension.py +297 -0
  40. t2v_enhanced/model/pl_module_params_controlnet.py +356 -0
  41. t2v_enhanced/model/requires_grad_setter.py +36 -0
  42. t2v_enhanced/model/video_ldm.py +327 -0
  43. t2v_enhanced/model/video_noise_generator.py +225 -0
  44. t2v_enhanced/model_func.py +117 -0
  45. t2v_enhanced/model_init.py +112 -0
  46. t2v_enhanced/utils/conversions.py +48 -0
  47. t2v_enhanced/utils/iimage.py +517 -0
  48. t2v_enhanced/utils/image_converter.py +45 -0
  49. t2v_enhanced/utils/object_loader.py +26 -0
  50. t2v_enhanced/utils/video_utils.py +376 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ t2v_enhanced/checkpoints/streaming_t2v.ckpt filter=lfs diff=lfs merge=lfs -text
+ __assets__/github/teaser/teaser_final.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
+ __pycache__/
+ *.py[cod]
+
+ .mlflow/
+ /logs
+ /experiments
+ /t2v_enhanced/.mlflow
+ /t2v_enhanced/logs
+ /t2v_enhanced/slurm_logs
+ /t2v_enhanced/results
+
+ t2v_enhanced/.mlflow
+ t2v_enhanced/logs
+ t2v_enhanced/slurm_logs
+ t2v_enhanced/lightning_logs
+ t2v_enhanced/results
+ t2v_enhanced/gradio_output
README.md CHANGED
@@ -7,7 +7,108 @@ sdk: gradio
  sdk_version: 4.25.0
  app_file: app.py
  pinned: false
- short_description: 'StreamingT2V: Consistent, Dynamic, and Extendable Long Video'
+ short_description: Consistent, Dynamic, and Extendable Long Video Generation fr
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+ # StreamingT2V
+
+ This repository is the official implementation of [StreamingT2V](https://streamingt2v.github.io/).
+
+
+ **[StreamingT2V: Consistent, Dynamic, and Extendable Long Video Generation from Text](https://arxiv.org/abs/2403.14773)**
+ </br>
+ Roberto Henschel,
+ Levon Khachatryan,
+ Daniil Hayrapetyan,
+ Hayk Poghosyan,
+ Vahram Tadevosyan,
+ Zhangyang Wang, Shant Navasardyan, Humphrey Shi
+ </br>
+
+ [arXiv preprint](https://arxiv.org/abs/2403.14773) | [Video](https://twitter.com/i/status/1770909673463390414) | [Project page](https://streamingt2v.github.io/)
+
+
+ <p align="center">
+ <img src="__assets__/github/teaser/teaser_final.png" width="800px"/>
+ <br>
+ <br>
+ <em>StreamingT2V is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation. It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality. Our demonstrations include successful examples of videos up to 1200 frames, spanning 2 minutes, and can be extended for even longer durations. Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.</em>
+ </p>
+
+ ## News
+
+ * [03/21/2024] Paper [StreamingT2V](https://arxiv.org/abs/2403.14773) released!
+ * [04/03/2024] Code and [model](https://huggingface.co/PAIR/StreamingT2V) released!
+
+
+ ## Setup
+
+
+
+ 1. Clone this repository and enter:
+
+ ``` shell
+ git clone https://github.com/Picsart-AI-Research/StreamingT2V.git
+ cd StreamingT2V/
+ ```
+ 2. Install requirements using Python 3.10 and CUDA >= 11.6
+ ``` shell
+ conda create -n st2v python=3.10
+ conda activate st2v
+ pip install -r requirements.txt
+ ```
+ 3. (Optional) Install FFmpeg if it's missing on your system
+ ``` shell
+ conda install conda-forge::ffmpeg
+ ```
+ 4. Download the weights from [HF](https://huggingface.co/PAIR/StreamingT2V) and put them into the `t2v_enhanced/checkpoints` directory.
+
+ ---
+
+
+ ## Inference
+
+
+
+ ### For Text-to-Video
+
+ ``` shell
+ cd StreamingT2V/
+ python inference.py --prompt="A cat running on the street"
+ ```
+ To use other base models add the `--base_model=AnimateDiff` argument. Use `python inference.py --help` for more options.
+
+ ### For Image-to-Video
+
+ ``` shell
+ cd StreamingT2V/
+ python inference.py --image=../examples/underwater.png --base_model=SVD
+ ```
+
+
+
+ ## Results
+ Detailed results can be found in the [Project page](https://streamingt2v.github.io/).
+
+ ## License
+ Our code is published under the CreativeML Open RAIL-M license.
+
+ We include [ModelscopeT2V](https://github.com/modelscope/modelscope), [AnimateDiff](https://github.com/guoyww/AnimateDiff), [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter) in the demo for research purposes and to demonstrate the flexibility of the StreamingT2V framework to include different T2V/I2V models. For commercial usage of such components, please refer to their original license.
+
+
+
+
+ ## BibTeX
+ If you use our work in your research, please cite our publication:
+ ```
+ @article{henschel2024streamingt2v,
+ title={StreamingT2V: Consistent, Dynamic, and Extendable Long Video Generation from Text},
+ author={Henschel, Roberto and Khachatryan, Levon and Hayrapetyan, Daniil and Poghosyan, Hayk and Tadevosyan, Vahram and Wang, Zhangyang and Navasardyan, Shant and Shi, Humphrey},
+ journal={arXiv preprint arXiv:2403.14773},
+ year={2024}
+ }
+ ```
+
+
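Setup step 4 in the README above leaves the checkpoint download manual. A minimal sketch of fetching it programmatically with `huggingface_hub` (pinned in `requirements.txt`); the filename is an assumption, chosen to match the `t2v_enhanced/checkpoints/streaming_t2v.ckpt` file tracked via LFS in this commit:

```python
# Sketch: fetch the StreamingT2V checkpoint into the directory inference.py expects.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="PAIR/StreamingT2V",
    filename="streaming_t2v.ckpt",        # assumed filename in the model repo
    local_dir="t2v_enhanced/checkpoints",  # matches the path added in this commit
)
print(ckpt_path)
```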
__assets__/github/teaser/teaser_final.png ADDED
Git LFS Details
  • SHA256: fd4343418202d8aad2f08a65096482eb17527b784562a4e116da432aa22a30c5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.66 MB
app.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+
+ def greet(name):
+ return "Hello " + name + "!!"
+
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ iface.launch()
examples/underwater.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,39 @@
+ [tool.poetry]
+ name = "t2v-enhanced"
+ version = "0.1.0"
+ description = ""
+ authors = ["Your Name <[email protected]>"]
+ readme = "README.md"
+ packages = [{include = "t2v_enhanced"}]
+
+ [tool.poetry.dependencies]
+ python = "^3.9"
+ torch = "^2.0.0"
+ omegaconf = "^2.3.0"
+ hydra-core = "^1.3.2"
+ pytorch-lightning = {extras = ["extra"], version = "^2.0.9"}
+ transformers = "^4.28.1"
+ torchmetrics = {extras = ["image"], version = "^0.11.4"}
+ mlflow = {extras = ["extras"], version = "^2.3.0"}
+ torchvision = "^0.15.1"
+ av = "^10.0.0"
+ rich = "^13.3.4"
+ albumentations = "^1.3.0"
+ datasets = "^2.12.0"
+ xformers = "^0.0.19"
+ kornia = "^0.7.0"
+ decord = "^0.6.0"
+ gdown = "^4.7.1"
+ pygifsicle = "^1.0.7"
+ ftfy = "^6.1.1"
+ regex = "^2023.6.3"
+ clip = {git = "https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33"}
+
+
+ [tool.poetry.group.dev.dependencies]
+ yapf = "^0.33.0"
+ autopep8 = "^2.0.2"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,37 @@
+ accelerate==0.28.0
+ bitsandbytes==0.43.0
+ transformers==4.39.3
+ diffusers==0.27.2
+ albumentations==1.3.0
+ av==10.0.0
+ boto3==1.26.115
+ clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
+ decord==0.6.0
+ einops==0.7.0
+ fastapi==0.103.2
+ Flask==2.3.2
+ gdown==4.7.1
+ gradio==4.25.0
+ gradio_client==0.15.0
+ huggingface-hub==0.21.4
+ jupyterlab==3.6.3
+ omegaconf==2.3.0
+ pandas==2.0.0
+ pytorch-lightning==2.0.9
+ scikit-image==0.20.0
+ scikit-learn==1.2.2
+ scipy==1.9.1
+ seaborn==0.12.2
+ -e .
+ torch==2.0.0
+ torchdata==0.6.0
+ torchvision==0.15.1
+ tqdm==4.65.0
+ xformers==0.0.19
+ open-clip-torch==2.24.0
+ jsonargparse==4.20.1
+ fairscale==0.4.13
+ rotary-embedding-torch==0.5.3
+ easydict==1.13
+ torchsde==0.2.6
+ imageio[ffmpeg]==2.25.0
t2v_enhanced/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from pathlib import Path
+
+
+ WORK_DIR = Path(__file__).resolve().parent
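`WORK_DIR` resolves to the installed `t2v_enhanced` package directory. A minimal usage sketch; the sub-paths simply mirror files added elsewhere in this commit:

```python
# Sketch: resolve repo-relative paths regardless of the current working directory.
from t2v_enhanced import WORK_DIR

config_path = WORK_DIR / "configs" / "inference" / "inference_long_video.yaml"
ckpt_path = WORK_DIR / "checkpoints" / "streaming_t2v.ckpt"
print(config_path.exists(), ckpt_path.exists())
```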
t2v_enhanced/checkpoints/streaming_t2v.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:821f94e00bb9e25b0b03ab5a37ac01d31a24df4573f2b7809c34c54c9712aa5c
+ size 25568849525
t2v_enhanced/configs/inference/inference_long_video.yaml ADDED
@@ -0,0 +1,37 @@
+ trainer:
+ devices: '1'
+ num_nodes: 1
+ model:
+ inference_params:
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams
+ init_args:
+ num_inference_steps: 50 # number of inference steps
+ frame_rate: 3
+ eta: 1.0 # eta used for DDIM sampler
+ guidance_scale: 7.5 # classifier free guidance scale
+ conditioning_type: fixed
+ start_from_real_input: false
+ n_autoregressive_generations: 6 # how many autoregressive generations
+ scheduler_cls: '' # we can load other models
+ unet_params:
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams
+ init_args:
+ use_standard_attention_processor: False
+ opt_params:
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams
+ init_args:
+ noise_generator:
+ class_path: t2v_enhanced.model.video_noise_generator.NoiseGenerator
+ init_args:
+ mode: vanilla # can be 'vanilla','mixed_noise', 'consistI2V' or 'mixed_noise_consistI2V'
+ alpha: 1.0
+ shared_noise_across_chunks: True # if true, shared noise between all chunks of a video
+ forward_steps: 850 # number of DDPM forward steps
+ radius: [2,2,2] # radius for time, width and height
+ n_predictions: 300
+ data:
+ class_path: t2v_enhanced.model.datasets.prompt_reader.PromptReader
+ init_args:
+ prompt_cfg:
+ type: file
+ content: /home/roberto.henschel/T2V-Enhanced/repo/training_code/t2v_enhanced/evaluation_prompts/prompts_long_eval.txt
t2v_enhanced/configs/text_to_video/config.yaml ADDED
@@ -0,0 +1,227 @@
1
+ # pytorch_lightning==2.0.9
2
+ seed_everything: 33
3
+ trainer:
4
+ accelerator: auto
5
+ strategy: auto
6
+ devices: '8'
7
+ num_nodes: 1
8
+ precision: 16-mixed
9
+ logger: null
10
+ callbacks:
11
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
12
+ init_args:
13
+ max_depth: 1
14
+ - class_path: pytorch_lightning.callbacks.RichProgressBar
15
+ init_args:
16
+ refresh_rate: 1
17
+ leave: false
18
+ theme:
19
+ description: white
20
+ progress_bar: '#6206E0'
21
+ progress_bar_finished: '#6206E0'
22
+ progress_bar_pulse: '#6206E0'
23
+ batch_progress: white
24
+ time: grey54
25
+ processing_speed: grey70
26
+ metrics: white
27
+ console_kwargs: null
28
+ fast_dev_run: false
29
+ max_epochs: 5000
30
+ min_epochs: null
31
+ max_steps: 2020000
32
+ min_steps: null
33
+ max_time: null
34
+ limit_train_batches: null
35
+ limit_val_batches: 512
36
+ limit_test_batches: null
37
+ limit_predict_batches: null
38
+ overfit_batches: 0.0
39
+ val_check_interval: 8000
40
+ check_val_every_n_epoch: 1
41
+ num_sanity_val_steps: null
42
+ log_every_n_steps: 10
43
+ enable_checkpointing: null
44
+ enable_progress_bar: null
45
+ enable_model_summary: null
46
+ accumulate_grad_batches: 8
47
+ gradient_clip_val: 1
48
+ gradient_clip_algorithm: norm
49
+ deterministic: null
50
+ benchmark: null
51
+ inference_mode: true
52
+ use_distributed_sampler: true
53
+ profiler: null
54
+ detect_anomaly: false
55
+ barebones: false
56
+ plugins: null
57
+ sync_batchnorm: false
58
+ reload_dataloaders_every_n_epochs: 0
59
+ default_root_dir: null
60
+ model:
61
+ inference_params:
62
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.InferenceParams
63
+ init_args:
64
+ width: 256
65
+ height: 256
66
+ video_length: 16
67
+ guidance_scale: 7.5
68
+ use_dec_scaling: true
69
+ frame_rate: 8
70
+ num_inference_steps: 50
71
+ eta: 1.0
72
+ n_autoregressive_generations: 1
73
+ mode: long_video
74
+ start_from_real_input: true
75
+ eval_loss_metrics: false
76
+ scheduler_cls: ''
77
+ negative_prompt: ''
78
+ conditioning_from_all_past: false
79
+ validation_samples: 80
80
+ conditioning_type: last_chunk
81
+ result_formats:
82
+ - eval_gif
83
+ - gif
84
+ - mp4
85
+ concat_video: true
86
+ opt_params:
87
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.OptimizerParams
88
+ init_args:
89
+ learning_rate: 5.0e-05
90
+ layers_config:
91
+ class_path: t2v_enhanced.model.requires_grad_setter.LayerConfig
92
+ init_args:
93
+ gradient_setup:
94
+ - - false
95
+ - - vae
96
+ - - false
97
+ - - text_encoder
98
+ - - false
99
+ - - image_encoder
100
+ - - true
101
+ - - resampler
102
+ - - true
103
+ - - unet
104
+ - - true
105
+ - - base_model
106
+ - - false
107
+ - - base_model
108
+ - transformer_in
109
+ - - false
110
+ - - base_model
111
+ - temp_attentions
112
+ - - false
113
+ - - base_model
114
+ - temp_convs
115
+ layers_config_base: null
116
+ use_warmup: false
117
+ warmup_steps: 10000
118
+ warmup_start_factor: 1.0e-05
119
+ learning_rate_spatial: 0.0
120
+ use_8_bit_adam: false
121
+ noise_generator: null
122
+ noise_decomposition: null
123
+ perceptual_loss: false
124
+ noise_offset: 0.0
125
+ split_opt_by_node: false
126
+ reset_prediction_type_to_eps: false
127
+ train_val_sampler_may_differ: true
128
+ measure_similarity: false
129
+ similarity_loss: false
130
+ similarity_loss_weight: 1.0
131
+ loss_conditional_weight: 0.0
132
+ loss_conditional_weight_convex: false
133
+ loss_conditional_change_after_step: 0
134
+ mask_conditional_frames: false
135
+ sample_from_noise: true
136
+ mask_alternating: false
137
+ uncondition_freq: -1
138
+ no_text_condition_control: false
139
+ inject_image_into_input: false
140
+ inject_at_T: false
141
+ resampling_steps: 1
142
+ control_freq_in_resample: 1
143
+ resample_to_T: false
144
+ adaptive_loss_reweight: false
145
+ load_resampler_from_ckpt: ''
146
+ skip_controlnet_branch: false
147
+ use_fps_conditioning: false
148
+ num_frame_embeddings_range: 16
149
+ start_frame_training: 16
150
+ start_frame_ctrl: 16
151
+ load_trained_base_model_and_resampler_from_ckpt: ''
152
+ load_trained_controlnet_from_ckpt: ''
153
+ unet_params:
154
+ class_path: t2v_enhanced.model.pl_module_params_controlnet.UNetParams
155
+ init_args:
156
+ conditioning_embedding_out_channels:
157
+ - 32
158
+ - 96
159
+ - 256
160
+ - 512
161
+ ckpt_spatial_layers: ''
162
+ pipeline_repo: damo-vilab/text-to-video-ms-1.7b
163
+ unet_from_diffusers: true
164
+ spatial_latent_input: false
165
+ num_frame_conditioning: 1
166
+ pipeline_class: t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline
167
+ frame_expansion: none
168
+ downsample_controlnet_cond: true
169
+ num_frames: 16
170
+ pre_transformer_in_cond: false
171
+ num_tranformers: 1
172
+ zero_conv_3d: false
173
+ merging_mode: addition
174
+ compute_only_conditioned_frames: false
175
+ condition_encoder: ''
176
+ zero_conv_mode: Identity
177
+ clean_model: true
178
+ merging_mode_base: attention_cross_attention
179
+ attention_mask_params: null
180
+ attention_mask_params_base: null
181
+ modelscope_input_format: true
182
+ temporal_self_attention_only_on_conditioning: false
183
+ temporal_self_attention_mask_included_itself: false
184
+ use_post_merger_zero_conv: false
185
+ weight_control_sample: 1.0
186
+ use_controlnet_mask: false
187
+ random_mask_shift: false
188
+ random_mask: false
189
+ use_resampler: true
190
+ unet_from_pipe: false
191
+ unet_operates_on_2d: false
192
+ image_encoder: CLIP
193
+ use_standard_attention_processor: false
194
+ num_frames_before_chunk: 0
195
+ resampler_type: single_frame
196
+ resampler_cls: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.ImgEmbContextResampler
197
+ resampler_merging_layers: 4
198
+ image_encoder_obj:
199
+ class_path: t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder.FrozenOpenCLIPImageEmbedder
200
+ init_args:
201
+ arch: ViT-H-14
202
+ version: laion2b_s32b_b79k
203
+ device: cuda
204
+ max_length: 77
205
+ freeze: true
206
+ antialias: true
207
+ ucg_rate: 0.0
208
+ unsqueeze_dim: false
209
+ repeat_to_max_len: false
210
+ num_image_crops: 0
211
+ output_tokens: false
212
+ cfg_text_image: false
213
+ aggregation: last_out
214
+ resampler_random_shift: true
215
+ img_cond_alpha_per_frame: false
216
+ num_control_input_frames: 8
217
+ use_image_encoder_normalization: false
218
+ use_of: false
219
+ ema_param: -1.0
220
+ concat: false
221
+ use_image_tokens_main: true
222
+ use_image_tokens_ctrl: false
223
+ result_fol: results
224
+ exp_name: my_exp_name
225
+ run_name: my_run_name
226
+ scale_lr: false
227
+ matmul_precision: high
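The training config above is plain LightningCLI-style YAML, so it can also be inspected directly. A minimal sketch using OmegaConf (already a pinned dependency); the field names follow the config as committed:

```python
# Sketch: peek at a few fields of the committed training config.
from omegaconf import OmegaConf

cfg = OmegaConf.load("t2v_enhanced/configs/text_to_video/config.yaml")
print(cfg.trainer.precision)                       # 16-mixed
print(cfg.model.inference_params.init_args.width)  # 256
```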
t2v_enhanced/gradio_demo.py ADDED
@@ -0,0 +1,189 @@
1
+ # General
2
+ import os
3
+ from os.path import join as opj
4
+ import argparse
5
+ import datetime
6
+ from pathlib import Path
7
+ import torch
8
+ import gradio as gr
9
+ import tempfile
10
+ import yaml
11
+ from t2v_enhanced.model.video_ldm import VideoLDM
12
+
13
+ # Utilities
14
+ from inference_utils import *
15
+ from model_init import *
16
+ from model_func import *
17
+
18
+
19
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--public_access', action='store_true', default=True)
22
+ parser.add_argument('--where_to_log', type=str, default="gradio_output")
23
+ parser.add_argument('--device', type=str, default="cuda")
24
+ args = parser.parse_args()
25
+
26
+
27
+ Path(args.where_to_log).mkdir(parents=True, exist_ok=True)
28
+ result_fol = Path(args.where_to_log).absolute()
29
+ device = args.device
30
+
31
+
32
+ # --------------------------
33
+ # ----- Configurations -----
34
+ # --------------------------
35
+ ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
36
+ cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
37
+
38
+
39
+ # --------------------------
40
+ # ----- Initialization -----
41
+ # --------------------------
42
+ ms_model = init_modelscope(device)
43
+ # zs_model = init_zeroscope(device)
44
+ stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
45
+ msxl_model = init_v2v_model(cfg_v2v)
46
+
47
+ inference_generator = torch.Generator(device="cuda")
48
+
49
+
50
+ # -------------------------
51
+ # ----- Functionality -----
52
+ # -------------------------
53
+ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance, where_to_log=result_fol):
54
+ now = datetime.datetime.now()
55
+ name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
56
+
57
+ if num_frames == [] or num_frames is None:
58
+ num_frames = 56
59
+ else:
60
+ num_frames = int(num_frames.split(" ")[0])
61
+
62
+ n_autoreg_gen = num_frames/8-8
63
+
64
+ inference_generator.manual_seed(seed)
65
+ short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
66
+ stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, name, stream_cli, stream_model)
67
+ video_path = opj(where_to_log, name+".mp4")
68
+ return video_path
69
+
70
+ def enhance(prompt, input_to_enhance):
71
+ encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
72
+ return encoded_video
73
+
74
+
75
+ # --------------------------
76
+ # ----- Gradio-Demo UI -----
77
+ # --------------------------
78
+ with gr.Blocks() as demo:
79
+ gr.HTML(
80
+ """
81
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
82
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
83
+ <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingT2V</a>
84
+ </h1>
85
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
86
+ Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, Humphrey Shi<sup>1,3</sup>
87
+ </h2>
88
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
89
+ <sup>1</sup>Picsart AI Research (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
90
+ </h2>
91
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
92
+ *Equal Contribution
93
+ </h2>
94
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
95
+ [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
96
+ [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
97
+ </h2>
98
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
99
+ <b>StreamingT2V</b> is an advanced autoregressive technique that enables the creation of long videos featuring rich motion dynamics without any stagnation.
100
+ It ensures temporal consistency throughout the video, aligns closely with the descriptive text, and maintains high frame-level image quality.
101
+ Our demonstrations include successful examples of videos up to <b>1200 frames, spanning 2 minutes</b>, and can be extended for even longer durations.
102
+ Importantly, the effectiveness of StreamingT2V is not limited by the specific Text2Video model used, indicating that improvements in base models could yield even higher-quality videos.
103
+ </h2>
104
+ </div>
105
+ """)
106
+
107
+ if on_huggingspace:
108
+ gr.HTML("""
109
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
110
+ <br/>
111
+ <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
112
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
113
+ </p>""")
114
+
115
+ with gr.Row():
116
+ with gr.Column():
117
+ with gr.Row():
118
+ with gr.Column():
119
+ with gr.Row():
120
+ num_frames = gr.Dropdown(["24", "32", "40", "48", "56", "80 - only on local", "240 - only on local", "600 - only on local", "1200 - only on local", "10000 - only on local"], label="Number of Video Frames: Default is 56", info="For >80 frames use local workstation!")
121
+ with gr.Row():
122
+ prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.")
123
+ with gr.Row():
124
+ image_stage1 = gr.Image(label='Image Prompt (only required for I2V base models)', show_label=True, scale=1, show_download_button=True)
125
+ with gr.Column():
126
+ video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
127
+ with gr.Row():
128
+ run_button_stage1 = gr.Button("Long Video Preview Generation")
129
+
130
+ with gr.Row():
131
+ with gr.Column():
132
+ with gr.Accordion('Advanced options', open=False):
133
+ model_name_stage1 = gr.Dropdown(
134
+ choices=["T2V: ModelScope", "T2V: ZeroScope", "I2V: AnimateDiff"],
135
+ label="Base Model. Default is ModelScope",
136
+ info="Currently supports only ModelScope. We will add more options later!",
137
+ )
138
+ model_name_stage2 = gr.Dropdown(
139
+ choices=["ModelScope-XL", "Another", "Another"],
140
+ label="Enhancement Model. Default is ModelScope-XL",
141
+ info="Currently supports only ModelScope-XL. We will add more options later!",
142
+ )
143
+ n_prompt = gr.Textbox(label="Optional Negative Prompt", value='')
144
+ seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,)
145
+
146
+ t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
147
+ image_guidance = gr.Slider(label='Image guidance scale', minimum=1, maximum=10, value=9.0, step=1.0)
148
+
149
+ with gr.Column():
150
+ with gr.Row():
151
+ video_stage2 = gr.Video(label='Enhanced Long Video', show_label=True, interactive=False, height=473, show_download_button=True)
152
+ with gr.Row():
153
+ run_button_stage2 = gr.Button("Long Video Enhancement")
154
+ '''
155
+ '''
156
+ gr.HTML(
157
+ """
158
+ <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
159
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
160
+ <b>Version: v1.0</b>
161
+ </h3>
162
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
163
+ <b>Caution</b>:
164
+ We would like to raise users' awareness of this demo's potential issues and concerns.
165
+ Like previous large foundation models, StreamingT2V can be problematic in some cases; in part because it builds on the pretrained ModelScope model, StreamingT2V can inherit its imperfections.
166
+ So far, we keep all features available for research testing both to show the great potential of the StreamingT2V framework and to collect important feedback to improve the model in the future.
167
+ We welcome researchers and users to report issues via the Hugging Face community discussion feature or by emailing the authors.
168
+ </h3>
169
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
170
+ <b>Biases and content acknowledgement</b>:
171
+ Beware that StreamingT2V may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
172
+ StreamingT2V in this demo is meant only for research purposes.
173
+ </h3>
174
+ </div>
175
+ """)
176
+
177
+ inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance]
178
+ run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
179
+
180
+ inputs_v2v = [prompt_stage1, video_stage1]
181
+ run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)
182
+
183
+
184
+ if on_huggingspace:
185
+ demo.queue(max_size=20)
186
+ demo.launch(debug=True)
187
+ else:
188
+ _, _, link = demo.queue(api_open=False).launch(share=args.public_access)
189
+ print(link)
t2v_enhanced/inference.py ADDED
@@ -0,0 +1,82 @@
1
+ # General
2
+ import os
3
+ from os.path import join as opj
4
+ import argparse
5
+ import datetime
6
+ from pathlib import Path
7
+ import torch
8
+ import gradio as gr
9
+ import tempfile
10
+ import yaml
11
+ from t2v_enhanced.model.video_ldm import VideoLDM
12
+
13
+ # Utilities
14
+ from inference_utils import *
15
+ from model_init import *
16
+ from model_func import *
17
+
18
+
19
+ if __name__ == "__main__":
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--prompt', type=str, default="A cat running on the street", help="The prompt to guide video generation.")
22
+ parser.add_argument('--image', type=str, default="", help="Path to image conditioning.")
23
+ # parser.add_argument('--video', type=str, default="", help="Path to video conditioning.")
24
+ parser.add_argument('--base_model', type=str, default="ModelscopeT2V", help="Base model to generate first chunk from", choices=["ModelscopeT2V", "AnimateDiff", "SVD"])
25
+ parser.add_argument('--num_frames', type=int, default=24, help="The number of video frames to generate.")
26
+ parser.add_argument('--negative_prompt', type=str, default="", help="The prompt to guide what to not include in video generation.")
27
+ parser.add_argument('--num_steps', type=int, default=50, help="The number of denoising steps.")
28
+ parser.add_argument('--image_guidance', type=float, default=9.0, help="The guidance scale.")
29
+
30
+ parser.add_argument('--output_dir', type=str, default="results", help="Path where to save the generated videos.")
31
+ parser.add_argument('--device', type=str, default="cuda")
32
+ parser.add_argument('--seed', type=int, default=33, help="Random seed")
33
+ args = parser.parse_args()
34
+
35
+
36
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
37
+ result_fol = Path(args.output_dir).absolute()
38
+ device = args.device
39
+
40
+
41
+ # --------------------------
42
+ # ----- Configurations -----
43
+ # --------------------------
44
+ ckpt_file_streaming_t2v = Path("checkpoints/streaming_t2v.ckpt").absolute()
45
+ cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-to-Video', 'pad': True}
46
+
47
+
48
+ # --------------------------
49
+ # ----- Initialization -----
50
+ # --------------------------
51
+ stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
52
+ if args.base_model == "ModelscopeT2V":
53
+ model = init_modelscope(device)
54
+ elif args.base_model == "AnimateDiff":
55
+ model = init_animatediff(device)
56
+ elif args.base_model == "SVD":
57
+ model = init_svd(device)
58
+ sdxl_model = init_sdxl(device)
59
+
60
+
61
+ inference_generator = torch.Generator(device="cuda")
62
+
63
+
64
+ # ------------------
65
+ # ----- Inputs -----
66
+ # ------------------
67
+ now = datetime.datetime.now()
68
+ name = args.prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
69
+
70
+ inference_generator = torch.Generator(device="cuda")
71
+ inference_generator.manual_seed(args.seed)
72
+
73
+ if args.base_model == "ModelscopeT2V":
74
+ short_video = ms_short_gen(args.prompt, model, inference_generator)
75
+ elif args.base_model == "AnimateDiff":
76
+ short_video = ad_short_gen(args.prompt, model, inference_generator)
77
+ elif args.base_model == "SVD":
78
+ short_video = svd_short_gen(args.image, args.prompt, model, sdxl_model, inference_generator)
79
+
80
+ n_autoreg_gen = args.num_frames // 8 - 8
81
+ stream_long_gen(args.prompt, short_video, n_autoreg_gen, args.negative_prompt, args.seed, args.num_steps, args.image_guidance, name, stream_cli, stream_model)
82
+ video2video(args.prompt, opj(result_fol, name+".mp4"), result_fol, cfg_v2v, msxl_model)
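Note that, as committed, `inference.py` passes `msxl_model` to `video2video` but never creates it, whereas `gradio_demo.py` initializes the same enhancement model explicitly. A minimal sketch of the presumably missing line, mirroring the demo (placed before the final `video2video(...)` call):

```python
# Sketch (assumption): initialize the video-to-video enhancement model as gradio_demo.py does,
# so the final video2video(...) call has a model to work with.
msxl_model = init_v2v_model(cfg_v2v)  # provided via the star import from model_init
```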
t2v_enhanced/inference_utils.py ADDED
@@ -0,0 +1,101 @@
1
+ # import argparse
2
+ import sys
3
+ from pathlib import Path
4
+ from pytorch_lightning.cli import LightningCLI
5
+ from PIL import Image
6
+
7
+ # For streaming
8
+ import yaml
9
+ from copy import deepcopy
10
+ from typing import List, Optional
11
+ from jsonargparse.typing import restricted_string_type
12
+
13
+
14
+ # --------------------------------------
15
+ # ----------- For Streaming ------------
16
+ # --------------------------------------
17
+ class CustomCLI(LightningCLI):
18
+ def add_arguments_to_parser(self, parser):
19
+ parser.add_argument("--result_fol", type=Path,
20
+ help="Set the path to the result folder", default="results")
21
+ parser.add_argument("--exp_name", type=str, help="Experiment name")
22
+ parser.add_argument("--run_name", type=str,
23
+ help="Current run name")
24
+ parser.add_argument("--prompts", type=Optional[List[str]])
25
+ parser.add_argument("--scale_lr", type=bool,
26
+ help="Scale lr", default=False)
27
+ CodeType = restricted_string_type(
28
+ 'CodeType', '(medium)|(high)|(highest)')
29
+ parser.add_argument("--matmul_precision", type=CodeType)
30
+ parser.add_argument("--ckpt", type=Path,)
31
+ parser.add_argument("--n_predictions", type=int)
32
+ return parser
33
+
34
+ def remove_value(dictionary, x):
35
+ for key, value in list(dictionary.items()):
36
+ if key == x:
37
+ del dictionary[key]
38
+ elif isinstance(value, dict):
39
+ remove_value(value, x)
40
+ return dictionary
41
+
42
+ def legacy_transformation(cfg: yaml):
43
+ cfg = deepcopy(cfg)
44
+ cfg["trainer"]["devices"] = "1"
45
+ cfg["trainer"]['num_nodes'] = 1
46
+
47
+ if not "class_path" in cfg["model"]["inference_params"]:
48
+ cfg["model"]["inference_params"] = {
49
+ "class_path": "t2v_enhanced.model.pl_module_params.InferenceParams", "init_args": cfg["model"]["inference_params"]}
50
+ return cfg
51
+
52
+
53
+ # ---------------------------------------------
54
+ # ----------- For enhancement -----------
55
+ # ---------------------------------------------
56
+ def add_margin(pil_img, top, right, bottom, left, color):
57
+ width, height = pil_img.size
58
+ new_width = width + right + left
59
+ new_height = height + top + bottom
60
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
61
+ result.paste(pil_img, (left, top))
62
+ return result
63
+
64
+ def resize_to_fit(image, size):
65
+ W, H = size
66
+ w, h = image.size
67
+ if H / h > W / w:
68
+ H_ = int(h * W / w)
69
+ W_ = W
70
+ else:
71
+ W_ = int(w * H / h)
72
+ H_ = H
73
+ return image.resize((W_, H_))
74
+
75
+ def pad_to_fit(image, size):
76
+ W, H = size
77
+ w, h = image.size
78
+ pad_h = (H - h) // 2
79
+ pad_w = (W - w) // 2
80
+ return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0))
81
+
82
+ def resize_and_keep(pil_img):
83
+ myheight = 576
84
+ hpercent = (myheight/float(pil_img.size[1]))
85
+ wsize = int((float(pil_img.size[0])*float(hpercent)))
86
+ pil_img = pil_img.resize((wsize, myheight))
87
+ return pil_img
88
+
89
+ def center_crop(pil_img):
90
+ width, height = pil_img.size
91
+ new_width = 576
92
+ new_height = 576
93
+
94
+ left = (width - new_width)/2
95
+ top = (height - new_height)/2
96
+ right = (width + new_width)/2
97
+ bottom = (height + new_height)/2
98
+
99
+ # Crop the center of the image
100
+ pil_img = pil_img.crop((left, top, right, bottom))
101
+ return pil_img
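The resize/pad helpers above are what the enhancement stage uses to bring frames to the target resolution (`upscale_size` with `pad: True` in `cfg_v2v`). A minimal sketch of chaining them; the import path assumes the script is run from inside `t2v_enhanced/`, as the other entry points do:

```python
# Sketch: letterbox an example frame to 1280x720 using the helpers above.
from PIL import Image
from inference_utils import resize_to_fit, pad_to_fit  # run from inside t2v_enhanced/

frame = Image.open("../examples/underwater.png")
frame = resize_to_fit(frame, (1280, 720))  # resize while preserving aspect ratio
frame = pad_to_fit(frame, (1280, 720))     # pad with black margins to the exact target size
frame.save("underwater_720p.png")
```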
t2v_enhanced/model/__init__.py ADDED
File without changes
t2v_enhanced/model/callbacks.py ADDED
@@ -0,0 +1,102 @@
1
+
2
+ from pathlib import Path
3
+ from pytorch_lightning import Callback
4
+ import os
5
+ import torch
6
+ from lightning_fabric.utilities.cloud_io import get_filesystem
7
+ from pytorch_lightning.cli import LightningArgumentParser
8
+ from pytorch_lightning import LightningModule, Trainer
9
+ from lightning_utilities.core.imports import RequirementCache
10
+ from omegaconf import OmegaConf
11
+
12
+ _JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache(
13
+ "jsonargparse[signatures]>=4.17.0")
14
+
15
+ if _JSONARGPARSE_SIGNATURES_AVAILABLE:
16
+ import docstring_parser
17
+ from jsonargparse import (
18
+ ActionConfigFile,
19
+ ArgumentParser,
20
+ class_from_function,
21
+ Namespace,
22
+ register_unresolvable_import_paths,
23
+ set_config_read_mode,
24
+ )
25
+
26
+ # Required until fix https://github.com/pytorch/pytorch/issues/74483
27
+ register_unresolvable_import_paths(torch)
28
+ set_config_read_mode(fsspec_enabled=True)
29
+ else:
30
+ locals()["ArgumentParser"] = object
31
+ locals()["Namespace"] = object
32
+
33
+
34
+ class SaveConfigCallback(Callback):
35
+ """Saves a LightningCLI config to the log_dir when training starts.
36
+
37
+ Args:
38
+ parser: The parser object used to parse the configuration.
39
+ config: The parsed configuration that will be saved.
40
+ config_filename: Filename for the config file.
41
+ overwrite: Whether to overwrite an existing config file.
42
+ multifile: When input is multiple config files, saved config preserves this structure.
43
+
44
+ Raises:
45
+ RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ parser: LightningArgumentParser,
51
+ config: Namespace,
52
+ log_dir: str,
53
+ config_filename: str = "config.yaml",
54
+ overwrite: bool = False,
55
+ multifile: bool = False,
56
+
57
+ ) -> None:
58
+ self.parser = parser
59
+ self.config = config
60
+ self.config_filename = config_filename
61
+ self.overwrite = overwrite
62
+ self.multifile = multifile
63
+ self.already_saved = False
64
+ self.log_dir = log_dir
65
+
66
+ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
67
+ if self.already_saved:
68
+ return
69
+
70
+ log_dir = self.log_dir
71
+ assert log_dir is not None
72
+ config_path = os.path.join(log_dir, self.config_filename)
73
+ fs = get_filesystem(log_dir)
74
+
75
+ if not self.overwrite:
76
+ # check if the file exists on rank 0
77
+ file_exists = fs.isfile(
78
+ config_path) if trainer.is_global_zero else False
79
+ # broadcast whether to fail to all ranks
80
+ file_exists = trainer.strategy.broadcast(file_exists)
81
+ if file_exists:
82
+ raise RuntimeError(
83
+ f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
84
+ " results of a previous run. You can delete the previous config file,"
85
+ " set `LightningCLI(save_config_callback=None)` to disable config saving,"
86
+ ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
87
+ )
88
+
89
+ # save the file on rank 0
90
+ if trainer.is_global_zero:
91
+ # save only on rank zero to avoid race conditions.
92
+ # the `log_dir` needs to be created as we rely on the logger to do it usually
93
+ # but it hasn't logged anything at this point
94
+ fs.makedirs(log_dir, exist_ok=True)
95
+ self.parser.save(
96
+ self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
97
+ )
98
+ self.already_saved = True
99
+ trainer.logger.log_hyperparams(OmegaConf.load(config_path))
100
+
101
+ # broadcast so that all ranks are in sync on future calls to .setup()
102
+ self.already_saved = trainer.strategy.broadcast(self.already_saved)
t2v_enhanced/model/datasets/prompt_reader.py ADDED
@@ -0,0 +1,80 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS
8
+
9
+ from t2v_enhanced.model.datasets.video_dataset import Annotations
10
+ import json
11
+
12
+
13
+ class ConcatDataset(torch.utils.data.Dataset):
14
+ def __init__(self, datasets):
15
+ self.datasets = datasets
16
+ self.model_id = datasets["reconstruction_dataset"].model_id
17
+
18
+ def __getitem__(self, idx):
19
+ sample = {ds: self.datasets[ds].__getitem__(
20
+ idx) for ds in self.datasets}
21
+ return sample
22
+
23
+ def __len__(self):
24
+ return min(len(self.datasets[d]) for d in self.datasets)
25
+
26
+
27
+ class CustomPromptsDataset(torch.utils.data.Dataset):
28
+
29
+ def __init__(self, prompt_cfg: Dict[str, str]):
30
+ super().__init__()
31
+
32
+ if prompt_cfg["type"] == "prompt":
33
+ self.prompts = [prompt_cfg["content"]]
34
+ elif prompt_cfg["type"] == "file":
35
+ file = Path(prompt_cfg["content"])
36
+ if file.suffix == ".npy":
37
+ self.prompts = np.load(file.as_posix())
38
+ elif file.suffix == ".txt":
39
+ with open(prompt_cfg["content"]) as f:
40
+ lines = [line.rstrip() for line in f]
41
+ self.prompts = lines
42
+ elif file.suffix == ".json":
43
+ with open(prompt_cfg["content"],"r") as file:
44
+ metadata = json.load(file)
45
+ if "videos_root" in prompt_cfg:
46
+ videos_root = Path(prompt_cfg["videos_root"])
47
+ video_path = [str(videos_root / sample["page_dir"] /
48
+ f"{sample['videoid']}.mp4") for sample in metadata]
49
+ else:
50
+ video_path = [str(sample["page_dir"] /
51
+ f"{sample['videoid']}.mp4") for sample in metadata]
52
+ self.prompts = [sample["prompt"] for sample in metadata]
53
+ self.video_path = video_path
54
+
55
+
56
+
57
+
58
+ transformed_prompts = []
59
+ for prompt in self.prompts:
60
+ transformed_prompts.append(
61
+ Annotations.clean_prompt(prompt))
62
+ self.prompts = transformed_prompts
63
+
64
+ def __len__(self):
65
+ return len(self.prompts)
66
+
67
+ def __getitem__(self, index):
68
+ output = {"prompt": self.prompts[index]}
69
+ if hasattr(self,"video_path"):
70
+ output["video"] = self.video_path[index]
71
+ return output
72
+
73
+
74
+ class PromptReader(pl.LightningDataModule):
75
+ def __init__(self, prompt_cfg: Dict[str, str]):
76
+ super().__init__()
77
+ self.predict_dataset = CustomPromptsDataset(prompt_cfg)
78
+
79
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
80
+ return torch.utils.data.DataLoader(self.predict_dataset, batch_size=1, pin_memory=False, shuffle=False, drop_last=False)
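`PromptReader` is what the inference configs point at (see `inference_long_video.yaml`, which hard-codes an absolute path to a prompt file). A minimal sketch of driving it with an inline prompt instead of a file, using the `type: prompt` branch of `CustomPromptsDataset`:

```python
# Sketch: feed a single inline prompt instead of the hard-coded prompt file.
from t2v_enhanced.model.datasets.prompt_reader import PromptReader

reader = PromptReader(prompt_cfg={"type": "prompt", "content": "A cat running on the street"})
for batch in reader.predict_dataloader():
    print(batch["prompt"])  # cleaned via Annotations.clean_prompt
```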
t2v_enhanced/model/datasets/video_dataset.py ADDED
@@ -0,0 +1,57 @@
1
+ from tqdm import tqdm
2
+ from einops import repeat
3
+ from diffusers import DiffusionPipeline
4
+ from decord import VideoReader, cpu
5
+ import torchvision
6
+ import torch
7
+ import numpy as np
8
+ import decord
9
+ import albumentations as album
10
+ import math
11
+ import random
12
+ from abc import abstractmethod
13
+ from copy import deepcopy
14
+ from pathlib import Path
15
+ from typing import Any, Dict, List, Union
16
+ from PIL import Image
17
+ import json
18
+ Image.MAX_IMAGE_PIXELS = None
19
+
20
+ decord.bridge.set_bridge("torch")
21
+
22
+ class Annotations():
23
+
24
+ def __init__(self,
25
+ annotation_cfg: Dict) -> None:
26
+ self.annotation_cfg = annotation_cfg
27
+
28
+ # TODO find all special characters
29
+
30
+ @staticmethod
31
+ def process_string(string):
32
+ for special_char in [".", ",", ":"]:
33
+ result = ""
34
+ i = 0
35
+ while i < len(string):
36
+ if string[i] == special_char:
37
+ if i > 0 and i < len(string) - 1 and string[i-1].isalpha() and string[i+1].isalpha():
38
+ result += special_char+" "
39
+ else:
40
+ result += special_char
41
+ else:
42
+ result += string[i]
43
+ i += 1
44
+ string = result
45
+ string = result
46
+ return result
47
+
48
+ @staticmethod
49
+ def clean_prompt(prompt):
50
+ prompt = " ".join(prompt.split())
51
+ prompt = prompt.replace(" , ", ", ")
52
+ prompt = prompt.replace(" . ", ". ")
53
+ prompt = prompt.replace(" : ", ": ")
54
+ prompt = Annotations.process_string(prompt)
55
+ return prompt
56
+ # return " ".join(prompt.split())
57
+
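`Annotations.clean_prompt` is the normalization applied to every prompt read by `CustomPromptsDataset`: it collapses repeated whitespace and tightens spacing around `.`, `,`, and `:`. A minimal sketch:

```python
# Sketch: prompt normalization as used by CustomPromptsDataset.
from t2v_enhanced.model.datasets.video_dataset import Annotations

raw = "A  red   panda , climbing a tree . Cinematic lighting"
print(Annotations.clean_prompt(raw))
# Expected along the lines of: "A red panda, climbing a tree. Cinematic lighting"
```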
t2v_enhanced/model/diffusers_conditional/__init__.py ADDED
File without changes
t2v_enhanced/model/diffusers_conditional/models/__init__.py ADDED
File without changes
t2v_enhanced/model/diffusers_conditional/models/controlnet/__init__.py ADDED
File without changes
t2v_enhanced/model/diffusers_conditional/models/controlnet/attention.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Callable, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils.import_utils import is_xformers_available
22
+ # from diffusers.models.attention_processor import Attention
23
+ # from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import Attention
24
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
25
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
26
+ # from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
27
+
28
+
29
+ if is_xformers_available():
30
+ import xformers
31
+ import xformers.ops
32
+ else:
33
+ xformers = None
34
+
35
+
36
+
37
+ class BasicTransformerBlock(nn.Module):
38
+ r"""
39
+ A basic Transformer block.
40
+
41
+ Parameters:
42
+ dim (`int`): The number of channels in the input and output.
43
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
44
+ attention_head_dim (`int`): The number of channels in each head.
45
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
46
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
47
+ only_cross_attention (`bool`, *optional*):
48
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
49
+ double_self_attention (`bool`, *optional*):
50
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
51
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
52
+ num_embeds_ada_norm (:
53
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
54
+ attention_bias (:
55
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ dim: int,
61
+ num_attention_heads: int,
62
+ attention_head_dim: int,
63
+ is_spatial_attention: bool = False,
64
+ dropout=0.0,
65
+ cross_attention_dim: Optional[int] = None,
66
+ activation_fn: str = "geglu",
67
+ num_embeds_ada_norm: Optional[int] = None,
68
+ attention_bias: bool = False,
69
+ only_cross_attention: bool = False,
70
+ double_self_attention: bool = False,
71
+ upcast_attention: bool = False,
72
+ norm_elementwise_affine: bool = True,
73
+ norm_type: str = "layer_norm",
74
+ final_dropout: bool = False,
75
+ use_image_embedding: bool = False,
76
+ unet_params=None,
77
+ ):
78
+ super().__init__()
79
+
80
+ self.only_cross_attention = only_cross_attention
81
+
82
+ self.use_ada_layer_norm_zero = (
83
+ num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
84
+ self.use_ada_layer_norm = (
85
+ num_embeds_ada_norm is not None) and norm_type == "ada_norm"
86
+
87
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
88
+ raise ValueError(
89
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
90
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
91
+ )
92
+
93
+ # Define 3 blocks. Each block has its own normalization layer.
94
+ # 1. Self-Attn
95
+ if self.use_ada_layer_norm:
96
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
97
+ elif self.use_ada_layer_norm_zero:
98
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
99
+ else:
100
+ self.norm1 = nn.LayerNorm(
101
+ dim, elementwise_affine=norm_elementwise_affine)
102
+
103
+ self.attn1 = Attention(
104
+ query_dim=dim,
105
+ heads=num_attention_heads,
106
+ dim_head=attention_head_dim,
107
+ dropout=dropout,
108
+ bias=attention_bias,
109
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
110
+ upcast_attention=upcast_attention,
111
+ is_spatial_attention=is_spatial_attention,
112
+ use_image_embedding=use_image_embedding,
113
+ )
114
+
115
+ # 2. Cross-Attn
116
+ if cross_attention_dim is not None or double_self_attention:
117
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
118
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
119
+ # the second cross attention block.
120
+ self.norm2 = (
121
+ AdaLayerNorm(dim, num_embeds_ada_norm)
122
+ if self.use_ada_layer_norm
123
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
124
+ )
125
+ self.attn2 = Attention(
126
+ query_dim=dim,
127
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
128
+ heads=num_attention_heads,
129
+ dim_head=attention_head_dim,
130
+ dropout=dropout,
131
+ bias=attention_bias,
132
+ upcast_attention=upcast_attention,
133
+ is_spatial_attention=is_spatial_attention,
134
+ use_image_embedding=use_image_embedding,
135
+ unet_params=unet_params,
136
+ ) # is self-attn if encoder_hidden_states is none
137
+ else:
138
+ self.norm2 = None
139
+ self.attn2 = None
140
+
141
+ # 3. Feed-forward
142
+ self.norm3 = nn.LayerNorm(
143
+ dim, elementwise_affine=norm_elementwise_affine)
144
+ self.ff = FeedForward(
145
+ dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states,
150
+ attention_mask=None,
151
+ encoder_hidden_states=None,
152
+ encoder_attention_mask=None,
153
+ timestep=None,
154
+ cross_attention_kwargs=None,
155
+ class_labels=None,
156
+ ):
157
+ # Notice that normalization is always applied before the real computation in the following blocks.
158
+ # 1. Self-Attention
159
+
160
+ if self.use_ada_layer_norm:
161
+ norm_hidden_states = self.norm1(hidden_states, timestep)
162
+ elif self.use_ada_layer_norm_zero:
163
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
164
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
165
+ )
166
+ else:
167
+ norm_hidden_states = self.norm1(hidden_states)
168
+
169
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
170
+ attn_output = self.attn1(
171
+ norm_hidden_states,
172
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
173
+ attention_mask=attention_mask,
174
+ **cross_attention_kwargs,
175
+ )
176
+ if self.use_ada_layer_norm_zero:
177
+ attn_output = gate_msa.unsqueeze(1) * attn_output
178
+ hidden_states = attn_output + hidden_states
179
+
180
+ # 2. Cross-Attention
181
+ if self.attn2 is not None:
182
+ norm_hidden_states = (
183
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
184
+ hidden_states)
185
+ )
186
+ # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
187
+ # prepare attention mask here
188
+
189
+ attn_output = self.attn2(
190
+ norm_hidden_states,
191
+ encoder_hidden_states=encoder_hidden_states,
192
+ attention_mask=encoder_attention_mask,
193
+ **cross_attention_kwargs,
194
+ )
195
+ hidden_states = attn_output + hidden_states
196
+
197
+ # 3. Feed-forward
198
+ norm_hidden_states = self.norm3(hidden_states)
199
+
200
+ if self.use_ada_layer_norm_zero:
201
+ norm_hidden_states = norm_hidden_states * \
202
+ (1 + scale_mlp[:, None]) + shift_mlp[:, None]
203
+
204
+ ff_output = self.ff(norm_hidden_states)
205
+
206
+ if self.use_ada_layer_norm_zero:
207
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
208
+
209
+ hidden_states = ff_output + hidden_states
210
+
211
+ return hidden_states
212
+
213
+
214
+ class FeedForward(nn.Module):
215
+ r"""
216
+ A feed-forward layer.
217
+
218
+ Parameters:
219
+ dim (`int`): The number of channels in the input.
220
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
221
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
222
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
223
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
224
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ dim: int,
230
+ dim_out: Optional[int] = None,
231
+ mult: int = 4,
232
+ dropout: float = 0.0,
233
+ activation_fn: str = "geglu",
234
+ final_dropout: bool = False,
235
+ ):
236
+ super().__init__()
237
+ inner_dim = int(dim * mult)
238
+ dim_out = dim_out if dim_out is not None else dim
239
+
240
+ if activation_fn == "gelu":
241
+ act_fn = GELU(dim, inner_dim)
242
+ if activation_fn == "gelu-approximate":
243
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
244
+ elif activation_fn == "geglu":
245
+ act_fn = GEGLU(dim, inner_dim)
246
+ elif activation_fn == "geglu-approximate":
247
+ act_fn = ApproximateGELU(dim, inner_dim)
248
+
249
+ self.net = nn.ModuleList([])
250
+ # project in
251
+ self.net.append(act_fn)
252
+ # project dropout
253
+ self.net.append(nn.Dropout(dropout))
254
+ # project out
255
+ self.net.append(nn.Linear(inner_dim, dim_out))
256
+ # FF as used in Vision Transformer, MLP-Mixer, etc., has a final dropout
257
+ if final_dropout:
258
+ self.net.append(nn.Dropout(dropout))
259
+
260
+ def forward(self, hidden_states):
261
+ for module in self.net:
262
+ hidden_states = module(hidden_states)
263
+ return hidden_states
264
+
265
+
266
+ class GEGLU(nn.Module):
267
+ r"""
268
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
269
+
270
+ Parameters:
271
+ dim_in (`int`): The number of channels in the input.
272
+ dim_out (`int`): The number of channels in the output.
273
+ """
274
+
275
+ def __init__(self, dim_in: int, dim_out: int):
276
+ super().__init__()
277
+ self.proj = nn.Linear(dim_in, dim_out * 2)
278
+
279
+ def gelu(self, gate):
280
+ if gate.device.type != "mps":
281
+ return F.gelu(gate)
282
+ # mps: gelu is not implemented for float16
283
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
287
+ return hidden_states * self.gelu(gate)
288
+
289
+
290
+
291
+
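
The GEGLU module above projects its input to twice the target width, splits the result into two halves, and multiplies one half by the GELU of the other; FeedForward then stacks this with dropout and an output projection. A minimal, self-contained sketch of the gating step (plain PyTorch, illustrative sizes, not part of the repository code):

import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 320, 1280              # illustrative widths
proj = nn.Linear(dim_in, dim_out * 2)    # single projection to 2 * dim_out

x = torch.randn(2, 77, dim_in)                   # (batch, seq_len, dim_in)
hidden, gate = proj(x).chunk(2, dim=-1)          # two (2, 77, dim_out) halves
out = hidden * F.gelu(gate)                      # gated output, shape (2, 77, dim_out)
assert out.shape == (2, 77, dim_out)
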
t2v_enhanced/model/diffusers_conditional/models/controlnet/attention_processor.py ADDED
@@ -0,0 +1,444 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from einops import repeat
15
+ from typing import Callable, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import deprecate, logging
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ if is_xformers_available():
29
+ import xformers
30
+ import xformers.ops
31
+ else:
32
+ xformers = None
33
+
34
+
35
+ class Attention(nn.Module):
36
+ r"""
37
+ A cross attention layer.
38
+
39
+ Parameters:
40
+ query_dim (`int`): The number of channels in the query.
41
+ cross_attention_dim (`int`, *optional*):
42
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
43
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
44
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
45
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
46
+ bias (`bool`, *optional*, defaults to False):
47
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ query_dim: int,
53
+ is_spatial_attention: bool,
54
+ cross_attention_dim: Optional[int] = None,
55
+ heads: int = 8,
56
+ dim_head: int = 64,
57
+ dropout: float = 0.0,
58
+ bias=False,
59
+ upcast_attention: bool = False,
60
+ upcast_softmax: bool = False,
61
+ cross_attention_norm: Optional[str] = None,
62
+ cross_attention_norm_num_groups: int = 32,
63
+ added_kv_proj_dim: Optional[int] = None,
64
+ norm_num_groups: Optional[int] = None,
65
+ out_bias: bool = True,
66
+ scale_qk: bool = True,
67
+ only_cross_attention: bool = False,
68
+ processor: Optional["AttnProcessor"] = None,
69
+ use_image_embedding: bool = False,
70
+ unet_params=None,
71
+ ):
72
+ super().__init__()
73
+ inner_dim = dim_head * heads
74
+ self.cross_attention_mode = cross_attention_dim is not None
75
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
76
+ self.is_spatial_attention = is_spatial_attention
77
+ self.upcast_attention = upcast_attention
78
+ self.upcast_softmax = upcast_softmax
79
+ self.train_image_cond_weight = use_image_embedding
80
+ self.use_image_embedding = use_image_embedding
81
+
82
+ self.scale = dim_head**-0.5 if scale_qk else 1.0
83
+
84
+ self.heads = heads
85
+ # for slice_size > 0 the attention score computation
86
+ # is split across the batch axis to save memory
87
+ # You can set slice_size with `set_attention_slice`
88
+ self.sliceable_head_dim = heads
89
+
90
+ self.added_kv_proj_dim = added_kv_proj_dim
91
+ self.only_cross_attention = only_cross_attention
92
+
93
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
94
+ raise ValueError(
95
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
96
+ )
97
+
98
+ if norm_num_groups is not None:
99
+ self.group_norm = nn.GroupNorm(
100
+ num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
101
+ else:
102
+ self.group_norm = None
103
+
104
+ if cross_attention_norm is None:
105
+ self.norm_cross = None
106
+ elif cross_attention_norm == "layer_norm":
107
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
108
+ elif cross_attention_norm == "group_norm":
109
+ if self.added_kv_proj_dim is not None:
110
+ # The given `encoder_hidden_states` are initially of shape
111
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
112
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
113
+ # before the projection, so we need to use `added_kv_proj_dim` as
114
+ # the number of channels for the group norm.
115
+ norm_cross_num_channels = added_kv_proj_dim
116
+ else:
117
+ norm_cross_num_channels = cross_attention_dim
118
+
119
+ self.norm_cross = nn.GroupNorm(
120
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
121
+ )
122
+ else:
123
+ raise ValueError(
124
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
125
+ )
126
+
127
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
128
+
129
+ if not self.only_cross_attention:
130
+ # only relevant for the `AddedKVProcessor` classes
131
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
132
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
133
+ else:
134
+ self.to_k = None
135
+ self.to_v = None
136
+
137
+ if self.added_kv_proj_dim is not None:
138
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
139
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
140
+
141
+ self.to_out = nn.ModuleList([])
142
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
143
+ self.to_out.append(nn.Dropout(dropout))
144
+
145
+ embed_dim = 93
146
+ if self.cross_attention_mode and self.is_spatial_attention and self.use_image_embedding:
147
+ self.conv = torch.nn.Conv1d(embed_dim, 77, kernel_size=3, padding="same")
148
+ self.conv_ln = nn.LayerNorm(1024)
149
+ self.register_parameter("alpha", nn.Parameter(torch.tensor(0.)))
150
+
151
+ # set attention processor
152
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
153
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
154
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
155
+ if processor is None:
156
+ processor = (
157
+ AttnProcessor2_0() if hasattr(
158
+ F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
159
+ )
160
+ self.set_processor(processor)
161
+
162
+ def set_use_memory_efficient_attention_xformers(
163
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
164
+ ):
165
+ is_lora = hasattr(self, "processor") and isinstance(
166
+ self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
167
+ )
168
+
169
+ if use_memory_efficient_attention_xformers:
170
+ if self.added_kv_proj_dim is not None:
171
+ # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
172
+ # which uses this type of cross attention ONLY because the attention mask of format
173
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
174
+ raise NotImplementedError(
175
+ "Memory efficient attention with `xformers` is currently not supported when"
176
+ " `self.added_kv_proj_dim` is defined."
177
+ )
178
+ elif not is_xformers_available():
179
+ raise ModuleNotFoundError(
180
+ (
181
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
182
+ " xformers"
183
+ ),
184
+ name="xformers",
185
+ )
186
+ elif not torch.cuda.is_available():
187
+ raise ValueError(
188
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
189
+ " only available for GPU "
190
+ )
191
+ else:
192
+ try:
193
+ # Make sure we can run the memory efficient attention
194
+ _ = xformers.ops.memory_efficient_attention(
195
+ torch.randn((1, 2, 40), device="cuda"),
196
+ torch.randn((1, 2, 40), device="cuda"),
197
+ torch.randn((1, 2, 40), device="cuda"),
198
+ )
199
+ except Exception as e:
200
+ raise e
201
+
202
+ if is_lora:
203
+ processor = LoRAXFormersAttnProcessor(
204
+ hidden_size=self.processor.hidden_size,
205
+ cross_attention_dim=self.processor.cross_attention_dim,
206
+ rank=self.processor.rank,
207
+ attention_op=attention_op,
208
+ )
209
+ processor.load_state_dict(self.processor.state_dict())
210
+ processor.to(self.processor.to_q_lora.up.weight.device)
211
+ else:
212
+ processor = XFormersAttnProcessor(attention_op=attention_op)
213
+ else:
214
+ if is_lora:
215
+ processor = LoRAAttnProcessor(
216
+ hidden_size=self.processor.hidden_size,
217
+ cross_attention_dim=self.processor.cross_attention_dim,
218
+ rank=self.processor.rank,
219
+ )
220
+ processor.load_state_dict(self.processor.state_dict())
221
+ processor.to(self.processor.to_q_lora.up.weight.device)
222
+ else:
223
+ processor = AttnProcessor()
224
+
225
+ self.set_processor(processor)
226
+
227
+ def set_attention_slice(self, slice_size):
228
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
229
+ raise ValueError(
230
+ f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
231
+
232
+ if slice_size is not None and self.added_kv_proj_dim is not None:
233
+ processor = SlicedAttnAddedKVProcessor(slice_size)
234
+ elif slice_size is not None:
235
+ processor = SlicedAttnProcessor(slice_size)
236
+ elif self.added_kv_proj_dim is not None:
237
+ processor = AttnAddedKVProcessor()
238
+ else:
239
+ processor = AttnProcessor()
240
+
241
+ self.set_processor(processor)
242
+
243
+ def set_processor(self, processor: "AttnProcessor"):
244
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
245
+ # pop `processor` from `self._modules`
246
+ if (
247
+ hasattr(self, "processor")
248
+ and isinstance(self.processor, torch.nn.Module)
249
+ and not isinstance(processor, torch.nn.Module)
250
+ ):
251
+ logger.info(
252
+ f"You are removing possibly trained weights of {self.processor} with {processor}")
253
+ self._modules.pop("processor")
254
+
255
+ self.processor = processor
256
+
257
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
258
+ # The `Attention` class can call different attention processors / attention functions
259
+ # here we simply pass along all tensors to the selected processor class
260
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
261
+ return self.processor(
262
+ self,
263
+ hidden_states,
264
+ encoder_hidden_states=encoder_hidden_states,
265
+ attention_mask=attention_mask,
266
+ **cross_attention_kwargs,
267
+ )
268
+
269
+ def batch_to_head_dim(self, tensor):
270
+ head_size = self.heads
271
+ batch_size, seq_len, dim = tensor.shape
272
+ tensor = tensor.reshape(batch_size // head_size,
273
+ head_size, seq_len, dim)
274
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
275
+ batch_size // head_size, seq_len, dim * head_size)
276
+ return tensor
277
+
278
+ def head_to_batch_dim(self, tensor, out_dim=3):
279
+ head_size = self.heads
280
+ batch_size, seq_len, dim = tensor.shape
281
+ tensor = tensor.reshape(batch_size, seq_len,
282
+ head_size, dim // head_size)
283
+ tensor = tensor.permute(0, 2, 1, 3)
284
+
285
+ if out_dim == 3:
286
+ tensor = tensor.reshape(
287
+ batch_size * head_size, seq_len, dim // head_size)
288
+
289
+ return tensor
290
+
291
+ def get_attention_scores(self, query, key, attention_mask=None):
292
+ dtype = query.dtype
293
+ if self.upcast_attention:
294
+ query = query.float()
295
+ key = key.float()
296
+
297
+ if attention_mask is None:
298
+ baddbmm_input = torch.empty(
299
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
300
+ )
301
+ beta = 0
302
+ else:
303
+ baddbmm_input = attention_mask
304
+ beta = 1
305
+
306
+ attention_scores = torch.baddbmm(
307
+ baddbmm_input,
308
+ query,
309
+ key.transpose(-1, -2),
310
+ beta=beta,
311
+ alpha=self.scale,
312
+ )
313
+
314
+ if self.upcast_softmax:
315
+ attention_scores = attention_scores.float()
316
+
317
+ attention_probs = attention_scores.softmax(dim=-1)
318
+ attention_probs = attention_probs.to(dtype)
319
+
320
+ return attention_probs
321
+
322
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
323
+ if batch_size is None:
324
+ deprecate(
325
+ "batch_size=None",
326
+ "0.0.15",
327
+ (
328
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
329
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
330
+ " `prepare_attention_mask` when preparing the attention_mask."
331
+ ),
332
+ )
333
+ batch_size = 1
334
+
335
+ head_size = self.heads
336
+ if attention_mask is None:
337
+ return attention_mask
338
+
339
+ if attention_mask.shape[-1] != target_length:
340
+ if attention_mask.device.type == "mps":
341
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
342
+ # Instead, we can manually construct the padding tensor.
343
+ padding_shape = (
344
+ attention_mask.shape[0], attention_mask.shape[1], target_length)
345
+ padding = torch.zeros(
346
+ padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
347
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
348
+ else:
349
+ attention_mask = F.pad(
350
+ attention_mask, (0, target_length), value=0.0)
351
+
352
+ if out_dim == 3:
353
+ if attention_mask.shape[0] < batch_size * head_size:
354
+ attention_mask = attention_mask.repeat_interleave(
355
+ head_size, dim=0)
356
+ elif out_dim == 4:
357
+ attention_mask = attention_mask.unsqueeze(1)
358
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
359
+
360
+ return attention_mask
361
+
362
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
363
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
364
+
365
+ if isinstance(self.norm_cross, nn.LayerNorm):
366
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
367
+ elif isinstance(self.norm_cross, nn.GroupNorm):
368
+ # Group norm norms along the channels dimension and expects
369
+ # input to be in the shape of (N, C, *). In this case, we want
370
+ # to norm along the hidden dimension, so we need to move
371
+ # (batch_size, sequence_length, hidden_size) ->
372
+ # (batch_size, hidden_size, sequence_length)
373
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
374
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
375
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
376
+ else:
377
+ assert False
378
+
379
+ return encoder_hidden_states
380
+
381
+
382
+
383
+
384
+ class AttnProcessor2_0:
385
+ def __init__(self):
386
+ if not hasattr(F, "scaled_dot_product_attention"):
387
+ raise ImportError(
388
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
389
+
390
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
391
+
392
+ batch_size, sequence_length, _ = (
393
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
394
+ )
395
+ inner_dim = hidden_states.shape[-1]
396
+
397
+ if attention_mask is not None:
398
+ attention_mask = attn.prepare_attention_mask(
399
+ attention_mask, sequence_length, batch_size)
400
+ # scaled_dot_product_attention expects attention_mask shape to be
401
+ # (batch, heads, source_length, target_length)
402
+ attention_mask = attention_mask.view(
403
+ batch_size, attn.heads, -1, attention_mask.shape[-1])
404
+
405
+ query = attn.to_q(hidden_states)
406
+
407
+ if encoder_hidden_states is None:
408
+ encoder_hidden_states = hidden_states
409
+ elif attn.norm_cross:
410
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
411
+ encoder_hidden_states)
412
+
413
+ key = attn.to_k(encoder_hidden_states)
414
+ value = attn.to_v(encoder_hidden_states)
415
+
416
+ head_dim = inner_dim // attn.heads
417
+ query = query.view(batch_size, -1, attn.heads,
418
+ head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads,
421
+ head_dim).transpose(1, 2)
422
+
423
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
424
+ # TODO: add support for attn.scale when we move to Torch 2.1
425
+
426
+ hidden_states = F.scaled_dot_product_attention(
427
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
428
+ )
429
+
430
+ hidden_states = hidden_states.transpose(1, 2).reshape(
431
+ batch_size, -1, attn.heads * head_dim)
432
+ hidden_states = hidden_states.to(query.dtype)
433
+
434
+ # linear proj
435
+ hidden_states = attn.to_out[0](hidden_states)
436
+ # dropout
437
+ hidden_states = attn.to_out[1](hidden_states)
438
+ return hidden_states
439
+
440
+
441
+
442
+ AttentionProcessor = Union[
443
+ AttnProcessor2_0,
444
+ ]
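
AttnProcessor2_0 above reshapes the projected Q/K/V from (batch, seq, heads*head_dim) into the (batch, heads, seq, head_dim) layout that F.scaled_dot_product_attention expects, and reverses the reshape afterwards. A self-contained sketch of that round trip with assumed sizes (requires PyTorch 2.x):

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 77, 8, 64
inner_dim = heads * head_dim

q = torch.randn(batch, seq, inner_dim)   # stand-ins for attn.to_q / to_k / to_v outputs
k = torch.randn(batch, seq, inner_dim)
v = torch.randn(batch, seq, inner_dim)

# (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
q, k, v = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# back to (batch, seq, inner_dim) before the output projection
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
assert out.shape == (batch, seq, inner_dim)
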
t2v_enhanced/model/diffusers_conditional/models/controlnet/conditioning.py ADDED
@@ -0,0 +1,100 @@
1
+ import diffusers
2
+ from diffusers.models.transformer_temporal import TransformerTemporalModel, TransformerTemporalModelOutput
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+ from diffusers.models.attention_processor import Attention
6
+ # from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
7
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal_crossattention import TransformerTemporalModel as TransformerTemporalModelCrossAttn
8
+ import torch
9
+
10
+
11
+ class CrossAttention(nn.Module):
12
+
13
+ def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
14
+ super().__init__()
15
+ self.attention = Attention(
16
+ query_dim=input_channels, cross_attention_dim=input_channels, heads=input_channels//attention_head_dim, dim_head=attention_head_dim, bias=False, upcast_attention=False)
17
+ self.norm = torch.nn.GroupNorm(
18
+ num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
19
+ self.proj_in = nn.Linear(input_channels, input_channels)
20
+ self.proj_out = nn.Linear(input_channels, input_channels)
21
+
22
+ def forward(self, hidden_state, encoder_hidden_states, num_frames):
23
+ h, w = hidden_state.shape[2], hidden_state.shape[3]
24
+ hidden_state_norm = rearrange(
25
+ hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
26
+ hidden_state_norm = self.norm(hidden_state_norm)
27
+ hidden_state_norm = rearrange(
28
+ hidden_state_norm, "B C F H W -> (B H W) F C")
29
+ hidden_state_norm = self.proj_in(hidden_state_norm)
30
+ attn = self.attention(hidden_state_norm,
31
+ encoder_hidden_states=encoder_hidden_states,
32
+ attention_mask=None,
33
+ )
34
+ # proj_out
35
+
36
+ residual = self.proj_out(attn)
37
+
38
+ residual = rearrange(
39
+ residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
40
+ output = hidden_state + residual
41
+ return TransformerTemporalModelOutput(sample=output)
42
+
43
+
44
+ class ConditionalModel(nn.Module):
45
+
46
+ def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
47
+ super().__init__()
48
+ num_layers = 1
49
+ if "_layers_" in conditional_model:
50
+ config = conditional_model.split("_layers_")
51
+ conditional_model = config[0]
52
+ num_layers = int(config[1])
53
+
54
+ if conditional_model == "self_cross_transformer":
55
+ self.temporal_transformer = TransformerTemporalModel(num_attention_heads=input_channels//attention_head_dim, attention_head_dim=attention_head_dim, in_channels=input_channels,
56
+ double_self_attention=False, cross_attention_dim=input_channels)
57
+ elif conditional_model == "cross_transformer":
58
+ self.temporal_transformer = TransformerTemporalModelCrossAttn(num_attention_heads=input_channels//attention_head_dim, attention_head_dim=attention_head_dim, in_channels=input_channels,
59
+ double_self_attention=False, cross_attention_dim=input_channels, num_layers=num_layers)
60
+ elif conditional_model == "cross_attention":
61
+ self.temporal_transformer = CrossAttention(
62
+ input_channels=input_channels, attention_head_dim=attention_head_dim)
63
+ elif conditional_model == "test_conv":
64
+ self.temporal_transformer = nn.Conv2d(
65
+ input_channels, input_channels, kernel_size=1)
66
+ else:
67
+ raise NotImplementedError(
68
+ f"mode {conditional_model} not implemented")
69
+ if conditional_model != "test_conv":
70
+ nn.init.zeros_(self.temporal_transformer.proj_out.weight)
71
+ nn.init.zeros_(self.temporal_transformer.proj_out.bias)
72
+ else:
73
+ nn.init.zeros_(self.temporal_transformer.weight)
74
+ nn.init.zeros_(self.temporal_transformer.bias)
75
+ self.conditional_model = conditional_model
76
+
77
+ def forward(self, sample, conditioning, num_frames=None):
78
+
79
+ assert conditioning.ndim == 5
80
+ assert sample.ndim == 5
81
+ if self.conditional_model != "test_conv":
82
+ conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")
83
+
84
+ num_frames = sample.shape[1]
85
+
86
+ sample = rearrange(sample, "B F C H W -> (B F) C H W")
87
+
88
+ sample = self.temporal_transformer(
89
+ sample, encoder_hidden_states=conditioning, num_frames=num_frames).sample
90
+
91
+ sample = rearrange(
92
+ sample, "(B F) C H W -> B F C H W", F=num_frames)
93
+ else:
94
+
95
+ conditioning = rearrange(conditioning, "B F C H W -> (B F) C H W")
96
+ f = sample.shape[1]
97
+ sample = rearrange(sample, "B F C H W -> (B F) C H W")
98
+ sample = sample + self.temporal_transformer(conditioning)
99
+ sample = rearrange(sample, "(B F) C H W -> B F C H W", F=f)
100
+ return sample
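
The CrossAttention block in conditioning.py folds the spatial grid into the batch axis so that attention runs along the frame axis independently at every spatial location. A shape-tracing sketch of the einops patterns involved (hypothetical sizes, no model weights):

import torch
from einops import rearrange

B, F_frames, C, H, W = 1, 16, 320, 32, 32
x = torch.randn(B * F_frames, C, H, W)                       # "(B F) C H W" as in the module

tokens = rearrange(x, "(b f) c h w -> (b h w) f c", f=F_frames)
assert tokens.shape == (B * H * W, F_frames, C)              # per-pixel sequences of frames

# ... temporal attention would act on `tokens` here ...

y = rearrange(tokens, "(b h w) f c -> (b f) c h w", h=H, w=W)
assert y.shape == x.shape                                     # residual can be added back
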
t2v_enhanced/model/diffusers_conditional/models/controlnet/controlnet.py ADDED
@@ -0,0 +1,865 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+ from einops import rearrange, repeat
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
25
+ # from diffusers.models.transformer_temporal import TransformerTemporalModel
26
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_blocks import (
29
+ CrossAttnDownBlock3D,
30
+ CrossAttnUpBlock3D,
31
+ DownBlock3D,
32
+ UNetMidBlock3DCrossAttn,
33
+ UpBlock3D,
34
+ get_down_block,
35
+ get_up_block,
36
+ transformer_g_c
37
+ )
38
+ # from diffusers.models.unet_3d_condition import UNet3DConditionModel
39
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel
40
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel
41
+ from t2v_enhanced.model.layers.conv_channel_extension import Conv2D_SubChannels
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @dataclass
46
+ class ControlNetOutput(BaseOutput):
47
+ down_block_res_samples: Tuple[torch.Tensor]
48
+ mid_block_res_sample: torch.Tensor
49
+
50
+
51
+ class Merger(nn.Module):
52
+ def __init__(self, n_frames_condition: int = 8, n_frames_sample: int = 16, merge_mode: str = "addition", input_channels=0, frame_expansion="last_frame") -> None:
53
+ super().__init__()
54
+ self.merge_mode = merge_mode
55
+ self.n_frames_condition = n_frames_condition
56
+ self.n_frames_sample = n_frames_sample
57
+ self.frame_expansion = frame_expansion
58
+
59
+ if merge_mode.startswith("attention"):
60
+ self.attention = ConditionalModel(input_channels=input_channels,
61
+ conditional_model=merge_mode.split("attention_")[1])
62
+
63
+ def forward(self, x, condition_signal):
64
+ x = rearrange(x, "(B F) C H W -> B F C H W", F=self.n_frames_sample)
65
+
66
+ condition_signal = rearrange(
67
+ condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0])
68
+
69
+ if x.shape[1] - condition_signal.shape[1] > 0:
70
+ if self.frame_expansion == "last_frame":
71
+ fillup_latent = repeat(
72
+ condition_signal[:, -1], "B C H W -> B F C H W", F=x.shape[1] - condition_signal.shape[1])
73
+ elif self.frame_expansion == "zero":
74
+ fillup_latent = torch.zeros(
75
+ (x.shape[0], self.n_frames_sample-self.n_frames_condition, *x.shape[2:]), device=x.device, dtype=x.dtype)
76
+
77
+ if self.frame_expansion != "none":
78
+ condition_signal = torch.cat(
79
+ [condition_signal, fillup_latent], dim=1)
80
+
81
+ if self.merge_mode == "addition":
82
+ out = x + condition_signal
83
+ elif self.merge_mode.startswith("attention"):
84
+ out = self.attention(x, condition_signal)
85
+ out = rearrange(out, "B F C H W -> (B F) C H W")
86
+ return out
87
+
88
+
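
The Merger above reconciles a conditioning tensor with n_frames_condition frames against a sample with n_frames_sample frames: with frame_expansion="last_frame" it tiles the final conditioning frame, with "zero" it pads with zeros, and only then merges. A small sketch of the "last_frame" padding under assumed sizes:

import torch
from einops import repeat

B, C, H, W = 1, 320, 32, 32
n_frames_sample, n_frames_condition = 16, 8

condition = torch.randn(B, n_frames_condition, C, H, W)
missing = n_frames_sample - n_frames_condition

# tile the last conditioning frame to cover the remaining sample frames
fillup = repeat(condition[:, -1], "b c h w -> b f c h w", f=missing)
condition_full = torch.cat([condition, fillup], dim=1)
assert condition_full.shape == (B, n_frames_sample, C, H, W)
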
89
+ class ZeroConv(nn.Module):
90
+ def __init__(self, channels: int, mode: str = "2d", num_frames: int = 8, zero_init=True):
91
+ super().__init__()
92
+ mode_parts = mode.split("_")
93
+ if len(mode_parts) > 1 and mode_parts[1] == "noinit":
94
+ zero_init = False
95
+
96
+ if mode.startswith("2d"):
97
+ model = nn.Conv2d(
98
+ channels, channels, kernel_size=1)
99
+ model = zero_module(model, reset=zero_init)
100
+ elif mode.startswith("3d"):
101
+ model = ZeroConv3D(num_frames=num_frames,
102
+ channels=channels, zero_init=zero_init)
103
+ elif mode == "Identity":
104
+ model = nn.Identity()
105
+ self.model = model
106
+
107
+ def forward(self, x):
108
+ return self.model(x)
109
+
110
+
111
+
112
+
113
+
114
+ class ControlNetConditioningEmbedding(nn.Module):
115
+ """
116
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
117
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
118
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
119
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
120
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
121
+ model) to encode image-space conditions ... into feature maps ..."
122
+ """
123
+ # TODO: why is Gaussian initialization not used here?
124
+ # TODO: why not a 4x4 kernel?
125
+ # TODO: why not a 2x2 stride?
126
+
127
+ def __init__(
128
+ self,
129
+ conditioning_embedding_channels: int,
130
+ conditioning_channels: int = 3,
131
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
132
+ downsample: bool = True,
133
+ final_3d_conv: bool = False,
134
+ num_frame_conditioning: int = 8,
135
+ num_frames: int = 16,
136
+ zero_init: bool = True,
137
+ use_controlnet_mask: bool = False,
138
+ use_normalization: bool = False,
139
+ ):
140
+ super().__init__()
141
+ self.num_frame_conditioning = num_frame_conditioning
142
+ self.num_frames = num_frames
143
+ self.final_3d_conv = final_3d_conv
144
+ self.conv_in = nn.Conv2d(
145
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
146
+ if final_3d_conv:
147
+ print("USING 3D CONV in ControlNET")
148
+
149
+ self.blocks = nn.ModuleList([])
150
+ if use_normalization:
151
+ self.norms = nn.ModuleList([])
152
+ self.use_normalization = use_normalization
153
+
154
+ stride = 2 if downsample else 1
155
+ if use_normalization:
156
+ res = 256 # HARD-CODED Resolution!
157
+
158
+ for i in range(len(block_out_channels) - 1):
159
+ channel_in = block_out_channels[i]
160
+ channel_out = block_out_channels[i + 1]
161
+ self.blocks.append(
162
+ nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
163
+ if use_normalization:
164
+ self.norms.append(nn.LayerNorm((channel_in, res, res)))
165
+ self.blocks.append(
166
+ nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride))
167
+ if use_normalization:
168
+ res = res // 2
169
+ self.norms.append(nn.LayerNorm((channel_out, res, res)))
170
+
171
+ if not final_3d_conv:
172
+ self.conv_out = zero_module(
173
+ nn.Conv2d(
174
+ block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1), reset=zero_init
175
+ )
176
+ else:
177
+ self.conv_temp = zero_module(TemporalConvLayer_Custom(
178
+ num_frame_conditioning, num_frames, dropout=0.0), reset=zero_init)
179
+ self.conv_out = nn.Conv2d(
180
+ block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1)
181
+ # self.conv_temp = zero_module(nn.Conv3d(
182
+ # num_frame_conditioning, num_frames, kernel_size=3, padding=1)
183
+ # )
184
+
185
+ def forward(self, conditioning, vq_gan=None, controlnet_mask=None):
186
+ embedding = self.conv_in(conditioning)
187
+ embedding = F.silu(embedding)
188
+
189
+ if self.use_normalization:
190
+ for block, norm in zip(self.blocks, self.norms):
191
+ embedding = block(embedding)
192
+ embedding = norm(embedding)
193
+ embedding = F.silu(embedding)
194
+ else:
195
+ for block in self.blocks:
196
+ embedding = block(embedding)
197
+ embedding = F.silu(embedding)
198
+
199
+ if controlnet_mask is not None:
200
+ embedding = rearrange(
201
+ embedding, "(B F) C H W -> F B C H W", F=self.num_frames)
202
+ controlnet_mask_expanded = controlnet_mask[:, :, None, None, None]
203
+ controlnet_mask_expanded = rearrange(
204
+ controlnet_mask_expanded, "B F C W H -> F B C W H")
205
+ masked_embedding = controlnet_mask_expanded * embedding
206
+ embedding = rearrange(masked_embedding, "F B C H W -> (B F) C H W")
207
+ controlnet_mask_expanded = rearrange(
208
+ controlnet_mask_expanded, "F B C H W -> (B F) C H W")
209
+ # controlnet_mask_expanded = repeat(controlnet_mask_expanded,"B C W H -> B (C x) W H",x=embedding.shape[1])
210
+ controlnet_mask_expanded = repeat(
211
+ controlnet_mask_expanded, "B C W H -> B C (W y) H", y=embedding.shape[2])
212
+ controlnet_mask_expanded = repeat(
213
+ controlnet_mask_expanded, "B C W H -> B C W (H z)", z=embedding.shape[3])
214
+
215
+ embedding = torch.cat([embedding, controlnet_mask_expanded], dim=1)
216
+
217
+ embedding = self.conv_out(embedding)
218
+ if self.final_3d_conv:
219
+ # embedding = F.silu(embedding)
220
+ embedding = rearrange(
221
+ embedding, "(b f) c h w -> b f c h w", f=self.num_frame_conditioning)
222
+ embedding = self.conv_temp(embedding)
223
+ embedding = rearrange(embedding, "b f c h w -> (b f) c h w")
224
+
225
+ return embedding
226
+
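
ControlNetConditioningEmbedding above maps a pixel-space condition (e.g. 256x256 RGB) to the latent resolution through a small strided conv stack with SiLU activations, ending in a zero-initialized 3x3 conv so the branch contributes nothing at initialization. A stripped-down sketch of that shape flow using the default block_out_channels; everything else here is an assumption, not the project's module:

import torch
from torch import nn
import torch.nn.functional as F

channels = (16, 32, 96, 256)                 # default conditioning_embedding_out_channels above
conv_in = nn.Conv2d(3, channels[0], 3, padding=1)
blocks = nn.ModuleList()
for c_in, c_out in zip(channels[:-1], channels[1:]):
    blocks.append(nn.Conv2d(c_in, c_in, 3, padding=1))
    blocks.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=2))   # halves H and W

conv_out = nn.Conv2d(channels[-1], 320, 3, padding=1)
nn.init.zeros_(conv_out.weight)              # "zero module": no signal at init
nn.init.zeros_(conv_out.bias)

x = torch.randn(1, 3, 256, 256)              # pixel-space condition frame
h = F.silu(conv_in(x))
for block in blocks:
    h = F.silu(block(h))
h = conv_out(h)
assert h.shape == (1, 320, 32, 32)           # 256 -> 128 -> 64 -> 32 after three strided convs
assert h.abs().sum() == 0                    # zero-initialized output at the start of training
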
227
+ class ControlNetModel(ModelMixin, ConfigMixin):
228
+ _supports_gradient_checkpointing = False
229
+
230
+ @register_to_config
231
+ def __init__(
232
+ self,
233
+ in_channels: int = 4,
234
+ flip_sin_to_cos: bool = True,
235
+ freq_shift: int = 0,
236
+ down_block_types: Tuple[str] = (
237
+ "CrossAttnDownBlock3D",
238
+ "CrossAttnDownBlock3D",
239
+ "CrossAttnDownBlock3D",
240
+ "DownBlock3D",
241
+ ),
242
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
243
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
244
+ layers_per_block: int = 2,
245
+ downsample_padding: int = 1,
246
+ mid_block_scale_factor: float = 1,
247
+ act_fn: str = "silu",
248
+ norm_num_groups: Optional[int] = 32,
249
+ norm_eps: float = 1e-5,
250
+ cross_attention_dim: int = 1280,
251
+ attention_head_dim: Union[int, Tuple[int]] = 8,
252
+ use_linear_projection: bool = False,
253
+ class_embed_type: Optional[str] = None,
254
+ num_class_embeds: Optional[int] = None,
255
+ upcast_attention: bool = False,
256
+ resnet_time_scale_shift: str = "default",
257
+ projection_class_embeddings_input_dim: Optional[int] = None,
258
+ controlnet_conditioning_channel_order: str = "rgb",
259
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (
260
+ 16, 32, 96, 256),
261
+ global_pool_conditions: bool = False,
262
+ downsample_controlnet_cond: bool = True,
263
+ frame_expansion: str = "zero",
264
+ condition_encoder: str = "",
265
+ num_frames: int = 16,
266
+ num_frame_conditioning: int = 8,
267
+ num_tranformers: int = 1,
268
+ vae=None,
269
+ merging_mode: str = "addition",
270
+ zero_conv_mode: str = "2d",
271
+ use_controlnet_mask: bool = False,
272
+ use_image_embedding: bool = False,
273
+ use_image_encoder_normalization: bool = False,
274
+ unet_params=None,
275
+ ):
276
+ super().__init__()
277
+ self.gradient_checkpointing = False
278
+ # Check inputs
279
+ if len(block_out_channels) != len(down_block_types):
280
+ raise ValueError(
281
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
285
+ raise ValueError(
286
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
287
+ )
288
+
289
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
290
+ raise ValueError(
291
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
292
+ )
293
+ self.use_image_tokens = unet_params.use_image_tokens_ctrl
294
+ self.image_encoder_name = type(unet_params.image_encoder).__name__
295
+
296
+ # input
297
+ conv_in_kernel = 3
298
+ conv_in_padding = (conv_in_kernel - 1) // 2
299
+ '''Conv2D_SubChannels
300
+ self.conv_in = nn.Conv2d(
301
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
302
+ )
303
+ '''
304
+ self.conv_in = Conv2D_SubChannels(
305
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
306
+ )
307
+ # time
308
+ time_embed_dim = block_out_channels[0] * 4
309
+
310
+ self.time_proj = Timesteps(
311
+ block_out_channels[0], flip_sin_to_cos, freq_shift)
312
+ timestep_input_dim = block_out_channels[0]
313
+
314
+ self.time_embedding = TimestepEmbedding(
315
+ timestep_input_dim,
316
+ time_embed_dim,
317
+ act_fn=act_fn,
318
+ )
319
+
320
+ self.transformer_in = TransformerTemporalModel(
321
+ num_attention_heads=8,
322
+ attention_head_dim=attention_head_dim,
323
+ in_channels=block_out_channels[0],
324
+ num_layers=1,
325
+ )
326
+
327
+ # class embedding
328
+ if class_embed_type is None and num_class_embeds is not None:
329
+ self.class_embedding = nn.Embedding(
330
+ num_class_embeds, time_embed_dim)
331
+ elif class_embed_type == "timestep":
332
+ self.class_embedding = TimestepEmbedding(
333
+ timestep_input_dim, time_embed_dim)
334
+ elif class_embed_type == "identity":
335
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
336
+ elif class_embed_type == "projection":
337
+ if projection_class_embeddings_input_dim is None:
338
+ raise ValueError(
339
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
340
+ )
341
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
342
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
343
+ # 2. it projects from an arbitrary input dimension.
344
+ #
345
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
346
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
347
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
348
+ self.class_embedding = TimestepEmbedding(
349
+ projection_class_embeddings_input_dim, time_embed_dim)
350
+ else:
351
+ self.class_embedding = None
352
+ conditioning_channels = 3 if downsample_controlnet_cond else 4
353
+ # control net conditioning embedding
354
+
355
+ if condition_encoder == "temp_conv_vq":
356
+ controlnet_cond_embedding = ControlNetConditioningEmbeddingVQ(
357
+ conditioning_embedding_channels=block_out_channels[0],
358
+ conditioning_channels=4,
359
+ block_out_channels=conditioning_embedding_out_channels,
360
+ downsample=False,
361
+
362
+ num_frame_conditioning=num_frame_conditioning,
363
+ num_frames=num_frames,
364
+ num_tranformers=num_tranformers,
365
+ # zero_init=not merging_mode.startswith("attention"),
366
+ )
367
+ elif condition_encoder == "vq":
368
+ controlnet_cond_embedding = ControlNetConditioningOptVQ(vq=vae,
369
+ conditioning_embedding_channels=block_out_channels[
370
+ 0],
371
+ conditioning_channels=4,
372
+ block_out_channels=conditioning_embedding_out_channels,
373
+ num_frame_conditioning=num_frame_conditioning,
374
+ num_frames=num_frames,
375
+ )
376
+
377
+ else:
378
+ controlnet_cond_embedding = ControlNetConditioningEmbedding(
379
+ conditioning_embedding_channels=block_out_channels[0],
380
+ conditioning_channels=conditioning_channels,
381
+ block_out_channels=conditioning_embedding_out_channels,
382
+ downsample=downsample_controlnet_cond,
383
+ final_3d_conv=condition_encoder.endswith("3DConv"),
384
+ num_frame_conditioning=num_frame_conditioning,
385
+ num_frames=num_frames,
386
+ # zero_init=not merging_mode.startswith("attention")
387
+ use_controlnet_mask=use_controlnet_mask,
388
+ use_normalization=use_image_encoder_normalization,
389
+ )
390
+ self.use_controlnet_mask = use_controlnet_mask
391
+ self.down_blocks = nn.ModuleList([])
392
+ self.controlnet_down_blocks = nn.ModuleList([])
393
+
394
+ # conv_in
395
+ self.merger = Merger(n_frames_sample=num_frames, n_frames_condition=num_frame_conditioning,
396
+ merge_mode=merging_mode, input_channels=block_out_channels[0], frame_expansion=frame_expansion)
397
+
398
+ if isinstance(only_cross_attention, bool):
399
+ only_cross_attention = [
400
+ only_cross_attention] * len(down_block_types)
401
+
402
+ if isinstance(attention_head_dim, int):
403
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
404
+
405
+ # down
406
+ output_channel = block_out_channels[0]
407
+ self.controlnet_down_blocks.append(
408
+ ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames))
409
+ for i, down_block_type in enumerate(down_block_types):
410
+ input_channel = output_channel
411
+ output_channel = block_out_channels[i]
412
+ is_final_block = i == len(block_out_channels) - 1
413
+
414
+ down_block = get_down_block(
415
+ down_block_type,
416
+ num_layers=layers_per_block,
417
+ in_channels=input_channel,
418
+ out_channels=output_channel,
419
+ temb_channels=time_embed_dim,
420
+ add_downsample=not is_final_block,
421
+ resnet_eps=norm_eps,
422
+ resnet_act_fn=act_fn,
423
+ resnet_groups=norm_num_groups,
424
+ cross_attention_dim=cross_attention_dim,
425
+ attn_num_head_channels=attention_head_dim[i],
426
+ downsample_padding=downsample_padding,
427
+ dual_cross_attention=False,
428
+ use_image_embedding=use_image_embedding,
429
+ unet_params=unet_params,
430
+ )
431
+ self.down_blocks.append(down_block)
432
+
433
+ for _ in range(layers_per_block):
434
+ self.controlnet_down_blocks.append(
435
+ ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames))
436
+
437
+ if not is_final_block:
438
+ self.controlnet_down_blocks.append(
439
+ ZeroConv(channels=output_channel, mode=zero_conv_mode, num_frames=num_frames))
440
+
441
+ # mid
442
+ mid_block_channel = block_out_channels[-1]
443
+
444
+ self.controlnet_mid_block = ZeroConv(
445
+ channels=mid_block_channel, mode=zero_conv_mode, num_frames=num_frames)
446
+
447
+ self.mid_block = UNetMidBlock3DCrossAttn(
448
+ in_channels=block_out_channels[-1],
449
+ temb_channels=time_embed_dim,
450
+ resnet_eps=norm_eps,
451
+ resnet_act_fn=act_fn,
452
+ output_scale_factor=mid_block_scale_factor,
453
+ cross_attention_dim=cross_attention_dim,
454
+ attn_num_head_channels=attention_head_dim[-1],
455
+ resnet_groups=norm_num_groups,
456
+ dual_cross_attention=False,
457
+ use_image_embedding=use_image_embedding,
458
+ unet_params=unet_params,
459
+ )
460
+ self.controlnet_cond_embedding = controlnet_cond_embedding
461
+ self.num_frames = num_frames
462
+ self.num_frame_conditioning = num_frame_conditioning
463
+
464
+ @classmethod
465
+ def from_unet(
466
+ cls,
467
+ unet: UNet3DConditionModel,
468
+ controlnet_conditioning_channel_order: str = "rgb",
469
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (
470
+ 16, 32, 96, 256),
471
+ load_weights_from_unet: bool = True,
472
+ downsample_controlnet_cond: bool = True,
473
+ num_frames: int = 16,
474
+ num_frame_conditioning: int = 8,
475
+ frame_expansion: str = "zero",
476
+ num_tranformers: int = 1,
477
+ vae=None,
478
+ zero_conv_mode: str = "2d",
479
+ merging_mode: str = "addition",
480
+ # [spatial,spatial_3DConv,temp_conv_vq]
481
+ condition_encoder: str = "spatial_3DConv",
482
+ use_controlnet_mask: bool = False,
483
+ use_image_embedding: bool = False,
484
+ use_image_encoder_normalization: bool = False,
485
+ unet_params=None,
486
+ ** kwargs,
487
+ ):
488
+ r"""
489
+ Instantiate a ControlNet class from a UNet3DConditionModel.
490
+
491
+ Parameters:
492
+ unet (`UNet3DConditionModel`):
493
+ The UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
494
+ copied where applicable.
495
+ """
496
+ controlnet = cls(
497
+ in_channels=unet.config.in_channels,
498
+ down_block_types=unet.config.down_block_types,
499
+ block_out_channels=unet.config.block_out_channels,
500
+ layers_per_block=unet.config.layers_per_block,
501
+ act_fn=unet.config.act_fn,
502
+ norm_num_groups=unet.config.norm_num_groups,
503
+ norm_eps=unet.config.norm_eps,
504
+ cross_attention_dim=unet.config.cross_attention_dim,
505
+ attention_head_dim=unet.config.attention_head_dim,
506
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
507
+ downsample_controlnet_cond=downsample_controlnet_cond,
508
+ num_frame_conditioning=num_frame_conditioning,
509
+ num_frames=num_frames,
510
+ frame_expansion=frame_expansion,
511
+ num_tranformers=num_tranformers,
512
+ vae=vae,
513
+ zero_conv_mode=zero_conv_mode,
514
+ merging_mode=merging_mode,
515
+ condition_encoder=condition_encoder,
516
+ use_controlnet_mask=use_controlnet_mask,
517
+ use_image_embedding=use_image_embedding,
518
+ use_image_encoder_normalization=use_image_encoder_normalization,
519
+ unet_params=unet_params,
520
+
521
+ )
522
+
523
+ if load_weights_from_unet:
524
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
525
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
526
+ controlnet.transformer_in.load_state_dict(
527
+ unet.transformer_in.state_dict())
528
+ controlnet.time_embedding.load_state_dict(
529
+ unet.time_embedding.state_dict())
530
+
531
+ if controlnet.class_embedding:
532
+ controlnet.class_embedding.load_state_dict(
533
+ unet.class_embedding.state_dict())
534
+
535
+ controlnet.down_blocks.load_state_dict(
536
+ unet.down_blocks.state_dict(), strict=False) # can be that the controlnet model does not use image clip encoding
537
+ controlnet.mid_block.load_state_dict(
538
+ unet.mid_block.state_dict(), strict=False)
539
+
540
+ return controlnet
541
+
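
from_unet above copies the encoder weights with load_state_dict(..., strict=False) because the ControlNet blocks may not contain every conditioning layer present in the UNet (or vice versa). A toy illustration of why strict=False is needed for such partial copies, using stand-in modules rather than the real blocks:

import torch
from torch import nn

class Source(nn.Module):            # stand-in for a UNet block with an extra conditioning layer
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)
        self.image_proj = nn.Linear(1024, 4)   # layer the target does not have

class Target(nn.Module):            # stand-in for the corresponding ControlNet block
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1)

src, dst = Source(), Target()
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.unexpected_keys)                           # the extra keys are simply ignored
assert torch.equal(dst.conv.weight, src.conv.weight)    # shared weights were copied
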
542
+ @property
543
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.attn_processors
544
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
545
+ r"""
546
+ Returns:
547
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
549
+ indexed by their weight names.
549
+ """
550
+ # set recursively
551
+ processors = {}
552
+
553
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
554
+ if hasattr(module, "set_processor"):
555
+ processors[f"{name}.processor"] = module.processor
556
+
557
+ for sub_name, child in module.named_children():
558
+ fn_recursive_add_processors(
559
+ f"{name}.{sub_name}", child, processors)
560
+
561
+ return processors
562
+
563
+ for name, module in self.named_children():
564
+ fn_recursive_add_processors(name, module, processors)
565
+
566
+ return processors
567
+
568
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attn_processor
569
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
570
+ r"""
571
+ Parameters:
572
+ processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
573
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
574
+ of **all** `Attention` layers.
575
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
576
+
577
+ """
578
+ count = len(self.attn_processors.keys())
579
+
580
+ if isinstance(processor, dict) and len(processor) != count:
581
+ raise ValueError(
582
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
583
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
584
+ )
585
+
586
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
587
+ if hasattr(module, "set_processor"):
588
+ if not isinstance(processor, dict):
589
+ module.set_processor(processor)
590
+ else:
591
+ module.set_processor(processor.pop(f"{name}.processor"))
592
+
593
+ for sub_name, child in module.named_children():
594
+ fn_recursive_attn_processor(
595
+ f"{name}.{sub_name}", child, processor)
596
+
597
+ for name, module in self.named_children():
598
+ fn_recursive_attn_processor(name, module, processor)
599
+
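
attn_processors and set_attn_processor above both walk named_children() recursively and key every attention module by its dotted path plus ".processor". A toy version of that traversal with stand-in modules (no diffusers classes involved):

from torch import nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = "default"
    def set_processor(self, processor):        # mirrors the hook the recursion looks for
        self.processor = processor

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = ToyAttention()
        self.attn2 = ToyAttention()

root = nn.ModuleDict({"down_block": ToyBlock()})

def collect(name, module, out):
    if hasattr(module, "set_processor"):
        out[f"{name}.processor"] = module.processor
    for sub_name, child in module.named_children():
        collect(f"{name}.{sub_name}", child, out)

processors = {}
for name, module in root.named_children():
    collect(name, module, processors)
print(sorted(processors))   # ['down_block.attn1.processor', 'down_block.attn2.processor']
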
600
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_default_attn_processor
601
+ def set_default_attn_processor(self):
602
+ """
603
+ Disables custom attention processors and sets the default attention implementation.
604
+ """
605
+ self.set_attn_processor(AttnProcessor())
606
+
607
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.set_attention_slice
608
+ def set_attention_slice(self, slice_size):
609
+ r"""
610
+ Enable sliced attention computation.
611
+
612
+ When this option is enabled, the attention module will split the input tensor into slices to compute attention
613
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
614
+
615
+ Args:
616
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
617
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
618
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
619
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
620
+ must be a multiple of `slice_size`.
621
+ """
622
+ sliceable_head_dims = []
623
+
624
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
625
+ if hasattr(module, "set_attention_slice"):
626
+ sliceable_head_dims.append(module.sliceable_head_dim)
627
+
628
+ for child in module.children():
629
+ fn_recursive_retrieve_sliceable_dims(child)
630
+
631
+ # retrieve number of attention layers
632
+ for module in self.children():
633
+ fn_recursive_retrieve_sliceable_dims(module)
634
+
635
+ num_sliceable_layers = len(sliceable_head_dims)
636
+
637
+ if slice_size == "auto":
638
+ # half the attention head size is usually a good trade-off between
639
+ # speed and memory
640
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
641
+ elif slice_size == "max":
642
+ # make smallest slice possible
643
+ slice_size = num_sliceable_layers * [1]
644
+
645
+ slice_size = num_sliceable_layers * \
646
+ [slice_size] if not isinstance(slice_size, list) else slice_size
647
+
648
+ if len(slice_size) != len(sliceable_head_dims):
649
+ raise ValueError(
650
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
651
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
652
+ )
653
+
654
+ for i in range(len(slice_size)):
655
+ size = slice_size[i]
656
+ dim = sliceable_head_dims[i]
657
+ if size is not None and size > dim:
658
+ raise ValueError(
659
+ f"size {size} has to be smaller or equal to {dim}.")
660
+
661
+ # Recursively walk through all the children.
662
+ # Any children which exposes the set_attention_slice method
663
+ # gets the message
664
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
665
+ if hasattr(module, "set_attention_slice"):
666
+ module.set_attention_slice(slice_size.pop())
667
+
668
+ for child in module.children():
669
+ fn_recursive_set_attention_slice(child, slice_size)
670
+
671
+ reversed_slice_size = list(reversed(slice_size))
672
+ for module in self.children():
673
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
674
+
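
Attention slicing, as configured by set_attention_slice above, trades speed for memory by computing softmax(QK^T)V over chunks of the batch-times-heads axis instead of all at once. A minimal sketch of the idea with toy shapes (plain PyTorch, not the library's sliced processor):

import torch

bh, seq, d = 16, 128, 64                 # (batch*heads, sequence length, head dim)
q, k, v = (torch.randn(bh, seq, d) for _ in range(3))
scale = d ** -0.5

def full_attention(q, k, v):
    probs = (q @ k.transpose(-1, -2) * scale).softmax(dim=-1)
    return probs @ v

def sliced_attention(q, k, v, slice_size=4):
    out = torch.empty_like(q)
    for start in range(0, bh, slice_size):          # only `slice_size` score matrices live at once
        sl = slice(start, start + slice_size)
        probs = (q[sl] @ k[sl].transpose(-1, -2) * scale).softmax(dim=-1)
        out[sl] = probs @ v[sl]
    return out

assert torch.allclose(full_attention(q, k, v), sliced_attention(q, k, v), atol=1e-5)
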
675
+ def _set_gradient_checkpointing(self, module, value=False):
676
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
677
+ module.gradient_checkpointing = value
678
+
679
+ # TODO ADD WEIGHT CONTROL
680
+ def forward(
681
+ self,
682
+ sample: torch.FloatTensor,
683
+ timestep: Union[torch.Tensor, float, int],
684
+ encoder_hidden_states: torch.Tensor,
685
+ controlnet_cond: torch.FloatTensor,
686
+ conditioning_scale: float = 1.0,
687
+ class_labels: Optional[torch.Tensor] = None,
688
+ timestep_cond: Optional[torch.Tensor] = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
691
+ guess_mode: bool = False,
692
+ return_dict: bool = True,
693
+ weight_control: float = 1.0,
694
+ weight_control_sample: float = 1.0,
695
+ controlnet_mask: Optional[torch.Tensor] = None,
696
+ vq_gan=None,
697
+ ) -> Union[ControlNetOutput, Tuple]:
698
+ # check channel order
699
+ # TODO SET ATTENTION MASK And WEIGHT CONTROL as in CONTROLNET.PY
700
+ '''
701
+ # prepare attention_mask
702
+ if attention_mask is not None:
703
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
704
+ attention_mask = attention_mask.unsqueeze(1)
705
+ '''
706
+ # assert controlnet_mask is None, "Controlnet Mask not implemented yet for clean model"
707
+ # 1. time
708
+
709
+ timesteps = timestep
710
+ if not torch.is_tensor(timesteps):
711
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
712
+ # This would be a good case for the `match` statement (Python 3.10+)
713
+ is_mps = sample.device.type == "mps"
714
+ if isinstance(timestep, float):
715
+ dtype = torch.float32 if is_mps else torch.float64
716
+ else:
717
+ dtype = torch.int32 if is_mps else torch.int64
718
+ timesteps = torch.tensor(
719
+ [timesteps], dtype=dtype, device=sample.device)
720
+ elif len(timesteps.shape) == 0:
721
+ timesteps = timesteps[None].to(sample.device)
722
+
723
+ sample = sample[:, :, :self.num_frames]
724
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
725
+ num_frames = sample.shape[2]
726
+ timesteps = timesteps.expand(sample.shape[0])
727
+
728
+ t_emb = self.time_proj(timesteps)
729
+
730
+ # timesteps does not contain any weights and will always return f32 tensors
731
+ # but time_embedding might actually be running in fp16. so we need to cast here.
732
+ # there might be better ways to encapsulate this.
733
+ t_emb = t_emb.to(dtype=self.dtype)
734
+
735
+ emb = self.time_embedding(t_emb, timestep_cond)
736
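+ # Frames are folded into the batch dimension downstream, so repeat the time embedding once per frame.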
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
737
+
738
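+ # The first 77 tokens are the CLIP text embeddings and are repeated per frame; any extra image tokens
+ # are either repeated per frame or reshaped from a per-frame layout, depending on the image encoder.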
+ if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77:
739
+ encoder_hidden_states = encoder_hidden_states[:, :77]
740
+
741
+ if encoder_hidden_states.shape[1] > 77:
742
+ # assert (
743
+ # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}"
744
+ context_text, context_img = encoder_hidden_states[:,
745
+ :77, :], encoder_hidden_states[:, 77:, :]
746
+ context_text = context_text.repeat_interleave(
747
+ repeats=num_frames, dim=0)
748
+
749
+ if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder":
750
+ context_img = context_img.repeat_interleave(
751
+ repeats=num_frames, dim=0)
752
+ else:
753
+ context_img = rearrange(
754
+ context_img, 'b (t l) c -> (b t) l c', t=num_frames)
755
+
756
+ encoder_hidden_states = torch.cat(
757
+ [context_text, context_img], dim=1)
758
+ else:
759
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
760
+ repeats=num_frames, dim=0)
761
+
762
+ # print(f"ctrl with tokens = {encoder_hidden_states.shape[1]}")
763
+ '''
764
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
765
+ repeats=num_frames, dim=0)
766
+ '''
767
+
768
+ # 2. pre-process
769
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(
770
+ (sample.shape[0] * num_frames, -1) + sample.shape[3:])
771
+ sample = self.conv_in(sample)
772
+
773
+ controlnet_cond = self.controlnet_cond_embedding(
774
+ controlnet_cond, vq_gan=vq_gan, controlnet_mask=controlnet_mask)
775
+
776
+ if num_frames > 1:
777
+ if self.gradient_checkpointing:
778
+ sample = transformer_g_c(
779
+ self.transformer_in, sample, num_frames)
780
+ else:
781
+ sample = self.transformer_in(
782
+ sample, num_frames=num_frames, attention_mask=attention_mask).sample
783
+
784
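+ # Fuse the weighted UNet hidden states with the weighted ControlNet conditioning embedding.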
+ sample = self.merger(sample * weight_control_sample,
785
+ weight_control * controlnet_cond)
786
+
787
+ # 3. down
788
+ down_block_res_samples = (sample,)
789
+ for downsample_block in self.down_blocks:
790
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
791
+ sample, res_samples = downsample_block(
792
+ hidden_states=sample,
793
+ temb=emb,
794
+ encoder_hidden_states=encoder_hidden_states,
795
+ attention_mask=attention_mask,
796
+ num_frames=num_frames,
797
+ cross_attention_kwargs=cross_attention_kwargs,
798
+ )
799
+ else:
800
+ sample, res_samples = downsample_block(
801
+ hidden_states=sample, temb=emb, num_frames=num_frames)
802
+
803
+ down_block_res_samples += res_samples
804
+
805
+ # 4. mid
806
+ if self.mid_block is not None:
807
+ sample = self.mid_block(
808
+ sample,
809
+ emb,
810
+ encoder_hidden_states=encoder_hidden_states,
811
+ attention_mask=attention_mask,
812
+ num_frames=num_frames,
813
+ cross_attention_kwargs=cross_attention_kwargs,
814
+ )
815
+
816
+ # 5. Control net blocks
817
+
818
+ controlnet_down_block_res_samples = ()
819
+
820
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
821
+ down_block_res_sample = controlnet_block(down_block_res_sample)
822
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + \
823
+ (down_block_res_sample,)
824
+
825
+ down_block_res_samples = controlnet_down_block_res_samples
826
+
827
+ mid_block_res_sample = self.controlnet_mid_block(sample)
828
+
829
+ # 6. scaling
830
+ if guess_mode and not self.config.global_pool_conditions:
831
+ # 0.1 to 1.0
832
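+ # In guess mode, ramp the scales exponentially so deeper blocks receive stronger conditioning,
+ # then apply the global conditioning_scale on top.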
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) +
833
+ 1, device=sample.device)
834
+
835
+ scales = scales * conditioning_scale
836
+ down_block_res_samples = [
837
+ sample * scale for sample, scale in zip(down_block_res_samples, scales)]
838
+ mid_block_res_sample = mid_block_res_sample * \
839
+ scales[-1] # last one
840
+ else:
841
+ down_block_res_samples = [
842
+ sample * conditioning_scale for sample in down_block_res_samples]
843
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
844
+
845
+ if self.config.global_pool_conditions:
846
+ down_block_res_samples = [
847
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
848
+ ]
849
+ mid_block_res_sample = torch.mean(
850
+ mid_block_res_sample, dim=(2, 3), keepdim=True)
851
+
852
+ if not return_dict:
853
+ return (down_block_res_samples, mid_block_res_sample)
854
+
855
+ return ControlNetOutput(
856
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
857
+ )
858
+
859
+
860
+
861
+ def zero_module(module, reset=True):
862
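+ # Zero-initialize all parameters in place (commonly used so ControlNet projection layers start with
+ # no effect on the base model) and return the module for chaining.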
+ if reset:
863
+ for p in module.parameters():
864
+ nn.init.zeros_(p)
865
+ return module
t2v_enhanced/model/diffusers_conditional/models/controlnet/cross_attention.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Callable, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils.import_utils import is_xformers_available
22
+ # from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
23
+
24
+
25
+ if is_xformers_available():
26
+ import xformers
27
+ import xformers.ops
28
+ else:
29
+ xformers = None
30
+
t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py ADDED
@@ -0,0 +1,211 @@
1
+ import math
2
+ from typing import Any, Mapping
3
+ import torch
4
+ import torch.nn as nn
5
+ import kornia
6
+ import open_clip
7
+ from transformers import AutoImageProcessor, AutoModel
8
+ from transformers.models.bit.image_processing_bit import BitImageProcessor
9
+ from einops import rearrange, repeat
10
+ # FFN
11
+ # from mamba_ssm import Mamba
12
+
13
+
14
+
15
+ class ImgEmbContextResampler(nn.Module):
16
+
17
+ def __init__(
18
+ self,
19
+ inner_dim=1280,
20
+ cross_attention_dim=1024,
21
+ expansion_factor=16,
22
+ **kwargs,
23
+ ):
24
+ super().__init__()
25
+ self.context_embedding = nn.Sequential(
26
+ nn.Linear(cross_attention_dim, inner_dim),
27
+ nn.SiLU(),
28
+ nn.Linear(inner_dim, cross_attention_dim * expansion_factor),
29
+ )
30
+ self.expansion_factor = expansion_factor
31
+ self.cross_attention_dim = cross_attention_dim
32
+
33
+ def forward(self, x, batch_size=0):
34
+ if x.ndim == 2:
35
+ x = rearrange(x, "(B F) C -> B F C", B=batch_size)
36
+ assert x.ndim == 3
37
+ x = torch.mean(x, dim=1, keepdim=True)
38
+ x = self.context_embedding(x)
39
+ x = x.view(-1, self.expansion_factor, self.cross_attention_dim)
40
+ return x
41
+
42
+
43
+
44
+ class AbstractEncoder(nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.embedding_dim = -1
48
+ self.num_tokens = -1
49
+
50
+ def encode(self, *args, **kwargs):
51
+ raise NotImplementedError
52
+
53
+
54
+
55
+ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
56
+ """
57
+ Uses the OpenCLIP vision transformer encoder for images
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ arch="ViT-H-14",
63
+ version="laion2b_s32b_b79k",
64
+ device="cuda",
65
+ max_length=77,
66
+ freeze=True,
67
+ antialias=True,
68
+ ucg_rate=0.0,
69
+ unsqueeze_dim=False,
70
+ repeat_to_max_len=False,
71
+ num_image_crops=0,
72
+ output_tokens=False,
73
+ ):
74
+ super().__init__()
75
+ model, _, _ = open_clip.create_model_and_transforms(
76
+ arch,
77
+ device=torch.device("cpu"),
78
+ pretrained=version,
79
+ )
80
+ del model.transformer
81
+ self.model = model
82
+ self.max_crops = num_image_crops
83
+ self.pad_to_max_len = self.max_crops > 0
84
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
85
+ self.device = device
86
+ self.max_length = max_length
87
+ if freeze:
88
+ self.freeze()
89
+
90
+ self.antialias = antialias
91
+
92
+ self.register_buffer(
93
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
94
+ )
95
+ self.register_buffer(
96
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
97
+ )
98
+ self.ucg_rate = ucg_rate
99
+ self.unsqueeze_dim = unsqueeze_dim
100
+ self.stored_batch = None
101
+ self.model.visual.output_tokens = output_tokens
102
+ self.output_tokens = output_tokens
103
+
104
+ def preprocess(self, x):
105
+ # normalize to [0,1]
106
+ x = kornia.geometry.resize(
107
+ x,
108
+ (224, 224),
109
+ interpolation="bicubic",
110
+ align_corners=True,
111
+ antialias=self.antialias,
112
+ )
113
+ x = (x + 1.0) / 2.0
114
+ # renormalize according to clip
115
+ x = kornia.enhance.normalize(x, self.mean, self.std)
116
+ return x
117
+
118
+ def freeze(self):
119
+ self.model = self.model.eval()
120
+ for param in self.parameters():
121
+ param.requires_grad = False
122
+
123
+ def forward(self, image, no_dropout=False):
124
+ z = self.encode_with_vision_transformer(image)
125
+ tokens = None
126
+ if self.output_tokens:
127
+ z, tokens = z[0], z[1]
128
+ z = z.to(image.dtype)
129
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
130
+ z = (
131
+ torch.bernoulli(
132
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
133
+ )[:, None]
134
+ * z
135
+ )
136
+ if tokens is not None:
137
+ tokens = (
138
+ expand_dims_like(
139
+ torch.bernoulli(
140
+ (1.0 - self.ucg_rate)
141
+ * torch.ones(tokens.shape[0], device=tokens.device)
142
+ ),
143
+ tokens,
144
+ )
145
+ * tokens
146
+ )
147
+ if self.unsqueeze_dim:
148
+ z = z[:, None, :]
149
+ if self.output_tokens:
150
+ assert not self.repeat_to_max_len
151
+ assert not self.pad_to_max_len
152
+ return tokens, z
153
+ if self.repeat_to_max_len:
154
+ if z.dim() == 2:
155
+ z_ = z[:, None, :]
156
+ else:
157
+ z_ = z
158
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
159
+ elif self.pad_to_max_len:
160
+ assert z.dim() == 3
161
+ z_pad = torch.cat(
162
+ (
163
+ z,
164
+ torch.zeros(
165
+ z.shape[0],
166
+ self.max_length - z.shape[1],
167
+ z.shape[2],
168
+ device=z.device,
169
+ ),
170
+ ),
171
+ 1,
172
+ )
173
+ return z_pad, z_pad[:, 0, ...]
174
+ return z
175
+
176
+ def encode_with_vision_transformer(self, img):
177
+ # if self.max_crops > 0:
178
+ # img = self.preprocess_by_cropping(img)
179
+ if img.dim() == 5:
180
+ assert self.max_crops == img.shape[1]
181
+ img = rearrange(img, "b n c h w -> (b n) c h w")
182
+ img = self.preprocess(img)
183
+ if not self.output_tokens:
184
+ assert not self.model.visual.output_tokens
185
+ x = self.model.visual(img)
186
+ tokens = None
187
+ else:
188
+ assert self.model.visual.output_tokens
189
+ x, tokens = self.model.visual(img)
190
+ if self.max_crops > 0:
191
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
192
+ # drop out between 0 and all along the sequence axis
193
+ x = (
194
+ torch.bernoulli(
195
+ (1.0 - self.ucg_rate)
196
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
197
+ )
198
+ * x
199
+ )
200
+ if tokens is not None:
201
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
202
+ print(
203
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
204
+ f"Check what you are doing, and then remove this message."
205
+ )
206
+ if self.output_tokens:
207
+ return x, tokens
208
+ return x
209
+
210
+ def encode(self, text):
211
+ return self(text)
t2v_enhanced/model/diffusers_conditional/models/controlnet/mask_generator.py ADDED
@@ -0,0 +1,27 @@
1
+ from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams
2
+ import torch
3
+
4
+
5
+ class MaskGenerator():
6
+
7
+ def __init__(self, params: AttentionMaskParams, num_frame_conditioning, num_frames):
8
+ self.params = params
9
+ self.num_frame_conditioning = num_frame_conditioning
10
+ self.num_frames = num_frames
11
+ def get_mask(self, precision, device):
12
+
13
+ params = self.params
14
+ if params.temporal_self_attention_only_on_conditioning:
15
+ with torch.no_grad():
16
+ attention_mask = torch.zeros((1, self.num_frames, self.num_frames), dtype=torch.float16 if precision.startswith(
17
+ "16") else torch.float32, device=device)
18
+ for frame in range(self.num_frame_conditioning, self.num_frames):
19
+ attention_mask[:, frame,
20
+ self.num_frame_conditioning:] = float("-inf")
21
+ if params.temporal_self_attention_mask_included_itself:
22
+ attention_mask[:, frame, frame] = 0
23
+ if params.temp_attend_on_uncond_include_past:
24
+ attention_mask[:, frame, :frame] = 0
25
+ else:
26
+ attention_mask = None
27
+ return attention_mask
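+
+ # Usage sketch (illustrative values): MaskGenerator(params, num_frame_conditioning=8, num_frames=16)
+ # .get_mask(precision="16", device=device) returns a (1, 16, 16) additive float mask for temporal
+ # self-attention, or None when temporal_self_attention_only_on_conditioning is disabled.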
t2v_enhanced/model/diffusers_conditional/models/controlnet/pipeline_text_to_video_w_controlnet_synth.py ADDED
@@ -0,0 +1,925 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import PIL.Image
19
+ import numpy as np
20
+ import torch
21
+ from transformers import CLIPTextModel, CLIPTokenizer
22
+
23
+ from diffusers.loaders import TextualInversionLoaderMixin
24
+ from diffusers.models import AutoencoderKL, UNet3DConditionModel
25
+ from diffusers.schedulers import KarrasDiffusionSchedulers
26
+ from diffusers.utils import (
27
+ PIL_INTERPOLATION,
28
+ is_accelerate_available,
29
+ is_accelerate_version,
30
+ logging,
31
+ replace_example_docstring,
32
+ )
33
+ from diffusers.utils.torch_utils import randn_tensor
34
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
35
+ from diffusers.pipelines.text_to_video_synthesis import TextToVideoSDPipelineOutput
36
+ from einops import rearrange
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ EXAMPLE_DOC_STRING = """
41
+ Examples:
42
+ ```py
43
+ >>> import torch
44
+ >>> from diffusers import TextToVideoSDPipeline
45
+ >>> from diffusers.utils import export_to_video
46
+
47
+ >>> pipe = TextToVideoSDPipeline.from_pretrained(
48
+ ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
49
+ ... )
50
+ >>> pipe.enable_model_cpu_offload()
51
+
52
+ >>> prompt = "Spiderman is surfing"
53
+ >>> video_frames = pipe(prompt).frames
54
+ >>> video_path = export_to_video(video_frames)
55
+ >>> video_path
56
+ ```
57
+ """
58
+
59
+
60
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], output_type="list") -> List[np.ndarray]:
61
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
62
+ # reshape to ncfhw
63
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
64
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
65
+ # unnormalize back to [0,1]
66
+ video = video.mul_(std).add_(mean)
67
+ video.clamp_(0, 1)
68
+ # prepare the final outputs
69
+ i, c, f, h, w = video.shape
70
+ images = video.permute(2, 3, 0, 4, 1).reshape(
71
+ f, h, i * w, c
72
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
73
+ if output_type == "list":
74
+ # prepare a list of individual (consecutive) frames
75
+ images = images.unbind(dim=0)
76
+ images = [(image.cpu().numpy() * 255).astype("uint8")
77
+ for image in images] # f h w c
78
+ elif output_type == "pt":
79
+ pass
80
+ return images
81
+
82
+
83
+ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
84
+ r"""
85
+ Pipeline for text-to-video generation.
86
+
87
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
88
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
89
+
90
+ Args:
91
+ vae ([`AutoencoderKL`]):
92
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
93
+ text_encoder ([`CLIPTextModel`]):
94
+ Frozen text-encoder. Same as Stable Diffusion 2.
95
+ tokenizer (`CLIPTokenizer`):
96
+ Tokenizer of class
97
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
98
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
99
+ scheduler ([`SchedulerMixin`]):
100
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
101
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ vae: AutoencoderKL,
107
+ text_encoder: CLIPTextModel,
108
+ tokenizer: CLIPTokenizer,
109
+ unet: UNet3DConditionModel,
110
+ controlnet,
111
+ scheduler: KarrasDiffusionSchedulers,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.register_modules(
116
+ vae=vae,
117
+ text_encoder=text_encoder,
118
+ tokenizer=tokenizer,
119
+ unet=unet,
120
+ controlnet=controlnet,
121
+ scheduler=scheduler,
122
+ )
123
+ self.vae_scale_factor = 2 ** (
124
+ len(self.vae.config.block_out_channels) - 1)
125
+
126
+ def prepare_image(
127
+ self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance, cfg_text_image=False,
128
+ ):
129
+ if not isinstance(image, torch.Tensor):
130
+ if isinstance(image, PIL.Image.Image):
131
+ image = [image]
132
+
133
+ if isinstance(image[0], PIL.Image.Image):
134
+ images = []
135
+
136
+ for image_ in image:
137
+ image_ = image_.convert("RGB")
138
+ image_ = image_.resize(
139
+ (width, height), resample=PIL_INTERPOLATION["lanczos"])
140
+ image_ = np.array(image_)
141
+ image_ = image_[None, :]
142
+ images.append(image_)
143
+
144
+ image = images
145
+
146
+ image = np.concatenate(image, axis=0)
147
+ image = np.array(image).astype(np.float32) / 255.0
148
+ image = image.transpose(0, 3, 1, 2)
149
+ image = torch.from_numpy(image)
150
+ elif isinstance(image[0], torch.Tensor):
151
+ image = torch.cat(image, dim=0)
152
+
153
+ image_batch_size = image.shape[0]
154
+
155
+ if image_batch_size == 1:
156
+ repeat_by = batch_size
157
+ else:
158
+ # image batch size is the same as prompt batch size
159
+ repeat_by = num_images_per_prompt
160
+
161
+ image = image.repeat_interleave(repeat_by, dim=0)
162
+
163
+ image = image.to(device=device, dtype=dtype)
164
+
165
+ image_vq_enc = self.vae.encode(rearrange(
166
+ image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor
167
+ image_vq_enc = rearrange(
168
+ image_vq_enc, "(B F) C W H -> B F C W H", B=image_batch_size)
169
+ if do_classifier_free_guidance:
170
+ if cfg_text_image:
171
+ image = torch.cat([torch.zeros_like(image), image], dim=0)
172
+ else:
173
+ image = torch.cat([image] * 2)
174
+ # image_vq_enc = torch.cat([image_vq_enc] * 2)
175
+
176
+ return image, image_vq_enc
177
+
178
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
179
+
180
+ def enable_vae_slicing(self):
181
+ r"""
182
+ Enable sliced VAE decoding.
183
+
184
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
185
+ steps. This is useful to save some memory and allow larger batch sizes.
186
+ """
187
+ self.vae.enable_slicing()
188
+
189
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
190
+ def disable_vae_slicing(self):
191
+ r"""
192
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
193
+ computing decoding in one step.
194
+ """
195
+ self.vae.disable_slicing()
196
+
197
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
198
+ def enable_vae_tiling(self):
199
+ r"""
200
+ Enable tiled VAE decoding.
201
+
202
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
203
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
204
+ """
205
+ self.vae.enable_tiling()
206
+
207
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
208
+ def disable_vae_tiling(self):
209
+ r"""
210
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
211
+ computing decoding in one step.
212
+ """
213
+ self.vae.disable_tiling()
214
+
215
+ def enable_sequential_cpu_offload(self, gpu_id=0):
216
+ r"""
217
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
218
+ text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded
219
+ to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a
220
+ submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
221
+ """
222
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
223
+ from accelerate import cpu_offload
224
+ else:
225
+ raise ImportError(
226
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
227
+
228
+ device = torch.device(f"cuda:{gpu_id}")
229
+
230
+ if self.device.type != "cpu":
231
+ self.to("cpu", silence_dtype_warnings=True)
232
+ # otherwise we don't see the memory savings (but they probably exist)
233
+ torch.cuda.empty_cache()
234
+
235
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
236
+ cpu_offload(cpu_offloaded_model, device)
237
+
238
+ def enable_model_cpu_offload(self, gpu_id=0):
239
+ r"""
240
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
241
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
242
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
243
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
244
+ """
245
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
246
+ from accelerate import cpu_offload_with_hook
247
+ else:
248
+ raise ImportError(
249
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
250
+
251
+ device = torch.device(f"cuda:{gpu_id}")
252
+
253
+ if self.device.type != "cpu":
254
+ self.to("cpu", silence_dtype_warnings=True)
255
+ # otherwise we don't see the memory savings (but they probably exist)
256
+ torch.cuda.empty_cache()
257
+
258
+ hook = None
259
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
260
+ _, hook = cpu_offload_with_hook(
261
+ cpu_offloaded_model, device, prev_module_hook=hook)
262
+
263
+ # We'll offload the last model manually.
264
+ self.final_offload_hook = hook
265
+
266
+ @property
267
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
268
+ def _execution_device(self):
269
+ r"""
270
+ Returns the device on which the pipeline's models will be executed. After calling
271
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
272
+ hooks.
273
+ """
274
+ if not hasattr(self.unet, "_hf_hook"):
275
+ return self.device
276
+ for module in self.unet.modules():
277
+ if (
278
+ hasattr(module, "_hf_hook")
279
+ and hasattr(module._hf_hook, "execution_device")
280
+ and module._hf_hook.execution_device is not None
281
+ ):
282
+ return torch.device(module._hf_hook.execution_device)
283
+ return self.device
284
+
285
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
286
+ def _encode_prompt(
287
+ self,
288
+ prompt,
289
+ device,
290
+ num_images_per_prompt,
291
+ do_classifier_free_guidance,
292
+ negative_prompt=None,
293
+ prompt_embeds: Optional[torch.FloatTensor] = None,
294
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
295
+ img_cond: Optional[torch.FloatTensor] = None,
296
+ img_cond_unc: Optional[torch.FloatTensor] = None,
297
+ ):
298
+ r"""
299
+ Encodes the prompt into text encoder hidden states.
300
+
301
+ Args:
302
+ prompt (`str` or `List[str]`, *optional*):
303
+ prompt to be encoded
304
+ device: (`torch.device`):
305
+ torch device
306
+ num_images_per_prompt (`int`):
307
+ number of images that should be generated per prompt
308
+ do_classifier_free_guidance (`bool`):
309
+ whether to use classifier free guidance or not
310
+ negative_prompt (`str` or `List[str]`, *optional*):
311
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
312
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
313
+ less than `1`).
314
+ prompt_embeds (`torch.FloatTensor`, *optional*):
315
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
316
+ provided, text embeddings will be generated from `prompt` input argument.
317
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
318
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
319
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
320
+ argument.
321
+ """
322
+ if prompt is not None and isinstance(prompt, str):
323
+ batch_size = 1
324
+ elif prompt is not None and isinstance(prompt, list):
325
+ batch_size = len(prompt)
326
+ else:
327
+ batch_size = prompt_embeds.shape[0]
328
+
329
+ if prompt_embeds is None:
330
+ # textual inversion: procecss multi-vector tokens if necessary
331
+ if isinstance(self, TextualInversionLoaderMixin):
332
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
333
+
334
+ text_inputs = self.tokenizer(
335
+ prompt,
336
+ padding="max_length",
337
+ max_length=self.tokenizer.model_max_length,
338
+ truncation=True,
339
+ return_tensors="pt",
340
+ )
341
+ text_input_ids = text_inputs.input_ids
342
+ untruncated_ids = self.tokenizer(
343
+ prompt, padding="longest", return_tensors="pt").input_ids
344
+
345
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
346
+ text_input_ids, untruncated_ids
347
+ ):
348
+ removed_text = self.tokenizer.batch_decode(
349
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
350
+ )
351
+ logger.warning(
352
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
353
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
354
+ )
355
+
356
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
357
+ attention_mask = text_inputs.attention_mask.to(device)
358
+ else:
359
+ attention_mask = None
360
+
361
+ prompt_embeds = self.text_encoder(
362
+ text_input_ids.to(device),
363
+ attention_mask=attention_mask,
364
+ )
365
+ prompt_embeds = prompt_embeds[0]
366
+
367
+ prompt_embeds = prompt_embeds.to(
368
+ dtype=self.text_encoder.dtype, device=device)
369
+
370
+ bs_embed, seq_len, _ = prompt_embeds.shape
371
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
372
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
373
+ prompt_embeds = prompt_embeds.view(
374
+ bs_embed * num_images_per_prompt, seq_len, -1)
375
+ max_length = prompt_embeds.shape[1]
376
+ if img_cond is not None:
377
+ if img_cond.ndim == 2:
378
+ img_cond = img_cond.unsqueeze(1)
379
+ prompt_embeds = torch.cat([prompt_embeds, img_cond], dim=1)
380
+
381
+ # get unconditional embeddings for classifier free guidance
382
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
383
+ uncond_tokens: List[str]
384
+ if negative_prompt is None:
385
+ uncond_tokens = [""] * batch_size
386
+ elif type(prompt) is not type(negative_prompt):
387
+ raise TypeError(
388
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
389
+ f" {type(prompt)}."
390
+ )
391
+ elif isinstance(negative_prompt, str):
392
+ uncond_tokens = [negative_prompt]
393
+ elif batch_size != len(negative_prompt):
394
+ raise ValueError(
395
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
396
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
397
+ " the batch size of `prompt`."
398
+ )
399
+ else:
400
+ uncond_tokens = negative_prompt
401
+
402
+ # textual inversion: process multi-vector tokens if necessary
403
+ if isinstance(self, TextualInversionLoaderMixin):
404
+ uncond_tokens = self.maybe_convert_prompt(
405
+ uncond_tokens, self.tokenizer)
406
+
407
+ # max_length = prompt_embeds.shape[1]
408
+ uncond_input = self.tokenizer(
409
+ uncond_tokens,
410
+ padding="max_length",
411
+ max_length=max_length,
412
+ truncation=True,
413
+ return_tensors="pt",
414
+ )
415
+
416
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
417
+ attention_mask = uncond_input.attention_mask.to(device)
418
+ else:
419
+ attention_mask = None
420
+
421
+ negative_prompt_embeds = self.text_encoder(
422
+ uncond_input.input_ids.to(device),
423
+ attention_mask=attention_mask,
424
+ )
425
+ negative_prompt_embeds = negative_prompt_embeds[0]
426
+
427
+ if do_classifier_free_guidance:
428
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
429
+ seq_len = negative_prompt_embeds.shape[1]
430
+
431
+ negative_prompt_embeds = negative_prompt_embeds.to(
432
+ dtype=self.text_encoder.dtype, device=device)
433
+
434
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
435
+ 1, num_images_per_prompt, 1)
436
+ negative_prompt_embeds = negative_prompt_embeds.view(
437
+ batch_size * num_images_per_prompt, seq_len, -1)
438
+
439
+ if img_cond_unc is not None:
440
+ if img_cond_unc.ndim == 2:
441
+ img_cond_unc = img_cond_unc.unsqueeze(1)
442
+ negative_prompt_embeds = torch.cat(
443
+ [negative_prompt_embeds, img_cond_unc], dim=1)
444
+
445
+ # For classifier free guidance, we need to do two forward passes.
446
+ # Here we concatenate the unconditional and text embeddings into a single batch
447
+ # to avoid doing two forward passes
448
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
449
+
450
+ return prompt_embeds
451
+
452
+ def decode_latents(self, latents):
453
+ latents = 1 / self.vae.config.scaling_factor * latents
454
+
455
+ batch_size, channels, num_frames, height, width = latents.shape
456
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
457
+ batch_size * num_frames, channels, height, width)
458
+
459
+ image = self.vae.decode(latents).sample
460
+ video = (
461
+ image[None, :]
462
+ .reshape(
463
+ (
464
+ batch_size,
465
+ num_frames,
466
+ -1,
467
+ )
468
+ + image.shape[2:]
469
+ )
470
+ .permute(0, 2, 1, 3, 4)
471
+ )
472
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
473
+ video = video.float()
474
+ return video
475
+
476
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
477
+ def prepare_extra_step_kwargs(self, generator, eta):
478
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
479
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
480
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
481
+ # and should be between [0, 1]
482
+
483
+ accepts_eta = "eta" in set(inspect.signature(
484
+ self.scheduler.step).parameters.keys())
485
+ extra_step_kwargs = {}
486
+ if accepts_eta:
487
+ extra_step_kwargs["eta"] = eta
488
+
489
+ # check if the scheduler accepts generator
490
+ accepts_generator = "generator" in set(
491
+ inspect.signature(self.scheduler.step).parameters.keys())
492
+ if accepts_generator:
493
+ extra_step_kwargs["generator"] = generator
494
+ return extra_step_kwargs
495
+
496
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
497
+ def check_inputs(
498
+ self,
499
+ prompt,
500
+ height,
501
+ width,
502
+ callback_steps,
503
+ negative_prompt=None,
504
+ prompt_embeds=None,
505
+ negative_prompt_embeds=None,
506
+ ):
507
+ if height % 8 != 0 or width % 8 != 0:
508
+ raise ValueError(
509
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
510
+
511
+ if (callback_steps is None) or (
512
+ callback_steps is not None and (not isinstance(
513
+ callback_steps, int) or callback_steps <= 0)
514
+ ):
515
+ raise ValueError(
516
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
+ f" {type(callback_steps)}."
518
+ )
519
+
520
+ if prompt is not None and prompt_embeds is not None:
521
+ raise ValueError(
522
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
523
+ " only forward one of the two."
524
+ )
525
+ elif prompt is None and prompt_embeds is None:
526
+ raise ValueError(
527
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
528
+ )
529
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
530
+ raise ValueError(
531
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
532
+
533
+ if negative_prompt is not None and negative_prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
536
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
537
+ )
538
+
539
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
540
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
541
+ raise ValueError(
542
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
543
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
544
+ f" {negative_prompt_embeds.shape}."
545
+ )
546
+
547
+ def prepare_latents(
548
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
549
+ ):
550
+ shape = (
551
+ batch_size,
552
+ num_channels_latents,
553
+ num_frames,
554
+ height // self.vae_scale_factor,
555
+ width // self.vae_scale_factor,
556
+ )
557
+ if isinstance(generator, list) and len(generator) != batch_size:
558
+ raise ValueError(
559
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
560
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
561
+ )
562
+ if hasattr(self, "noise_generator"):
563
+ latents = self.noise_generator.sample_noise(
564
+ shape=shape, generator=generator, device=device, dtype=dtype)
565
+ elif latents is None:
566
+ latents = randn_tensor(
567
+ shape, generator=generator, device=device, dtype=dtype)
568
+ else:
569
+ latents = latents.to(device)
570
+
571
+ # scale the initial noise by the standard deviation required by the scheduler
572
+ latents = latents * self.scheduler.init_noise_sigma
573
+ return latents
574
+
575
+ def set_noise_generator(self, noise_generator):
576
+ if noise_generator is not None and noise_generator.mode != "vanilla":
577
+ self.noise_generator = noise_generator
578
+
579
+ def reset_noise_generator_state(self):
580
+ if hasattr(self, "noise_generator") and hasattr(self.noise_generator, "reset_noise"):
581
+ self.noise_generator.reset_noise_generator_state()
582
+
583
+ @torch.no_grad()
584
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
585
+ def __call__(
586
+ self,
587
+ prompt: Union[str, List[str]] = None,
588
+ # the image input for the controlnet branch
589
+ image: Optional[torch.FloatTensor] = None,
590
+ height: Optional[int] = None,
591
+ width: Optional[int] = None,
592
+ num_frames: int = 16,
593
+ num_inference_steps: int = 50,
594
+ guidance_scale: float = 9.0,
595
+ negative_prompt: Optional[Union[str, List[str]]] = None,
596
+ eta: float = 0.0,
597
+ generator: Optional[Union[torch.Generator,
598
+ List[torch.Generator]]] = None,
599
+ latents: Optional[torch.FloatTensor] = None,
600
+ prompt_embeds: Optional[torch.FloatTensor] = None,
601
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
602
+ output_type: Optional[str] = "np",
603
+ return_dict: bool = True,
604
+ callback: Optional[Callable[[
605
+ int, int, torch.FloatTensor], None]] = None,
606
+ callback_steps: int = 1,
607
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
608
+ precision: str = "16",
609
+ mask_generator=None,
610
+ no_text_condition_control: bool = False,
611
+ weight_control_sample: float = 1.0,
612
+ use_controlnet_mask: bool = False,
613
+ skip_controlnet_branch: bool = False,
614
+ img_cond_resampler=None,
615
+ img_cond_encoder=None,
616
+ input_frames_conditioning=None,
617
+ cfg_text_image: bool = False,
618
+ use_of: bool = False,
619
+ **kwargs,
620
+ ):
621
+ r"""
622
+ Function invoked when calling the pipeline for generation.
623
+
624
+ Args:
625
+ prompt (`str` or `List[str]`, *optional*):
626
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
627
+ instead.
628
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
629
+ The height in pixels of the generated video.
630
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
631
+ The width in pixels of the generated video.
632
+ num_frames (`int`, *optional*, defaults to 16):
633
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per second
634
+ amounts to 2 seconds of video.
635
+ num_inference_steps (`int`, *optional*, defaults to 50):
636
+ The number of denoising steps. More denoising steps usually lead to higher quality videos at the
637
+ expense of slower inference.
638
+ guidance_scale (`float`, *optional*, defaults to 9.0):
639
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
640
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
641
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
642
+ 1`. Higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
643
+ usually at the expense of lower video quality.
644
+ negative_prompt (`str` or `List[str]`, *optional*):
645
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
646
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
647
+ less than `1`).
648
+ eta (`float`, *optional*, defaults to 0.0):
649
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
650
+ [`schedulers.DDIMScheduler`], will be ignored for others.
651
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
652
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
653
+ to make generation deterministic.
654
+ latents (`torch.FloatTensor`, *optional*):
655
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
656
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
657
+ tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
658
+ `(batch_size, num_channel, num_frames, height, width)`.
659
+ prompt_embeds (`torch.FloatTensor`, *optional*):
660
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
661
+ provided, text embeddings will be generated from `prompt` input argument.
662
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
663
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
664
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
665
+ argument.
666
+ output_type (`str`, *optional*, defaults to `"np"`):
667
+ The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
668
+ return_dict (`bool`, *optional*, defaults to `True`):
669
+ Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
670
+ plain tuple.
671
+ callback (`Callable`, *optional*):
672
+ A function that will be called every `callback_steps` steps during inference. The function will be
673
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
674
+ callback_steps (`int`, *optional*, defaults to 1):
675
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
676
+ called at every step.
677
+ cross_attention_kwargs (`dict`, *optional*):
678
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
679
+ `self.processor` in
680
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
681
+
682
+ Examples:
683
+
684
+ Returns:
685
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
686
+ [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
687
+ When returning a tuple, the first element is a list with the generated frames.
688
+ """
689
+ # 0. Default height and width to unet
690
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
691
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
692
+
693
+ num_images_per_prompt = 1
694
+ controlnet_mask = None
695
+
696
+ # 1. Check inputs. Raise error if not correct
697
+ self.check_inputs(
698
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
699
+ )
700
+ # import pdb
701
+ # pdb.set_trace()
702
+
703
+ if img_cond_resampler is not None and image is not None:
704
+ bsz = image.shape[0]
705
+ image_for_conditioning = rearrange(
706
+ input_frames_conditioning, "B F C W H -> (B F) C W H")
707
+ image_enc = img_cond_encoder(image_for_conditioning)
708
+ img_cond = img_cond_resampler(image_enc, batch_size=bsz)
709
+ image_enc_unc = img_cond_encoder(
710
+ torch.zeros_like(image_for_conditioning))
711
+ img_cond_unc = img_cond_resampler(image_enc_unc, batch_size=bsz)
712
+ else:
713
+ img_cond = None
714
+ img_cond_unc = None
715
+
716
+ # 2. Define call parameters
717
+ if prompt is not None and isinstance(prompt, str):
718
+ batch_size = 1
719
+ elif prompt is not None and isinstance(prompt, list):
720
+ batch_size = len(prompt)
721
+ else:
722
+ batch_size = prompt_embeds.shape[0]
723
+
724
+ device = self._execution_device
725
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
726
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
727
+ # corresponds to doing no classifier free guidance.
728
+ do_classifier_free_guidance = guidance_scale > 1.0
729
+
730
+ # 3. Encode input prompt
731
+ prompt_embeds = self._encode_prompt(
732
+ prompt,
733
+ device,
734
+ num_images_per_prompt,
735
+ do_classifier_free_guidance,
736
+ negative_prompt,
737
+ prompt_embeds=prompt_embeds,
738
+ negative_prompt_embeds=negative_prompt_embeds,
739
+ img_cond=img_cond,
740
+ img_cond_unc=img_cond_unc
741
+ )
742
+ skip_conditioning = image is None or skip_controlnet_branch
743
+ # import pdb
744
+ # pdb.set_trace()
745
+ if not skip_conditioning:
746
+ num_condition_frames = image.shape[1]
747
+ image, image_vq_enc = self.prepare_image(
748
+ image=image,
749
+ width=width,
750
+ height=height,
751
+ batch_size=batch_size * num_images_per_prompt,
752
+ num_images_per_prompt=num_images_per_prompt,
753
+ device=device,
754
+ dtype=self.controlnet.dtype,
755
+ do_classifier_free_guidance=do_classifier_free_guidance,
756
+ cfg_text_image=cfg_text_image,
757
+ )
758
+ if len(image.shape) == 5:
759
+ image = rearrange(image, "B F C H W -> (B F) C H W")
760
+ if use_controlnet_mask:
761
+ # num_condition_frames = all possible frames, e.g. 16
762
+ assert num_condition_frames == num_frames
763
+ image = rearrange(
764
+ image, "(B F) C H W -> B F C H W", F=num_condition_frames)
765
+ # image = torch.cat([image, image], dim=1)
766
+ controlnet_mask = torch.zeros(
767
+ (image.shape[0], num_frames), device=image.device, dtype=image.dtype)
768
+ # TODO HARDCODED number of frames!
769
+ controlnet_mask[:, :8] = 1.0
770
+ image = rearrange(image, "B F C H W -> (B F) C H W")
771
+
772
+ # 4. Prepare timesteps
773
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
774
+ timesteps = self.scheduler.timesteps
775
+
776
+ # 5. Prepare latent variables
777
+ of_channels = 2 if use_of else 0
778
+ num_channels_ctrl = self.unet.config.in_channels
779
+ num_channels_latents = num_channels_ctrl + of_channels
780
+ if not skip_conditioning:
781
+ image_vq_enc = rearrange(
782
+ image_vq_enc, "B F C H W -> B C F H W ", F=num_condition_frames)
783
+
784
+ latents = self.prepare_latents(
785
+ batch_size * num_images_per_prompt,
786
+ num_channels_latents,
787
+ num_frames,
788
+ height,
789
+ width,
790
+ prompt_embeds.dtype,
791
+ device,
792
+ generator,
793
+ latents,
794
+ )
795
+
796
+ if self.unet.concat:
797
+ image_latents = self.vae.encode(rearrange(
798
+ image, "B F C W H -> (B F) C W H")).latent_dist.sample() * self.vae.config.scaling_factor
799
+ image_latents = rearrange(
800
+ image_latents, "(B F) C W H -> B C F W H", B=latents.shape[0])
801
+ image_shape = image_latents.shape
802
+ image_shape = [ax_dim for ax_dim in image_shape]
803
+ image_shape[2] = 16-image_shape[2]
804
+ image_latents = torch.cat([image_latents, torch.zeros(
805
+ image_shape, dtype=image_latents.dtype, device=image_latents.device)], dim=2)
806
+ controlnet_mask = torch.zeros(
807
+ image_latents.shape, device=image_latents.device, dtype=image_latents.dtype)
808
+ controlnet_mask[:, :, :8] = 1.0
809
+ image_latents = image_latents * controlnet_mask
810
+ # torch.cat([latents, image_latents, controlnet_mask[:, :1]], dim=1)
811
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
812
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
813
+
814
+ # 7. Denoising loop
815
+ num_warmup_steps = len(timesteps) - \
816
+ num_inference_steps * self.scheduler.order
817
+
818
+ if mask_generator is not None:
819
+ attention_mask = mask_generator.get_mask(
820
+ device=latents.device, precision=precision)
821
+ else:
822
+ attention_mask = None
823
+
824
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
825
+ for i, t in enumerate(timesteps):
826
+ # expand the latents if we are doing classifier free guidance
827
+ latent_model_input = torch.cat(
828
+ [latents] * 2) if do_classifier_free_guidance else latents
829
+ latent_model_input = self.scheduler.scale_model_input(
830
+ latent_model_input, t)
831
+
832
+ if self.unet.concat:
833
+ latent_model_input = torch.cat([latent_model_input, image_latents.repeat(
834
+ 2, 1, 1, 1, 1), controlnet_mask[:, :1].repeat(2, 1, 1, 1, 1)], dim=1)
835
+ if not skip_conditioning:
836
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
837
+ latent_model_input[:, :num_channels_ctrl],
838
+ t,
839
+ encoder_hidden_states=prompt_embeds if (not no_text_condition_control) else torch.stack([
840
+ prompt_embeds[0], prompt_embeds[0]]),
841
+ controlnet_cond=image,
842
+ attention_mask=attention_mask,
843
+ vq_gan=self.vae,
844
+ weight_control_sample=weight_control_sample,
845
+ return_dict=False,
846
+ controlnet_mask=controlnet_mask,
847
+ )
848
+ else:
849
+ down_block_res_samples = None
850
+ mid_block_res_sample = None
851
+
852
+ # predict the noise residual
853
+ noise_pred = self.unet(
854
+ latent_model_input,
855
+ t,
856
+ encoder_hidden_states=prompt_embeds,
857
+ cross_attention_kwargs=cross_attention_kwargs,
858
+ attention_mask=attention_mask,
859
+ down_block_additional_residuals=[
860
+ sample.to(dtype=latent_model_input.dtype) for sample in down_block_res_samples
861
+ ] if down_block_res_samples is not None else None,
862
+ mid_block_additional_residual=mid_block_res_sample.to(
863
+ dtype=latent_model_input.dtype) if mid_block_res_sample is not None else None,
864
+ fps=None,
865
+
866
+ ).sample
867
+ # perform guidance
868
+ if do_classifier_free_guidance:
869
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(
870
+ 2)
871
+ noise_pred = noise_pred_uncond + guidance_scale * \
872
+ (noise_pred_text - noise_pred_uncond)
873
+
874
+ # reshape latents
875
+ bsz, channel, frames, width, height = latents.shape
876
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
877
+ bsz * frames, channel, width, height)
878
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
879
+ bsz * frames, channel, width, height)
880
+
881
+ # compute the previous noisy sample x_t -> x_t-1
882
+ scheduler_step = self.scheduler.step(
883
+ noise_pred, t, latents, **extra_step_kwargs)
884
+ latents = scheduler_step.prev_sample
885
+
886
+ # reshape latents back
887
+ latents = latents[None, :].reshape(
888
+ bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
889
+
890
+ # call the callback, if provided
891
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
892
+ progress_bar.update()
893
+ if callback is not None and i % callback_steps == 0:
894
+ callback(i, t, latents)
895
+
896
+ latents_video = latents[:, :num_channels_ctrl]
897
+ if of_channels > 0:
898
+ latents_of = latents[:, num_channels_ctrl:]
899
+ latents_of = rearrange(latents_of, "B C F W H -> (B F) C W H")
900
+ video_tensor = self.decode_latents(latents_video)
901
+
902
+ if output_type == "pt":
903
+ video = video_tensor
904
+ elif output_type == "pt_t2v":
905
+ video = tensor2vid(video_tensor, output_type="pt")
906
+ video = rearrange(video, "f h w c -> f c h w")
907
+ elif output_type == "concat_image":
908
+ image_video = image.unsqueeze(2)[0:1].repeat([1, 1, 24, 1, 1])
909
+ video_tensor_concat = torch.concat(
910
+ [image_video, video_tensor], dim=4)
911
+ video = tensor2vid(video_tensor_concat)
912
+ else:
913
+ video = tensor2vid(video_tensor)
914
+
915
+ # Offload last model to CPU
916
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
917
+ self.final_offload_hook.offload()
918
+
919
+ if not return_dict:
920
+ if of_channels == 0:
921
+ return video
922
+ else:
923
+ return video, latents_of
924
+
925
+ return TextToVideoSDPipelineOutput(frames=video)
t2v_enhanced/model/diffusers_conditional/models/controlnet/processor.py ADDED
@@ -0,0 +1,240 @@
1
+ from einops import repeat, rearrange
2
+ from typing import Callable, Optional, Union
3
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
4
+ # from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+ from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams
7
+ import torch
8
+ import torch.nn.functional as F
9
+ if is_xformers_available():
10
+ import xformers
11
+ import xformers.ops
12
+ else:
13
+ xformers = None
14
+
15
+
16
+ def set_use_memory_efficient_attention_xformers(
17
+ model, num_frame_conditioning: int, num_frames: int, attention_mask_params: AttentionMaskParams, valid: bool = True, attention_op: Optional[Callable] = None
18
+ ) -> None:
19
+ # Recursively walk through all the children.
20
+ # Any child that exposes the set_use_memory_efficient_attention_xformers method
21
+ # gets the message
22
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
23
+ if hasattr(module, "set_processor"):
24
+
25
+ module.set_processor(XFormersAttnProcessor(attention_op=attention_op,
26
+ num_frame_conditioning=num_frame_conditioning,
27
+ num_frames=num_frames,
28
+ attention_mask_params=attention_mask_params,)
29
+ )
30
+
31
+ for child in module.children():
32
+ fn_recursive_set_mem_eff(child)
33
+
34
+ for module in model.children():
35
+ if isinstance(module, torch.nn.Module):
36
+ fn_recursive_set_mem_eff(module)
37
+
38
+
39
+ class XFormersAttnProcessor:
40
+ def __init__(self,
41
+ attention_mask_params: AttentionMaskParams,
42
+ attention_op: Optional[Callable] = None,
43
+ num_frame_conditioning: int = None,
44
+ num_frames: int = None,
45
+ use_image_embedding: bool = False,
46
+ ):
47
+ self.attention_op = attention_op
48
+ self.num_frame_conditioning = num_frame_conditioning
49
+ self.num_frames = num_frames
50
+ self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames
51
+ self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames
52
+ self.use_image_embedding = use_image_embedding
53
+
54
+ def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None):
55
+ batch_size, sequence_length, _ = (
56
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
57
+ )
58
+
59
+ key_img = None
60
+ value_img = None
61
+ hidden_states_img = None
62
+ if attention_mask is not None:
63
+ attention_mask = repeat(
64
+ attention_mask, "1 F D -> B F D", B=batch_size)
65
+
66
+ attention_mask = attn.prepare_attention_mask(
67
+ attention_mask, sequence_length, batch_size)
68
+
69
+ query = attn.to_q(hidden_states)
70
+
71
+ is_cross_attention = encoder_hidden_states is not None
72
+
73
+ if encoder_hidden_states is None:
74
+ encoder_hidden_states = hidden_states
75
+ elif attn.norm_cross:
76
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
77
+ encoder_hidden_states)
78
+ default_attention = not hasattr(attn, "is_spatial_attention")
79
+ if default_attention:
80
+ assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with new interface"
81
+ assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with new interface"
82
+ is_spatial_attention = attn.is_spatial_attention if hasattr(
83
+ attn, "is_spatial_attention") else False
84
+ use_image_embedding = attn.use_image_embedding if hasattr(
85
+ attn, "use_image_embedding") else False
86
+
87
+ if is_spatial_attention and use_image_embedding and attn.cross_attention_mode:
88
+ assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding"
89
+
90
+ alpha = attn.alpha
91
+ encoder_hidden_states_txt = encoder_hidden_states[:, :77, :]
92
+
93
+ encoder_hidden_states_mixed = attn.conv(encoder_hidden_states)
94
+ encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed)
95
+ encoder_hidden_states = encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha)
96
+
97
+ key = attn.to_k(encoder_hidden_states)
98
+ value = attn.to_v(encoder_hidden_states)
99
+ else:
100
+ key = attn.to_k(encoder_hidden_states)
101
+ value = attn.to_v(encoder_hidden_states)
102
+
103
+
104
+
105
+
106
+ if not default_attention and not is_spatial_attention and self.temp_attend_on_neighborhood_of_condition_frames and not attn.cross_attention_mode:
107
+ # normal attention
108
+ query_condition = query[:, :self.num_frame_conditioning]
109
+ query_condition = attn.head_to_batch_dim(
110
+ query_condition).contiguous()
111
+ key_condition = key
112
+ value_condition = value
113
+ key_condition = attn.head_to_batch_dim(key_condition).contiguous()
114
+ value_condition = attn.head_to_batch_dim(
115
+ value_condition).contiguous()
116
+ hidden_states_condition = xformers.ops.memory_efficient_attention(
117
+ query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
118
+ )
119
+ hidden_states_condition = hidden_states_condition.to(query.dtype)
120
+ hidden_states_condition = attn.batch_to_head_dim(
121
+ hidden_states_condition)
122
+ #
123
+ query_uncondition = query[:, self.num_frame_conditioning:]
124
+
125
+ key = key[:, :self.num_frame_conditioning]
126
+ value = value[:, :self.num_frame_conditioning]
127
+ key = rearrange(key, "(B W H) F C -> B W H F C",
128
+ H=hidden_state_height, W=hidden_state_width)
129
+ value = rearrange(value, "(B W H) F C -> B W H F C",
130
+ H=hidden_state_height, W=hidden_state_width)
131
+
132
+ keys = []
133
+ values = []
134
+ for shifts_width in [-1, 0, 1]:
135
+ for shifts_height in [-1, 0, 1]:
136
+ keys.append(torch.roll(key, shifts=(
137
+ shifts_width, shifts_height), dims=(1, 2)))
138
+ values.append(torch.roll(value, shifts=(
139
+ shifts_width, shifts_height), dims=(1, 2)))
140
+ key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C")
141
+ value = rearrange(torch.cat(values, dim=3),
142
+ 'B W H F C -> (B W H) F C')
143
+
144
+ query = attn.head_to_batch_dim(query_uncondition).contiguous()
145
+ key = attn.head_to_batch_dim(key).contiguous()
146
+ value = attn.head_to_batch_dim(value).contiguous()
147
+
148
+ hidden_states = xformers.ops.memory_efficient_attention(
149
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
150
+ )
151
+ hidden_states = hidden_states.to(query.dtype)
152
+ hidden_states = attn.batch_to_head_dim(hidden_states)
153
+ hidden_states = torch.cat(
154
+ [hidden_states_condition, hidden_states], dim=1)
155
+ elif not default_attention and is_spatial_attention and self.spatial_attend_on_condition_frames and not attn.cross_attention_mode:
156
+ # (B F) W H C -> B F W H C
157
+ query_condition = rearrange(
158
+ query, "(B F) S C -> B F S C", F=self.num_frames)
159
+ query_condition = query_condition[:, :self.num_frame_conditioning]
160
+ query_condition = rearrange(
161
+ query_condition, "B F S C -> (B F) S C")
162
+ query_condition = attn.head_to_batch_dim(
163
+ query_condition).contiguous()
164
+
165
+ key_condition = rearrange(
166
+ key, "(B F) S C -> B F S C", F=self.num_frames)
167
+ key_condition = key_condition[:, :self.num_frame_conditioning]
168
+ key_condition = rearrange(key_condition, "B F S C -> (B F) S C")
169
+
170
+ value_condition = rearrange(
171
+ value, "(B F) S C -> B F S C", F=self.num_frames)
172
+ value_condition = value_condition[:, :self.num_frame_conditioning]
173
+ value_condition = rearrange(
174
+ value_condition, "B F S C -> (B F) S C")
175
+
176
+ key_condition = attn.head_to_batch_dim(key_condition).contiguous()
177
+ value_condition = attn.head_to_batch_dim(
178
+ value_condition).contiguous()
179
+ hidden_states_condition = xformers.ops.memory_efficient_attention(
180
+ query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
181
+ )
182
+ hidden_states_condition = hidden_states_condition.to(query.dtype)
183
+ hidden_states_condition = attn.batch_to_head_dim(
184
+ hidden_states_condition)
185
+
186
+ query_uncondition = rearrange(
187
+ query, "(B F) S C -> B F S C", F=self.num_frames)
188
+ query_uncondition = query_uncondition[:,
189
+ self.num_frame_conditioning:]
190
+ key_uncondition = rearrange(
191
+ key, "(B F) S C -> B F S C", F=self.num_frames)
192
+ value_uncondition = rearrange(
193
+ value, "(B F) S C -> B F S C", F=self.num_frames)
194
+ key_uncondition = key_uncondition[:,
195
+ self.num_frame_conditioning-1, None]
196
+ value_uncondition = value_uncondition[:,
197
+ self.num_frame_conditioning-1, None]
198
+ # if self.trainer.training:
199
+ # import pdb
200
+ # pdb.set_trace()
201
+ # print("now")
202
+ query_uncondition = rearrange(
203
+ query_uncondition, "B F S C -> (B F) S C")
204
+ key_uncondition = repeat(rearrange(
205
+ key_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
206
+ value_uncondition = repeat(rearrange(
207
+ value_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
208
+ query_uncondition = attn.head_to_batch_dim(
209
+ query_uncondition).contiguous()
210
+ key_uncondition = attn.head_to_batch_dim(
211
+ key_uncondition).contiguous()
212
+ value_uncondition = attn.head_to_batch_dim(
213
+ value_uncondition).contiguous()
214
+ hidden_states_uncondition = xformers.ops.memory_efficient_attention(
215
+ query_uncondition, key_uncondition, value_uncondition, attn_bias=None, op=self.attention_op, scale=attn.scale
216
+ )
217
+ hidden_states_uncondition = hidden_states_uncondition.to(
218
+ query.dtype)
219
+ hidden_states_uncondition = attn.batch_to_head_dim(
220
+ hidden_states_uncondition)
221
+ hidden_states = torch.cat([rearrange(hidden_states_condition, "(B F) S C -> B F S C", F=self.num_frame_conditioning), rearrange(
222
+ hidden_states_uncondition, "(B F) S C -> B F S C", F=self.num_frames-self.num_frame_conditioning)], dim=1)
223
+ hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C")
224
+ else:
225
+ query = attn.head_to_batch_dim(query).contiguous()
226
+ key = attn.head_to_batch_dim(key).contiguous()
227
+ value = attn.head_to_batch_dim(value).contiguous()
228
+
229
+ hidden_states = xformers.ops.memory_efficient_attention(
230
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
231
+ )
232
+
233
+ hidden_states = hidden_states.to(query.dtype)
234
+ hidden_states = attn.batch_to_head_dim(hidden_states)
235
+
236
+ # linear proj
237
+ hidden_states = attn.to_out[0](hidden_states)
238
+ # dropout
239
+ hidden_states = attn.to_out[1](hidden_states)
240
+ return hidden_states
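The helper at the top of this file walks a module tree and swaps in `XFormersAttnProcessor` on every submodule that exposes `set_processor`. Below is a minimal smoke-test sketch of that traversal; the dummy module and the `SimpleNamespace` stand-in for `AttentionMaskParams` (the processor only reads two flags from it) are assumptions for illustration, not part of the repository.

```python
from types import SimpleNamespace

import torch

from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import (
    XFormersAttnProcessor,
    set_use_memory_efficient_attention_xformers,
)


class DummyAttention(torch.nn.Module):
    # Minimal stand-in exposing the `set_processor` hook the recursive traversal looks for.
    def __init__(self):
        super().__init__()
        self.processor = None

    def set_processor(self, processor):
        self.processor = processor


# The processor only reads these two flags from AttentionMaskParams.
mask_params = SimpleNamespace(
    temp_attend_on_neighborhood_of_condition_frames=False,
    spatial_attend_on_condition_frames=False,
)

model = torch.nn.Sequential(DummyAttention(), torch.nn.Linear(4, 4))
set_use_memory_efficient_attention_xformers(
    model,
    num_frame_conditioning=8,   # illustrative values
    num_frames=16,
    attention_mask_params=mask_params,
)
assert isinstance(model[0].processor, XFormersAttnProcessor)
```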
t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_2d.py ADDED
@@ -0,0 +1,333 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import BasicTransformerBlock
25
+ from diffusers.models.embeddings import PatchEmbed
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
35
+ for the unnoised latent pixels.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class Transformer2DModel(ModelMixin, ConfigMixin):
42
+ """
43
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
44
+ embeddings) inputs.
45
+
46
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
47
+ transformer action. Finally, reshape to image.
48
+
49
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
50
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
51
+ classes of unnoised image.
52
+
53
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
54
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
55
+
56
+ Parameters:
57
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
58
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
59
+ in_channels (`int`, *optional*):
60
+ Pass if the input is continuous. The number of channels in the input and output.
61
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
62
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
63
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
64
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
65
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
66
+ `ImagePositionalEmbeddings`.
67
+ num_vector_embeds (`int`, *optional*):
68
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
69
+ Includes the class for the masked latent pixel.
70
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
71
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
72
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
73
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
74
+ up to, but not more than, `num_embeds_ada_norm` steps.
75
+ attention_bias (`bool`, *optional*):
76
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
77
+ """
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ num_attention_heads: int = 16,
83
+ attention_head_dim: int = 88,
84
+ in_channels: Optional[int] = None,
85
+ out_channels: Optional[int] = None,
86
+ num_layers: int = 1,
87
+ dropout: float = 0.0,
88
+ norm_num_groups: int = 32,
89
+ cross_attention_dim: Optional[int] = None,
90
+ attention_bias: bool = False,
91
+ sample_size: Optional[int] = None,
92
+ num_vector_embeds: Optional[int] = None,
93
+ patch_size: Optional[int] = None,
94
+ activation_fn: str = "geglu",
95
+ num_embeds_ada_norm: Optional[int] = None,
96
+ use_linear_projection: bool = False,
97
+ only_cross_attention: bool = False,
98
+ upcast_attention: bool = False,
99
+ norm_type: str = "layer_norm",
100
+ norm_elementwise_affine: bool = True,
101
+ use_image_embedding: bool = False,
102
+ unet_params=None,
103
+ ):
104
+ super().__init__()
105
+ self.use_linear_projection = use_linear_projection
106
+ self.num_attention_heads = num_attention_heads
107
+ self.attention_head_dim = attention_head_dim
108
+ inner_dim = num_attention_heads * attention_head_dim
109
+
110
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
111
+ # Define whether input is continuous or discrete depending on configuration
112
+ self.is_input_continuous = (
113
+ in_channels is not None) and (patch_size is None)
114
+ self.is_input_vectorized = num_vector_embeds is not None
115
+ self.is_input_patches = in_channels is not None and patch_size is not None
116
+
117
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
118
+ deprecation_message = (
119
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
120
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
121
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
122
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
123
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
124
+ )
125
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0",
126
+ deprecation_message, standard_warn=False)
127
+ norm_type = "ada_norm"
128
+
129
+ if self.is_input_continuous and self.is_input_vectorized:
130
+ raise ValueError(
131
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
132
+ " sure that either `in_channels` or `num_vector_embeds` is None."
133
+ )
134
+ elif self.is_input_vectorized and self.is_input_patches:
135
+ raise ValueError(
136
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
137
+ " sure that either `num_vector_embeds` or `num_patches` is None."
138
+ )
139
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
140
+ raise ValueError(
141
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
142
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
143
+ )
144
+
145
+ # 2. Define input layers
146
+ if self.is_input_continuous:
147
+ self.in_channels = in_channels
148
+
149
+ self.norm = torch.nn.GroupNorm(
150
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
+ if use_linear_projection:
152
+ self.proj_in = nn.Linear(in_channels, inner_dim)
153
+ else:
154
+ self.proj_in = nn.Conv2d(
155
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
156
+ elif self.is_input_vectorized:
157
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
158
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
159
+
160
+ self.height = sample_size
161
+ self.width = sample_size
162
+ self.num_vector_embeds = num_vector_embeds
163
+ self.num_latent_pixels = self.height * self.width
164
+
165
+ self.latent_image_embedding = ImagePositionalEmbeddings(
166
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
167
+ )
168
+ elif self.is_input_patches:
169
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
170
+
171
+ self.height = sample_size
172
+ self.width = sample_size
173
+
174
+ self.patch_size = patch_size
175
+ self.pos_embed = PatchEmbed(
176
+ height=sample_size,
177
+ width=sample_size,
178
+ patch_size=patch_size,
179
+ in_channels=in_channels,
180
+ embed_dim=inner_dim,
181
+ )
182
+
183
+ # 3. Define transformers blocks
184
+ self.transformer_blocks = nn.ModuleList(
185
+ [
186
+ BasicTransformerBlock(
187
+ inner_dim,
188
+ num_attention_heads,
189
+ attention_head_dim,
190
+ dropout=dropout,
191
+ cross_attention_dim=cross_attention_dim,
192
+ activation_fn=activation_fn,
193
+ num_embeds_ada_norm=num_embeds_ada_norm,
194
+ attention_bias=attention_bias,
195
+ only_cross_attention=only_cross_attention,
196
+ upcast_attention=upcast_attention,
197
+ norm_type=norm_type,
198
+ norm_elementwise_affine=norm_elementwise_affine,
199
+ is_spatial_attention=True,
200
+ use_image_embedding=use_image_embedding,
201
+ unet_params=unet_params,
202
+ )
203
+ for d in range(num_layers)
204
+ ]
205
+ )
206
+
207
+ # 4. Define output layers
208
+ self.out_channels = in_channels if out_channels is None else out_channels
209
+ if self.is_input_continuous:
210
+ # TODO: should use out_channels for continuous projections
211
+ if use_linear_projection:
212
+ self.proj_out = nn.Linear(inner_dim, in_channels)
213
+ else:
214
+ self.proj_out = nn.Conv2d(
215
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
216
+ elif self.is_input_vectorized:
217
+ self.norm_out = nn.LayerNorm(inner_dim)
218
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
219
+ elif self.is_input_patches:
220
+ self.norm_out = nn.LayerNorm(
221
+ inner_dim, elementwise_affine=False, eps=1e-6)
222
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
223
+ self.proj_out_2 = nn.Linear(
224
+ inner_dim, patch_size * patch_size * self.out_channels)
225
+
226
+ def forward(
227
+ self,
228
+ hidden_states,
229
+ encoder_hidden_states=None,
230
+ timestep=None,
231
+ class_labels=None,
232
+ cross_attention_kwargs=None,
233
+ return_dict: bool = True,
234
+ ):
235
+ """
236
+ Args:
237
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
238
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
239
+ hidden_states
240
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
241
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
242
+ self-attention.
243
+ timestep ( `torch.long`, *optional*):
244
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
245
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
246
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
247
+ conditioning.
248
+ return_dict (`bool`, *optional*, defaults to `True`):
249
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
250
+
251
+ Returns:
252
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
253
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
254
+ returning a tuple, the first element is the sample tensor.
255
+ """
256
+ # 1. Input
257
+ if self.is_input_continuous:
258
+ batch, _, height, width = hidden_states.shape
259
+ residual = hidden_states
260
+
261
+ hidden_states = self.norm(hidden_states)
262
+ if not self.use_linear_projection:
263
+ hidden_states = self.proj_in(hidden_states)
264
+ inner_dim = hidden_states.shape[1]
265
+ hidden_states = hidden_states.permute(
266
+ 0, 2, 3, 1).reshape(batch, height * width, inner_dim)
267
+ else:
268
+ inner_dim = hidden_states.shape[1]
269
+ hidden_states = hidden_states.permute(
270
+ 0, 2, 3, 1).reshape(batch, height * width, inner_dim)
271
+ hidden_states = self.proj_in(hidden_states)
272
+ elif self.is_input_vectorized:
273
+ hidden_states = self.latent_image_embedding(hidden_states)
274
+ elif self.is_input_patches:
275
+ hidden_states = self.pos_embed(hidden_states)
276
+
277
+ # 2. Blocks
278
+ for block in self.transformer_blocks:
279
+ hidden_states = block(
280
+ hidden_states,
281
+ encoder_hidden_states=encoder_hidden_states,
282
+ timestep=timestep,
283
+ cross_attention_kwargs=cross_attention_kwargs,
284
+ class_labels=class_labels,
285
+ )
286
+
287
+ # 3. Output
288
+ if self.is_input_continuous:
289
+ if not self.use_linear_projection:
290
+ hidden_states = hidden_states.reshape(
291
+ batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
292
+ hidden_states = self.proj_out(hidden_states)
293
+ else:
294
+ hidden_states = self.proj_out(hidden_states)
295
+ hidden_states = hidden_states.reshape(
296
+ batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
297
+
298
+ output = hidden_states + residual
299
+ elif self.is_input_vectorized:
300
+ hidden_states = self.norm_out(hidden_states)
301
+ logits = self.out(hidden_states)
302
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
303
+ logits = logits.permute(0, 2, 1)
304
+
305
+ # log(p(x_0))
306
+ output = F.log_softmax(logits.double(), dim=1).float()
307
+ elif self.is_input_patches:
308
+ # TODO: cleanup!
309
+ conditioning = self.transformer_blocks[0].norm1.emb(
310
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
311
+ )
312
+ shift, scale = self.proj_out_1(
313
+ F.silu(conditioning)).chunk(2, dim=1)
314
+ hidden_states = self.norm_out(
315
+ hidden_states) * (1 + scale[:, None]) + shift[:, None]
316
+ hidden_states = self.proj_out_2(hidden_states)
317
+
318
+ # unpatchify
319
+ height = width = int(hidden_states.shape[1] ** 0.5)
320
+ hidden_states = hidden_states.reshape(
321
+ shape=(-1, height, width, self.patch_size,
322
+ self.patch_size, self.out_channels)
323
+ )
324
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
325
+ output = hidden_states.reshape(
326
+ shape=(-1, self.out_channels, height *
327
+ self.patch_size, width * self.patch_size)
328
+ )
329
+
330
+ if not return_dict:
331
+ return (output,)
332
+
333
+ return Transformer2DModelOutput(sample=output)
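For the continuous-input path above, the data flow is essentially a `(B, C, H, W) -> (B, H*W, C) -> transformer blocks -> (B, C, H, W)` round trip with a residual connection. The snippet below is a minimal sketch of just that reshaping contract (the `BasicTransformerBlock` stack is elided); the concrete sizes are illustrative and the `use_linear_projection=True` branch is assumed.

```python
import torch
from torch import nn

batch, in_channels, height, width = 2, 320, 32, 32
num_heads, head_dim = 8, 40
inner_dim = num_heads * head_dim  # 320

hidden_states = torch.randn(batch, in_channels, height, width)
residual = hidden_states

norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
proj_in = nn.Linear(in_channels, inner_dim)   # linear-projection branch
proj_out = nn.Linear(inner_dim, in_channels)

x = norm(hidden_states)
x = x.permute(0, 2, 3, 1).reshape(batch, height * width, in_channels)
x = proj_in(x)                                # (B, H*W, inner_dim) spatial-token sequence
# ... the BasicTransformerBlock stack operates on these spatial tokens here ...
x = proj_out(x)                               # back to (B, H*W, in_channels)
x = x.reshape(batch, height, width, in_channels).permute(0, 3, 1, 2).contiguous()
output = x + residual
assert output.shape == (batch, in_channels, height, width)
```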
t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ # from diffusers.models.attention import BasicTransformerBlock
23
+ # from t2v_enhanced.model.diffusers_conditional.models.attention import BasicTransformerBlock
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import BasicTransformerBlock
26
+
27
+
28
+ @dataclass
29
+ class TransformerTemporalModelOutput(BaseOutput):
30
+ """
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
33
+ Hidden states conditioned on `encoder_hidden_states` input.
34
+ """
35
+
36
+ sample: torch.FloatTensor
37
+
38
+
39
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
40
+ """
41
+ Transformer model for video-like data.
42
+
43
+ Parameters:
44
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46
+ in_channels (`int`, *optional*):
47
+ Pass if the input is continuous. The number of channels in the input and output.
48
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
51
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
52
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
53
+ `ImagePositionalEmbeddings`.
54
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
55
+ attention_bias (`bool`, *optional*):
56
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
57
+ double_self_attention (`bool`, *optional*):
58
+ Configure if each TransformerBlock should contain two self-attention layers
59
+ """
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ num_attention_heads: int = 16,
65
+ attention_head_dim: int = 88,
66
+ in_channels: Optional[int] = None,
67
+ out_channels: Optional[int] = None,
68
+ num_layers: int = 1,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ cross_attention_dim: Optional[int] = None,
72
+ attention_bias: bool = False,
73
+ sample_size: Optional[int] = None,
74
+ activation_fn: str = "geglu",
75
+ norm_elementwise_affine: bool = True,
76
+ double_self_attention: bool = True,
77
+ ):
78
+ super().__init__()
79
+ self.num_attention_heads = num_attention_heads
80
+ self.attention_head_dim = attention_head_dim
81
+ inner_dim = num_attention_heads * attention_head_dim
82
+
83
+ self.in_channels = in_channels
84
+
85
+ self.norm = torch.nn.GroupNorm(
86
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
87
+ self.proj_in = nn.Linear(in_channels, inner_dim)
88
+
89
+ # 3. Define transformers blocks
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ BasicTransformerBlock(
93
+ inner_dim,
94
+ num_attention_heads,
95
+ attention_head_dim,
96
+ dropout=dropout,
97
+ cross_attention_dim=cross_attention_dim,
98
+ activation_fn=activation_fn,
99
+ attention_bias=attention_bias,
100
+ double_self_attention=double_self_attention,
101
+ norm_elementwise_affine=norm_elementwise_affine,
102
+ is_spatial_attention=False,
103
+ )
104
+ for d in range(num_layers)
105
+ ]
106
+ )
107
+
108
+ self.proj_out = nn.Linear(inner_dim, in_channels)
109
+
110
+ def forward(
111
+ self,
112
+ hidden_states,
113
+ encoder_hidden_states=None,
114
+ timestep=None,
115
+ class_labels=None,
116
+ num_frames=1,
117
+ cross_attention_kwargs=None,
118
+ return_dict: bool = True,
119
+ attention_mask=None,
120
+ ):
121
+ """
122
+ Args:
123
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
124
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
125
+ hidden_states
126
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
127
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
128
+ self-attention.
129
+ timestep ( `torch.long`, *optional*):
130
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
131
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
132
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
133
+ conditioning.
134
+ return_dict (`bool`, *optional*, defaults to `True`):
135
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
136
+
137
+ Returns:
138
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
139
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
140
+ When returning a tuple, the first element is the sample tensor.
141
+ """
142
+ # 1. Input
143
+ batch_frames, channel, height, width = hidden_states.shape
144
+ batch_size = batch_frames // num_frames
145
+
146
+ residual = hidden_states
147
+
148
+ hidden_states = hidden_states[None, :].reshape(
149
+ batch_size, num_frames, channel, height, width)
150
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
151
+
152
+ hidden_states = self.norm(hidden_states)
153
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(
154
+ batch_size * height * width, num_frames, channel)
155
+
156
+ hidden_states = self.proj_in(hidden_states)
157
+ if cross_attention_kwargs is None:
158
+ cross_attention_kwargs = {}
159
+ cross_attention_kwargs["hidden_state_height"] = height
160
+ cross_attention_kwargs["hidden_state_width"] = width
161
+
162
+ # 2. Blocks
163
+ for block in self.transformer_blocks:
164
+ hidden_states = block(
165
+ hidden_states,
166
+ encoder_hidden_states=encoder_hidden_states,
167
+ timestep=timestep,
168
+ cross_attention_kwargs=cross_attention_kwargs,
169
+ class_labels=class_labels,
170
+ attention_mask=attention_mask,
171
+ encoder_attention_mask=attention_mask,
172
+ )
173
+
174
+ # 3. Output
175
+ hidden_states = self.proj_out(hidden_states)
176
+ hidden_states = (
177
+ hidden_states[None, None, :]
178
+ .reshape(batch_size, height, width, channel, num_frames)
179
+ .permute(0, 3, 4, 1, 2)
180
+ .contiguous()
181
+ )
182
+ hidden_states = hidden_states.reshape(
183
+ batch_frames, channel, height, width)
184
+
185
+ output = hidden_states + residual
186
+
187
+ if not return_dict:
188
+ return (output,)
189
+
190
+ return TransformerTemporalModelOutput(sample=output)
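The temporal transformer's key trick is the reshape around the block stack: frames arriving flattened into the batch dimension as `(B*F, C, H, W)` are regrouped so that every spatial location becomes a length-`F` token sequence, `(B*H*W, F, C)`, so attention mixes information across time only. Below is a shape-only sketch of that round trip, mirroring the permutes/reshapes in `forward` above; the sizes are illustrative.

```python
import torch

batch_size, num_frames, channel, height, width = 1, 16, 320, 32, 32
hidden_states = torch.randn(batch_size * num_frames, channel, height, width)

# (B*F, C, H, W) -> (B, C, F, H, W): regroup frames per sample.
x = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
x = x.permute(0, 2, 1, 3, 4)

# (B, C, F, H, W) -> (B*H*W, F, C): each spatial position becomes a sequence over frames.
x = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

# ... the transformer blocks attend over the frame axis here ...

# Mirror of the output reshape in `forward`, restoring the (B*F, C, H, W) shape for the residual add.
x = (
    x[None, None, :]
    .reshape(batch_size, height, width, channel, num_frames)
    .permute(0, 3, 4, 1, 2)
    .contiguous()
    .reshape(batch_size * num_frames, channel, height, width)
)
assert x.shape == hidden_states.shape
```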
t2v_enhanced/model/diffusers_conditional/models/controlnet/transformer_temporal_crossattention.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+
23
+ from diffusers.models.modeling_utils import ModelMixin
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention import BasicTransformerBlock
24
+
25
+
26
+ @dataclass
27
+ class TransformerTemporalModelOutput(BaseOutput):
28
+ """
29
+ Args:
30
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
31
+ Hidden states conditioned on `encoder_hidden_states` input.
32
+ """
33
+
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
38
+ """
39
+ Transformer model for video-like data.
40
+
41
+ Parameters:
42
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
43
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
44
+ in_channels (`int`, *optional*):
45
+ Pass if the input is continuous. The number of channels in the input and output.
46
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
47
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
49
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
50
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
51
+ `ImagePositionalEmbeddings`.
52
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
53
+ attention_bias (`bool`, *optional*):
54
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
55
+ double_self_attention (`bool`, *optional*):
56
+ Configure if each TransformerBlock should contain two self-attention layers
57
+ """
58
+
59
+ @register_to_config
60
+ def __init__(
61
+ self,
62
+ num_attention_heads: int = 16,
63
+ attention_head_dim: int = 88,
64
+ in_channels: Optional[int] = None,
65
+ out_channels: Optional[int] = None,
66
+ num_layers: int = 1,
67
+ dropout: float = 0.0,
68
+ norm_num_groups: int = 32,
69
+ cross_attention_dim: Optional[int] = None,
70
+ attention_bias: bool = False,
71
+ sample_size: Optional[int] = None,
72
+ activation_fn: str = "geglu",
73
+ norm_elementwise_affine: bool = True,
74
+ double_self_attention: bool = True,
75
+ ):
76
+ super().__init__()
77
+
78
+ self.num_attention_heads = num_attention_heads
79
+ self.attention_head_dim = attention_head_dim
80
+ inner_dim = num_attention_heads * attention_head_dim
81
+
82
+ self.in_channels = in_channels
83
+
84
+ self.norm = torch.nn.GroupNorm(
85
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
86
+ self.proj_in = nn.Linear(in_channels, inner_dim)
87
+
88
+ # 3. Define transformers blocks
89
+ self.transformer_blocks = nn.ModuleList(
90
+ [
91
+ BasicTransformerBlock(
92
+ inner_dim,
93
+ num_attention_heads,
94
+ attention_head_dim,
95
+ dropout=dropout,
96
+ cross_attention_dim=cross_attention_dim,
97
+ activation_fn=activation_fn,
98
+ attention_bias=attention_bias,
99
+ double_self_attention=double_self_attention,
100
+ norm_elementwise_affine=norm_elementwise_affine,
101
+ only_cross_attention=True,
102
+ )
103
+ for d in range(num_layers)
104
+ ]
105
+ )
106
+
107
+ self.proj_out = nn.Linear(inner_dim, in_channels)
108
+
109
+ def forward(
110
+ self,
111
+ hidden_states,
112
+ encoder_hidden_states=None,
113
+ timestep=None,
114
+ class_labels=None,
115
+ num_frames=1,
116
+ cross_attention_kwargs=None,
117
+ return_dict: bool = True,
118
+ ):
119
+ """
120
+ Args:
121
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
122
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
123
+ hidden_states
124
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
125
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
126
+ self-attention.
127
+ timestep ( `torch.long`, *optional*):
128
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
129
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
130
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
131
+ conditioning.
132
+ return_dict (`bool`, *optional*, defaults to `True`):
133
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
134
+
135
+ Returns:
136
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
137
+ [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
138
+ When returning a tuple, the first element is the sample tensor.
139
+ """
140
+ # 1. Input
141
+ batch_frames, channel, height, width = hidden_states.shape
142
+ batch_size = batch_frames // num_frames
143
+
144
+ residual = hidden_states
145
+
146
+ hidden_states = hidden_states[None, :].reshape(
147
+ batch_size, num_frames, channel, height, width)
148
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
149
+
150
+ hidden_states = self.norm(hidden_states)
151
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(
152
+ batch_size * height * width, num_frames, channel)
153
+
154
+ hidden_states = self.proj_in(hidden_states)
155
+
156
+ # 2. Blocks
157
+ for block in self.transformer_blocks:
158
+ hidden_states = block(
159
+ hidden_states,
160
+ encoder_hidden_states=encoder_hidden_states,
161
+ timestep=timestep,
162
+ cross_attention_kwargs=cross_attention_kwargs,
163
+ class_labels=class_labels,
164
+ )
165
+
166
+ # 3. Output
167
+ hidden_states = self.proj_out(hidden_states)
168
+ hidden_states = (
169
+ hidden_states[None, None, :]
170
+ .reshape(batch_size, height, width, channel, num_frames)
171
+ .permute(0, 3, 4, 1, 2)
172
+ .contiguous()
173
+ )
174
+ hidden_states = hidden_states.reshape(
175
+ batch_frames, channel, height, width)
176
+
177
+ output = hidden_states + residual
178
+
179
+ if not return_dict:
180
+ return (output,)
181
+
182
+ return TransformerTemporalModelOutput(sample=output)
t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_blocks.py ADDED
@@ -0,0 +1,930 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.utils.checkpoint as checkpoint
17
+ from torch import nn
18
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ # from diffusers.models.transformer_2d import Transformer2DModel
20
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_2d import Transformer2DModel
21
+ # from diffusers.models.transformer_temporal import TransformerTemporalModel
22
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel
23
+
24
+
25
+ # Assign gradient checkpoint function to simple variable for readability.
26
+ g_c = checkpoint.checkpoint
27
+
28
+
29
+ def is_video(num_frames, only_video=True):
30
+ if num_frames == 1 and not only_video:
31
+ return False
32
+ return num_frames > 1
33
+
34
+
35
+ def custom_checkpoint(module, mode=None):
36
+ if mode == None:
37
+ raise ValueError('Mode for gradient checkpointing cannot be none.')
38
+
39
+ custom_forward = None
40
+
41
+ if mode == 'resnet':
42
+ def custom_forward(hidden_states, temb):
43
+ inputs = module(hidden_states, temb)
44
+ return inputs
45
+
46
+ if mode == 'attn':
47
+ def custom_forward(
48
+ hidden_states,
49
+ encoder_hidden_states=None,
50
+ cross_attention_kwargs=None,
51
+ attention_mask=None,
52
+ ):
53
+ inputs = module(
54
+ hidden_states,
55
+ encoder_hidden_states,
56
+ cross_attention_kwargs,
57
+ attention_mask
58
+ )
59
+ return inputs.sample
60
+
61
+ if mode == 'temp':
62
+ # If inputs are not None, we can assume that this was a single image.
63
+ # Otherwise, do temporal convolutions / attention.
64
+ def custom_forward(hidden_states, num_frames=None):
65
+ if not is_video(num_frames):
66
+ return hidden_states
67
+ else:
68
+ inputs = module(
69
+ hidden_states,
70
+ num_frames=num_frames
71
+ )
72
+ if isinstance(module, TransformerTemporalModel):
73
+ return inputs.sample
74
+ else:
75
+ return inputs
76
+
77
+ return custom_forward
78
+
79
+
80
+ def transformer_g_c(transformer, sample, num_frames):
81
+ sample = g_c(custom_checkpoint(transformer, mode='temp'),
82
+ sample, num_frames, use_reentrant=False,
83
+ )
84
+ return sample
85
+
86
+
87
+ def cross_attn_g_c(
88
+ attn,
89
+ temp_attn,
90
+ resnet,
91
+ temp_conv,
92
+ hidden_states,
93
+ encoder_hidden_states,
94
+ cross_attention_kwargs,
95
+ temb,
96
+ num_frames,
97
+ inverse_temp=False,
98
+ attention_mask=None,
99
+ ):
100
+
101
+ def ordered_g_c(idx):
102
+
103
+ # Self and CrossAttention
104
+ if idx == 0:
105
+ return g_c(custom_checkpoint(attn, mode='attn'),
106
+ hidden_states,
107
+ encoder_hidden_states,
108
+ cross_attention_kwargs,
109
+ attention_mask,
110
+ use_reentrant=False
111
+ )
112
+
113
+ # Temporal Self and CrossAttention
114
+ if idx == 1:
115
+ return g_c(custom_checkpoint(temp_attn, mode='temp'),
116
+ hidden_states,
117
+ num_frames,
118
+ use_reentrant=False
119
+ )
120
+
121
+ # Resnets
122
+ if idx == 2:
123
+ return g_c(custom_checkpoint(resnet, mode='resnet'),
124
+ hidden_states,
125
+ temb,
126
+ use_reentrant=False
127
+ )
128
+
129
+ # Temporal Convolutions
130
+ if idx == 3:
131
+ return g_c(custom_checkpoint(temp_conv, mode='temp'),
132
+ hidden_states,
133
+ num_frames,
134
+ use_reentrant=False
135
+ )
136
+
137
+ # Here we call the function depending on the order in which they are called.
138
+ # For some layers, the orders are different, so we access the appropriate one by index.
139
+
140
+ if not inverse_temp:
141
+ for idx in [0, 1, 2, 3]:
142
+ hidden_states = ordered_g_c(idx)
143
+ else:
144
+ for idx in [2, 3, 0, 1]:
145
+ hidden_states = ordered_g_c(idx)
146
+
147
+ return hidden_states
148
+
149
+
150
+ def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
151
+ hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'),
152
+ hidden_states,
153
+ temb,
154
+ use_reentrant=False
155
+ )
156
+ hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'),
157
+ hidden_states,
158
+ num_frames,
159
+ use_reentrant=False
160
+ )
161
+ return hidden_states
162
+
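The `g_c`/`custom_checkpoint` wrappers above apply PyTorch activation checkpointing to each sub-module (resnet, spatial attention, temporal attention, temporal conv), so intermediate activations are recomputed during backward instead of being stored. A minimal, self-contained illustration of that mechanism, using a hypothetical block rather than one of the UNet blocks:

```python
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

# Hypothetical sub-module; in the code above the equivalent role is played by a
# ResnetBlock2D, Transformer2DModel, TransformerTemporalModel, or TemporalConvLayer.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)


def run_block(inp):
    return block(inp)


# use_reentrant=False matches the calls in transformer_g_c / cross_attn_g_c / up_down_g_c.
out = checkpoint.checkpoint(run_block, x, use_reentrant=False)
out.sum().backward()   # activations inside `block` are recomputed here rather than cached
print(x.grad.shape)
```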
163
+
164
+ def get_down_block(
165
+ down_block_type,
166
+ num_layers,
167
+ in_channels,
168
+ out_channels,
169
+ temb_channels,
170
+ add_downsample,
171
+ resnet_eps,
172
+ resnet_act_fn,
173
+ attn_num_head_channels,
174
+ resnet_groups=None,
175
+ cross_attention_dim=None,
176
+ downsample_padding=None,
177
+ dual_cross_attention=False,
178
+ use_linear_projection=True,
179
+ only_cross_attention=False,
180
+ upcast_attention=False,
181
+ resnet_time_scale_shift="default",
182
+ use_image_embedding=False,
183
+ unet_params=None,
184
+ ):
185
+ if down_block_type == "DownBlock3D":
186
+ return DownBlock3D(
187
+ num_layers=num_layers,
188
+ in_channels=in_channels,
189
+ out_channels=out_channels,
190
+ temb_channels=temb_channels,
191
+ add_downsample=add_downsample,
192
+ resnet_eps=resnet_eps,
193
+ resnet_act_fn=resnet_act_fn,
194
+ resnet_groups=resnet_groups,
195
+ downsample_padding=downsample_padding,
196
+ resnet_time_scale_shift=resnet_time_scale_shift,
197
+ )
198
+ elif down_block_type == "CrossAttnDownBlock3D":
199
+ if cross_attention_dim is None:
200
+ raise ValueError(
201
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D")
202
+ return CrossAttnDownBlock3D(
203
+ num_layers=num_layers,
204
+ in_channels=in_channels,
205
+ out_channels=out_channels,
206
+ temb_channels=temb_channels,
207
+ add_downsample=add_downsample,
208
+ resnet_eps=resnet_eps,
209
+ resnet_act_fn=resnet_act_fn,
210
+ resnet_groups=resnet_groups,
211
+ downsample_padding=downsample_padding,
212
+ cross_attention_dim=cross_attention_dim,
213
+ attn_num_head_channels=attn_num_head_channels,
214
+ dual_cross_attention=dual_cross_attention,
215
+ use_linear_projection=use_linear_projection,
216
+ only_cross_attention=only_cross_attention,
217
+ upcast_attention=upcast_attention,
218
+ resnet_time_scale_shift=resnet_time_scale_shift,
219
+ use_image_embedding=use_image_embedding,
220
+ unet_params=unet_params,
221
+ )
222
+ raise ValueError(f"{down_block_type} does not exist.")
223
+
224
+
225
+ def get_up_block(
226
+ up_block_type,
227
+ num_layers,
228
+ in_channels,
229
+ out_channels,
230
+ prev_output_channel,
231
+ temb_channels,
232
+ add_upsample,
233
+ resnet_eps,
234
+ resnet_act_fn,
235
+ attn_num_head_channels,
236
+ resnet_groups=None,
237
+ cross_attention_dim=None,
238
+ dual_cross_attention=False,
239
+ use_linear_projection=True,
240
+ only_cross_attention=False,
241
+ upcast_attention=False,
242
+ resnet_time_scale_shift="default",
243
+ use_image_embedding=False,
244
+ unet_params=None,
245
+ ):
246
+ if up_block_type == "UpBlock3D":
247
+ return UpBlock3D(
248
+ num_layers=num_layers,
249
+ in_channels=in_channels,
250
+ out_channels=out_channels,
251
+ prev_output_channel=prev_output_channel,
252
+ temb_channels=temb_channels,
253
+ add_upsample=add_upsample,
254
+ resnet_eps=resnet_eps,
255
+ resnet_act_fn=resnet_act_fn,
256
+ resnet_groups=resnet_groups,
257
+ resnet_time_scale_shift=resnet_time_scale_shift,
258
+ )
259
+ elif up_block_type == "CrossAttnUpBlock3D":
260
+ if cross_attention_dim is None:
261
+ raise ValueError(
262
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D")
263
+ return CrossAttnUpBlock3D(
264
+ num_layers=num_layers,
265
+ in_channels=in_channels,
266
+ out_channels=out_channels,
267
+ prev_output_channel=prev_output_channel,
268
+ temb_channels=temb_channels,
269
+ add_upsample=add_upsample,
270
+ resnet_eps=resnet_eps,
271
+ resnet_act_fn=resnet_act_fn,
272
+ resnet_groups=resnet_groups,
273
+ cross_attention_dim=cross_attention_dim,
274
+ attn_num_head_channels=attn_num_head_channels,
275
+ dual_cross_attention=dual_cross_attention,
276
+ use_linear_projection=use_linear_projection,
277
+ only_cross_attention=only_cross_attention,
278
+ upcast_attention=upcast_attention,
279
+ resnet_time_scale_shift=resnet_time_scale_shift,
280
+ use_image_embedding=use_image_embedding,
281
+ unet_params=unet_params,
282
+ )
283
+ raise ValueError(f"{up_block_type} does not exist.")
284
+
285
+
286
+ class UNetMidBlock3DCrossAttn(nn.Module):
287
+ def __init__(
288
+ self,
289
+ in_channels: int,
290
+ temb_channels: int,
291
+ dropout: float = 0.0,
292
+ num_layers: int = 1,
293
+ resnet_eps: float = 1e-6,
294
+ resnet_time_scale_shift: str = "default",
295
+ resnet_act_fn: str = "swish",
296
+ resnet_groups: int = 32,
297
+ resnet_pre_norm: bool = True,
298
+ attn_num_head_channels=1,
299
+ output_scale_factor=1.0,
300
+ cross_attention_dim=1280,
301
+ dual_cross_attention=False,
302
+ use_linear_projection=True,
303
+ upcast_attention=False,
304
+ use_image_embedding=False,
305
+ unet_params=None,
306
+ ):
307
+ super().__init__()
308
+ self.gradient_checkpointing = False
309
+ self.has_cross_attention = True
310
+ self.attn_num_head_channels = attn_num_head_channels
311
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
312
+ in_channels // 4, 32)
313
+
314
+ # there is always at least one resnet
315
+ resnets = [
316
+ ResnetBlock2D(
317
+ in_channels=in_channels,
318
+ out_channels=in_channels,
319
+ temb_channels=temb_channels,
320
+ eps=resnet_eps,
321
+ groups=resnet_groups,
322
+ dropout=dropout,
323
+ time_embedding_norm=resnet_time_scale_shift,
324
+ non_linearity=resnet_act_fn,
325
+ output_scale_factor=output_scale_factor,
326
+ pre_norm=resnet_pre_norm,
327
+ )
328
+ ]
329
+ temp_convs = [
330
+ TemporalConvLayer(
331
+ in_channels,
332
+ in_channels,
333
+ dropout=0.1
334
+ )
335
+ ]
336
+ attentions = []
337
+ temp_attentions = []
338
+
339
+ for _ in range(num_layers):
340
+ attentions.append(
341
+ Transformer2DModel(
342
+ in_channels // attn_num_head_channels,
343
+ attn_num_head_channels,
344
+ in_channels=in_channels,
345
+ num_layers=1,
346
+ cross_attention_dim=cross_attention_dim,
347
+ norm_num_groups=resnet_groups,
348
+ use_linear_projection=use_linear_projection,
349
+ upcast_attention=upcast_attention,
350
+ use_image_embedding=use_image_embedding,
351
+ unet_params=unet_params,
352
+ )
353
+ )
354
+ temp_attentions.append(
355
+ TransformerTemporalModel(
356
+ in_channels // attn_num_head_channels,
357
+ attn_num_head_channels,
358
+ in_channels=in_channels,
359
+ num_layers=1,
360
+ cross_attention_dim=cross_attention_dim,
361
+ norm_num_groups=resnet_groups,
362
+ )
363
+ )
364
+ resnets.append(
365
+ ResnetBlock2D(
366
+ in_channels=in_channels,
367
+ out_channels=in_channels,
368
+ temb_channels=temb_channels,
369
+ eps=resnet_eps,
370
+ groups=resnet_groups,
371
+ dropout=dropout,
372
+ time_embedding_norm=resnet_time_scale_shift,
373
+ non_linearity=resnet_act_fn,
374
+ output_scale_factor=output_scale_factor,
375
+ pre_norm=resnet_pre_norm,
376
+ )
377
+ )
378
+ temp_convs.append(
379
+ TemporalConvLayer(
380
+ in_channels,
381
+ in_channels,
382
+ dropout=0.1
383
+ )
384
+ )
385
+
386
+ self.resnets = nn.ModuleList(resnets)
387
+ self.temp_convs = nn.ModuleList(temp_convs)
388
+ self.attentions = nn.ModuleList(attentions)
389
+ self.temp_attentions = nn.ModuleList(temp_attentions)
390
+
391
+ def forward(
392
+ self,
393
+ hidden_states,
394
+ temb=None,
395
+ encoder_hidden_states=None,
396
+ attention_mask=None,
397
+ num_frames=1,
398
+ cross_attention_kwargs=None,
399
+ ):
400
+ if self.gradient_checkpointing:
401
+ hidden_states = up_down_g_c(
402
+ self.resnets[0],
403
+ self.temp_convs[0],
404
+ hidden_states,
405
+ temb,
406
+ num_frames
407
+ )
408
+ else:
409
+ hidden_states = self.resnets[0](hidden_states, temb)
410
+ hidden_states = self.temp_convs[0](
411
+ hidden_states, num_frames=num_frames)
412
+
413
+ for attn, temp_attn, resnet, temp_conv in zip(
414
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
415
+ ):
416
+ if self.gradient_checkpointing:
417
+ hidden_states = cross_attn_g_c(
418
+ attn,
419
+ temp_attn,
420
+ resnet,
421
+ temp_conv,
422
+ hidden_states,
423
+ encoder_hidden_states,
424
+ cross_attention_kwargs,
425
+ temb,
426
+ num_frames
427
+ )
428
+ else:
429
+ hidden_states = attn(
430
+ hidden_states,
431
+ encoder_hidden_states=encoder_hidden_states,
432
+ cross_attention_kwargs=cross_attention_kwargs,
433
+ ).sample
434
+
435
+ if num_frames > 1:
436
+ hidden_states = temp_attn(
437
+ hidden_states, num_frames=num_frames, attention_mask=attention_mask,
438
+
439
+ ).sample
440
+
441
+ hidden_states = resnet(hidden_states, temb)
442
+
443
+ if num_frames > 1:
444
+ hidden_states = temp_conv(
445
+ hidden_states, num_frames=num_frames)
446
+
447
+ return hidden_states
448
+
449
+
450
+ class CrossAttnDownBlock3D(nn.Module):
451
+ def __init__(
452
+ self,
453
+ in_channels: int,
454
+ out_channels: int,
455
+ temb_channels: int,
456
+ dropout: float = 0.0,
457
+ num_layers: int = 1,
458
+ resnet_eps: float = 1e-6,
459
+ resnet_time_scale_shift: str = "default",
460
+ resnet_act_fn: str = "swish",
461
+ resnet_groups: int = 32,
462
+ resnet_pre_norm: bool = True,
463
+ attn_num_head_channels=1,
464
+ cross_attention_dim=1280,
465
+ output_scale_factor=1.0,
466
+ downsample_padding=1,
467
+ add_downsample=True,
468
+ dual_cross_attention=False,
469
+ use_linear_projection=False,
470
+ only_cross_attention=False,
471
+ upcast_attention=False,
472
+ use_image_embedding=False,
473
+ unet_params=None,
474
+ ):
475
+ super().__init__()
476
+ resnets = []
477
+ attentions = []
478
+ temp_attentions = []
479
+ temp_convs = []
480
+
481
+ self.gradient_checkpointing = False
482
+ self.has_cross_attention = True
483
+ self.attn_num_head_channels = attn_num_head_channels
484
+
485
+ for i in range(num_layers):
486
+ in_channels = in_channels if i == 0 else out_channels
487
+ resnets.append(
488
+ ResnetBlock2D(
489
+ in_channels=in_channels,
490
+ out_channels=out_channels,
491
+ temb_channels=temb_channels,
492
+ eps=resnet_eps,
493
+ groups=resnet_groups,
494
+ dropout=dropout,
495
+ time_embedding_norm=resnet_time_scale_shift,
496
+ non_linearity=resnet_act_fn,
497
+ output_scale_factor=output_scale_factor,
498
+ pre_norm=resnet_pre_norm,
499
+ )
500
+ )
501
+ temp_convs.append(
502
+ TemporalConvLayer(
503
+ out_channels,
504
+ out_channels,
505
+ dropout=0.1
506
+ )
507
+ )
508
+ attentions.append(
509
+ Transformer2DModel(
510
+ out_channels // attn_num_head_channels,
511
+ attn_num_head_channels,
512
+ in_channels=out_channels,
513
+ num_layers=1,
514
+ cross_attention_dim=cross_attention_dim,
515
+ norm_num_groups=resnet_groups,
516
+ use_linear_projection=use_linear_projection,
517
+ only_cross_attention=only_cross_attention,
518
+ upcast_attention=upcast_attention,
519
+ use_image_embedding=use_image_embedding,
520
+ unet_params=unet_params,
521
+ )
522
+ )
523
+ temp_attentions.append(
524
+ TransformerTemporalModel(
525
+ out_channels // attn_num_head_channels,
526
+ attn_num_head_channels,
527
+ in_channels=out_channels,
528
+ num_layers=1,
529
+ cross_attention_dim=cross_attention_dim,
530
+ norm_num_groups=resnet_groups,
531
+ )
532
+ )
533
+ self.resnets = nn.ModuleList(resnets)
534
+ self.temp_convs = nn.ModuleList(temp_convs)
535
+ self.attentions = nn.ModuleList(attentions)
536
+ self.temp_attentions = nn.ModuleList(temp_attentions)
537
+
538
+ if add_downsample:
539
+ self.downsamplers = nn.ModuleList(
540
+ [
541
+ Downsample2D(
542
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
543
+ )
544
+ ]
545
+ )
546
+ else:
547
+ self.downsamplers = None
548
+
549
+ def forward(
550
+ self,
551
+ hidden_states,
552
+ temb=None,
553
+ encoder_hidden_states=None,
554
+ attention_mask=None,
555
+ num_frames=1,
556
+ cross_attention_kwargs=None,
557
+ ):
558
+ # TODO(Patrick, William) - attention mask is not used
559
+ output_states = ()
560
+ layer_idx = 0
561
+
562
+ for resnet, temp_conv, attn, temp_attn in zip(
563
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
564
+ ):
565
+ if self.gradient_checkpointing:
566
+ hidden_states = cross_attn_g_c(
567
+ attn,
568
+ temp_attn,
569
+ resnet,
570
+ temp_conv,
571
+ hidden_states,
572
+ encoder_hidden_states,
573
+ cross_attention_kwargs,
574
+ temb,
575
+ num_frames,
576
+ inverse_temp=True
577
+ )
578
+ else:
579
+ hidden_states = resnet(hidden_states, temb)
580
+ if num_frames > 1:
581
+ hidden_states = temp_conv(
582
+ hidden_states, num_frames=num_frames)
583
+
584
+ hidden_states = attn(
585
+ hidden_states,
586
+ encoder_hidden_states=encoder_hidden_states,
587
+ cross_attention_kwargs=cross_attention_kwargs,
588
+ ).sample
589
+ if num_frames > 1:
590
+ hidden_states = temp_attn(
591
+ hidden_states, num_frames=num_frames, attention_mask=attention_mask,
592
+ ).sample
593
+ layer_idx += 1
594
+ output_states += (hidden_states,)
595
+
596
+ if self.downsamplers is not None:
597
+ for downsampler in self.downsamplers:
598
+ hidden_states = downsampler(hidden_states)
599
+
600
+ output_states += (hidden_states,)
601
+
602
+ return hidden_states, output_states
603
+
604
+
605
+ class DownBlock3D(nn.Module):
606
+ def __init__(
607
+ self,
608
+ in_channels: int,
609
+ out_channels: int,
610
+ temb_channels: int,
611
+ dropout: float = 0.0,
612
+ num_layers: int = 1,
613
+ resnet_eps: float = 1e-6,
614
+ resnet_time_scale_shift: str = "default",
615
+ resnet_act_fn: str = "swish",
616
+ resnet_groups: int = 32,
617
+ resnet_pre_norm: bool = True,
618
+ output_scale_factor=1.0,
619
+ add_downsample=True,
620
+ downsample_padding=1,
621
+ ):
622
+ super().__init__()
623
+ resnets = []
624
+ temp_convs = []
625
+
626
+ self.gradient_checkpointing = False
627
+ for i in range(num_layers):
628
+ in_channels = in_channels if i == 0 else out_channels
629
+ resnets.append(
630
+ ResnetBlock2D(
631
+ in_channels=in_channels,
632
+ out_channels=out_channels,
633
+ temb_channels=temb_channels,
634
+ eps=resnet_eps,
635
+ groups=resnet_groups,
636
+ dropout=dropout,
637
+ time_embedding_norm=resnet_time_scale_shift,
638
+ non_linearity=resnet_act_fn,
639
+ output_scale_factor=output_scale_factor,
640
+ pre_norm=resnet_pre_norm,
641
+ )
642
+ )
643
+ temp_convs.append(
644
+ TemporalConvLayer(
645
+ out_channels,
646
+ out_channels,
647
+ dropout=0.1
648
+ )
649
+ )
650
+
651
+ self.resnets = nn.ModuleList(resnets)
652
+ self.temp_convs = nn.ModuleList(temp_convs)
653
+
654
+ if add_downsample:
655
+ self.downsamplers = nn.ModuleList(
656
+ [
657
+ Downsample2D(
658
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
659
+ )
660
+ ]
661
+ )
662
+ else:
663
+ self.downsamplers = None
664
+
665
+ def forward(self, hidden_states, temb=None, num_frames=1):
666
+ output_states = ()
667
+
668
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
669
+ if self.gradient_checkpointing:
670
+ hidden_states = up_down_g_c(
671
+ resnet, temp_conv, hidden_states, temb, num_frames)
672
+ else:
673
+ hidden_states = resnet(hidden_states, temb)
674
+
675
+ if num_frames > 1:
676
+ hidden_states = temp_conv(
677
+ hidden_states, num_frames=num_frames)
678
+
679
+ output_states += (hidden_states,)
680
+
681
+ if self.downsamplers is not None:
682
+ for downsampler in self.downsamplers:
683
+ hidden_states = downsampler(hidden_states)
684
+
685
+ output_states += (hidden_states,)
686
+
687
+ return hidden_states, output_states
688
+
689
+
690
+ class CrossAttnUpBlock3D(nn.Module):
691
+ def __init__(
692
+ self,
693
+ in_channels: int,
694
+ out_channels: int,
695
+ prev_output_channel: int,
696
+ temb_channels: int,
697
+ dropout: float = 0.0,
698
+ num_layers: int = 1,
699
+ resnet_eps: float = 1e-6,
700
+ resnet_time_scale_shift: str = "default",
701
+ resnet_act_fn: str = "swish",
702
+ resnet_groups: int = 32,
703
+ resnet_pre_norm: bool = True,
704
+ attn_num_head_channels=1,
705
+ cross_attention_dim=1280,
706
+ output_scale_factor=1.0,
707
+ add_upsample=True,
708
+ dual_cross_attention=False,
709
+ use_linear_projection=False,
710
+ only_cross_attention=False,
711
+ upcast_attention=False,
712
+ use_image_embedding=False,
713
+ unet_params=None,
714
+ ):
715
+ super().__init__()
716
+ resnets = []
717
+ temp_convs = []
718
+ attentions = []
719
+ temp_attentions = []
720
+
721
+ self.gradient_checkpointing = False
722
+ self.has_cross_attention = True
723
+ self.attn_num_head_channels = attn_num_head_channels
724
+
725
+ for i in range(num_layers):
726
+ res_skip_channels = in_channels if (
727
+ i == num_layers - 1) else out_channels
728
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
729
+
730
+ resnets.append(
731
+ ResnetBlock2D(
732
+ in_channels=resnet_in_channels + res_skip_channels,
733
+ out_channels=out_channels,
734
+ temb_channels=temb_channels,
735
+ eps=resnet_eps,
736
+ groups=resnet_groups,
737
+ dropout=dropout,
738
+ time_embedding_norm=resnet_time_scale_shift,
739
+ non_linearity=resnet_act_fn,
740
+ output_scale_factor=output_scale_factor,
741
+ pre_norm=resnet_pre_norm,
742
+ )
743
+ )
744
+ temp_convs.append(
745
+ TemporalConvLayer(
746
+ out_channels,
747
+ out_channels,
748
+ dropout=0.1
749
+ )
750
+ )
751
+ attentions.append(
752
+ Transformer2DModel(
753
+ out_channels // attn_num_head_channels,
754
+ attn_num_head_channels,
755
+ in_channels=out_channels,
756
+ num_layers=1,
757
+ cross_attention_dim=cross_attention_dim,
758
+ norm_num_groups=resnet_groups,
759
+ use_linear_projection=use_linear_projection,
760
+ only_cross_attention=only_cross_attention,
761
+ upcast_attention=upcast_attention,
762
+ use_image_embedding=use_image_embedding,
763
+ unet_params=unet_params,
764
+ )
765
+ )
766
+ temp_attentions.append(
767
+ TransformerTemporalModel(
768
+ out_channels // attn_num_head_channels,
769
+ attn_num_head_channels,
770
+ in_channels=out_channels,
771
+ num_layers=1,
772
+ cross_attention_dim=cross_attention_dim,
773
+ norm_num_groups=resnet_groups,
774
+ )
775
+ )
776
+ self.resnets = nn.ModuleList(resnets)
777
+ self.temp_convs = nn.ModuleList(temp_convs)
778
+ self.attentions = nn.ModuleList(attentions)
779
+ self.temp_attentions = nn.ModuleList(temp_attentions)
780
+
781
+ if add_upsample:
782
+ self.upsamplers = nn.ModuleList(
783
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
784
+ else:
785
+ self.upsamplers = None
786
+
787
+ def forward(
788
+ self,
789
+ hidden_states,
790
+ res_hidden_states_tuple,
791
+ temb=None,
792
+ encoder_hidden_states=None,
793
+ upsample_size=None,
794
+ attention_mask=None,
795
+ num_frames=1,
796
+ cross_attention_kwargs=None,
797
+ ):
798
+ # TODO(Patrick, William) - attention mask is not used
799
+ output_states = ()
800
+ for resnet, temp_conv, attn, temp_attn in zip(
801
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
802
+ ):
803
+ # pop res hidden states
804
+ res_hidden_states = res_hidden_states_tuple[-1]
805
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
806
+ hidden_states = torch.cat(
807
+ [hidden_states, res_hidden_states], dim=1)
808
+
809
+ if self.gradient_checkpointing:
810
+ hidden_states = cross_attn_g_c(
811
+ attn,
812
+ temp_attn,
813
+ resnet,
814
+ temp_conv,
815
+ hidden_states,
816
+ encoder_hidden_states,
817
+ cross_attention_kwargs,
818
+ temb,
819
+ num_frames,
820
+ inverse_temp=True
821
+ )
822
+ else:
823
+ hidden_states = resnet(hidden_states, temb)
824
+
825
+ if num_frames > 1:
826
+ hidden_states = temp_conv(
827
+ hidden_states, num_frames=num_frames)
828
+
829
+ hidden_states = attn(
830
+ hidden_states,
831
+ encoder_hidden_states=encoder_hidden_states,
832
+ cross_attention_kwargs=cross_attention_kwargs,
833
+ ).sample
834
+
835
+ if num_frames > 1:
836
+ hidden_states = temp_attn(
837
+ hidden_states, num_frames=num_frames,
838
+ attention_mask=attention_mask,
839
+ ).sample
840
+ output_states += (hidden_states,)
841
+ if self.upsamplers is not None:
842
+ for upsampler in self.upsamplers:
843
+ hidden_states = upsampler(hidden_states, upsample_size)
844
+ output_states += (hidden_states,)
845
+
846
+ return hidden_states, output_states
847
+
848
+
849
+ class UpBlock3D(nn.Module):
850
+ def __init__(
851
+ self,
852
+ in_channels: int,
853
+ prev_output_channel: int,
854
+ out_channels: int,
855
+ temb_channels: int,
856
+ dropout: float = 0.0,
857
+ num_layers: int = 1,
858
+ resnet_eps: float = 1e-6,
859
+ resnet_time_scale_shift: str = "default",
860
+ resnet_act_fn: str = "swish",
861
+ resnet_groups: int = 32,
862
+ resnet_pre_norm: bool = True,
863
+ output_scale_factor=1.0,
864
+ add_upsample=True,
865
+ ):
866
+ super().__init__()
867
+ resnets = []
868
+ temp_convs = []
869
+ self.gradient_checkpointing = False
870
+ for i in range(num_layers):
871
+ res_skip_channels = in_channels if (
872
+ i == num_layers - 1) else out_channels
873
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
874
+
875
+ resnets.append(
876
+ ResnetBlock2D(
877
+ in_channels=resnet_in_channels + res_skip_channels,
878
+ out_channels=out_channels,
879
+ temb_channels=temb_channels,
880
+ eps=resnet_eps,
881
+ groups=resnet_groups,
882
+ dropout=dropout,
883
+ time_embedding_norm=resnet_time_scale_shift,
884
+ non_linearity=resnet_act_fn,
885
+ output_scale_factor=output_scale_factor,
886
+ pre_norm=resnet_pre_norm,
887
+ )
888
+ )
889
+ temp_convs.append(
890
+ TemporalConvLayer(
891
+ out_channels,
892
+ out_channels,
893
+ dropout=0.1
894
+ )
895
+ )
896
+
897
+ self.resnets = nn.ModuleList(resnets)
898
+ self.temp_convs = nn.ModuleList(temp_convs)
899
+
900
+ if add_upsample:
901
+ self.upsamplers = nn.ModuleList(
902
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
903
+ else:
904
+ self.upsamplers = None
905
+
906
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
907
+ output_states = ()
908
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
909
+ # pop res hidden states
910
+ res_hidden_states = res_hidden_states_tuple[-1]
911
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
912
+ hidden_states = torch.cat(
913
+ [hidden_states, res_hidden_states], dim=1)
914
+
915
+ if self.gradient_checkpointing:
916
+ hidden_states = up_down_g_c(
917
+ resnet, temp_conv, hidden_states, temb, num_frames)
918
+ else:
919
+ hidden_states = resnet(hidden_states, temb)
920
+
921
+ if num_frames > 1:
922
+ hidden_states = temp_conv(
923
+ hidden_states, num_frames=num_frames)
924
+ output_states += (hidden_states,)
925
+ if self.upsamplers is not None:
926
+ for upsampler in self.upsamplers:
927
+ hidden_states = upsampler(hidden_states, upsample_size)
928
+ output_states += (hidden_states,)
929
+
930
+ return hidden_states, output_states
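The down/up blocks above keep activations in a flattened `(batch * frames, channels, height, width)` layout and only regroup frames inside the temporal layers, which is why every `temp_conv`/`temp_attn` call is handed `num_frames`. The snippet below is a minimal sketch of that reshape convention, not the repository's `TemporalConvLayer`; the function name `toy_temporal_conv`, the channel count, and the `(3, 1, 1)` kernel are assumptions made for illustration.

```python
# Minimal sketch (not the repo's TemporalConvLayer): a module that receives
# (B*F, C, H, W) regroups the frames via `num_frames`, convolves across the
# frame axis only, and flattens back for the spatial ResNet/attention layers.
import torch
import torch.nn as nn
from einops import rearrange


def toy_temporal_conv(hidden_states: torch.Tensor, num_frames: int, conv3d: nn.Conv3d) -> torch.Tensor:
    # (B*F, C, H, W) -> (B, C, F, H, W): frames become a convolvable axis
    x = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=num_frames)
    x = conv3d(x)  # kernel (3, 1, 1): mixes information across frames only
    # back to the flattened layout expected by the 2D layers
    return rearrange(x, "b c f h w -> (b f) c h w")


if __name__ == "__main__":
    conv = nn.Conv3d(320, 320, kernel_size=(3, 1, 1), padding=(1, 0, 0))
    sample = torch.randn(2 * 16, 320, 32, 32)  # 2 videos of 16 frames each
    out = toy_temporal_conv(sample, num_frames=16, conv3d=conv)
    print(out.shape)  # torch.Size([32, 320, 32, 32])
```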
t2v_enhanced/model/diffusers_conditional/models/controlnet/unet_3d_condition.py ADDED
@@ -0,0 +1,635 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ # from diffusers.models.transformer_temporal import TransformerTemporalModel
27
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal import TransformerTemporalModel
28
+ from .unet_3d_blocks import (
29
+ CrossAttnDownBlock3D,
30
+ CrossAttnUpBlock3D,
31
+ DownBlock3D,
32
+ UNetMidBlock3DCrossAttn,
33
+ UpBlock3D,
34
+ get_down_block,
35
+ get_up_block,
36
+ transformer_g_c
37
+ )
38
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.conditioning import ConditionalModel
39
+ from einops import rearrange
40
+ from t2v_enhanced.model.layers.conv_channel_extension import Conv2D_ExtendedChannels
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class UNet3DConditionOutput(BaseOutput):
46
+ """
47
+ Args:
48
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
49
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
50
+ """
51
+
52
+ sample: torch.FloatTensor
53
+
54
+
55
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
56
+ r"""
57
+ UNet3DConditionModel is a conditional 3D UNet model that takes in a noisy sample, conditional state, and a timestep
58
+ and returns a sample-shaped output.
59
+
60
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
61
+ implements for all the models (such as downloading or saving, etc.)
62
+
63
+ Parameters:
64
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
65
+ Height and width of input/output sample.
66
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
67
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
68
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
69
+ The tuple of downsample blocks to use.
70
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
71
+ The tuple of upsample blocks to use.
72
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
73
+ The tuple of output channels for each block.
74
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
75
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
76
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
77
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
78
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
79
+ If `None`, it will skip the normalization and activation layers in post-processing
80
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
81
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
82
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
83
+ """
84
+
85
+ _supports_gradient_checkpointing = True
86
+
87
+ @register_to_config
88
+ def __init__(
89
+ self,
90
+ sample_size: Optional[int] = None,
91
+ in_channels: int = 4,
92
+ out_channels: int = 4,
93
+ down_block_types: Tuple[str] = (
94
+ "CrossAttnDownBlock3D",
95
+ "CrossAttnDownBlock3D",
96
+ "CrossAttnDownBlock3D",
97
+ "DownBlock3D",
98
+ ),
99
+ up_block_types: Tuple[str] = (
100
+ "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
101
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
102
+ layers_per_block: int = 2,
103
+ downsample_padding: int = 1,
104
+ mid_block_scale_factor: float = 1,
105
+ act_fn: str = "silu",
106
+ norm_num_groups: Optional[int] = 32,
107
+ norm_eps: float = 1e-5,
108
+ cross_attention_dim: int = 1024,
109
+ attention_head_dim: Union[int, Tuple[int]] = 64,
110
+ merging_mode: str = "addition",
111
+ use_image_embedding: bool = False,
112
+ use_fps_conditioning: bool = False,
113
+ unet_params=None,
114
+ ):
115
+ super().__init__()
116
+ channel_expansion = unet_params.use_of
117
+ self.concat = unet_params.concat
118
+ self.use_image_tokens = unet_params.use_image_tokens_main
119
+ self.image_encoder_name = type(unet_params.image_encoder).__name__
120
+ self.use_image_embedding = use_image_embedding
121
+ self.sample_size = sample_size
122
+ self.gradient_checkpointing = False
123
+ # Check inputs
124
+ if len(down_block_types) != len(up_block_types):
125
+ raise ValueError(
126
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
127
+ )
128
+
129
+ if len(block_out_channels) != len(down_block_types):
130
+ raise ValueError(
131
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
132
+ )
133
+
134
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
135
+ raise ValueError(
136
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
137
+ )
138
+
139
+ # input
140
+ conv_in_kernel = 3
141
+ conv_out_kernel = 3
142
+ conv_in_padding = (conv_in_kernel - 1) // 2
143
+ '''
144
+ self.conv_in = nn.Conv2d(
145
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
146
+ )
147
+ '''
148
+ self.conv_in = Conv2D_ExtendedChannels(
149
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding, in_channel_extension=5 if self.concat else 0,
150
+ )
151
+
152
+ # time
153
+ time_embed_dim = block_out_channels[0] * 4
154
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
155
+ timestep_input_dim = block_out_channels[0]
156
+
157
+ self.time_embedding = TimestepEmbedding(
158
+ timestep_input_dim,
159
+ time_embed_dim,
160
+ act_fn=act_fn,
161
+ )
162
+ self.use_fps_conditioning = use_fps_conditioning
163
+ if use_fps_conditioning:
164
+ fps_embed_dim = block_out_channels[0] * 4
165
+ fps_input_dim = block_out_channels[0]
166
+ self.fps_embedding = TimestepEmbedding(
167
+ fps_input_dim, fps_embed_dim, act_fn=act_fn)
168
+ self.fps_proj = Timesteps(block_out_channels[0], True, 0)
169
+
170
+ self.transformer_in = TransformerTemporalModel(
171
+ num_attention_heads=8,
172
+ attention_head_dim=attention_head_dim,
173
+ in_channels=block_out_channels[0],
174
+ num_layers=1,
175
+ )
176
+
177
+ # class embedding
178
+ self.down_blocks = nn.ModuleList([])
179
+ self.up_blocks = nn.ModuleList([])
180
+
181
+ self.merging_mode = merging_mode
182
+ print("self.merging_mode", self.merging_mode)
183
+ if self.merging_mode.startswith("attention"):
184
+ self.cross_attention_merger_down_blocks = nn.ModuleList([])
185
+ self.cross_attention_merger_mid_block = nn.ModuleList([])
186
+ if isinstance(attention_head_dim, int):
187
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
188
+
189
+ # down
190
+ output_channel = block_out_channels[0]
191
+ for i, down_block_type in enumerate(down_block_types):
192
+ input_channel = output_channel
193
+ output_channel = block_out_channels[i]
194
+ is_final_block = i == len(block_out_channels) - 1
195
+
196
+ down_block = get_down_block(
197
+ down_block_type,
198
+ num_layers=layers_per_block,
199
+ in_channels=input_channel,
200
+ out_channels=output_channel,
201
+ temb_channels=time_embed_dim,
202
+ add_downsample=not is_final_block,
203
+ resnet_eps=norm_eps,
204
+ resnet_act_fn=act_fn,
205
+ resnet_groups=norm_num_groups,
206
+ cross_attention_dim=cross_attention_dim,
207
+ attn_num_head_channels=attention_head_dim[i],
208
+ downsample_padding=downsample_padding,
209
+ dual_cross_attention=False,
210
+ use_image_embedding=use_image_embedding,
211
+ unet_params=unet_params,
212
+ )
213
+ self.down_blocks.append(down_block)
214
+
215
+ if self.merging_mode.startswith("attention"):
216
+ for idx in range(3):
217
+ self.cross_attention_merger_down_blocks.append(ConditionalModel(
218
+ input_channels=input_channel if idx == 0 else output_channel, conditional_model=self.merging_mode.split("attention_")[1]))
219
+
220
+ # mid
221
+ self.mid_block = UNetMidBlock3DCrossAttn(
222
+ in_channels=block_out_channels[-1],
223
+ temb_channels=time_embed_dim,
224
+ resnet_eps=norm_eps,
225
+ resnet_act_fn=act_fn,
226
+ output_scale_factor=mid_block_scale_factor,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=attention_head_dim[-1],
229
+ resnet_groups=norm_num_groups,
230
+ dual_cross_attention=False,
231
+ use_image_embedding=use_image_embedding,
232
+ unet_params=unet_params,
233
+ )
234
+ if self.merging_mode.startswith("attention"):
235
+ self.cross_attention_merger_mid_block = ConditionalModel(
236
+ input_channels=block_out_channels[-1], conditional_model=self.merging_mode.split("attention_")[1])
237
+ # count how many layers upsample the images
238
+ self.num_upsamplers = 0
239
+
240
+ # up
241
+ reversed_block_out_channels = list(reversed(block_out_channels))
242
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
243
+
244
+ output_channel = reversed_block_out_channels[0]
245
+ for i, up_block_type in enumerate(up_block_types):
246
+ is_final_block = i == len(block_out_channels) - 1
247
+
248
+ prev_output_channel = output_channel
249
+ output_channel = reversed_block_out_channels[i]
250
+ input_channel = reversed_block_out_channels[min(
251
+ i + 1, len(block_out_channels) - 1)]
252
+
253
+ # add upsample block for all BUT final layer
254
+ if not is_final_block:
255
+ add_upsample = True
256
+ self.num_upsamplers += 1
257
+ else:
258
+ add_upsample = False
259
+
260
+ up_block = get_up_block(
261
+ up_block_type,
262
+ num_layers=layers_per_block + 1,
263
+ in_channels=input_channel,
264
+ out_channels=output_channel,
265
+ prev_output_channel=prev_output_channel,
266
+ temb_channels=time_embed_dim,
267
+ add_upsample=add_upsample,
268
+ resnet_eps=norm_eps,
269
+ resnet_act_fn=act_fn,
270
+ resnet_groups=norm_num_groups,
271
+ cross_attention_dim=cross_attention_dim,
272
+ attn_num_head_channels=reversed_attention_head_dim[i],
273
+ dual_cross_attention=False,
274
+ use_image_embedding=use_image_embedding,
275
+ unet_params=unet_params,
276
+ )
277
+ self.up_blocks.append(up_block)
278
+ prev_output_channel = output_channel
279
+
280
+ # out
281
+ if norm_num_groups is not None:
282
+ self.conv_norm_out = nn.GroupNorm(
283
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
284
+ )
285
+ self.conv_act = nn.SiLU()
286
+ else:
287
+ self.conv_norm_out = None
288
+ self.conv_act = None
289
+
290
+ conv_out_padding = (conv_out_kernel - 1) // 2
291
+ '''
292
+ self.conv_out = nn.Conv2d(
293
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
294
+ )
295
+ '''
296
+ self.conv_out = Conv2D_ExtendedChannels(
297
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding, out_channel_extension=2 if channel_expansion else 0,
298
+ )
299
+
300
+ def set_attention_slice(self, slice_size):
301
+ r"""
302
+ Enable sliced attention computation.
303
+
304
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
305
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
306
+
307
+ Args:
308
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
309
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
310
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
311
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
312
+ must be a multiple of `slice_size`.
313
+ """
314
+ sliceable_head_dims = []
315
+
316
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
317
+ if hasattr(module, "set_attention_slice"):
318
+ sliceable_head_dims.append(module.sliceable_head_dim)
319
+
320
+ for child in module.children():
321
+ fn_recursive_retrieve_slicable_dims(child)
322
+
323
+ # retrieve number of attention layers
324
+ for module in self.children():
325
+ fn_recursive_retrieve_slicable_dims(module)
326
+
327
+ num_slicable_layers = len(sliceable_head_dims)
328
+
329
+ if slice_size == "auto":
330
+ # half the attention head size is usually a good trade-off between
331
+ # speed and memory
332
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
333
+ elif slice_size == "max":
334
+ # make smallest slice possible
335
+ slice_size = num_slicable_layers * [1]
336
+
337
+ slice_size = num_slicable_layers * \
338
+ [slice_size] if not isinstance(slice_size, list) else slice_size
339
+
340
+ if len(slice_size) != len(sliceable_head_dims):
341
+ raise ValueError(
342
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
343
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
344
+ )
345
+
346
+ for i in range(len(slice_size)):
347
+ size = slice_size[i]
348
+ dim = sliceable_head_dims[i]
349
+ if size is not None and size > dim:
350
+ raise ValueError(
351
+ f"size {size} has to be smaller or equal to {dim}.")
352
+
353
+ # Recursively walk through all the children.
354
+ # Any children which exposes the set_attention_slice method
355
+ # gets the message
356
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
357
+ if hasattr(module, "set_attention_slice"):
358
+ module.set_attention_slice(slice_size.pop())
359
+
360
+ for child in module.children():
361
+ fn_recursive_set_attention_slice(child, slice_size)
362
+
363
+ reversed_slice_size = list(reversed(slice_size))
364
+ for module in self.children():
365
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
366
+
367
+ def _set_gradient_checkpointing(self, value=False):
368
+ self.gradient_checkpointing = value
369
+ self.mid_block.gradient_checkpointing = value
370
+ for module in self.down_blocks + self.up_blocks:
371
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
372
+ module.gradient_checkpointing = value
373
+
374
+ def forward(
375
+ self,
376
+ sample: torch.FloatTensor,
377
+ timestep: Union[torch.Tensor, float, int],
378
+ encoder_hidden_states: torch.Tensor,
379
+ fps: Optional[torch.Tensor] = None,
380
+ class_labels: Optional[torch.Tensor] = None,
381
+ timestep_cond: Optional[torch.Tensor] = None,
382
+ attention_mask: Optional[torch.Tensor] = None,
383
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
384
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
385
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
386
+ return_dict: bool = True,
387
+ ) -> Union[UNet3DConditionOutput, Tuple]:
388
+ r"""
389
+ Args:
390
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
391
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
392
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
393
+ return_dict (`bool`, *optional*, defaults to `True`):
394
+ Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
395
+ cross_attention_kwargs (`dict`, *optional*):
396
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
397
+ `self.processor` in
398
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
399
+
400
+ Returns:
401
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
402
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
403
+ returning a tuple, the first element is the sample tensor.
404
+ """
405
+ # By default samples have to be at least a multiple of the overall upsampling factor.
406
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
407
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
408
+ # on the fly if necessary.
409
+ default_overall_up_factor = 2**self.num_upsamplers
410
+
411
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
412
+ forward_upsample_size = False
413
+ upsample_size = None
414
+
415
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
416
+ logger.info(
417
+ "Forward upsample size to force interpolation output size.")
418
+ forward_upsample_size = True
419
+
420
+ # prepare attention_mask
421
+ '''
422
+ if attention_mask is not None:
423
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
424
+ attention_mask = attention_mask.unsqueeze(1)
425
+ '''
426
+ debug = False
427
+ if self.use_fps_conditioning:
428
+
429
+ if torch.is_tensor(fps):
430
+ assert (fps > -1).all(), "FPS not set"
431
+ if len(fps.shape) == 0:
432
+ fps = fps[None].to(sample.device)
433
+ else:
434
+ assert (fps > -1), "FPS not set"
435
+ is_mps = sample.device.type == "mps"
436
+ if isinstance(fps, float):
437
+ dtype = torch.float32 if is_mps else torch.float64
438
+ else:
439
+ dtype = torch.int32 if is_mps else torch.int64
440
+ fps = torch.tensor([fps], dtype=dtype, device=sample.device)
441
+ fps = fps.expand(sample.shape[0])
442
+ fps_proj = self.fps_proj(fps)
443
+ fps_proj = fps_proj.to(dtype=self.dtype)
444
+ fps_emb = self.fps_embedding(fps_proj)
445
+ # 1. time
446
+ timesteps = timestep
447
+ if not torch.is_tensor(timesteps):
448
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
449
+ # This would be a good case for the `match` statement (Python 3.10+)
450
+ is_mps = sample.device.type == "mps"
451
+ if isinstance(timestep, float):
452
+ dtype = torch.float32 if is_mps else torch.float64
453
+ else:
454
+ dtype = torch.int32 if is_mps else torch.int64
455
+ timesteps = torch.tensor(
456
+ [timesteps], dtype=dtype, device=sample.device)
457
+ elif len(timesteps.shape) == 0:
458
+ timesteps = timesteps[None].to(sample.device)
459
+
460
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
461
+ num_frames = sample.shape[2]
462
+ timesteps = timesteps.expand(sample.shape[0])
463
+ batch_size = sample.shape[0]
464
+
465
+ t_emb = self.time_proj(timesteps)
466
+
467
+ # timesteps does not contain any weights and will always return f32 tensors
468
+ # but time_embedding might actually be running in fp16. so we need to cast here.
469
+ # there might be better ways to encapsulate this.
470
+ t_emb = t_emb.to(dtype=self.dtype)
471
+
472
+ emb = self.time_embedding(t_emb, timestep_cond)
473
+
474
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
475
+ if self.use_fps_conditioning:
476
+ fps_emb = fps_emb.repeat_interleave(repeats=num_frames, dim=0)
477
+ emb = emb + fps_emb
478
+
479
+ if not self.use_image_tokens and encoder_hidden_states.shape[1] > 77:
480
+ encoder_hidden_states = encoder_hidden_states[:, :77]
481
+ # print(f"MAIN with tokens = {encoder_hidden_states.shape[1]}")
482
+ if encoder_hidden_states.shape[1] > 77:
483
+ # assert (
484
+ # encoder_hidden_states.shape[1]-77) % num_frames == 0, f"Encoder shape {encoder_hidden_states.shape}. Num frames = {num_frames}"
485
+ context_text, context_img = encoder_hidden_states[:,
486
+ :77, :], encoder_hidden_states[:, 77:, :]
487
+ context_text = context_text.repeat_interleave(
488
+ repeats=num_frames, dim=0)
489
+
490
+ if self.image_encoder_name == "FrozenOpenCLIPImageEmbedder":
491
+ context_img = context_img.repeat_interleave(
492
+ repeats=num_frames, dim=0)
493
+ else:
494
+ context_img = rearrange(
495
+ context_img, 'b (t l) c -> (b t) l c', t=num_frames)
496
+
497
+ encoder_hidden_states = torch.cat(
498
+ [context_text, context_img], dim=1)
499
+ else:
500
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
501
+ repeats=num_frames, dim=0)
502
+
503
+ # 2. pre-process
504
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(
505
+ (sample.shape[0] * num_frames, -1) + sample.shape[3:])
506
+ sample = self.conv_in(sample)
507
+
508
+ if num_frames > 1:
509
+ if self.gradient_checkpointing:
510
+ sample = transformer_g_c(
511
+ self.transformer_in, sample, num_frames)
512
+ else:
513
+ sample = self.transformer_in(
514
+ sample, num_frames=num_frames, attention_mask=attention_mask).sample
515
+
516
+ # 3. down
517
+ down_block_res_samples = (sample,)
518
+ for downsample_block in self.down_blocks:
519
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
520
+ sample, res_samples = downsample_block(
521
+ hidden_states=sample,
522
+ temb=emb,
523
+ encoder_hidden_states=encoder_hidden_states,
524
+ attention_mask=attention_mask,
525
+ num_frames=num_frames,
526
+ cross_attention_kwargs=cross_attention_kwargs,
527
+ )
528
+ else:
529
+ sample, res_samples = downsample_block(
530
+ hidden_states=sample, temb=emb, num_frames=num_frames)
531
+
532
+ down_block_res_samples += res_samples
533
+
534
+ if down_block_additional_residuals is not None:
535
+ new_down_block_res_samples = ()
536
+
537
+ if self.merging_mode == "addition":
538
+ for down_block_res_sample, down_block_additional_residual in zip(
539
+ down_block_res_samples, down_block_additional_residuals
540
+ ):
541
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
542
+ new_down_block_res_samples += (down_block_res_sample,)
543
+ elif self.merging_mode.startswith("attention"):
544
+ for down_block_res_sample, down_block_additional_residual, merger in zip(
545
+ down_block_res_samples, down_block_additional_residuals, self.cross_attention_merger_down_blocks
546
+ ):
547
+
548
+ down_block_res_sample = merger(
549
+ rearrange(down_block_res_sample, "(B F) C H W -> B F C H W", B=batch_size), rearrange(down_block_additional_residual, "(B F) C H W -> B F C H W", B=batch_size))
550
+ down_block_res_sample = rearrange(
551
+ down_block_res_sample, "B F C H W -> (B F) C H W")
552
+ new_down_block_res_samples += (down_block_res_sample,)
553
+ elif self.merging_mode == "overwrite":
554
+ for down_block_res_sample, down_block_additional_residual in zip(
555
+ down_block_res_samples, down_block_additional_residuals
556
+ ):
557
+ down_block_res_sample = down_block_additional_residual
558
+ new_down_block_res_samples += (down_block_res_sample,)
559
+ down_block_res_samples = new_down_block_res_samples
560
+
561
+ # 4. mid
562
+ if self.mid_block is not None:
563
+ sample = self.mid_block(
564
+ sample,
565
+ emb,
566
+ encoder_hidden_states=encoder_hidden_states,
567
+ attention_mask=attention_mask,
568
+ num_frames=num_frames,
569
+ cross_attention_kwargs=cross_attention_kwargs,
570
+ )
571
+
572
+ if mid_block_additional_residual is not None:
573
+ if self.merging_mode == "addition":
574
+ sample = sample + mid_block_additional_residual
575
+ elif self.merging_mode == "overwrite":
576
+ sample = sample + mid_block_additional_residual
577
+ elif self.merging_mode.startswith("attention"):
578
+ sample = self.cross_attention_merger_mid_block(
579
+ rearrange(sample, "(B F) C H W -> B F C H W", B=batch_size), rearrange(mid_block_additional_residual, "(B F) C H W -> B F C H W", B=batch_size))
580
+ sample = rearrange(sample, "B F C H W -> (B F) C H W")
581
+
582
+ if debug:
583
+ upblockout = (sample,)
584
+ # 5. up
585
+ # import pdb
586
+ # pdb.set_trace()
587
+ for i, upsample_block in enumerate(self.up_blocks):
588
+ is_final_block = i == len(self.up_blocks) - 1
589
+
590
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
591
+ down_block_res_samples = down_block_res_samples[: -len(
592
+ upsample_block.resnets)]
593
+
594
+ # if we have not reached the final block and need to forward the
595
+ # upsample size, we do it here
596
+ if not is_final_block and forward_upsample_size:
597
+ upsample_size = down_block_res_samples[-1].shape[2:]
598
+
599
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
600
+ sample, output_states = upsample_block(
601
+ hidden_states=sample,
602
+ temb=emb,
603
+ res_hidden_states_tuple=res_samples,
604
+ encoder_hidden_states=encoder_hidden_states,
605
+ upsample_size=upsample_size,
606
+ attention_mask=attention_mask,
607
+ num_frames=num_frames,
608
+ cross_attention_kwargs=cross_attention_kwargs,
609
+ )
610
+ else:
611
+ sample, output_states = upsample_block(
612
+ hidden_states=sample,
613
+ temb=emb,
614
+ res_hidden_states_tuple=res_samples,
615
+ upsample_size=upsample_size,
616
+ num_frames=num_frames,
617
+ )
618
+ if debug:
619
+ upblockout += output_states
620
+
621
+ # 6. post-process
622
+ if self.conv_norm_out:
623
+ sample = self.conv_norm_out(sample)
624
+ sample = self.conv_act(sample)
625
+
626
+ sample = self.conv_out(sample)
627
+
628
+ # reshape back to (batch, channel, num_frames, height, width)
629
+ sample = sample[None, :].reshape(
630
+ (-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
631
+
632
+ if not return_dict:
633
+ return (sample,)
634
+
635
+ return UNet3DConditionOutput(sample=sample)
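`forward` folds the frame axis into the batch dimension before `conv_in` and unfolds it again just before returning. The standalone snippet below replays that round trip on a dummy tensor to make the layout explicit; the concrete shape values are arbitrary assumptions.

```python
# Illustration of the (B, C, F, H, W) <-> (B*F, C, H, W) round trip used
# around conv_in / conv_out above. Shapes are example values only.
import torch

B, C, F, H, W = 2, 4, 16, 32, 32
sample = torch.randn(B, C, F, H, W)
num_frames = sample.shape[2]

# fold frames into the batch dimension, as done before self.conv_in(sample)
flat = sample.permute(0, 2, 1, 3, 4).reshape((B * num_frames, -1) + sample.shape[3:])
print(flat.shape)  # torch.Size([32, 4, 32, 32])

# unfold again, as done before returning UNet3DConditionOutput
restored = flat[None, :].reshape((-1, num_frames) + flat.shape[1:]).permute(0, 2, 1, 3, 4)
print(torch.equal(restored, sample))  # True
```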
t2v_enhanced/model/flags.py ADDED
@@ -0,0 +1 @@
1
+ TORCH_DISTRIBUTED_DEBUG = "DETAIL"
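`TORCH_DISTRIBUTED_DEBUG` is consumed by PyTorch as an environment variable, so a module-level Python assignment like the one above does not, on its own, change torch.distributed behaviour. A minimal sketch of how the flag is usually exported from Python (assuming it is not already set in the shell):

```python
# Sketch: export TORCH_DISTRIBUTED_DEBUG before torch.distributed is initialized.
# Valid values are "OFF", "INFO", and "DETAIL".
import os

os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")
```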
t2v_enhanced/model/layers/conv_channel_extension.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Union
4
+ from torch.nn.common_types import _size_2_t
5
+
6
+
7
+ class Conv2D_SubChannels(nn.Conv2d):
8
+ def __init__(self,
9
+ in_channels: int,
10
+ out_channels: int,
11
+ kernel_size: _size_2_t,
12
+ stride: _size_2_t = 1,
13
+ padding: Union[str, _size_2_t] = 0,
14
+ dilation: _size_2_t = 1,
15
+ groups: int = 1,
16
+ bias: bool = True,
17
+ padding_mode: str = 'zeros',
18
+ device=None,
19
+ dtype=None,
20
+ ) -> None:
21
+ super().__init__(in_channels, out_channels, kernel_size, stride,
22
+ padding, dilation, groups, bias, padding_mode, device, dtype)
23
+
24
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
25
+
26
+ if prefix+"weight" in state_dict and ((state_dict[prefix+"weight"].shape[0] > self.out_channels) or (state_dict[prefix+"weight"].shape[1] > self.in_channels)):
27
+ print(
28
+ f"Model checkpoint has too many channels. Excluding channels of convolution {prefix}.")
29
+ if self.bias is not None:
30
+ bias = state_dict[prefix+"bias"][:self.out_channels]
31
+ state_dict[prefix+"bias"] = bias
32
+ del bias
33
+
34
+ weight = state_dict[prefix+"weight"]
35
+ state_dict[prefix+"weight"] = weight[:self.out_channels,
36
+ :self.in_channels]
37
+ del weight
38
+
39
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
40
+
41
+
42
+ class Conv2D_ExtendedChannels(nn.Conv2d):
43
+
44
+ def __init__(self,
45
+ in_channels: int,
46
+ out_channels: int,
47
+ kernel_size: _size_2_t,
48
+ stride: _size_2_t = 1,
49
+ padding: Union[str, _size_2_t] = 0,
50
+ dilation: _size_2_t = 1,
51
+ groups: int = 1,
52
+ bias: bool = True,
53
+ padding_mode: str = 'zeros',
54
+ device=None,
55
+ dtype=None,
56
+ in_channel_extension: int = 0,
57
+ out_channel_extension: int = 0,
58
+ ) -> None:
59
+ super().__init__(in_channels+in_channel_extension, out_channels+out_channel_extension, kernel_size, stride,
60
+ padding, dilation, groups, bias, padding_mode, device, dtype)
61
+
62
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
63
+ print(f"Call extend channel loader with {prefix}")
64
+ if prefix+"weight" in state_dict and (state_dict[prefix+"weight"].shape[0] < self.out_channels or state_dict[prefix+"weight"].shape[1] < self.in_channels):
65
+ print(
66
+ f"Model checkpoint has insufficient channels. Extending channels of convolution {prefix} by adding zeros.")
67
+ if self.bias is not None:
68
+ bias = state_dict[prefix+"bias"]
69
+ state_dict[prefix+"bias"] = torch.cat(
70
+ [bias, torch.zeros(self.out_channels-len(bias), dtype=bias.dtype, layout=bias.layout, device=bias.device)])
71
+ del bias
72
+
73
+ weight = state_dict[prefix+"weight"]
74
+ extended_weight = torch.zeros(self.out_channels, self.in_channels,
75
+ weight.shape[2], weight.shape[3], device=weight.device, dtype=weight.dtype, layout=weight.layout)
76
+ extended_weight[:weight.shape[0], :weight.shape[1]] = weight
77
+ state_dict[prefix+"weight"] = extended_weight
78
+ del extended_weight
79
+ del weight
80
+
81
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ class MyModel(nn.Module):
86
+
87
+ def __init__(self, conv_type: str, c_in, c_out, in_extension, out_extension) -> None:
88
+ super().__init__()
89
+
90
+ if not conv_type == "normal":
91
+
92
+ self.conv1 = Conv2D_ExtendedChannels(
93
+ c_in, c_out, 3, padding=1, in_channel_extension=in_extension, out_channel_extension=out_extension, bias=True)
94
+
95
+ else:
96
+ self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True)
97
+
98
+ def forward(self, x):
99
+ return self.conv1(x)
100
+
101
+ c_in = 9
102
+ c_out = 12
103
+ c_in_ext = 0
104
+ c_out_ext = 3
105
+ model = MyModel("normal", c_in, c_out, c_in_ext, c_out_ext)
106
+
107
+ input = torch.randn((4, c_in+c_in_ext, 128, 128))
108
+ out_normal = model(input[:, :c_in])
109
+ torch.save(model.state_dict(), "model_dummy.py")
110
+
111
+ model_2 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
112
+ model_2.load_state_dict(torch.load("model_dummy.py"))
113
+ out_model_2 = model_2(input)
114
+ out_special = out_model_2[:, :c_out]
115
+
116
+ out_new = out_model_2[:, c_out:]
117
+ model_3 = MyModel("special", c_in, c_out, c_in_ext, c_out_ext)
118
+ model_3.load_state_dict(model_2.state_dict())
119
+ # out_model_2 = model_2(input)
120
+ # out_special = out_model_2[:, :c_out]
121
+
122
+ print(
123
+ f"Difference: Forward pass with extended convolution minus initial convolution: {(out_normal-out_special).abs().max()}")
124
+
125
+ print(f"Compared tensors with shape: ",
126
+ out_normal.shape, out_special.shape)
127
+
128
+ if model_3.conv1.bias is not None:
129
+ criterion = nn.MSELoss()
130
+
131
+ before_opt = model_3.conv1.bias.detach().clone()
132
+ target = torch.ones_like(out_model_2)
133
+ optimizer = torch.optim.SGD(
134
+ model_3.parameters(), lr=0.01, momentum=0.9)
135
+ for iter in range(10):
136
+ optimizer.zero_grad()
137
+ out = model_3(input)
138
+ loss = criterion(out, target)
139
+ loss.backward()
140
+ optimizer.step()
141
+ print(
142
+ f"Weights before and after are the same? {before_opt[c_out:].detach()} | {model_3.conv1.bias[c_out:].detach()} ")
143
+ print(model_3.conv1.bias, model_2.conv1.bias)
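The `_load_from_state_dict` overrides above truncate or zero-pad checkpoint tensors so a convolution whose channel count was changed can still consume the original checkpoint. A condensed usage sketch, mirroring the self-test above with hypothetical channel counts:

```python
# Load a plain nn.Conv2d checkpoint into a conv whose output was widened by 2
# channels; the extra channels start at zero, so the original outputs match.
import torch
import torch.nn as nn
from t2v_enhanced.model.layers.conv_channel_extension import Conv2D_ExtendedChannels

plain = nn.Conv2d(4, 320, kernel_size=3, padding=1)
extended = Conv2D_ExtendedChannels(4, 320, kernel_size=3, padding=1, out_channel_extension=2)
extended.load_state_dict(plain.state_dict())  # hook zero-pads weight and bias

x = torch.randn(1, 4, 64, 64)
assert torch.allclose(extended(x)[:, :320], plain(x))  # original channels unchanged
```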
t2v_enhanced/model/pl_module_extension.py ADDED
@@ -0,0 +1,297 @@
1
+ import torch
2
+ from copy import deepcopy
3
+ from einops import repeat
4
+ import math
5
+
6
+
7
+ class FrameConditioning():
8
+ def __init__(self,
9
+ add_frame_to_input: bool = False,
10
+ add_frame_to_layers: bool = False,
11
+ fill_zero: bool = False,
12
+ randomize_mask: bool = False,
13
+ concatenate_mask: bool = False,
14
+ injection_probability: float = 0.9,
15
+ ) -> None:
16
+ self.use = None
17
+ self.add_frame_to_input = add_frame_to_input
18
+ self.add_frame_to_layers = add_frame_to_layers
19
+ self.fill_zero = fill_zero
20
+ self.randomize_mask = randomize_mask
21
+ self.concatenate_mask = concatenate_mask
22
+ self.injection_probability = injection_probability
23
+ self.add_frame_to_input or self.add_frame_to_layers
24
+
25
+ assert not add_frame_to_layers or not add_frame_to_input
26
+
27
+ def set_random_mask(self, random_mask: bool):
28
+ frame_conditioning = deepcopy(self)
29
+ frame_conditioning.randomize_mask = random_mask
30
+ return frame_conditioning
31
+
32
+ @property
33
+ def use(self):
34
+ return self.add_frame_to_input or self.add_frame_to_layers
35
+
36
+ @use.setter
37
+ def use(self, value):
38
+ if value is not None:
39
+ raise NotImplementedError("Direct access not allowed")
40
+
41
+ def attach_video_frames(self, pl_module, z_0: torch.Tensor = None, batch: torch.Tensor = None, random_mask: bool = False):
42
+ assert self.fill_zero, "Not filling with zero not implemented yet"
43
+ n_frames_inference = pl_module.inference_params.video_length
44
+ with torch.no_grad():
45
+ if z_0 is None:
46
+ assert batch is not None
47
+ z_0 = pl_module.encode_frame(batch)
48
+ assert n_frames_inference == z_0.shape[1], "For frame injection, the number of frames sampled by the dataloader must match the number of frames used for video generation"
49
+ shape = list(z_0.shape)
50
+
51
+ shape[1] = pl_module.inference_params.video_length
52
+ M = torch.zeros(shape, dtype=z_0.dtype,
53
+ device=pl_module.device) # [B F C W H]
54
+ bsz = z_0.shape[0]
55
+ if random_mask:
56
+ p_inject_frame = self.injection_probability
57
+ use_masks = torch.bernoulli(
58
+ torch.tensor(p_inject_frame).repeat(bsz)).long()
59
+ keep_frame_idx = torch.randint(
60
+ 0, n_frames_inference, (bsz,), device=pl_module.device).long()
61
+ else:
62
+ use_masks = torch.ones((bsz,), device=pl_module.device).long()
63
+ # keep only first frame
64
+ keep_frame_idx = 0 * use_masks
65
+ frame_idx = []
66
+
67
+ for batch_idx, (keep_frame, use_mask) in enumerate(zip(keep_frame_idx, use_masks)):
68
+ M[batch_idx, keep_frame] = use_mask
69
+ frame_idx.append(keep_frame if use_mask == 1 else -1)
70
+
71
+ x0 = z_0*M
72
+ if self.concatenate_mask:
73
+ # flatten mask
74
+ M = M[:, :, 0, None]
75
+ x0 = torch.cat([x0, M], dim=2)
76
+ if getattr(pl_module.opt_params.noise_decomposition, "use", False) and random_mask:
77
+ assert x0.shape[0] == 1, "randomizing frame injection with noise decomposition not implemented for batch size >1"
78
+ return x0, frame_idx
79
+
80
+
81
+ class NoiseDecomposition():
82
+
83
+ def __init__(self,
84
+ use: bool = False,
85
+ random_frame: bool = False,
86
+ lambda_f: float = 0.5,
87
+ use_base_model: bool = True,
88
+ ):
89
+ self.use = use
90
+ self.random_frame = random_frame
91
+ self.lambda_f = lambda_f
92
+ self.use_base_model = use_base_model
93
+
94
+ def get_loss(self, x0, unet_base, unet, noise_scheduler, frame_idx, z_t_base, timesteps, encoder_hidden_states, base_noise, z_t_residual, composed_noise):
95
+ if x0 is not None:
96
+ # x0.shape = [B,F,C,W,H], if extrapolation_params.fill_zero=true, only one frame per batch non-zero
97
+ assert not self.random_frame
98
+
99
+ # TODO add x0 injection
100
+ x0_base = []
101
+ for batch_idx, frame in enumerate(frame_idx):
102
+ x0_base.append(x0[batch_idx, frame, None, None])
103
+
104
+ x0_base = torch.cat(x0_base, dim=0)
105
+ x0_residual = repeat(
106
+ x0[:, 0], "B C W H -> B F C W H", F=x0.shape[1]-1)
107
+ else:
108
+ x0_base = None
+ x0_residual = None
109
+
110
+ if self.use_base_model:
111
+ base_pred = unet_base(z_t_base, timesteps,
112
+ encoder_hidden_states, x0=x0_base).sample
113
+ else:
114
+ base_pred = base_noise
115
+
116
+ timesteps_alphas = [
117
+ noise_scheduler.alphas_cumprod[t.cpu()] for t in timesteps]
118
+ timesteps_alphas = torch.stack(
119
+ timesteps_alphas).to(base_pred.device)
120
+ timesteps_alphas = repeat(timesteps_alphas, "B -> B F C W H",
121
+ F=base_pred.shape[1], C=base_pred.shape[2], W=base_pred.shape[3], H=base_pred.shape[4])
122
+ base_correction = math.sqrt(
123
+ self.lambda_f) * torch.sqrt(1-timesteps_alphas) * base_pred
124
+
125
+ z_t_residual_dash = z_t_residual - base_correction
126
+
127
+ residual_pred = unet(
128
+ z_t_residual_dash, timesteps, encoder_hidden_states, x0=x0_residual).sample
129
+ composed_pred = math.sqrt(
130
+ self.lambda_f)*base_pred.detach() + math.sqrt(1-self.lambda_f) * residual_pred
131
+
132
+ loss_residual = torch.nn.functional.mse_loss(
133
+ composed_noise.float(), composed_pred.float(), reduction="mean")
134
+ if self.use_base_model:
135
+ loss_base = torch.nn.functional.mse_loss(
136
+ base_noise.float(), base_pred.float(), reduction="mean")
137
+ loss = loss_residual+loss_base
138
+ else:
139
+ loss = loss_residual
140
+ return loss
141
+
142
+ def add_noise(self, z_base, base_noise, z_residual, composed_noise, noise_scheduler, timesteps):
143
+ z_t_base = noise_scheduler.add_noise(
144
+ z_base, base_noise, timesteps)
145
+ z_t_residual = noise_scheduler.add_noise(
146
+ z_residual, composed_noise, timesteps)
147
+ return z_t_base, z_t_residual
148
+
149
+ def split_latent_into_base_residual(self, z_0, pl_module, noise_generator):
150
+ if self.random_frame:
151
+ raise NotImplementedError("Must be synced with x0 mask!")
152
+ fr_select = torch.randint(
153
+ 0, z_0.shape[1], (z_0.shape[0],), device=pl_module.device).long()
154
+ z_base = z_0[:, fr_select, None]
155
+ fr_residual = [fr for fr in range(
156
+ z_0.shape[1]) if fr != fr_select]
157
+ z_residual = z_0[:, fr_residual, None]
158
+ else:
159
+ if not pl_module.unet_params.frame_conditioning.randomize_mask:
160
+ z_base = z_0[:, 0, None]
161
+ z_residual = z_0[:, 1:]
162
+ else:
163
+ z_base = []
164
+ for batch_idx, frame_at_batch in enumerate(frame_idx):
165
+ z_base.append(
166
+ z_0[batch_idx, frame_at_batch, None, None])
167
+ z_base = torch.cat(z_base, dim=0)
168
+ # z_residual = z_0[[:, 1:]
169
+ z_residual = []
170
+
171
+ for batch_idx, frame_idx_batch in enumerate(frame_idx):
172
+ z_residual_batch = []
173
+ for frame in range(z_0.shape[1]):
174
+ if frame_idx_batch != frame:
175
+ z_residual_batch.append(
176
+ z_0[batch_idx, frame, None, None])
177
+ z_residual_batch = torch.cat(
178
+ z_residual_batch, dim=1)
179
+ z_residual.append(z_residual_batch)
180
+ z_residual = torch.cat(z_residual, dim=0)
181
+ base_noise = noise_generator.sample_noise(z_base) # b_t
182
+ residual_noise = noise_generator.sample_noise(z_residual) # r^f_t
183
+ lambda_f = self.lambda_f
184
+ composed_noise = math.sqrt(
185
+ lambda_f) * base_noise + math.sqrt(1-lambda_f) * residual_noise # dimension issue?
186
+
187
+ return z_base, base_noise, z_residual, composed_noise
188
+
189
+
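A plausible reading (not stated in the source) of the square-root weights used for `composed_noise` and `composed_pred`: if the base noise $b$ and the residual noise $r$ are independent standard Gaussians, the combination keeps unit variance, which is what the diffusion scheduler assumes for the noise target $\epsilon$.

```latex
\epsilon = \sqrt{\lambda_f}\, b + \sqrt{1-\lambda_f}\, r,
\qquad
\operatorname{Var}[\epsilon] = \lambda_f \operatorname{Var}[b] + (1-\lambda_f)\,\operatorname{Var}[r] = 1 .
```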
190
+ class NoiseGenerator():
191
+
192
+ def __init__(self, mode="vanilla") -> None:
193
+ self.mode = mode
194
+
195
+ def set_seed(self, seed: int):
196
+ self.seed = seed
197
+
198
+ def reset_seed(self, seed: int):
199
+ pass
200
+
201
+ def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None):
202
+
203
+ assert (z_0 is not None) != (
204
+ shape is not None), "Exactly one of z_0 or shape must be provided."
205
+ kwargs = {}
206
+
207
+ if z_0 is None:
208
+ if device is not None:
209
+ kwargs["device"] = device
210
+ if dtype is not None:
211
+ kwargs["dtype"] = dtype
212
+
213
+ else:
214
+ kwargs["device"] = z_0.device
215
+ kwargs["dtype"] = z_0.dtype
216
+ shape = z_0.shape
217
+
218
+ if generator is not None:
219
+ kwargs["generator"] = generator
220
+
221
+ B, F, C, W, H = shape
222
+
223
+ if self.mode == "vanilla":
224
+ noise = torch.randn(
225
+ shape, **kwargs)
226
+ elif self.mode == "free_noise":
227
+ noise = torch.randn(shape, **kwargs)
228
+ if noise.shape[1] > 4:
229
+ # HARD CODED
230
+ noise = noise[:, :8]
231
+ noise = torch.cat(
232
+ [noise, noise[:, torch.randperm(noise.shape[1])]], dim=1)
233
+ elif noise.shape[2] > 4:
234
+ noise = noise[:, :, :8]
235
+ noise = torch.cat(
236
+ [noise, noise[:, :, torch.randperm(noise.shape[2])]], dim=2)
237
+ else:
238
+ raise NotImplementedError(
239
+ f"Shape of noise vector not as expected {noise.shape}")
240
+ elif self.mode == "equal":
241
+ shape = list(shape)
242
+ shape[1] = 1
243
+ noise_init = torch.randn(
244
+ shape, **kwargs)
245
+ shape[1] = F
246
+ noise = torch.zeros(
247
+ shape, device=noise_init.device, dtype=noise_init.dtype)
248
+ for fr in range(F):
249
+ noise[:, fr] = noise_init[:, 0]
250
+ elif self.mode == "fusion":
251
+ shape = list(shape)
252
+ shape[1] = 1
253
+ noise_init = torch.randn(
254
+ shape, **kwargs)
255
+ noises = []
256
+ noises.append(noise_init)
257
+ for fr in range(F-1):
258
+
259
+ shift = 2*(fr+1)
260
+ local_copy = noise_init
261
+ shifted_noise = torch.cat(
262
+ [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
263
+ noises.append(math.sqrt(0.2)*shifted_noise +
+ math.sqrt(1-0.2)*torch.randn(shape, **kwargs))
265
+ noise = torch.cat(noises, dim=1)
266
+
267
+ elif self.mode == "motion_dynamics" or self.mode == "equal_noise_per_sequence":
268
+
269
+ shape = list(shape)
270
+ normal_frames = 1
271
+ shape[1] = normal_frames
272
+ init_noise = torch.randn(
273
+ shape, **kwargs)
274
+ noises = []
275
+ noises.append(init_noise)
276
+ init_noise = init_noise[:, -1, None]
277
+ print(f"UPDATE with noise = {init_noise.shape}")
278
+
279
+ if self.mode == "motion_dynamics":
280
+ for fr in range(F-normal_frames):
281
+
282
+ shift = 2*(fr+1)
283
+ print(fr, shift)
284
+ local_copy = init_noise
285
+ shifted_noise = torch.cat(
286
+ [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
287
+ noises.append(shifted_noise)
288
+ elif self.mode == "equal_noise_per_sequence":
289
+ for fr in range(F-1):
290
+ noises.append(init_noise)
291
+ else:
292
+ raise NotImplementedError()
293
+ # noises[0] = noises[0] * 0
294
+ noise = torch.cat(noises, dim=1)
295
+ print(noise.shape)
296
+
297
+ return noise
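The `NoiseGenerator` above selects between several per-frame noise layouts ("vanilla", "free_noise", "equal", "fusion", "motion_dynamics", "equal_noise_per_sequence"). A minimal usage sketch, assuming the class is importable from `t2v_enhanced.model.pl_module_extension` as added in this commit:

```python
import torch

from t2v_enhanced.model.pl_module_extension import NoiseGenerator

z_0 = torch.randn(2, 16, 4, 32, 32)   # stand-in latent video batch (B, F, C, W, H)

# "equal" copies one noise map to every frame; "vanilla" draws i.i.d. noise per frame.
noise = NoiseGenerator(mode="equal").sample_noise(z_0=z_0)
assert noise.shape == z_0.shape
assert torch.allclose(noise[:, 0], noise[:, 1])
```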
t2v_enhanced/model/pl_module_params_controlnet.py ADDED
@@ -0,0 +1,356 @@
1
+ from typing import Union, Any, Dict, List, Optional, Callable
2
+ from t2v_enhanced.model import pl_module_extension
3
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import AbstractEncoder
4
+ from t2v_enhanced.model.requires_grad_setter import LayerConfig as LayerConfigNew
5
+ from t2v_enhanced.model import video_noise_generator
6
+
7
+
8
+ def auto_str(cls):
9
+ def __str__(self):
10
+ return '%s(%s)' % (
11
+ type(self).__name__,
12
+ ', '.join('%s=%s' % item for item in vars(self).items())
13
+ )
14
+ cls.__str__ = __str__
15
+ return cls
16
+
17
+
18
+ class LayerConfig():
19
+ def __init__(self,
20
+ update_with_full_lr: Optional[Union[List[str],
21
+ List[List[str]]]] = None,
22
+ exclude: Optional[List[str]] = None,
23
+ deactivate_all_grads: bool = True,
24
+ ) -> None:
25
+ self.deactivate_all_grads = deactivate_all_grads
26
+ if exclude is not None:
27
+ self.exclude = exclude
28
+ if update_with_full_lr is not None:
29
+ self.update_with_full_lr = update_with_full_lr
30
+
31
+ def __str__(self) -> str:
32
+ str = f"Deactivate all gradients first={self.deactivate_all_grads}. "
33
+ if hasattr(self, "update_with_full_lr"):
34
+ str += f"Then activating gradients for: {self.update_with_full_lr}. "
35
+ if hasattr(self, "exclude"):
36
+ str += f"Finally, excluding: {self.exclude}. "
37
+ return str
38
+
39
+
40
+ class OptimizerParams():
41
+ def __init__(self,
42
+ learning_rate: float,
43
+ # Default value due to legacy
44
+ layers_config: Union[LayerConfig, LayerConfigNew] = None,
45
+ layers_config_base: LayerConfig = None, # Default value due to legacy
46
+ use_warmup: bool = False,
47
+ warmup_steps: int = 10000,
48
+ warmup_start_factor: float = 1e-5,
49
+ learning_rate_spatial: float = 0.0,
50
+ use_8_bit_adam: bool = False,
51
+ noise_generator: Union[pl_module_extension.NoiseGenerator,
52
+ video_noise_generator.NoiseGenerator] = None,
53
+ noise_decomposition: pl_module_extension.NoiseDecomposition = None,
54
+ perceptual_loss: bool = False,
55
+ noise_offset: float = 0.0,
56
+ split_opt_by_node: bool = False,
57
+ reset_prediction_type_to_eps: bool = False,
58
+ train_val_sampler_may_differ: bool = False,
59
+ measure_similarity: bool = False,
60
+ similarity_loss: bool = False,
61
+ similarity_loss_weight: float = 1.0,
62
+ loss_conditional_weight: float = 0.0,
63
+ loss_conditional_weight_convex: bool = False,
64
+ loss_conditional_change_after_step: int = 0,
65
+ mask_conditional_frames: bool = False,
66
+ sample_from_noise: bool = True,
67
+ mask_alternating: bool = False,
68
+ uncondition_freq: int = -1,
69
+ no_text_condition_control: bool = False,
70
+ inject_image_into_input: bool = False,
71
+ inject_at_T: bool = False,
72
+ resampling_steps: int = 1,
73
+ control_freq_in_resample: int = 1,
74
+ resample_to_T: bool = False,
75
+ adaptive_loss_reweight: bool = False,
76
+ load_resampler_from_ckpt: str = "",
77
+ skip_controlnet_branch: bool = False,
78
+ use_fps_conditioning: bool = False,
79
+ num_frame_embeddings_range: int = 16,
80
+ start_frame_training: int = 0,
81
+ start_frame_ctrl: int = 0,
82
+ load_trained_base_model_and_resampler_from_ckpt: str = "",
83
+ load_trained_controlnet_from_ckpt: str = "",
84
+ # fill_up_frame_to_video: bool = False,
85
+ ) -> None:
86
+ self.use_warmup = use_warmup
87
+ self.warmup_steps = warmup_steps
88
+ self.warmup_start_factor = warmup_start_factor
89
+ self.learning_rate_spatial = learning_rate_spatial
90
+ self.learning_rate = learning_rate
91
+ self.use_8_bit_adam = use_8_bit_adam
92
+ self.layers_config = layers_config
93
+ self.noise_generator = noise_generator
94
+ self.perceptual_loss = perceptual_loss
95
+ self.noise_decomposition = noise_decomposition
96
+ self.noise_offset = noise_offset
97
+ self.split_opt_by_node = split_opt_by_node
98
+ self.reset_prediction_type_to_eps = reset_prediction_type_to_eps
99
+ self.train_val_sampler_may_differ = train_val_sampler_may_differ
100
+ self.measure_similarity = measure_similarity
101
+ self.similarity_loss = similarity_loss
102
+ self.similarity_loss_weight = similarity_loss_weight
103
+ self.loss_conditional_weight = loss_conditional_weight
104
+ self.loss_conditional_change_after_step = loss_conditional_change_after_step
105
+ self.mask_conditional_frames = mask_conditional_frames
106
+ self.loss_conditional_weight_convex = loss_conditional_weight_convex
107
+ self.sample_from_noise = sample_from_noise
108
+ self.layers_config_base = layers_config_base
109
+ self.mask_alternating = mask_alternating
110
+ self.uncondition_freq = uncondition_freq
111
+ self.no_text_condition_control = no_text_condition_control
112
+ self.inject_image_into_input = inject_image_into_input
113
+ self.inject_at_T = inject_at_T
114
+ self.resampling_steps = resampling_steps
115
+ self.control_freq_in_resample = control_freq_in_resample
116
+ self.resample_to_T = resample_to_T
117
+ self.adaptive_loss_reweight = adaptive_loss_reweight
118
+ self.load_resampler_from_ckpt = load_resampler_from_ckpt
119
+ self.skip_controlnet_branch = skip_controlnet_branch
120
+ self.use_fps_conditioning = use_fps_conditioning
121
+ self.num_frame_embeddings_range = num_frame_embeddings_range
122
+ self.start_frame_training = start_frame_training
123
+ self.load_trained_base_model_and_resampler_from_ckpt = load_trained_base_model_and_resampler_from_ckpt
124
+ self.load_trained_controlnet_from_ckpt = load_trained_controlnet_from_ckpt
125
+ self.start_frame_ctrl = start_frame_ctrl
126
+ if start_frame_ctrl < 0:
+ raise ValueError("start_frame_ctrl must be non-negative")
129
+
130
+ # self.fill_up_frame_to_video = fill_up_frame_to_video
131
+
132
+ @property
133
+ def learning_rate_spatial(self):
134
+ return self._learning_rate_spatial
135
+
136
+ # legacy code that maps the states None and -1 to 0.0,
+ # so 0.0 indicates that no spatial learning rate is selected
138
+ @learning_rate_spatial.setter
139
+ def learning_rate_spatial(self, value):
140
+ if value is None or value == -1:
141
+ value = 0
142
+ self._learning_rate_spatial = value
143
+
144
+
145
+ # Legacy class
146
+ class SchedulerParams():
147
+ def __init__(self,
148
+ use_warmup: bool = False,
149
+ warmup_steps: int = 10000,
150
+ warmup_start_factor: float = 1e-5,
151
+ ) -> None:
152
+ self.use_warmup = use_warmup
153
+ self.warmup_steps = warmup_steps
154
+ self.warmup_start_factor = warmup_start_factor
155
+
156
+
157
+
158
+ class CrossFrameAttentionParams():
159
+
160
+ def __init__(self, attent_on: List[int], masking=False) -> None:
161
+ self.attent_on = attent_on
162
+ self.masking = masking
163
+
164
+
165
+ class InferenceParams():
166
+ def __init__(self,
167
+ width: int,
168
+ height: int,
169
+ video_length: int,
170
+ guidance_scale: float = 7.5,
171
+ use_dec_scaling: bool = True,
172
+ frame_rate: int = 2,
173
+ num_inference_steps: int = 50,
174
+ eta: float = 0.0,
175
+ n_autoregressive_generations: int = 1,
176
+ mode: str = "long_video",
177
+ start_from_real_input: bool = True,
178
+ eval_loss_metrics: bool = False,
179
+ scheduler_cls: str = "",
180
+ negative_prompt: str = "",
181
+ conditioning_from_all_past: bool = False,
182
+ validation_samples: int = 80,
183
+ conditioning_type: str = "last_chunk",
184
+ result_formats: List[str] = ["eval_gif", "gif", "mp4"],
185
+ concat_video: bool = True,
186
+ seed: int = 33,
187
+ ):
188
+ self.width = width
189
+ self.height = height
190
+ self.video_length = video_length if isinstance(
191
+ video_length, int) else int(video_length)
192
+ self.guidance_scale = guidance_scale
193
+ self.use_dec_scaling = use_dec_scaling
194
+ self.frame_rate = frame_rate
195
+ self.num_inference_steps = num_inference_steps
196
+ self.eta = eta
197
+ self.negative_prompt = negative_prompt
198
+ self.n_autoregressive_generations = n_autoregressive_generations
199
+ self.mode = mode
200
+ self.start_from_real_input = start_from_real_input
201
+ self.eval_loss_metrics = eval_loss_metrics
202
+ self.scheduler_cls = scheduler_cls
203
+ self.conditioning_from_all_past = conditioning_from_all_past
204
+ self.validation_samples = validation_samples
205
+ self.conditioning_type = conditioning_type
206
+ self.result_formats = result_formats
207
+ self.concat_video = concat_video
208
+ self.seed = seed
209
+
210
+ def to_dict(self):
211
+
212
+ keys = [entry for entry in dir(self) if not callable(getattr(
213
+ self, entry)) and not entry.startswith("__")]
214
+
215
+ result_dict = {}
216
+ for key in keys:
217
+ result_dict[key] = getattr(self, key)
218
+ return result_dict
219
+
220
+
221
+ @auto_str
222
+ class AttentionMaskParams():
223
+
224
+ def __init__(self,
225
+ temporal_self_attention_only_on_conditioning: bool = False,
226
+ temporal_self_attention_mask_included_itself: bool = False,
227
+ spatial_attend_on_condition_frames: bool = False,
228
+ temp_attend_on_neighborhood_of_condition_frames: bool = False,
229
+ temp_attend_on_uncond_include_past: bool = False,
230
+ ) -> None:
231
+ self.temporal_self_attention_mask_included_itself = temporal_self_attention_mask_included_itself
232
+ self.spatial_attend_on_condition_frames = spatial_attend_on_condition_frames
233
+ self.temp_attend_on_neighborhood_of_condition_frames = temp_attend_on_neighborhood_of_condition_frames
234
+ self.temporal_self_attention_only_on_conditioning = temporal_self_attention_only_on_conditioning
235
+ self.temp_attend_on_uncond_include_past = temp_attend_on_uncond_include_past
236
+
237
+ assert not temp_attend_on_neighborhood_of_condition_frames or not temporal_self_attention_only_on_conditioning
238
+
239
+
240
+ class UNetParams():
241
+
242
+ def __init__(self,
243
+ conditioning_embedding_out_channels: List[int],
244
+ ckpt_spatial_layers: str = "",
245
+ pipeline_repo: str = "",
246
+ unet_from_diffusers: bool = True,
247
+ spatial_latent_input: bool = False,
248
+ num_frame_conditioning: int = 1,
249
+ pipeline_class: str = "t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline",
250
+ frame_expansion: str = "last_frame",
251
+ downsample_controlnet_cond: bool = True,
252
+ num_frames: int = 1,
253
+ pre_transformer_in_cond: bool = False,
254
+ num_tranformers: int = 1,
255
+ zero_conv_3d: bool = False,
256
+ merging_mode: str = "addition",
257
+ compute_only_conditioned_frames: bool = False,
258
+ condition_encoder: str = "",
259
+ zero_conv_mode: str = "2d",
260
+ clean_model: bool = False,
261
+ merging_mode_base: str = "addition",
262
+ attention_mask_params: AttentionMaskParams = None,
263
+ attention_mask_params_base: AttentionMaskParams = None,
264
+ modelscope_input_format: bool = True,
265
+ temporal_self_attention_only_on_conditioning: bool = False,
266
+ temporal_self_attention_mask_included_itself: bool = False,
267
+ use_post_merger_zero_conv: bool = False,
268
+ weight_control_sample: float = 1.0,
269
+ use_controlnet_mask: bool = False,
270
+ random_mask_shift: bool = False,
271
+ random_mask: bool = False,
272
+ use_resampler: bool = False,
273
+ unet_from_pipe: bool = False,
274
+ unet_operates_on_2d: bool = False,
275
+ image_encoder: str = "CLIP",
276
+ use_standard_attention_processor: bool = True,
277
+ num_frames_before_chunk: int = 0,
278
+ resampler_type: str = "single_frame",
279
+ resampler_cls: str = "",
280
+ resampler_merging_layers: int = 1,
281
+ image_encoder_obj: AbstractEncoder = None,
282
+ cfg_text_image: bool = False,
283
+ aggregation: str = "last_out",
284
+ resampler_random_shift: bool = False,
285
+ img_cond_alpha_per_frame: bool = False,
286
+ num_control_input_frames: int = -1,
287
+ use_image_encoder_normalization: bool = False,
288
+ use_of: bool = False,
289
+ ema_param: float = -1.0,
290
+ concat: bool = False,
291
+ use_image_tokens_main: bool = True,
292
+ use_image_tokens_ctrl: bool = False,
293
+ ):
294
+
295
+ self.ckpt_spatial_layers = ckpt_spatial_layers
296
+ self.pipeline_repo = pipeline_repo
297
+ self.unet_from_diffusers = unet_from_diffusers
298
+ self.spatial_latent_input = spatial_latent_input
299
+ self.pipeline_class = pipeline_class
300
+ self.num_frame_conditioning = num_frame_conditioning
301
+ if num_control_input_frames == -1:
302
+ self.num_control_input_frames = num_frame_conditioning
303
+ else:
304
+ self.num_control_input_frames = num_control_input_frames
305
+
306
+ self.conditioning_embedding_out_channels = conditioning_embedding_out_channels
307
+ self.frame_expansion = frame_expansion
308
+ self.downsample_controlnet_cond = downsample_controlnet_cond
309
+ self.num_frames = num_frames
310
+ self.pre_transformer_in_cond = pre_transformer_in_cond
311
+ self.num_tranformers = num_tranformers
312
+ self.zero_conv_3d = zero_conv_3d
313
+ self.merging_mode = merging_mode
314
+ self.compute_only_conditioned_frames = compute_only_conditioned_frames
315
+ self.clean_model = clean_model
316
+ self.condition_encoder = condition_encoder
317
+ self.zero_conv_mode = zero_conv_mode
318
+ self.merging_mode_base = merging_mode_base
319
+ self.modelscope_input_format = modelscope_input_format
320
+ assert not temporal_self_attention_only_on_conditioning, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
321
+ assert not temporal_self_attention_mask_included_itself, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
322
+ if attention_mask_params is not None and attention_mask_params_base is None:
323
+ attention_mask_params_base = attention_mask_params
324
+ if attention_mask_params is None:
325
+ attention_mask_params = AttentionMaskParams()
326
+ if attention_mask_params_base is None:
327
+ attention_mask_params_base = AttentionMaskParams()
328
+ self.attention_mask_params = attention_mask_params
329
+ self.attention_mask_params_base = attention_mask_params_base
330
+ self.weight_control_sample = weight_control_sample
331
+ self.use_controlnet_mask = use_controlnet_mask
332
+ self.random_mask_shift = random_mask_shift
333
+ self.random_mask = random_mask
334
+ self.use_resampler = use_resampler
335
+ self.unet_from_pipe = unet_from_pipe
336
+ self.unet_operates_on_2d = unet_operates_on_2d
337
+ self.image_encoder = image_encoder_obj
338
+ self.use_standard_attention_processor = use_standard_attention_processor
339
+ self.num_frames_before_chunk = num_frames_before_chunk
340
+ self.resampler_type = resampler_type
341
+ self.resampler_cls = resampler_cls
342
+ self.resampler_merging_layers = resampler_merging_layers
343
+ self.cfg_text_image = cfg_text_image
344
+ self.aggregation = aggregation
345
+ self.resampler_random_shift = resampler_random_shift
346
+ self.img_cond_alpha_per_frame = img_cond_alpha_per_frame
347
+ self.use_image_encoder_normalization = use_image_encoder_normalization
348
+ self.use_of = use_of
349
+ self.ema_param = ema_param
350
+ self.concat = concat
351
+ self.use_image_tokens_main = use_image_tokens_main
352
+ self.use_image_tokens_ctrl = use_image_tokens_ctrl
353
+ assert not use_post_merger_zero_conv
354
+
355
+ if spatial_latent_input:
356
+ assert unet_from_diffusers, "Spatial latent input only implemented by original diffusers model. Set 'model.unet_params.unet_from_diffusers=True'."
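These parameter containers are plain configuration objects consumed by the Lightning module and the CLI. A small illustrative sketch (values are hypothetical) of how `InferenceParams` is built and serialized:

```python
from t2v_enhanced.model.pl_module_params_controlnet import InferenceParams

params = InferenceParams(width=256, height=256, video_length=16,
                         n_autoregressive_generations=4)
cfg = params.to_dict()                   # collects all non-callable attributes
print(cfg["video_length"], cfg["seed"])  # 16 33
```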
t2v_enhanced/model/requires_grad_setter.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import Union, Any, Dict, List, Optional, Tuple
2
+ import pytorch_lightning as pl
3
+
4
+
5
+ class LayerConfig():
6
+ def __init__(self,
7
+ gradient_setup: List[Tuple[bool, List[str]]] = None,
8
+ ) -> None:
9
+
10
+ if gradient_setup is not None:
11
+ self.gradient_setup = gradient_setup
12
+ self.new_config = True
13
+ # TODO add option to specify quantization per layer
14
+
15
+ def set_requires_grad(self, pl_module: pl.LightningModule):
16
+ # e.g. [[True, ["unet.a.b", "c"]], [True, []]]
17
+
18
+ for selected_module_setup in self.gradient_setup:
19
+ for model_name, p in pl_module.named_parameters():
20
+ grad_mode = selected_module_setup[0] == True
21
+ selected_module_path = selected_module_setup[1]
22
+ path_is_matching = True
23
+ model_name_selection = model_name
24
+ for selected_module in selected_module_path:
25
+ position = model_name_selection.find(selected_module)
26
+ if position == -1:
27
+ path_is_matching = False
28
+ continue
29
+ else:
30
+ shift = len(selected_module)
31
+ model_name_selection = model_name_selection[position+shift:]
32
+ if path_is_matching:
33
+ # if grad_mode:
34
+ # print(
35
+ # f"Setting gradient for {model_name} to {grad_mode}")
36
+ p.requires_grad = grad_mode
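`LayerConfig.set_requires_grad` walks `named_parameters()` and flips `requires_grad` for every parameter whose name contains all path fragments of an entry, in order; parameters matching no entry keep their current setting. A minimal sketch with a toy module (names are illustrative, not from the repo):

```python
import torch.nn as nn
import pytorch_lightning as pl

from t2v_enhanced.model.requires_grad_setter import LayerConfig


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(4, 4)
        self.vae = nn.Linear(4, 4)


module = ToyModule()
cfg = LayerConfig(gradient_setup=[[False, ["vae"]], [True, ["unet"]]])
cfg.set_requires_grad(module)
assert all(p.requires_grad for n, p in module.named_parameters() if n.startswith("unet"))
assert not any(p.requires_grad for n, p in module.named_parameters() if n.startswith("vae"))
```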
t2v_enhanced/model/video_ldm.py ADDED
@@ -0,0 +1,327 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional, Union, Callable
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler
7
+ from diffusers.utils.import_utils import is_xformers_available
8
+ from einops import rearrange, repeat
9
+
10
+ from transformers import CLIPTextModel, CLIPTokenizer
11
+ from utils.video_utils import ResultProcessor, save_videos_grid, video_naming
12
+
13
+ from t2v_enhanced.model import pl_module_params_controlnet
14
+
15
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.controlnet import ControlNetModel
16
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel
17
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import TextToVideoSDPipeline
18
+
19
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import set_use_memory_efficient_attention_xformers
20
+ from t2v_enhanced.model.diffusers_conditional.models.controlnet.mask_generator import MaskGenerator
21
+
22
+ import warnings
23
+ # from warnings import warn
24
+ from t2v_enhanced.utils.iimage import IImage
25
+ from t2v_enhanced.utils.object_loader import instantiate_object
26
+ from t2v_enhanced.utils.object_loader import get_class
27
+
28
+
29
+ class VideoLDM(pl.LightningModule):
30
+
31
+ def __init__(self,
32
+ inference_params: pl_module_params_controlnet.InferenceParams,
33
+ opt_params: pl_module_params_controlnet.OptimizerParams = None,
34
+ unet_params: pl_module_params_controlnet.UNetParams = None,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.inference_generator = torch.Generator(device=self.device)
39
+
40
+ self.opt_params = opt_params
41
+ self.unet_params = unet_params
42
+
43
+ print(f"Base pipeline from: {unet_params.pipeline_repo}")
44
+ print(f"Pipeline class {unet_params.pipeline_class}")
45
+ # load entire pipeline (unet, vq, text encoder,..)
46
+ state_dict_control_model = None
47
+ state_dict_fusion = None
48
+ state_dict_base_model = None
49
+
50
+ if len(opt_params.load_trained_controlnet_from_ckpt) > 0:
51
+ state_dict_ckpt = torch.load(opt_params.load_trained_controlnet_from_ckpt, map_location=torch.device("cpu"))
52
+ state_dict_ckpt = state_dict_ckpt["state_dict"]
53
+ state_dict_control_model = dict(filter(lambda x: x[0].startswith("unet"), state_dict_ckpt.items()))
54
+ state_dict_control_model = {k.split("unet.")[1]: v for (k, v) in state_dict_control_model.items()}
55
+
56
+ state_dict_fusion = dict(filter(lambda x: "cross_attention_merger" in x[0], state_dict_ckpt.items()))
57
+ state_dict_fusion = {k.split("base_model.")[1]: v for (k, v) in state_dict_fusion.items()}
58
+ del state_dict_ckpt
59
+
60
+ state_dict_proj = None
61
+ state_dict_ckpt = None
62
+
63
+ if hasattr(unet_params, "use_resampler") and unet_params.use_resampler:
64
+ num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None
65
+ if unet_params.use_image_tokens_ctrl:
66
+ num_queries = unet_params.num_control_input_frames
67
+ assert unet_params.frame_expansion == "none"
68
+ image_encoder = self.unet_params.image_encoder
69
+ embedding_dim = image_encoder.embedding_dim
70
+
71
+ resampler = instantiate_object(self.unet_params.resampler_cls, video_length=num_queries, embedding_dim=embedding_dim, input_tokens=image_encoder.num_tokens, num_layers=self.unet_params.resampler_merging_layers, aggregation=self.unet_params.aggregation)
72
+
73
+ state_dict_proj = None
74
+
75
+ self.resampler = resampler
76
+ self.image_encoder = image_encoder
77
+
78
+
79
+ noise_scheduler = DDPMScheduler.from_pretrained(self.unet_params.pipeline_repo, subfolder="scheduler")
80
+ tokenizer = CLIPTokenizer.from_pretrained(self.unet_params.pipeline_repo, subfolder="tokenizer")
81
+ text_encoder = CLIPTextModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="text_encoder")
82
+ vae = AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae")
83
+ base_model = UNet3DConditionModel.from_pretrained(self.unet_params.pipeline_repo, subfolder="unet", low_cpu_mem_usage=False, device_map=None, merging_mode=self.unet_params.merging_mode_base, use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_main, use_fps_conditioning=self.opt_params.use_fps_conditioning, unet_params=unet_params)
84
+
85
+ if state_dict_base_model is not None:
86
+ miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False)
87
+ assert len(unex) == 0
88
+ if len(miss) > 0:
89
+ warnings.warn(f"Missing keys when loading base_mode:{miss}")
90
+ del state_dict_base_model
91
+ if state_dict_fusion is not None:
92
+ miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False)
93
+ assert len(unex) == 0
94
+ del state_dict_fusion
95
+
96
+ print("PIPE LOADING DONE")
97
+ self.noise_scheduler = noise_scheduler
98
+ self.tokenizer = tokenizer
99
+ self.text_encoder = text_encoder
100
+ self.vae = vae
101
+
102
+ self.unet = ControlNetModel.from_unet(
103
+ unet=base_model,
104
+ conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels,
105
+ downsample_controlnet_cond=unet_params.downsample_controlnet_cond,
106
+ num_frames=unet_params.num_frames if (unet_params.frame_expansion != "none" or self.unet_params.use_controlnet_mask) else unet_params.num_control_input_frames,
107
+ num_frame_conditioning=unet_params.num_control_input_frames,
108
+ frame_expansion=unet_params.frame_expansion,
109
+ pre_transformer_in_cond=unet_params.pre_transformer_in_cond,
110
+ num_tranformers=unet_params.num_tranformers,
111
+ vae=AutoencoderKL.from_pretrained(self.unet_params.pipeline_repo, subfolder="vae"),
112
+ zero_conv_mode=unet_params.zero_conv_mode,
113
+ merging_mode=unet_params.merging_mode,
114
+ condition_encoder=unet_params.condition_encoder,
115
+ use_controlnet_mask=unet_params.use_controlnet_mask,
116
+ use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_ctrl,
117
+ unet_params=unet_params,
118
+ use_image_encoder_normalization=unet_params.use_image_encoder_normalization,
119
+ )
120
+ if state_dict_control_model is not None:
121
+ miss, unex = self.unet.load_state_dict(
122
+ state_dict_control_model, strict=False)
123
+ if len(miss) > 0:
124
+ print("WARNING: Loading checkpoint for controlnet misses states")
125
+ print(miss)
126
+
127
+ if unet_params.frame_expansion == "none":
128
+ attention_params = self.unet_params.attention_mask_params
129
+ assert not attention_params.temporal_self_attention_only_on_conditioning and not attention_params.spatial_attend_on_condition_frames and not attention_params.temp_attend_on_neighborhood_of_condition_frames
130
+
131
+ self.mask_generator = MaskGenerator(
132
+ self.unet_params.attention_mask_params, num_frame_conditioning=self.unet_params.num_control_input_frames, num_frames=self.unet_params.num_frames)
133
+ self.mask_generator_base = MaskGenerator(
134
+ self.unet_params.attention_mask_params_base, num_frame_conditioning=self.unet_params.num_control_input_frames, num_frames=self.unet_params.num_frames)
135
+
136
+ if state_dict_proj is not None and unet_params.use_image_tokens_main:
137
+ if unet_params.use_image_tokens_main:
138
+ missing, unexpected = base_model.load_state_dict(
139
+ state_dict_proj, strict=False)
140
+ elif unet_params.use_image_tokens_ctrl:
141
+ missing, unexpected = self.unet.load_state_dict(
+ state_dict_proj, strict=False)
143
+ assert len(unexpected) == 0, f"Unexpected entries {unexpected}"
144
+ print(f"Missing keys state proj = {missing}")
145
+ del state_dict_proj
146
+
147
+ base_model.requires_grad_(False)
148
+ self.base_model = base_model
149
+ self.unet.requires_grad_(False)
150
+ self.text_encoder.requires_grad_(False)
151
+ self.vae.requires_grad_(False)
152
+
153
+ layers_config = opt_params.layers_config
154
+ layers_config.set_requires_grad(self)
155
+
156
+ print("CUSTOM XFORMERS ATTENTION USED.")
157
+ if is_xformers_available():
158
+ set_use_memory_efficient_attention_xformers(self.unet, num_frame_conditioning=self.unet_params.num_control_input_frames,
159
+ num_frames=self.unet_params.num_frames,
160
+ attention_mask_params=self.unet_params.attention_mask_params
161
+ )
162
+ set_use_memory_efficient_attention_xformers(self.base_model, num_frame_conditioning=self.unet_params.num_control_input_frames,
163
+ num_frames=self.unet_params.num_frames,
164
+ attention_mask_params=self.unet_params.attention_mask_params_base)
165
+
166
+ if len(inference_params.scheduler_cls) > 0:
167
+ inf_scheduler_class = get_class(inference_params.scheduler_cls)
168
+ else:
169
+ inf_scheduler_class = DDIMScheduler
170
+
171
+ inf_scheduler = inf_scheduler_class.from_pretrained(
172
+ self.unet_params.pipeline_repo, subfolder="scheduler")
173
+ inference_pipeline = TextToVideoSDPipeline(vae=self.vae,
174
+ text_encoder=self.text_encoder,
175
+ tokenizer=self.tokenizer,
176
+ unet=self.base_model,
177
+ controlnet=self.unet,
178
+ scheduler=inf_scheduler
179
+ )
180
+
181
+ inference_pipeline.set_noise_generator(self.opt_params.noise_generator)
182
+ inference_pipeline.enable_vae_slicing()
183
+
184
+ inference_pipeline.set_progress_bar_config(disable=True)
185
+
186
+ self.inference_params = inference_params
187
+ self.inference_pipeline = inference_pipeline
188
+
189
+ self.result_processor = ResultProcessor(fps=self.inference_params.frame_rate, n_frames=self.inference_params.video_length)
190
+
191
+ def on_start(self):
192
+ datamodule = self.trainer._data_connector._datahook_selector.datamodule
193
+ pipe_id_model = self.unet_params.pipeline_repo
194
+ for dataset_key in ["video_dataset", "image_dataset", "predict_dataset"]:
195
+ dataset = getattr(datamodule, dataset_key, None)
196
+ if dataset is not None and hasattr(dataset, "model_id"):
197
+ pipe_id_data = dataset.model_id
198
+ assert pipe_id_model == pipe_id_data, f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_data}'"
199
+ self.result_processor.set_logger(self.logger)
200
+
201
+ def on_predict_start(self) -> None:
202
+ self.on_start()
203
+ # pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
204
+ # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
205
+ # pipe.set_progress_bar_config(disable=True)
206
+ # self.first_stage = pipe.to(self.device)
207
+
208
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
209
+ cfg = self.trainer.predict_cfg
210
+
211
+ result_file_stem = cfg["result_file_stem"]
212
+ storage_fol = Path(cfg['predict_dir'])
213
+ prompts = [cfg["prompt"]]
214
+
215
+ inference_params: pl_module_params_controlnet.InferenceParams = self.inference_params
216
+ conditioning_type = inference_params.conditioning_type
217
+ n_autoregressive_generations = inference_params.n_autoregressive_generations
218
+ mode = inference_params.mode
219
+ start_from_real_input = inference_params.start_from_real_input
220
+ assert isinstance(prompts, list)
221
+
222
+ prompts = n_autoregressive_generations * prompts
223
+
224
+ self.inference_generator.manual_seed(self.inference_params.seed)
225
+
226
+ assert self.unet_params.num_control_input_frames == self.inference_params.video_length//2, f"currently we assume the first and second half of the frame interval to be of equal size, e.g. 16 frames with conditioning on 8. Current setup: {self.unet_params.num_frame_conditioning} and {self.inference_params.video_length}"
227
+
228
+ chunks_conditional = []
229
+ batch_size = 1
230
+ shape = (batch_size, self.inference_pipeline.unet.config.in_channels, self.inference_params.video_length,
231
+ self.inference_pipeline.unet.config.sample_size, self.inference_pipeline.unet.config.sample_size)
232
+ for idx, prompt in enumerate(prompts):
233
+ if idx > 0:
234
+ content = sample*2-1
235
+ content_latent = self.vae.encode(content).latent_dist.sample() * self.vae.config.scaling_factor
236
+ content_latent = rearrange(content_latent, "F C W H -> 1 C F W H")
237
+ content_latent = content_latent[:, :, self.unet_params.num_control_input_frames:].detach().clone()
238
+
239
+ if hasattr(self.inference_pipeline, "noise_generator"):
240
+ latents = self.inference_pipeline.noise_generator.sample_noise(shape=shape, device=self.device, dtype=self.dtype, generator=self.inference_generator, content=content_latent if idx > 0 else None)
241
+ else:
242
+ latents = None
243
+ if idx == 0:
244
+ sample = cfg["video"]
245
+ else:
246
+ if inference_params.conditioning_type == "fixed":
247
+ context = chunks_conditional[0][:self.unet_params.num_frame_conditioning]
248
+ context = [context]
249
+ context = [2*sample-1 for sample in context]
250
+
251
+ input_frames_conditioning = torch.cat(context).detach().clone()
252
+ input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H")
253
+ elif inference_params.conditioning_type == "last_chunk":
254
+ input_frames_conditioning = condition_input[:, -self.unet_params.num_frame_conditioning:].detach().clone()
255
+ elif inference_params.conditioning_type == "past":
256
+ context = [sample[:self.unet_params.num_control_input_frames] for sample in chunks_conditional]
257
+ context = [2*sample-1 for sample in context]
258
+
259
+ input_frames_conditioning = torch.cat(context).detach().clone()
260
+ input_frames_conditioning = rearrange(input_frames_conditioning, "F C W H -> 1 F C W H")
261
+ else:
262
+ raise NotImplementedError()
263
+
264
+ input_frames = condition_input[:, self.unet_params.num_control_input_frames:].detach().clone()
265
+
266
+ sample = self(prompt, input_frames=input_frames, input_frames_conditioning=input_frames_conditioning, latents=latents)
267
+
268
+ if hasattr(self.inference_pipeline, "reset_noise_generator_state"):
269
+ self.inference_pipeline.reset_noise_generator_state()
270
+
271
+ condition_input = rearrange(sample, "F C W H -> 1 F C W H")
272
+ condition_input = (2*condition_input)-1 # range: [-1,1]
273
+
274
+ # store first 16 frames, then always last 8 of a chunk
275
+ chunks_conditional.append(sample)
276
+
277
+ result_formats = self.inference_params.result_formats
278
+ # result_formats = ["gif", "mp4"]
279
+ concat_video = self.inference_params.concat_video
280
+
281
+ def IImage_normalized(x): return IImage(x, vmin=0, vmax=1)
282
+ for result_format in result_formats:
283
+ save_format = result_format.replace("eval_", "")
284
+
285
+ merged_video = None
286
+ for chunk_idx, (prompt, video) in enumerate(zip(prompts, chunks_conditional)):
287
+ if chunk_idx == 0:
288
+ current_video = IImage_normalized(video)
289
+ else:
290
+ current_video = IImage_normalized(video[self.unet_params.num_control_input_frames:])
291
+
292
+ if merged_video is None:
293
+ merged_video = current_video
294
+ else:
295
+ merged_video &= current_video
296
+
297
+ if concat_video:
298
+ filename = video_naming(prompts[0], save_format, batch_idx, 0)
299
+ result_file_video = (storage_fol / filename).absolute().as_posix()
300
+ result_file_video = (Path(result_file_video).parent / (result_file_stem+Path(result_file_video).suffix)).as_posix()
301
+ self.result_processor.save_to_file(video=merged_video.torch(vmin=0, vmax=1), prompt=prompts[0], video_filename=result_file_video, prompt_on_vid=False)
302
+
303
+ def forward(self, prompt, input_frames=None, input_frames_conditioning=None, latents=None):
304
+ call_params = self.inference_params.to_dict()
305
+ print(f"INFERENCE PARAMS = {call_params}")
306
+ call_params["prompt"] = prompt
307
+
308
+ call_params["image"] = input_frames
309
+ call_params["num_frames"] = self.inference_params.video_length
310
+ call_params["return_dict"] = False
311
+ call_params["output_type"] = "pt_t2v"
312
+ call_params["mask_generator"] = self.mask_generator
313
+ call_params["precision"] = "16" if self.trainer.precision.startswith("16") else "32"
314
+ call_params["no_text_condition_control"] = self.opt_params.no_text_condition_control
315
+ call_params["weight_control_sample"] = self.unet_params.weight_control_sample
316
+ call_params["use_controlnet_mask"] = self.unet_params.use_controlnet_mask
317
+ call_params["skip_controlnet_branch"] = self.opt_params.skip_controlnet_branch
318
+ call_params["img_cond_resampler"] = self.resampler if self.unet_params.use_resampler else None
319
+ call_params["img_cond_encoder"] = self.image_encoder if self.unet_params.use_resampler else None
320
+ call_params["input_frames_conditioning"] = input_frames_conditioning
321
+ call_params["cfg_text_image"] = self.unet_params.cfg_text_image
322
+ call_params["use_of"] = self.unet_params.use_of
323
+ if latents is not None:
324
+ call_params["latents"] = latents
325
+
326
+ sample = self.inference_pipeline(generator=self.inference_generator, **call_params)
327
+ return sample
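`predict_step` above stitches the generated chunks together by keeping the full first chunk and only the second half of every later chunk (the first half of a later chunk repeats its conditioning frames). A minimal, runnable sketch of that bookkeeping with stand-in tensors:

```python
import torch

num_frames, num_cond = 16, 8   # video_length and num_control_input_frames
chunks = [torch.rand(num_frames, 3, 64, 64) for _ in range(4)]  # stand-in chunk outputs

merged = [chunks[0]] + [chunk[num_cond:] for chunk in chunks[1:]]
video = torch.cat(merged, dim=0)
print(video.shape[0])          # 16 + 3 * 8 = 40 frames
```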
t2v_enhanced/model/video_noise_generator.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ import torch.fft as fft
3
+ from torch import nn
4
+ from torch.nn import functional
5
+ from math import sqrt
6
+ from einops import rearrange
7
+ import math
8
+ import numbers
9
+ from typing import List
10
+
11
+ # adapted from https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10
12
+ # and https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/19
13
+
14
+
15
+ def gaussian_smoothing_kernel(shape, kernel_size, sigma, dim=2):
16
+ """
17
+ Apply gaussian smoothing on a
18
+ 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
19
+ in the input using a depthwise convolution.
20
+ Arguments:
21
+ channels (int, sequence): Number of channels of the input tensors. Output will
22
+ have this number of channels as well.
23
+ kernel_size (int, sequence): Size of the gaussian kernel.
24
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
25
+ dim (int, optional): The number of dimensions of the data.
26
+ Default value is 2 (spatial).
27
+ """
28
+ if isinstance(kernel_size, numbers.Number):
29
+ kernel_size = [kernel_size] * dim
30
+ if isinstance(sigma, numbers.Number):
31
+ sigma = [sigma] * dim
32
+
33
+ # The gaussian kernel is the product of the
34
+ # gaussian function of each dimension.
35
+ kernel = 1
36
+ meshgrids = torch.meshgrid(
37
+ [
38
+ torch.arange(size, dtype=torch.float32)
39
+ for size in kernel_size
40
+ ]
41
+ )
42
+
43
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
44
+ mean = (size - 1) / 2
45
+
46
+ kernel *= torch.exp(-((mgrid - mean) / std) ** 2 / 2)
47
+ # kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
48
+ # torch.exp(-((mgrid - mean) / std) ** 2 / 2)
49
+
50
+ # Make sure sum of values in gaussian kernel equals 1.
51
+ kernel = kernel / torch.sum(kernel)
52
+
53
+ pad_length = (math.floor(
54
+ (shape[-1]-kernel_size[-1])/2), math.floor((shape[-1]-kernel_size[-1])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-3]-kernel_size[-3])/2), math.floor((shape[-3]-kernel_size[-3])/2))
55
+
56
+ kernel = functional.pad(kernel, pad_length)
57
+ assert kernel.shape == shape[-3:]
58
+ return kernel
59
+
60
+ '''
61
+ # Reshape to depthwise convolutional weight
62
+ kernel = kernel.view(1, 1, *kernel.size())
63
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
64
+
65
+
66
+ self.register_buffer('weight', kernel)
67
+ self.groups = channels
68
+
69
+ if dim == 1:
70
+ self.conv = functional.conv1d
71
+ elif dim == 2:
72
+ self.conv = functional.conv2d
73
+ elif dim == 3:
74
+ self.conv = functional.conv3d
75
+ else:
76
+ raise RuntimeError(
77
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(
78
+ dim)
79
+ )
80
+ '''
81
+
82
+
83
+ class NoiseGenerator():
84
+
85
+ def __init__(self, alpha: float = 0.0, shared_noise_across_chunks: bool = False, mode="vanilla", forward_steps: int = 850, radius: List[float] = None) -> None:
86
+ self.mode = mode
87
+ self.alpha = alpha
88
+ self.shared_noise_across_chunks = shared_noise_across_chunks
89
+ self.forward_steps = forward_steps
90
+ self.radius = radius
91
+
92
+ def set_seed(self, seed: int):
93
+ self.seed = seed
94
+
95
+ def reset_seed(self, seed: int):
96
+ pass
97
+
98
+ def reset_noise_generator_state(self):
99
+ if hasattr(self, "e_shared"):
100
+ del self.e_shared
101
+
102
+ def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None, content=None):
103
+ assert (z_0 is not None) != (
104
+ shape is not None), f"either z_0 must be None, or shape must be None. Both provided."
105
+ kwargs = {}
107
+
108
+ if z_0 is None:
109
+ if device is not None:
110
+ kwargs["device"] = device
111
+ if dtype is not None:
112
+ kwargs["dtype"] = dtype
113
+
114
+ else:
115
+ kwargs["device"] = z_0.device
116
+ kwargs["dtype"] = z_0.dtype
117
+ shape = z_0.shape
118
+
119
+ if generator is not None:
+ kwargs["generator"] = generator
+ # sample the base noise only once device, dtype, shape and generator are known
+ noise = torch.randn(shape, **kwargs)
121
+
122
+ B, F, C, W, H = shape
123
+ if F == 4 and C > 4:
124
+ frame_idx = 2
125
+ F, C = C, F
126
+ else:
127
+ frame_idx = 1
128
+
129
+ if "mixed_noise" in self.mode:
130
+
131
+ shape_per_frame = [dim for dim in shape]
132
+ shape_per_frame[frame_idx] = 1
133
+ zero_mean = torch.zeros(
134
+ shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"])
135
+ std = torch.ones(
136
+ shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"])
137
+ alpha = self.alpha
138
+ std_coeff_shared = (alpha**2) / (1 + alpha**2)
139
+ if self.shared_noise_across_chunks and hasattr(self, "e_shared"):
140
+ e_shared = self.e_shared
141
+ else:
142
+ e_shared = torch.normal(mean=zero_mean, std=sqrt(
143
+ std_coeff_shared)*std, generator=kwargs["generator"] if "generator" in kwargs else None)
144
+ if self.shared_noise_across_chunks:
145
+ self.e_shared = e_shared
146
+
147
+ e_inds = []
148
+ for frame in range(shape[frame_idx]):
149
+ std_coeff_ind = 1 / (1 + alpha**2)
150
+ e_ind = torch.normal(
151
+ mean=zero_mean, std=sqrt(std_coeff_ind)*std, generator=kwargs["generator"] if "generator" in kwargs else None)
152
+ e_inds.append(e_ind)
153
+ noise = torch.cat(
154
+ [e_shared + e_ind for e_ind in e_inds], dim=frame_idx)
155
+
156
+ if "consistI2V" in self.mode and content is not None:
157
+ # if self.mode == "mixed_noise_consistI2V", we will use 'noise' from 'mixed_noise'. Otherwise, it is randn noise.
158
+
159
+ if frame_idx == 1:
160
+ assert content.shape[0] == noise.shape[0] and content.shape[2:] == noise.shape[2:]
161
+ content = torch.concat([content, content[:, -1:].repeat(
162
+ 1, noise.shape[1]-content.shape[1], 1, 1, 1)], dim=1)
163
+ noise = rearrange(noise, "B F C W H -> (B C) F W H")
164
+ content = rearrange(content, "B F C W H -> (B C) F W H")
165
+
166
+ else:
167
+ assert content.shape[:2] == noise.shape[:2] and content.shape[3:] == noise.shape[3:]
169
+ content = torch.concat(
170
+ [content, content[:, :, -1:].repeat(1, 1, noise.shape[2]-content.shape[2], 1, 1)], dim=2)
171
+ noise = rearrange(noise, "B C F W H -> (B C) F W H")
172
+ content = rearrange(content, "B C F W H -> (B C) F W H")
173
+
174
+ # TODO implement DDPM_forward using diffusers framework
175
+ '''
176
+ content_noisy = ddpm_forward(
177
+ content, noise, self.forward_steps)
178
+ '''
179
+
180
+ # A 2D low pass filter was given in the blog:
181
+ # see https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/
182
+
183
+ # alternative
184
+ # do we have to specify more (s,dim,norm?)
185
+ noise_fft = fft.fftn(noise)
186
+ content_noisy_fft = fft.fftn(content_noisy)
187
+
188
+ # shift low frequency parts to center
189
+ noise_fft_shifted = fft.fftshift(noise_fft)
190
+ content_noisy_fft_shifted = fft.fftshift(content_noisy_fft)
191
+
192
+ # create gaussian low pass filter 'gaussian_low_pass_filter' (specify std!)
193
+ # mask out high frequencies using 'cutoff_frequence', something like gaussian_low_pass_filter[freq > cut_off_frequency] = 0.0
194
+ # TODO define 'gaussian_low_pass_filter', apply frequency cutoff filter using self.cutoff_frequency. We need to apply fft.fftshift too probably.
195
+ # TODO what exactly is the "normalized space-time stop frequency" used for the cutoff?
196
+
197
+ gaussian_3d = gaussian_smoothing_kernel(noise_fft.shape, kernel_size=(
198
+ noise_fft.shape[-3], noise_fft.shape[-2], noise_fft.shape[-1]), sigma=1, dim=3).to(noise.device)
199
+
200
+ # define cutoff frequency around the kernel center
201
+ # TODO define center and cut off radius, e.g. something like gaussian_3d[...,:c_x-r_x,:c_y-r_y:,:c_z-r_z] = 0.0 and gaussian_3d[...,c_x+r_x:,c_y+r_y:,c_z+r_z:] = 0.0
202
+ # as we have 16 x 32 x 32, center should be (7.5,15.5,15.5)
203
+ radius = self.radius
204
+
205
+ # TODO we need to use rounding (ceil?)
206
+
207
+ gaussian_3d[:center[0]-radius[0], :center[1] -
208
+ radius[1], :center[2]-radius[2]] = 0.0
209
+ gaussian_3d[center[0]+radius[0]:,
210
+ center[1]+radius[1]:, center[2]+radius[2]:] = 0.0
211
+
212
+ noise_fft_shifted_hp = noise_fft_shifted * (1 - gaussian_3d)
213
+ content_noisy_fft_shifted_lp = content_noisy_fft_shifted * gaussian_3d
214
+
215
+ noise = fft.ifftn(fft.ifftshift(
216
+ noise_fft_shifted_hp+content_noisy_fft_shifted_lp))
217
+ if frame_idx == 1:
218
+ noise = rearrange(
219
+ noise, "(B C) F W H -> B F C W H", B=B)
220
+ else:
221
+ noise = rearrange(
222
+ noise, "(B C) F W H -> B C F W H", B=B)
223
+
224
+ assert noise.shape == shape
225
+ return noise
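The "mixed_noise" branch above decomposes each frame's noise into a component shared across the chunk and an independent component, with variances α²/(1+α²) and 1/(1+α²) so that the per-frame noise keeps roughly unit variance. A small self-contained sketch of that decomposition:

```python
import torch
from math import sqrt

alpha = 1.0
std_shared = sqrt(alpha**2 / (1 + alpha**2))   # std of the shared component
std_ind = sqrt(1 / (1 + alpha**2))             # std of the per-frame component

e_shared = std_shared * torch.randn(1, 1, 4, 32, 32)
frames = [e_shared + std_ind * torch.randn(1, 1, 4, 32, 32) for _ in range(16)]
noise = torch.cat(frames, dim=1)               # (1, 16, 4, 32, 32)
print(round(noise.var().item(), 2))            # close to 1.0
```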
t2v_enhanced/model_func.py ADDED
@@ -0,0 +1,117 @@
1
+ # General
2
+ import os
3
+ from os.path import join as opj
4
+ import datetime
5
+ import torch
6
+ from einops import rearrange, repeat
7
+
8
+ # Utilities
9
+ from inference_utils import *
10
+
11
+ from modelscope.outputs import OutputKeys
12
+ import imageio
13
+ from PIL import Image
14
+ import numpy as np
15
+
16
+ import torch.nn.functional as F
17
+ import torchvision.transforms as transforms
18
+ from diffusers.utils import load_image
19
+ transform = transforms.Compose([
20
+ transforms.PILToTensor()
21
+ ])
22
+
23
+
24
+ def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
25
+ frames = ms_model(prompt,
26
+ num_inference_steps=t,
27
+ generator=inference_generator,
28
+ eta=1.0,
29
+ height=256,
30
+ width=256,
31
+ latents=None).frames
32
+ frames = torch.stack([torch.from_numpy(frame) for frame in frames])
33
+ frames = frames.to(device).to(torch.float32)
34
+ return rearrange(frames[0], "F W H C -> F C W H")
35
+
36
+ def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
37
+ frames = ad_model(prompt,
38
+ negative_prompt="bad quality, worse quality",
39
+ num_frames=16,
40
+ num_inference_steps=t,
41
+ generator=inference_generator,
42
+ guidance_scale=7.5).frames[0]
43
+ frames = torch.stack([transform(frame) for frame in frames])
44
+ frames = frames.to(device).to(torch.float32)
45
+ frames = F.interpolate(frames, size=256)
46
+ frames = frames/255.0
47
+ return frames
48
+
49
+ def sdxl_image_gen(prompt, sdxl_model):
50
+ image = sdxl_model(prompt=prompt).images[0]
51
+ return image
52
+
53
+ def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
54
+ if image is None or image == "":
55
+ image = sdxl_image_gen(prompt, sdxl_model)
56
+ image = image.resize((576, 576))
57
+ image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
58
+ else:
59
+ image = load_image(image)
60
+ image = resize_and_keep(image)
61
+ image = center_crop(image)
62
+ image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
63
+
64
+ frames = svd_model(image, decode_chunk_size=8, generator=inference_generator).frames[0]
65
+ frames = torch.stack([transform(frame) for frame in frames])
66
+ frames = frames.to(device).to(torch.float32)
67
+ frames = frames[:16,:,:,224:-224]
68
+ frames = F.interpolate(frames, size=256)
69
+ frames = frames/255.0
70
+ return frames
71
+
72
+
73
+ def stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
74
+ trainer = stream_cli.trainer
75
+ trainer.limit_predict_batches = 1
76
+ trainer.predict_cfg = {
77
+ "predict_dir": stream_cli.config["result_fol"].as_posix(),
78
+ "result_file_stem": result_file_stem,
79
+ "prompt": prompt,
80
+ "video": short_video,
81
+ "seed": seed,
82
+ "num_inference_steps": t,
83
+ "guidance_scale": image_guidance,
84
+ 'n_autoregressive_generations': n_autoreg_gen,
85
+ }
86
+
87
+ trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)
88
+
89
+
90
+ def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
91
+ downscale = cfg_v2v['downscale']
92
+ upscale_size = cfg_v2v['upscale_size']
93
+ pad = cfg_v2v['pad']
94
+
95
+ now = datetime.datetime.now()
96
+ name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
97
+ enhanced_video_mp4 = opj(where_to_log, name+"_enhanced.mp4")
98
+
99
+ video_frames = imageio.mimread(video)
100
+ h, w, _ = video_frames[0].shape
101
+
102
+ # Downscale video, then resize to fit the upscale size
103
+ video = [Image.fromarray(frame).resize((w//downscale, h//downscale)) for frame in video_frames]
104
+ video = [resize_to_fit(frame, upscale_size) for frame in video]
105
+
106
+ if pad:
107
+ video = [pad_to_fit(frame, upscale_size) for frame in video]
108
+ # video = [np.array(frame) for frame in video]
109
+
110
+ imageio.mimsave(opj(where_to_log, 'temp.mp4'), video, fps=8)
111
+
112
+ p_input = {
113
+ 'video_path': opj(where_to_log, 'temp.mp4'),
114
+ 'text': prompt
115
+ }
116
+ output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
117
+ return enhanced_video_mp4
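`ad_short_gen` and `svd_short_gen` share the same frame post-processing: PIL frames become a float tensor, resized to 256 and scaled to [0, 1]. A minimal runnable sketch of that pattern with stand-in frames:

```python
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

to_tensor = transforms.Compose([transforms.PILToTensor()])
frames = [Image.new("RGB", (512, 512)) for _ in range(16)]  # stand-in frames

frames = torch.stack([to_tensor(frame) for frame in frames]).to(torch.float32)
frames = F.interpolate(frames, size=256)                    # (16, 3, 256, 256)
frames = frames / 255.0
print(frames.shape, float(frames.max()))
```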
t2v_enhanced/model_init.py ADDED
@@ -0,0 +1,112 @@
1
+ # General
2
+ import sys
3
+ from pathlib import Path
4
+ import torch
5
+ from pytorch_lightning import LightningDataModule
6
+
7
+ # For Stage-1
8
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
9
+ from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
10
+ from diffusers import StableVideoDiffusionPipeline, AutoPipelineForText2Image
11
+
12
+ # For Stage-2
13
+ import tempfile
14
+ import yaml
15
+ from t2v_enhanced.model.video_ldm import VideoLDM
16
+ from model.callbacks import SaveConfigCallback
17
+ from inference_utils import legacy_transformation, remove_value, CustomCLI
18
+
19
+ # For Stage-3
20
+ from modelscope.pipelines import pipeline
21
+
22
+
23
+ # Initialize Stage-1 models.
24
+ def init_modelscope(device="cuda"):
25
+ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
26
+ # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
27
+ # pipe.set_progress_bar_config(disable=True)
28
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
29
+ pipe.enable_model_cpu_offload()
30
+ pipe.enable_vae_slicing()
31
+ pipe.set_progress_bar_config(disable=True)
32
+ return pipe.to(device)
33
+
34
+ def init_zeroscope(device="cuda"):
35
+ pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
36
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
37
+ pipe.enable_model_cpu_offload()
38
+ return pipe.to(device)
39
+
40
+ def init_animatediff(device="cuda"):
41
+ adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
42
+ model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
43
+ pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
44
+ scheduler = DDIMScheduler.from_pretrained(
45
+ model_id,
46
+ subfolder="scheduler",
47
+ clip_sample=False,
48
+ timestep_spacing="linspace",
49
+ beta_schedule="linear",
50
+ steps_offset=1,
51
+ )
52
+ pipe.scheduler = scheduler
53
+ pipe.enable_vae_slicing()
54
+ pipe.enable_model_cpu_offload()
55
+ return pipe.to(device)
56
+
57
+ def init_sdxl(device="cuda"):
58
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
59
+ # pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
60
+ return pipe.to(device)
61
+
62
+ def init_svd(device="cuda"):
63
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
64
+ pipe.enable_model_cpu_offload()
65
+ return pipe.to(device)
66
+
67
+
68
+ # Initialize StreamingT2V model.
69
+ def init_streamingt2v_model(ckpt_file, result_fol):
70
+ config_file = "configs/text_to_video/config.yaml"
71
+ sys.argv = sys.argv[:1]
72
+ with tempfile.TemporaryDirectory() as tmpdirname:
73
+ storage_fol = Path(tmpdirname)
74
+ with open(config_file, "r") as yaml_handle:
75
+ yaml_obj = yaml.safe_load(yaml_handle)
76
+
77
+ yaml_obj_orig_data_cfg = legacy_transformation(yaml_obj)
78
+ yaml_obj_orig_data_cfg = remove_value(yaml_obj_orig_data_cfg, "video_dataset")
79
+
80
+ with open(storage_fol / 'config.yaml', 'w') as outfile:
81
+ yaml.dump(yaml_obj_orig_data_cfg, outfile, default_flow_style=False)
82
+ sys.argv.append("--config")
83
+ sys.argv.append((storage_fol / 'config.yaml').as_posix())
84
+ sys.argv.append("--ckpt")
85
+ sys.argv.append(ckpt_file.as_posix())
86
+ sys.argv.append("--result_fol")
87
+ sys.argv.append(result_fol.as_posix())
88
+ sys.argv.append("--config")
89
+ sys.argv.append("configs/inference/inference_long_video.yaml")
90
+ sys.argv.append("--data.prompt_cfg.type=prompt")
91
+ sys.argv.append(f"--data.prompt_cfg.content='test prompt for initialization'")
92
+ sys.argv.append("--trainer.devices=1")
93
+ sys.argv.append("--trainer.num_nodes=1")
94
+ sys.argv.append(f"--model.inference_params.num_inference_steps=50")
95
+ sys.argv.append(f"--model.inference_params.n_autoregressive_generations=4")
96
+ sys.argv.append("--model.inference_params.concat_video=True")
97
+ sys.argv.append("--model.inference_params.result_formats=[eval_mp4]")
98
+
99
+ cli = CustomCLI(VideoLDM, LightningDataModule, run=False, subclass_mode_data=True,
100
+ auto_configure_optimizers=False, parser_kwargs={"parser_mode": "omegaconf"}, save_config_callback=SaveConfigCallback, save_config_kwargs={"log_dir": result_fol, "overwrite": True})
101
+
102
+ model = cli.model
103
+ model.load_state_dict(torch.load(
104
+ cli.config["ckpt"].as_posix())["state_dict"])
105
+ return cli, model
106
+
107
+
108
+ # Initialize Stage-3 model.
109
+ def init_v2v_model(cfg):
110
+ model_id = cfg['model_id']
111
+ pipe_enhance = pipeline(task="video-to-video", model=model_id, model_revision='v1.1.0', device='cuda')
112
+ return pipe_enhance
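A hedged wiring sketch of how these initializers combine with the generation functions from `model_func.py` (run from within `t2v_enhanced/`; the checkpoint path, prompt and guidance value are illustrative, and a CUDA GPU plus the referenced weights are required):

```python
from pathlib import Path
import torch

from model_init import init_modelscope, init_streamingt2v_model
from model_func import ms_short_gen, stream_long_gen

ckpt_file = Path("checkpoints/streaming_t2v.ckpt")   # illustrative path
result_fol = Path("results")
result_fol.mkdir(exist_ok=True)

ms_model = init_modelscope("cuda")
stream_cli, stream_model = init_streamingt2v_model(ckpt_file, result_fol)

generator = torch.Generator(device="cuda").manual_seed(33)
prompt = "A surfer riding a wave at sunset"          # illustrative prompt
short_video = ms_short_gen(prompt, ms_model, generator)
stream_long_gen(prompt, short_video, 2, "", 33, 50, 9.0, "demo", stream_cli, stream_model)
```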
t2v_enhanced/utils/conversions.py ADDED
@@ -0,0 +1,48 @@
1
+ from pathlib import Path
2
+ import PIL
3
+ from PIL import Image
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+
7
+ # TODO add register new converter so that it is accessible via converters.to_x
8
+
9
+ def ensure_class(func, params):
10
+ def func_wrapper(function):
11
+ def wrapper(self=None, *args, **kwargs):
12
+ for key in kwargs:
13
+ if key in params:
14
+ kwargs[key] = func(kwargs[key])
15
+ if self is not None:
16
+ return function(self, *args, **kwargs)
17
+ else:
18
+ return function(*args, **kwargs)
19
+
20
+ return wrapper
21
+
22
+ return func_wrapper
23
+
24
+
25
+ def as_PIL(img):
26
+ if not isinstance(img, PIL.Image.Image):
27
+ if isinstance(img, Path):
28
+ img = img.as_posix()
29
+ if isinstance(img, str):
30
+ img = Image.open(img)
31
+ elif isinstance(img, np.ndarray):
32
+ img = Image.fromarray(img)
33
+
34
+ else:
35
+ raise NotImplementedError
36
+ return img
37
+
38
+
39
+ def to_ndarray(input):
40
+ if not isinstance(input, np.ndarray):
41
+ input = np.array(input)
42
+ return input
43
+
44
+
45
+ def to_Path(input):
46
+ if not isinstance(input, Path):
47
+ input = Path(input)
48
+ return input
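To illustrate how `ensure_class` and `as_PIL` compose, here is a hypothetical `image_size` helper decorated with them; note that the wrapper only converts keyword arguments.

```python
# Hypothetical example: convert the 'img' keyword argument to PIL before the call.
import numpy as np
from t2v_enhanced.utils.conversions import ensure_class, as_PIL

@ensure_class(as_PIL, ["img"])
def image_size(img=None):
    # img is guaranteed to be a PIL.Image.Image here
    return img.size

# Accepts a numpy array, a str path, or a pathlib.Path, but only when passed by keyword;
# positional arguments are not converted by the wrapper.
print(image_size(img=np.zeros((64, 32, 3), dtype=np.uint8)))  # -> (32, 64)
```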
t2v_enhanced/utils/iimage.py ADDED
@@ -0,0 +1,517 @@
1
+ import io
2
+ import math
3
+ import os
4
+ import PIL.Image
5
+ import numpy as np
6
+ import imageio.v3 as iio
7
+ import warnings
8
+
9
+
10
+ import torch
11
+ import torchvision.transforms.functional as TF
12
+ from scipy.ndimage import binary_dilation, binary_erosion
13
+ import cv2
14
+
15
+ import re
16
+
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib import animation
19
+ from IPython.display import HTML, Image, display
20
+
21
+
22
+ IMG_THUMBSIZE = None
23
+
24
+ def torch2np(x, vmin=-1, vmax=1):
25
+ if x.ndim != 4:
26
+ # raise Exception("Please only use (B,C,H,W) torch tensors!")
27
+ warnings.warn(
28
+             "Input was not in (B,C,H,W) format; missing dimensions were inferred automatically.")
29
+ if x.ndim == 3:
30
+ x = x[None]
31
+ if x.ndim == 2:
32
+ x = x[None, None]
33
+ x = x.detach().cpu().float()
34
+ if x.dtype == torch.uint8:
35
+ return x.numpy().astype(np.uint8)
36
+ elif vmin is not None and vmax is not None:
37
+ x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin))
38
+ x = x.permute(0, 2, 3, 1).to(torch.uint8)
39
+ return x.numpy()
40
+ else:
41
+ raise NotImplementedError()
42
+
43
+
44
+ class IImage:
45
+ '''
46
+ Generic media storage. Can store both images and videos.
47
+     Data is stored internally as a uint8 numpy array of shape (B, H, W, C).
48
+ Can be viewed in a jupyter notebook.
49
+ '''
50
+ @staticmethod
51
+ def open(path):
52
+
53
+ iio_obj = iio.imopen(path, 'r')
54
+ data = iio_obj.read()
55
+ try:
56
+ # .properties() does not work for images but for gif files
57
+ if not iio_obj.properties().is_batch:
58
+ data = data[None]
59
+ except AttributeError as e:
60
+ # this one works for gif files
61
+ if not "duration" in iio_obj.metadata():
62
+ data = data[None]
63
+ if data.ndim == 3:
64
+ data = data[..., None]
65
+ image = IImage(data)
66
+ image.link = os.path.abspath(path)
67
+ return image
68
+
69
+ @staticmethod
70
+ def normalized(x, dims=[-1, -2]):
71
+ x = (x - x.amin(dims, True)) / \
72
+ (x.amax(dims, True) - x.amin(dims, True))
73
+ return IImage(x, 0)
74
+
75
+ def numpy(self): return self.data
76
+
77
+ def torch(self, vmin=-1, vmax=1):
78
+ if self.data.ndim == 3:
79
+ data = self.data.transpose(2, 0, 1) / 255.
80
+ else:
81
+ data = self.data.transpose(0, 3, 1, 2) / 255.
82
+ return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
83
+
84
+ def cuda(self):
85
+ self.device = 'cuda'
86
+ return self
87
+
88
+ def cpu(self):
89
+ self.device = 'cpu'
90
+ return self
91
+
92
+ def pil(self):
93
+ ans = []
94
+ for x in self.data:
95
+ if x.shape[-1] == 1:
96
+ x = x[..., 0]
97
+
98
+ ans.append(PIL.Image.fromarray(x))
99
+ if len(ans) == 1:
100
+ return ans[0]
101
+ return ans
102
+
103
+ def is_iimage(self):
104
+ return True
105
+
106
+ @property
107
+ def shape(self): return self.data.shape
108
+ @property
109
+ def size(self): return (self.data.shape[-2], self.data.shape[-3])
110
+
111
+ def setFps(self, fps):
112
+ self.fps = fps
113
+ self.generate_display()
114
+ return self
115
+
116
+ def __init__(self, x, vmin=-1, vmax=1, fps=None):
117
+ if isinstance(x, PIL.Image.Image):
118
+ self.data = np.array(x)
119
+ if self.data.ndim == 2:
120
+ self.data = self.data[..., None] # (H,W,C)
121
+ self.data = self.data[None] # (B,H,W,C)
122
+ elif isinstance(x, IImage):
123
+ self.data = x.data.copy() # Simple Copy
124
+ elif isinstance(x, np.ndarray):
125
+ self.data = x.copy().astype(np.uint8)
126
+ if self.data.ndim == 2:
127
+ self.data = self.data[None, ..., None]
128
+ if self.data.ndim == 3:
129
+ warnings.warn(
130
+ "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
131
+ self.data = self.data[None]
132
+ elif isinstance(x, torch.Tensor):
133
+ self.data = torch2np(x, vmin, vmax)
134
+ self.display_str = None
135
+ self.device = 'cpu'
136
+ self.fps = fps if fps is not None else (
137
+ 1 if len(self.data) < 10 else 30)
138
+ self.link = None
139
+
140
+ def generate_display(self):
141
+ if IMG_THUMBSIZE is not None:
142
+ if self.size[1] < self.size[0]:
143
+ thumb = self.resize(
144
+ (self.size[1]*IMG_THUMBSIZE//self.size[0], IMG_THUMBSIZE))
145
+ else:
146
+ thumb = self.resize(
147
+ (IMG_THUMBSIZE, self.size[0]*IMG_THUMBSIZE//self.size[1]))
148
+ else:
149
+ thumb = self
150
+ if self.is_video():
151
+ self.anim = Animation(thumb.data, fps=self.fps)
152
+ self.anim.render()
153
+ self.display_str = self.anim.anim_str
154
+ else:
155
+ b = io.BytesIO()
156
+ data = thumb.data[0]
157
+ if data.shape[-1] == 1:
158
+ data = data[..., 0]
159
+ PIL.Image.fromarray(data).save(b, "PNG")
160
+ self.display_str = b.getvalue()
161
+ return self.display_str
162
+
163
+ def resize(self, size, *args, **kwargs):
164
+ if size is None:
165
+ return self
166
+ use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
167
+
168
+ # Backward compatibility
169
+ resample = kwargs.pop('filter', PIL.Image.BICUBIC)
170
+ resample = kwargs.pop('resample', resample)
171
+
172
+ if isinstance(size, int):
173
+ if use_small_edge_when_int:
174
+ h, w = self.data.shape[1:3]
175
+ aspect_ratio = h / w
176
+ size = (max(size, int(size * aspect_ratio)),
177
+ max(size, int(size / aspect_ratio)))
178
+ else:
179
+ h, w = self.data.shape[1:3]
180
+ aspect_ratio = h / w
181
+ size = (min(size, int(size * aspect_ratio)),
182
+ min(size, int(size / aspect_ratio)))
183
+
184
+ if self.size == size[::-1]:
185
+ return self
186
+ return stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self])
187
+
188
+ def pad(self, padding, *args, **kwargs):
189
+ return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0)
190
+
191
+ def padx(self, multiplier, *args, **kwargs):
192
+ size = np.array(self.size)
193
+ padding = np.concatenate(
194
+ [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size])
195
+ return self.pad(list(padding), *args, **kwargs)
196
+
197
+ def pad2wh(self, w=0, h=0, **kwargs):
198
+ cw, ch = self.size
199
+ return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs)
200
+
201
+ def pad2square(self, *args, **kwargs):
202
+ if self.size[0] > self.size[1]:
203
+ dx = self.size[0] - self.size[1]
204
+ return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs)
205
+ elif self.size[0] < self.size[1]:
206
+ dx = self.size[1] - self.size[0]
207
+ return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs)
208
+ return self
209
+
210
+ def crop2square(self, *args, **kwargs):
211
+ if self.size[0] > self.size[1]:
212
+ dx = self.size[0] - self.size[1]
213
+ return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs)
214
+ elif self.size[0] < self.size[1]:
215
+ dx = self.size[1] - self.size[0]
216
+ return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs)
217
+ return self
218
+
219
+ def alpha(self):
220
+ return IImage(self.data[..., -1, None], fps=self.fps)
221
+
222
+ def rgb(self):
223
+ return IImage(self.pil().convert('RGB'), fps=self.fps)
224
+
225
+ def png(self):
226
+ return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1))
227
+
228
+ def grid(self, nrows=None, ncols=None):
229
+ if nrows is not None:
230
+ ncols = math.ceil(self.data.shape[0] / nrows)
231
+ elif ncols is not None:
232
+ nrows = math.ceil(self.data.shape[0] / ncols)
233
+ else:
234
+ warnings.warn(
235
+ "No dimensions specified, creating a grid with 5 columns (default)")
236
+ ncols = 5
237
+ nrows = math.ceil(self.data.shape[0] / ncols)
238
+
239
+ pad = nrows * ncols - self.data.shape[0]
240
+ data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0)))
241
+ rows = [np.concatenate(x, 1, dtype=np.uint8)
242
+ for x in np.array_split(data, nrows)]
243
+ return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None])
244
+
245
+ def hstack(self):
246
+ return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None])
247
+
248
+ def vstack(self):
249
+ return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None])
250
+
251
+ def vsplit(self, number_of_splits):
252
+ return IImage(np.concatenate(np.split(self.data, number_of_splits, 1)))
253
+
254
+ def hsplit(self, number_of_splits):
255
+ return IImage(np.concatenate(np.split(self.data, number_of_splits, 2)))
256
+
257
+ def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET):
258
+ data = np.stack([cv2.cvtColor(cv2.applyColorMap(
259
+ x, cmap), cv2.COLOR_BGR2RGB) for x in self.data])
260
+ return IImage(data).resize(resize, use_small_edge_when_int=True)
261
+
262
+ def display(self):
263
+ try:
264
+ display(self)
265
+ except:
266
+ print("No display")
267
+ return self
268
+
269
+ def dilate(self, iterations=1, *args, **kwargs):
270
+ if iterations == 0:
271
+ return IImage(self.data)
272
+         return IImage((binary_dilation(self.data, iterations=iterations, *args, **kwargs)*255.).astype(np.uint8))
273
+
274
+ def erode(self, iterations=1, *args, **kwargs):
275
+         return IImage((binary_erosion(self.data, iterations=iterations, *args, **kwargs)*255.).astype(np.uint8))
276
+
277
+ def hull(self):
278
+ convex_hulls = []
279
+ for frame in self.data:
280
+ contours, hierarchy = cv2.findContours(
281
+ frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
282
+ contours = [x.astype(np.int32) for x in contours]
283
+ mask_contours = [cv2.convexHull(np.concatenate(contours))]
284
+ canvas = np.zeros(self.data[0].shape, np.uint8)
285
+ convex_hull = cv2.drawContours(
286
+ canvas, mask_contours, -1, (255, 0, 0), -1)
287
+ convex_hulls.append(convex_hull)
288
+ return IImage(np.array(convex_hulls))
289
+
290
+ def is_video(self):
291
+ return self.data.shape[0] > 1
292
+
293
+ def __getitem__(self, idx):
294
+ return IImage(self.data[None, idx], fps=self.fps)
295
+ # if self.is_video(): return IImage(self.data[idx], fps = self.fps)
296
+ # return self
297
+
298
+ def _repr_png_(self):
299
+ if self.is_video():
300
+ return None
301
+ if self.display_str is None:
302
+ self.generate_display()
303
+ return self.display_str
304
+
305
+ def _repr_html_(self):
306
+ if not self.is_video():
307
+ return None
308
+ if self.display_str is None:
309
+ self.generate_display()
310
+ return self.display_str
311
+
312
+ def save(self, path):
313
+ _, ext = os.path.splitext(path)
314
+ if self.is_video():
315
+ # if ext in ['.jpg', '.png']:
316
+ if self.display_str is None:
317
+ self.generate_display()
318
+ if ext == ".apng":
319
+ self.anim.anim_obj.save(path, writer="pillow")
320
+ else:
321
+ self.anim.anim_obj.save(path)
322
+ else:
323
+ data = self.data if self.data.ndim == 3 else self.data[0]
324
+ if data.shape[-1] == 1:
325
+ data = data[:, :, 0]
326
+ PIL.Image.fromarray(data).save(path)
327
+ return self
328
+
329
+ def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2):
330
+ if not isinstance(text, list):
331
+ text = [text for _ in self.data]
332
+ data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX,
333
+ font_scale, color, thickness) for x, t in zip(self.data, text)])
334
+ return IImage(data)
335
+
336
+ def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0):
337
+
338
+ assert np.count_nonzero(padding) == 1
339
+ axis_padding = np.nonzero(padding)[0][0]
340
+ scale_padding = padding[axis_padding]
341
+
342
+ y_0 = 0
343
+ x_0 = 0
344
+ if axis_padding == 0:
345
+ width = scale_padding
346
+ y_max = self.shape[1]
347
+ elif axis_padding == 1:
348
+ width = self.shape[2]
349
+ y_max = scale_padding
350
+ elif axis_padding == 2:
351
+ x_0 = self.shape[2]
352
+ width = scale_padding
353
+ y_max = self.shape[1]
354
+ elif axis_padding == 3:
355
+ width = self.shape[2]
356
+ y_0 = self.shape[1]
357
+ y_max = self.shape[1]+scale_padding
358
+
359
+ width -= center[0]
360
+ x_0 += center[0]
361
+ y_0 += center[1]
362
+
363
+ self = self.pad(padding, fill=fill)
364
+
365
+ def wrap_text(text, width, _font_scale):
366
+ allowed_seperator = ' |-|_|/|\n'
367
+ words = re.split(allowed_seperator, text)
368
+ # words = text.split()
369
+ lines = []
370
+ current_line = words[0]
371
+ sep_list = []
372
+ start_idx = 0
373
+ for start_word in words[:-1]:
374
+ pos = text.find(start_word, start_idx)
375
+ pos += len(start_word)
376
+ sep_list.append(text[pos])
377
+ start_idx = pos+1
378
+
379
+ for word, separator in zip(words[1:], sep_list):
380
+ if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
381
+ current_line += separator + word
382
+ else:
383
+ if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
384
+ lines.append(current_line)
385
+ current_line = word
386
+ else:
387
+ return []
388
+
389
+ if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
390
+ lines.append(current_line)
391
+ else:
392
+ return []
393
+ return lines
394
+
395
+ def wrap_text_and_scale(text, width, _font_scale, y_0, y_max):
396
+ height = y_max+1
397
+ while height > y_max:
398
+ text_lines = wrap_text(text, width, _font_scale)
399
+ if len(text) > 0 and len(text_lines) == 0:
400
+
401
+ height = y_max+1
402
+ else:
403
+ line_height = cv2.getTextSize(
404
+ text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1]
405
+ height = line_height * len(text_lines) + y_0
406
+
407
+ # scale font if out of frame
408
+ if height > y_max:
409
+ _font_scale = _font_scale * scale_factor
410
+
411
+ return text_lines, line_height, _font_scale
412
+
413
+ result = []
414
+ if not isinstance(text, list):
415
+ text = [text for _ in self.data]
416
+ else:
417
+ assert len(text) == len(self.data)
418
+
419
+ for x, t in zip(self.data, text):
420
+ x = x.copy()
421
+ text_lines, line_height, _font_scale = wrap_text_and_scale(
422
+ t, width, font_scale, y_0, y_max)
423
+ y = line_height
424
+ for line in text_lines:
425
+ x = cv2.putText(
426
+ x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness)
427
+ y += line_height
428
+ result.append(x)
429
+ data = np.stack(result)
430
+
431
+ return IImage(data)
432
+
433
+ # ========== OPERATORS =============
434
+
435
+ def __or__(self, other):
436
+ # TODO: fix for variable sizes
437
+ return IImage(np.concatenate([self.data, other.data], 2))
438
+
439
+ def __truediv__(self, other):
440
+ # TODO: fix for variable sizes
441
+ return IImage(np.concatenate([self.data, other.data], 1))
442
+
443
+ def __and__(self, other):
444
+ return IImage(np.concatenate([self.data, other.data], 0))
445
+
446
+ def __add__(self, other):
447
+ return IImage(0.5 * self.data + 0.5 * other.data)
448
+
449
+ def __mul__(self, other):
450
+ if isinstance(other, IImage):
451
+ return IImage(self.data / 255. * other.data)
452
+ return IImage(self.data * other / 255.)
453
+
454
+ def __xor__(self, other):
455
+ return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0))
456
+
457
+ def __invert__(self):
458
+ return IImage(255 - self.data)
459
+ __rmul__ = __mul__
460
+
461
+ def bbox(self):
462
+ return [cv2.boundingRect(x) for x in self.data]
463
+
464
+ def fill_bbox(self, bbox_list, fill=255):
465
+ data = self.data.copy()
466
+ for bbox in bbox_list:
467
+ x, y, w, h = bbox
468
+ data[:, y:y+h, x:x+w, :] = fill
469
+ return IImage(data)
470
+
471
+ def crop(self, bbox):
472
+ assert len(bbox) in [2, 4]
473
+ if len(bbox) == 2:
474
+ x, y = 0, 0
475
+ w, h = bbox
476
+ elif len(bbox) == 4:
477
+ x, y, w, h = bbox
478
+ return IImage(self.data[:, y:y+h, x:x+w, :])
479
+
480
+ def stack(images, axis = 0):
481
+ return IImage(np.concatenate([x.data for x in images], axis))
482
+
483
+ class Animation:
484
+ JS = 0
485
+ HTML = 1
486
+ ANIMATION_MODE = HTML
487
+ def __init__(self, frames, fps = 30):
488
+         """Matplotlib-based animation wrapper around a stack of frames.
489
+ 
490
+         Args:
491
+             frames (np.ndarray): uint8 frames of shape (F, H, W, C); fps sets the playback speed.
492
+         """
493
+ self.frames = frames
494
+ self.fps = fps
495
+ self.anim_obj = None
496
+ self.anim_str = None
497
+ def render(self):
498
+ size = (self.frames.shape[2],self.frames.shape[1])
499
+ self.fig = plt.figure(figsize = size, dpi = 1)
500
+ plt.axis('off')
501
+ img = plt.imshow(self.frames[0], cmap = 'gray')
502
+ self.fig.subplots_adjust(0,0,1,1)
503
+ self.anim_obj = animation.FuncAnimation(
504
+ self.fig,
505
+ lambda i: img.set_data(self.frames[i,:,:,:]),
506
+ frames=self.frames.shape[0],
507
+ interval = 1000 / self.fps
508
+ )
509
+ plt.close()
510
+ if Animation.ANIMATION_MODE == Animation.HTML:
511
+ self.anim_str = self.anim_obj.to_html5_video()
512
+ elif Animation.ANIMATION_MODE == Animation.JS:
513
+ self.anim_str = self.anim_obj.to_jshtml()
514
+ return self.anim_obj
515
+ def _repr_html_(self):
516
+ if self.anim_obj is None: self.render()
517
+ return self.anim_str
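A small sketch of typical `IImage` usage; the input path is assumed to be the example image shipped with the repo.

```python
# Hedged sketch of IImage usage; 'examples/underwater.png' is assumed to exist.
from t2v_enhanced.utils.iimage import IImage

img = IImage.open("examples/underwater.png")   # stored internally as (B, H, W, C) uint8
square = img.resize(512).pad2square()          # int resize, then zero-pad to a square canvas
square.save("underwater_square.png")

pair = img | img                               # overloaded operator: horizontal concatenation
pair.save("underwater_pair.png")
```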
t2v_enhanced/utils/image_converter.py ADDED
@@ -0,0 +1,45 @@
1
+ import cv2
2
+ import numpy as np
3
+ from albumentations.augmentations.geometric import functional as F
4
+ from albumentations.core.transforms_interface import DualTransform
5
+
6
+ __all__ = ["ProportionalMinScale"]
7
+
8
+
9
+ class ProportionalMinScale(DualTransform):
10
+
11
+ def __init__(
12
+ self,
13
+ width: int,
14
+ height: int,
15
+ interpolation: int = cv2.INTER_LINEAR,
16
+ always_apply: bool = False,
17
+ p: float = 1,
18
+ ):
19
+ super(ProportionalMinScale, self).__init__(always_apply, p)
20
+ self.width = width
21
+ self.height = height
22
+
23
+ def apply(
24
+ self, img: np.ndarray, width: int = 256, height: int = 256, interpolation: int = cv2.INTER_LINEAR, **params):
25
+ h_img, w_img, _ = img.shape
26
+
27
+ min_side = np.min([h_img, w_img])
28
+
29
+ if (height/h_img)*w_img >= width:
30
+ if h_img == min_side:
31
+ return F.smallest_max_size(img, max_size=height, interpolation=interpolation)
32
+ else:
33
+ return F.longest_max_size(img, max_size=height, interpolation=interpolation)
34
+ if (width/w_img)*h_img >= height:
35
+ if w_img == min_side:
36
+ return F.smallest_max_size(img, max_size=width, interpolation=interpolation)
37
+ else:
38
+ return F.longest_max_size(img, max_size=width, interpolation=interpolation)
39
+ return F.longest_max_size(img, max_size=width, interpolation=interpolation)
40
+
41
+ def get_params(self):
42
+ return {"width": self.width, "height": self.height}
43
+
44
+ def get_transform_init_args_names(self):
45
+         return ("width", "height", "interpolation")
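A sketch of how `ProportionalMinScale` might be combined with a crop in an albumentations pipeline; the target resolution and input size are arbitrary.

```python
# Hedged sketch: scale so a 448x256 crop is always possible, then center-crop.
import numpy as np
import albumentations as A
from t2v_enhanced.utils.image_converter import ProportionalMinScale

transform = A.Compose([
    ProportionalMinScale(width=448, height=256, p=1.0),
    A.CenterCrop(height=256, width=448),
])
image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
out = transform(image=image)["image"]
print(out.shape)  # (256, 448, 3)
```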
t2v_enhanced/utils/object_loader.py ADDED
@@ -0,0 +1,26 @@
1
+ import importlib
2
+ from functools import partialmethod
3
+
4
+
5
+ def instantiate_object(cls_path: str, *args, **kwargs):
6
+ class_ = get_class(cls_path, *args, **kwargs)
7
+ obj = class_()
8
+ return obj
9
+
10
+
11
+ def get_class(cls_path: str, *args, **kwargs):
12
+ module_name = ".".join(cls_path.split(".")[:-1])
13
+ module = importlib.import_module(module_name)
14
+
15
+ class_ = getattr(module, cls_path.split(".")[-1])
16
+ class_.__init__ = partialmethod(class_.__init__, *args, **kwargs)
17
+ return class_
18
+
19
+
20
+ if __name__ == "__main__":
21
+
22
+ class_ = get_class(
23
+ "diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler")
24
+ scheduler = class_.from_config("stabilityai/stable-diffusion-2-1",
25
+ subfolder="scheduler")
26
+ print(scheduler)
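Besides the `from_config` example above, `instantiate_object` can pre-bind constructor keyword arguments from an import path; `DDIMScheduler` is used here purely as an illustrative target.

```python
# Hedged sketch: pre-bind constructor kwargs via the import path.
# Note that get_class patches __init__ on the class object itself, so the binding
# persists for that class within the current process.
from t2v_enhanced.utils.object_loader import instantiate_object

scheduler = instantiate_object(
    "diffusers.schedulers.scheduling_ddim.DDIMScheduler",
    num_train_timesteps=1000,
)
print(type(scheduler).__name__)  # DDIMScheduler
```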
t2v_enhanced/utils/video_utils.py ADDED
@@ -0,0 +1,376 @@
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import Union
6
+ import shutil
7
+
8
+ import cv2
9
+ import imageio
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from decord import VideoReader, cpu
14
+ from einops import rearrange, repeat
15
+ from t2v_enhanced.utils.iimage import IImage
16
+ from PIL import Image, ImageDraw, ImageFont
17
+ from torchvision.utils import save_image
18
+
19
+ channel_first = 0
20
+ channel_last = -1
21
+
22
+
23
+ def video_naming(prompt, extension, batch_idx, idx):
24
+ prompt_identifier = prompt.replace(" ", "_")
25
+ prompt_identifier = prompt_identifier.replace("/", "_")
26
+ if len(prompt_identifier) > 40:
27
+ prompt_identifier = prompt_identifier[:40]
28
+ filename = f"{batch_idx:04d}_{idx:04d}_{prompt_identifier}.{extension}"
29
+ return filename
30
+
31
+
32
+ def video_naming_chunk(prompt, extension, batch_idx, idx, chunk_idx):
33
+ prompt_identifier = prompt.replace(" ", "_")
34
+ prompt_identifier = prompt_identifier.replace("/", "_")
35
+ if len(prompt_identifier) > 40:
36
+ prompt_identifier = prompt_identifier[:40]
37
+ filename = f"{batch_idx}_{idx}_{chunk_idx}_{prompt_identifier}.{extension}"
38
+ return filename
39
+
40
+
41
+ class ResultProcessor():
42
+
43
+ def __init__(self, fps: int, n_frames: int, logger=None) -> None:
44
+ self.fps = fps
45
+ self.logger = logger
46
+ self.n_frames = n_frames
47
+
48
+ def set_logger(self, logger):
49
+ self.logger = logger
50
+
51
+ def _create_video(self, video, prompt, filename: Union[str, Path], append_video: torch.FloatTensor = None, input_flow=None):
52
+
53
+ if video.ndim == 5:
54
+ # can be batches if we provide list of filenames
55
+ assert video.shape[0] == 1
56
+ video = video[0]
57
+
58
+ if video.shape[0] == 3 and video.shape[1] == self.n_frames:
59
+ video = rearrange(video, "C F W H -> F C W H")
60
+ assert video.shape[1] == 3, f"Wrong video format. Got {video.shape}"
61
+ if isinstance(filename, Path):
62
+ filename = filename.as_posix()
63
+ # assert video.max() <= 1 and video.min() >= 0
64
+ assert video.max() <=1.1 and video.min() >= -0.1, f"video has unexpected range: [{video.min()}, {video.max()}]"
65
+ vid_obj = IImage(video, vmin=0, vmax=1)
66
+
67
+ if prompt is not None:
68
+ vid_obj = vid_obj.append_text(prompt, padding=(0, 50, 0, 0))
69
+
70
+ if append_video is not None:
71
+ if append_video.ndim == 5:
72
+ assert append_video.shape[0] == 1
73
+ append_video = append_video[0]
74
+ if append_video.shape[0] < video.shape[0]:
75
+ append_video = torch.concat([append_video,
76
+ repeat(append_video[-1, None], "F C W H -> (rep F) C W H", rep=video.shape[0]-append_video.shape[0])], dim=0)
77
+ if append_video.ndim == 3 and video.ndim == 4:
78
+ append_video = repeat(
79
+ append_video, "C W H -> F C W H", F=video.shape[0])
80
+ append_video = IImage(append_video, vmin=-1, vmax=1)
81
+ if prompt is not None:
82
+ append_video = append_video.append_text(
83
+ "input_frame", padding=(0, 50, 0, 0))
84
+ vid_obj = vid_obj | append_video
85
+ vid_obj = vid_obj.setFps(self.fps)
86
+ vid_obj.save(filename)
87
+
88
+ def _create_prompt_file(self, prompt, filename, video_path: str = None):
89
+ filename = Path(filename)
90
+ filename = filename.parent / (filename.stem+".txt")
91
+
92
+ with open(filename.as_posix(), "w") as file_writer:
93
+ file_writer.write(prompt)
94
+ file_writer.write("\n")
95
+ if video_path is not None:
96
+ file_writer.write(video_path)
97
+ else:
98
+ file_writer.write(" no_source")
99
+
100
+ def log_video(self, video: torch.FloatTensor, prompt: str, video_id: str, log_folder: str, input_flow=None, video_path_input: str = None, extension: str = "gif", prompt_on_vid: bool = True, append_video: torch.FloatTensor = None):
101
+
102
+ with tempfile.TemporaryDirectory() as tmpdirname:
103
+ storage_fol = Path(tmpdirname)
104
+ filename = f"{video_id}.{extension}".replace("/", "_")
105
+ vid_filename = storage_fol / filename
106
+ self._create_video(
107
+ video, prompt if prompt_on_vid else None, vid_filename, append_video, input_flow=input_flow)
108
+
109
+ prompt_file = storage_fol / f"{video_id}.txt"
110
+ self._create_prompt_file(prompt, prompt_file, video_path_input)
111
+
112
+ if self.logger.experiment.__class__.__name__ == "_DummyExperiment":
113
+ run_fol = Path(self.logger.save_dir) / \
114
+ self.logger.experiment_id / self.logger.run_id / "artifacts" / log_folder
115
+ if not run_fol.exists():
116
+ run_fol.mkdir(parents=True, exist_ok=True)
117
+ shutil.copy(prompt_file.as_posix(),
118
+ (run_fol / f"{video_id}.txt").as_posix())
119
+ shutil.copy(vid_filename,
120
+ (run_fol / filename).as_posix())
121
+ else:
122
+ self.logger.experiment.log_artifact(
123
+ self.logger.run_id, prompt_file.as_posix(), log_folder)
124
+ self.logger.experiment.log_artifact(
125
+ self.logger.run_id, vid_filename, log_folder)
126
+
127
+ def save_to_file(self, video: torch.FloatTensor, prompt: str, video_filename: Union[str, Path], input_flow=None, conditional_video_path: str = None, prompt_on_vid: bool = True, conditional_video: torch.FloatTensor = None):
128
+ self._create_video(
129
+ video, prompt if prompt_on_vid else None, video_filename, conditional_video, input_flow=input_flow)
130
+ self._create_prompt_file(
131
+ prompt, video_filename, conditional_video_path)
132
+
133
+
134
+ def add_text_to_image(image_array, text, position, font_size, text_color, font_path=None):
135
+
136
+ # Convert the NumPy array to PIL Image
137
+ image_pil = Image.fromarray(image_array)
138
+
139
+ # Create a drawing object
140
+ draw = ImageDraw.Draw(image_pil)
141
+
142
+ if font_path is not None:
143
+ font = ImageFont.truetype(font_path, font_size)
144
+ else:
145
+ try:
146
+ # Load the font
147
+ font = ImageFont.truetype(
148
+ "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", font_size)
149
+ except:
150
+ font = ImageFont.load_default()
151
+
152
+ # Draw the text on the image
153
+ draw.text(position, text, font=font, fill=text_color)
154
+
155
+ # Convert the PIL Image back to NumPy array
156
+ modified_image_array = np.array(image_pil)
157
+
158
+ return modified_image_array
159
+
160
+
161
+ def add_text_to_video(video_path, prompt):
162
+
163
+ outputs_with_overlay = []
164
+ with open(video_path, "rb") as f:
165
+ vr = VideoReader(f, ctx=cpu(0))
166
+
167
+ for i in range(len(vr)):
168
+ frame = vr[i]
169
+ frame = add_text_to_image(frame, prompt, position=(
170
+ 10, 10), font_size=15, text_color=(255, 0, 0),)
171
+ outputs_with_overlay.append(frame)
172
+ outputs = outputs_with_overlay
173
+ video_path = video_path.replace("mp4", "gif")
174
+ imageio.mimsave(video_path, outputs, duration=100, loop=0)
175
+
176
+
177
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=30, prompt=None):
178
+ videos = rearrange(videos, "b c t h w -> t b c h w")
179
+ outputs = []
180
+ for x in videos:
181
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
182
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
183
+ if rescale:
184
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
185
+ x = (x * 255).numpy().astype(np.uint8)
186
+ outputs.append(x)
187
+
188
+ os.makedirs(os.path.dirname(path), exist_ok=True)
189
+
190
+ if prompt is not None:
191
+ outputs_with_overlay = []
192
+ for frame in outputs:
193
+ frame_out = add_text_to_image(
194
+ frame, prompt, position=(10, 10), font_size=10, text_color=(255, 0, 0),)
195
+ outputs_with_overlay.append(frame_out)
196
+ outputs = outputs_with_overlay
197
+ imageio.mimsave(path, outputs, duration=round(1/fps*1000), loop=0)
198
+ # iio.imwrite(path, outputs)
199
+ # optimize(path)
200
+
201
+
202
+ def set_channel_pos(data, shape_dict, channel_pos):
203
+
204
+ assert data.ndim == 5 or data.ndim == 4
205
+ batch_dim = data.shape[0]
206
+ frame_dim = shape_dict["frame_dim"]
207
+ channel_dim = shape_dict["channel_dim"]
208
+ width_dim = shape_dict["width_dim"]
209
+ height_dim = shape_dict["height_dim"]
210
+
211
+ assert batch_dim != frame_dim
212
+ assert channel_dim != frame_dim
213
+ assert channel_dim != batch_dim
214
+
215
+ video_shape = list(data.shape)
216
+ batch_pos = video_shape.index(batch_dim)
217
+
218
+ channel_pos = video_shape.index(channel_dim)
219
+ w_pos = video_shape.index(width_dim)
220
+ h_pos = video_shape.index(height_dim)
221
+ if w_pos == h_pos:
222
+ video_shape[w_pos] = -1
223
+ h_pos = video_shape.index(height_dim)
224
+ pattern_order = {}
225
+ pattern_order[batch_pos] = "B"
226
+ pattern_order[channel_pos] = "C"
227
+
228
+ pattern_order[w_pos] = "W"
229
+ pattern_order[h_pos] = "H"
230
+
231
+ if data.ndim == 5:
232
+ frame_pos = video_shape.index(frame_dim)
233
+ pattern_order[frame_pos] = "F"
234
+ if channel_pos == channel_first:
235
+ pattern = " -> B F C W H"
236
+ else:
237
+ pattern = " -> B F W H C"
238
+ else:
239
+ if channel_pos == channel_first:
240
+ pattern = " -> B C W H"
241
+ else:
242
+ pattern = " -> B W H C"
243
+ pattern_input = [pattern_order[idx] for idx in range(data.ndim)]
244
+ pattern_input = " ".join(pattern_input)
245
+ pattern = pattern_input + pattern
246
+     return rearrange(data, pattern)
247
+
248
+
249
+ def merge_first_two_dimensions(tensor):
250
+ dims = tensor.ndim
251
+ letters = []
252
+ for letter_idx in range(dims-2):
253
+ letters.append(chr(letter_idx+67))
254
+ latters_pattern = " ".join(letters)
255
+ tensor = rearrange(tensor, "A B "+latters_pattern +
256
+ " -> (A B) "+latters_pattern)
257
+ # TODO merging first two dimensions might be easier with reshape so no need to create letters
258
+     # e.g. 'tensor.reshape(tensor.shape[0] * tensor.shape[1], *tensor.shape[2:])'
259
+ return tensor
260
+
261
+
262
+ def apply_spatial_function_to_video_tensor(video, shape, func):
263
+ # TODO detect batch, frame, channel, width, and height
264
+
265
+ assert video.ndim == 5
266
+ batch_dim = shape["batch_dim"]
267
+ frame_dim = shape["frame_dim"]
268
+ channel_dim = shape["channel_dim"]
269
+ width_dim = shape["width_dim"]
270
+ height_dim = shape["height_dim"]
271
+
272
+ assert batch_dim != frame_dim
273
+ assert channel_dim != frame_dim
274
+ assert channel_dim != batch_dim
275
+
276
+ video_shape = list(video.shape)
277
+ batch_pos = video_shape.index(batch_dim)
278
+ frame_pos = video_shape.index(frame_dim)
279
+ channel_pos = video_shape.index(channel_dim)
280
+ w_pos = video_shape.index(width_dim)
281
+ h_pos = video_shape.index(height_dim)
282
+ if w_pos == h_pos:
283
+ video_shape[w_pos] = -1
284
+ h_pos = video_shape.index(height_dim)
285
+ pattern_order = {}
286
+ pattern_order[batch_pos] = "B"
287
+ pattern_order[channel_pos] = "C"
288
+ pattern_order[frame_pos] = "F"
289
+ pattern_order[w_pos] = "W"
290
+ pattern_order[h_pos] = "H"
291
+ pattern_order = sorted(pattern_order.items(), key=lambda x: x[1])
292
+ pattern_order = [x[0] for x in pattern_order]
293
+ input_pattern = " ".join(pattern_order)
294
+ video = rearrange(video, input_pattern+" -> (B F) C W H")
295
+
296
+ video = func(video)
297
+ video = rearrange(video, "(B F) C W H -> "+input_pattern, F=frame_dim)
298
+ return video
299
+
300
+
301
+ def dump_frames(videos, as_mosaik, storage_fol, save_image_kwargs):
302
+
303
+ # assume videos is in format B F C H W, range [0,1]
304
+ num_frames = videos.shape[1]
305
+ num_videos = videos.shape[0]
306
+
307
+ if videos.shape[2] != 3 and videos.shape[-1] == 3:
308
+ videos = rearrange(videos, "B F W H C -> B F C W H")
309
+
310
+ frame_counter = 0
311
+ if not isinstance(storage_fol, Path):
312
+ storage_fol = Path(storage_fol)
313
+
314
+ for frame_idx in range(num_frames):
315
+ print(f" Creating frame {frame_idx}")
316
+ batch_frame = videos[:, frame_idx, ...]
317
+
318
+ if as_mosaik:
319
+ filename = storage_fol / f"frame_{frame_counter:03d}.png"
320
+ save_image(batch_frame, fp=filename.as_posix(),
321
+ **save_image_kwargs)
322
+ frame_counter += 1
323
+ else:
324
+ for video_idx in range(num_videos):
325
+ frame = batch_frame[video_idx]
326
+
327
+ filename = storage_fol / f"frame_{frame_counter:03d}.png"
328
+ save_image(frame, fp=filename.as_posix(),
329
+ **save_image_kwargs)
330
+ frame_counter += 1
331
+
332
+
333
+ def gif_from_videos(videos):
334
+
335
+ assert videos.dim() == 5
336
+ assert videos.min() >= 0
337
+ assert videos.max() <= 1
338
+ gif_file = Path("tmp.gif").absolute()
339
+
340
+ with tempfile.TemporaryDirectory() as tmpdirname:
341
+ storage_fol = Path(tmpdirname)
342
+ nrows = min(4, videos.shape[0])
343
+ dump_frames(
344
+ videos=videos, storage_fol=storage_fol, as_mosaik=True, save_image_kwargs={"nrow": nrows})
345
+ cmd = f"ffmpeg -y -f image2 -framerate 4 -i {storage_fol / 'frame_%03d.png'} {gif_file.as_posix()}"
346
+ subprocess.check_call(
347
+ cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
348
+ return gif_file
349
+
350
+
351
+
352
+ def add_margin(pil_img, top, right, bottom, left, color):
353
+ width, height = pil_img.size
354
+ new_width = width + right + left
355
+ new_height = height + top + bottom
356
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
357
+ result.paste(pil_img, (left, top))
358
+ return result
359
+
360
+ def resize_to_fit(image, size):
361
+ W, H = size
362
+ w, h = image.size
363
+ if H / h > W / w:
364
+ H_ = int(h * W / w)
365
+ W_ = W
366
+ else:
367
+ W_ = int(w * H / h)
368
+ H_ = H
369
+ return image.resize((W_, H_))
370
+
371
+ def pad_to_fit(image, size):
372
+ W, H = size
373
+ w, h = image.size
374
+ pad_h = (H - h) // 2
375
+ pad_w = (W - w) // 2
376
+ return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0))
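Finally, a brief sketch of writing a batch of clips with `save_videos_grid`; the tensor here is random and the output folder is an assumption.

```python
# Hedged sketch: write a (B, C, T, H, W) float batch in [0, 1] as an annotated GIF grid.
import torch
from t2v_enhanced.utils.video_utils import save_videos_grid, video_naming

videos = torch.rand(2, 3, 16, 128, 128)  # two dummy 16-frame clips
filename = video_naming("a cat surfing a wave", "gif", batch_idx=0, idx=0)
save_videos_grid(videos, f"results/{filename}", n_rows=2, fps=8,
                 prompt="a cat surfing a wave")
```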