hao he commited on
Commit
308c973
1 Parent(s): 88b231a

Add gradio codes for CameraCtrl with SVD-xt model

Browse files
Files changed (47) hide show
  1. LICENSE.txt +201 -0
  2. README.md +1 -1
  3. app.py +579 -0
  4. assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png +0 -0
  5. assets/example_condition_images/A_car_running_on_Mars..png +0 -0
  6. assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png +0 -0
  7. assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png +0 -0
  8. assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png +0 -0
  9. assets/example_condition_images/An_exploding_cheese_house..png +0 -0
  10. assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png +0 -0
  11. assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png +0 -0
  12. assets/example_condition_images/Leaves_are_falling_from_trees..png +0 -0
  13. assets/example_condition_images/Rocky_coastline_with_crashing_waves..png +0 -0
  14. assets/pose_files/0bf152ef84195293.txt +26 -0
  15. assets/pose_files/0c11dbe781b1c11c.txt +26 -0
  16. assets/pose_files/0c9b371cc6225682.txt +26 -0
  17. assets/pose_files/0f47577ab3441480.txt +26 -0
  18. assets/pose_files/0f68374b76390082.txt +26 -0
  19. assets/pose_files/2c80f9eb0d3b2bb4.txt +26 -0
  20. assets/pose_files/2f25826f0d0ef09a.txt +26 -0
  21. assets/pose_files/3f79dc32d575bcdc.txt +26 -0
  22. assets/pose_files/4a2d6753676df096.txt +26 -0
  23. assets/reference_videos/0bf152ef84195293.mp4 +0 -0
  24. assets/reference_videos/0c11dbe781b1c11c.mp4 +0 -0
  25. assets/reference_videos/0c9b371cc6225682.mp4 +0 -0
  26. assets/reference_videos/0f47577ab3441480.mp4 +0 -0
  27. assets/reference_videos/0f68374b76390082.mp4 +0 -0
  28. assets/reference_videos/2c80f9eb0d3b2bb4.mp4 +0 -0
  29. assets/reference_videos/2f25826f0d0ef09a.mp4 +0 -0
  30. assets/reference_videos/3f79dc32d575bcdc.mp4 +0 -0
  31. assets/reference_videos/4a2d6753676df096.mp4 +0 -0
  32. cameractrl/data/dataset.py +355 -0
  33. cameractrl/models/attention.py +65 -0
  34. cameractrl/models/attention_processor.py +591 -0
  35. cameractrl/models/motion_module.py +399 -0
  36. cameractrl/models/pose_adaptor.py +240 -0
  37. cameractrl/models/transformer_temporal.py +191 -0
  38. cameractrl/models/unet.py +587 -0
  39. cameractrl/models/unet_3d_blocks.py +461 -0
  40. cameractrl/pipelines/pipeline_animation.py +523 -0
  41. cameractrl/utils/convert_from_ckpt.py +556 -0
  42. cameractrl/utils/convert_lora_safetensor_to_diffusers.py +154 -0
  43. cameractrl/utils/util.py +148 -0
  44. configs/train_cameractrl/svd_320_576_cameractrl.yaml +87 -0
  45. configs/train_cameractrl/svdxt_320_576_cameractrl.yaml +88 -0
  46. inference_cameractrl.py +255 -0
  47. requirements.txt +20 -0
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: CameraCtrl Svd Xt
3
- emoji: 🚀
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
  title: CameraCtrl Svd Xt
3
+ emoji: 🎥
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
+ import torch
4
+ import tempfile
5
+ import os
6
+ import cv2
7
+
8
+ import numpy as np
9
+ import gradio as gr
10
+ import torchvision.transforms.functional as F
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib as mpl
13
+
14
+
15
+ from omegaconf import OmegaConf
16
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
17
+ from inference_cameractrl import get_relative_pose, ray_condition, get_pipeline
18
+ from cameractrl.utils.util import save_videos_grid
19
+
20
+ cv2.setNumThreads(1)
21
+ mpl.use('agg')
22
+
23
+ #### Description ####
24
+ title = r"""<h1 align="center">CameraCtrl: Enabling Camera Control for Text-to-Video Generation</h1>"""
25
+ subtitle = r"""<h2 align="center">CameraCtrl Image2Video with <a href='https://arxiv.org/abs/2311.15127' target='_blank'> <b>Stable Video Diffusion (SVD)</b> </a>-xt <a href='https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt' target='_blank'> <b> model </b> </a> </h2>"""
26
+ description = r"""
27
+ <b>Official Gradio demo</b> for <a href='https://github.com/hehao13/CameraCtrl' target='_blank'><b>CameraCtrl: Enabling Camera Control for Text-to-Video Generation</b></a>.<br>
28
+ CameraCtrl is capable of precisely controlling the camera trajectory during the video generation process.<br>
29
+ Note that, with SVD-xt, CameraCtrl only support Image2Video now.<br>
30
+ """
31
+
32
+ closing_words = r"""
33
+
34
+ ---
35
+
36
+ If you are interested in this demo or CameraCtrl is helpful for you, please give us a ⭐ of the <a href='https://github.com/hehao13/CameraCtrl' target='_blank'> CameraCtrl</a> Github Repo !
37
+ [![GitHub Stars](https://img.shields.io/github/stars/hehao13/CameraCtrl
38
+ )](https://github.com/hehao13/CameraCtrl)
39
+
40
+ ---
41
+
42
+ 📝 **Citation**
43
+ <br>
44
+ If you find our paper or code is useful for your research, please consider citing:
45
+ ```bibtex
46
+ @article{he2024cameractrl,
47
+ title={CameraCtrl: Enabling Camera Control for Text-to-Video Generation},
48
+ author={Hao He and Yinghao Xu and Yuwei Guo and Gordon Wetzstein and Bo Dai and Hongsheng Li and Ceyuan Yang},
49
+ journal={arXiv preprint arXiv:2404.02101},
50
+ year={2024}
51
+ }
52
+ ```
53
+
54
+ 📧 **Contact**
55
+ <br>
56
+ If you have any questions, please feel free to contact me at <b>[email protected]</b>.
57
+
58
+ **Acknowledgement**
59
+ <br>
60
+ We thank <a href='https://wzhouxiff.github.io/projects/MotionCtrl/' target='_blank'><b>MotionCtrl</b></a> and <a href='https://huggingface.co/spaces/lllyasviel/IC-Light' target='_blank'><b>IC-Light</b></a> for their gradio codes.<br>
61
+ """
62
+
63
+
64
+ RESIZE_MODES = ['Resize then Center Crop', 'Directly resize']
65
+ CAMERA_TRAJECTORY_MODES = ["Provided Camera Trajectories", "Custom Camera Trajectories"]
66
+ height = 320
67
+ width = 576
68
+ num_frames = 25
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+
71
+ config = "configs/train_cameractrl/svdxt_320_576_cameractrl.yaml"
72
+ model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
73
+ ckpt = "checkpoints/CameraCtrl_svdxt.ckpt"
74
+ if not os.path.exists(ckpt):
75
+ os.makedirs("checkpoints", exist_ok=True)
76
+ os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_svd/resolve/main/CameraCtrl_svdxt.ckpt?download=true")
77
+ os.system("mv CameraCtrl_svdxt.ckpt?download=true checkpoints/CameraCtrl_svdxt.ckpt")
78
+ model_config = OmegaConf.load(config)
79
+
80
+
81
+ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
82
+ model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
83
+ ckpt, True, device)
84
+
85
+
86
+ examples = [
87
+ [
88
+ "assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png",
89
+ "assets/pose_files/0bf152ef84195293.txt",
90
+ "Trajectory 1"
91
+ ],
92
+ [
93
+ "assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png",
94
+ "assets/pose_files/0c9b371cc6225682.txt",
95
+ "Trajectory 2"
96
+ ],
97
+ [
98
+ "assets/example_condition_images/Rocky_coastline_with_crashing_waves..png",
99
+ "assets/pose_files/0c11dbe781b1c11c.txt",
100
+ "Trajectory 3"
101
+ ],
102
+ [
103
+ "assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png",
104
+ "assets/pose_files/0f47577ab3441480.txt",
105
+ "Trajectory 4"
106
+ ],
107
+ [
108
+ "assets/example_condition_images/An_exploding_cheese_house..png",
109
+ "assets/pose_files/0f47577ab3441480.txt",
110
+ "Trajectory 4"
111
+ ],
112
+ [
113
+ "assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png",
114
+ "assets/pose_files/0f68374b76390082.txt",
115
+ "Trajectory 5"
116
+ ],
117
+ [
118
+ "assets/example_condition_images/Leaves_are_falling_from_trees..png",
119
+ "assets/pose_files/2c80f9eb0d3b2bb4.txt",
120
+ "Trajectory 6"
121
+ ],
122
+ [
123
+ "assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png",
124
+ "assets/pose_files/2f25826f0d0ef09a.txt",
125
+ "Trajectory 7"
126
+ ],
127
+ [
128
+ "assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png",
129
+ "assets/pose_files/3f79dc32d575bcdc.txt",
130
+ "Trajectory 8"
131
+ ],
132
+ [
133
+ "assets/example_condition_images/A_car_running_on_Mars..png",
134
+ "assets/pose_files/4a2d6753676df096.txt",
135
+ "Trajectory 9"
136
+ ],
137
+ ]
138
+
139
+
140
+ class Camera(object):
141
+ def __init__(self, entry):
142
+ fx, fy, cx, cy = entry[1:5]
143
+ self.fx = fx
144
+ self.fy = fy
145
+ self.cx = cx
146
+ self.cy = cy
147
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
148
+ w2c_mat_4x4 = np.eye(4)
149
+ w2c_mat_4x4[:3, :] = w2c_mat
150
+ self.w2c_mat = w2c_mat_4x4
151
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
152
+
153
+
154
+ class CameraPoseVisualizer:
155
+ def __init__(self, xlim, ylim, zlim):
156
+ self.fig = plt.figure(figsize=(18, 7))
157
+ self.ax = self.fig.add_subplot(projection='3d')
158
+ self.plotly_data = None # plotly data traces
159
+ self.ax.set_aspect("auto")
160
+ self.ax.set_xlim(xlim)
161
+ self.ax.set_ylim(ylim)
162
+ self.ax.set_zlim(zlim)
163
+ self.ax.set_xlabel('x')
164
+ self.ax.set_ylabel('y')
165
+ self.ax.set_zlabel('z')
166
+
167
+ def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3):
168
+ vertex_std = np.array([[0, 0, 0, 1],
169
+ [base_xval, -base_xval * hw_ratio, zval, 1],
170
+ [base_xval, base_xval * hw_ratio, zval, 1],
171
+ [-base_xval, base_xval * hw_ratio, zval, 1],
172
+ [-base_xval, -base_xval * hw_ratio, zval, 1]])
173
+ vertex_transformed = vertex_std @ extrinsic.T
174
+ meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
175
+ [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
176
+ [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
177
+ [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
178
+ [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
179
+ vertex_transformed[4, :-1]]]
180
+
181
+ color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
182
+
183
+ self.ax.add_collection3d(
184
+ Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
185
+
186
+ def colorbar(self, max_frame_length):
187
+ cmap = mpl.cm.rainbow
188
+ norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
189
+ self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical',
190
+ label='Frame Indexes')
191
+
192
+ def show(self):
193
+ plt.title('Camera Trajectory')
194
+ plt.show()
195
+
196
+
197
+ def get_c2w(w2cs):
198
+ target_cam_c2w = np.array([
199
+ [1, 0, 0, 0],
200
+ [0, 1, 0, 0],
201
+ [0, 0, 1, 0],
202
+ [0, 0, 0, 1]
203
+ ])
204
+ abs2rel = target_cam_c2w @ w2cs[0]
205
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
206
+ camera_positions = np.asarray([c2w[:3, 3] for c2w in ret_poses]) # [n_frame, 3]
207
+ position_distances = [camera_positions[i] - camera_positions[i - 1] for i in range(1, len(camera_positions))]
208
+ xyz_max = np.max(camera_positions, axis=0)
209
+ xyz_min = np.min(camera_positions, axis=0)
210
+ xyz_ranges = xyz_max - xyz_min # [3, ]
211
+ max_range = np.max(xyz_ranges)
212
+ expected_xyz_ranges = 1
213
+ scale_ratio = expected_xyz_ranges / max_range
214
+ scaled_position_distances = [dis * scale_ratio for dis in position_distances] # [n_frame - 1]
215
+ scaled_camera_positions = [camera_positions[0], ]
216
+ scaled_camera_positions.extend([camera_positions[0] + np.sum(np.asarray(scaled_position_distances[:i]), axis=0)
217
+ for i in range(1, len(camera_positions))])
218
+ ret_poses = [np.concatenate(
219
+ (np.concatenate((ori_pose[:3, :3], cam_position[:, None]), axis=1), np.asarray([0, 0, 0, 1])[None]), axis=0)
220
+ for ori_pose, cam_position in zip(ret_poses, scaled_camera_positions)]
221
+ transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
222
+ ret_poses = [transform_matrix @ x for x in ret_poses]
223
+ return np.array(ret_poses, dtype=np.float32)
224
+
225
+
226
+ def visualize_trajectory(trajectory_file):
227
+ with open(trajectory_file, 'r') as f:
228
+ poses = f.readlines()
229
+ w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]]
230
+ num_frames = len(w2cs)
231
+ last_row = np.zeros((1, 4))
232
+ last_row[0, -1] = 1.0
233
+ w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
234
+ c2ws = get_c2w(w2cs)
235
+ visualizer = CameraPoseVisualizer([-1.2, 1.2], [-1.2, 1.2], [-1.2, 1.2])
236
+ for frame_idx, c2w in enumerate(c2ws):
237
+ visualizer.extrinsic2pyramid(c2w, frame_idx / num_frames, hw_ratio=9 / 16, base_xval=0.02, zval=0.1)
238
+ visualizer.colorbar(num_frames)
239
+ return visualizer.fig
240
+
241
+
242
+ vis_traj = visualize_trajectory('assets/pose_files/0bf152ef84195293.txt')
243
+
244
+
245
+ @torch.inference_mode()
246
+ def process_input_image(input_image, resize_mode):
247
+ global height, width
248
+ expected_hw_ratio = height / width
249
+ inp_w, inp_h = input_image.size
250
+ inp_hw_ratio = inp_h / inp_w
251
+
252
+ if inp_hw_ratio > expected_hw_ratio:
253
+ resized_height = inp_hw_ratio * width
254
+ resized_width = width
255
+ else:
256
+ resized_height = height
257
+ resized_width = height / inp_hw_ratio
258
+ resized_image = F.resize(input_image, size=[resized_height, resized_width])
259
+
260
+ if resize_mode == RESIZE_MODES[0]:
261
+ return_image = F.center_crop(resized_image, output_size=[height, width])
262
+ else:
263
+ return_image = resized_image
264
+
265
+ return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), gr.update(
266
+ visible=True), gr.update(visible=True), gr.update(visible=True)
267
+
268
+
269
+ def update_camera_trajectories(trajectory_mode):
270
+ if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
271
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
272
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
273
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
274
+ elif trajectory_mode == CAMERA_TRAJECTORY_MODES[1]:
275
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
276
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
277
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
278
+
279
+
280
+ def update_camera_args(trajectory_mode, provided_camera_trajectory, customized_trajectory_file):
281
+ if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
282
+ res = "Provided " + str(provided_camera_trajectory)
283
+ else:
284
+ if customized_trajectory_file is None:
285
+ res = " "
286
+ else:
287
+ res = f"Customized trajectory file {customized_trajectory_file.name.split('/')[-1]}"
288
+ return res
289
+
290
+
291
+ def update_camera_args_reset():
292
+ return " "
293
+
294
+
295
+ def update_trajectory_vis_plot(camera_trajectory_args, provided_camera_trajectory, customized_trajectory_file):
296
+ if 'Provided' in camera_trajectory_args:
297
+ if provided_camera_trajectory == "Trajectory 1":
298
+ trajectory_file_path = "assets/pose_files/0bf152ef84195293.txt"
299
+ elif provided_camera_trajectory == "Trajectory 2":
300
+ trajectory_file_path = "assets/pose_files/0c9b371cc6225682.txt"
301
+ elif provided_camera_trajectory == "Trajectory 3":
302
+ trajectory_file_path = "assets/pose_files/0c11dbe781b1c11c.txt"
303
+ elif provided_camera_trajectory == "Trajectory 4":
304
+ trajectory_file_path = "assets/pose_files/0f47577ab3441480.txt"
305
+ elif provided_camera_trajectory == "Trajectory 5":
306
+ trajectory_file_path = "assets/pose_files/0f68374b76390082.txt"
307
+ elif provided_camera_trajectory == "Trajectory 6":
308
+ trajectory_file_path = "assets/pose_files/2c80f9eb0d3b2bb4.txt"
309
+ elif provided_camera_trajectory == "Trajectory 7":
310
+ trajectory_file_path = "assets/pose_files/2f25826f0d0ef09a.txt"
311
+ elif provided_camera_trajectory == "Trajectory 8":
312
+ trajectory_file_path = "assets/pose_files/3f79dc32d575bcdc.txt"
313
+ else:
314
+ trajectory_file_path = "assets/pose_files/4a2d6753676df096.txt"
315
+ else:
316
+ trajectory_file_path = customized_trajectory_file.name
317
+ vis_traj = visualize_trajectory(trajectory_file_path)
318
+ return gr.update(visible=True), vis_traj, gr.update(visible=True), gr.update(visible=True), \
319
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
320
+ gr.update(visible=True), gr.update(visible=True), trajectory_file_path
321
+
322
+
323
+ def update_set_button():
324
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
325
+
326
+
327
+ def update_buttons_for_example(example_image, example_traj_path, provided_traj_name):
328
+ global height, width
329
+ return_image = example_image
330
+ return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), \
331
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
332
+ gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), \
333
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), \
334
+ gr.update(visible=True)
335
+
336
+ @spaces.GPU
337
+ @torch.inference_mode()
338
+ def sample_video(condition_image, trajectory_file, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, seed):
339
+ global height, width, num_frames, device, pipeline
340
+ with open(trajectory_file, 'r') as f:
341
+ poses = f.readlines()
342
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
343
+ cam_params = [[float(x) for x in pose] for pose in poses]
344
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
345
+ sample_wh_ratio = width / height
346
+ pose_wh_ratio = cam_params[0].fy / cam_params[0].fx
347
+ if pose_wh_ratio > sample_wh_ratio:
348
+ resized_ori_w = height * pose_wh_ratio
349
+ for cam_param in cam_params:
350
+ cam_param.fx = resized_ori_w * cam_param.fx / width
351
+ else:
352
+ resized_ori_h = width / pose_wh_ratio
353
+ for cam_param in cam_params:
354
+ cam_param.fy = resized_ori_h * cam_param.fy / height
355
+ intrinsic = np.asarray([[cam_param.fx * width,
356
+ cam_param.fy * height,
357
+ cam_param.cx * width,
358
+ cam_param.cy * height]
359
+ for cam_param in cam_params], dtype=np.float32)
360
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
361
+ c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
362
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
363
+ plucker_embedding = ray_condition(K, c2ws, height, width, device='cpu') # b f h w 6
364
+ plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)
365
+
366
+ generator = torch.Generator(device=device)
367
+ generator.manual_seed(int(seed))
368
+
369
+ with torch.no_grad():
370
+ sample = pipeline(
371
+ image=condition_image,
372
+ pose_embedding=plucker_embedding,
373
+ height=height,
374
+ width=width,
375
+ num_frames=num_frames,
376
+ num_inference_steps=num_inference_step,
377
+ min_guidance_scale=min_guidance_scale,
378
+ max_guidance_scale=max_guidance_scale,
379
+ fps=fps_id,
380
+ do_image_process=True,
381
+ generator=generator,
382
+ output_type='pt'
383
+ ).frames[0].transpose(0, 1).cpu()
384
+
385
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
386
+ save_videos_grid(sample[None], temporal_video_path, rescale=False)
387
+
388
+ return temporal_video_path
389
+
390
+
391
+ def main(args):
392
+ demo = gr.Blocks().queue()
393
+ with demo:
394
+ gr.Markdown(title)
395
+ gr.Markdown(subtitle)
396
+ gr.Markdown(description)
397
+
398
+ with gr.Column():
399
+ # step1: Input condition image
400
+ step1_title = gr.Markdown("---\n## Step 1: Input an Image", show_label=False, visible=True)
401
+ step1_dec = gr.Markdown(f"\n 1. Upload an Image by `Drag` or Click `Upload Image`; \
402
+ \n 2. Click `{RESIZE_MODES[0]}` or `{RESIZE_MODES[1]}` to select the image resize mode. \
403
+ \n - `{RESIZE_MODES[0]}`: First resize the input image, then center crop it into the resolution of 320 x 576. \
404
+ \n - `{RESIZE_MODES[1]}`: Only resize the input image, and keep the original aspect ratio.",
405
+ show_label=False, visible=True)
406
+ with gr.Row(equal_height=True):
407
+ with gr.Column(scale=2):
408
+ input_image = gr.Image(type='pil', interactive=True, elem_id='condition_image',
409
+ elem_classes='image',
410
+ visible=True)
411
+ with gr.Row():
412
+ resize_crop_button = gr.Button(RESIZE_MODES[0], visible=True)
413
+ directly_resize_button = gr.Button(RESIZE_MODES[1], visible=True)
414
+ with gr.Column(scale=2):
415
+ processed_image = gr.Image(type='pil', interactive=False, elem_id='processed_image',
416
+ elem_classes='image', visible=False)
417
+
418
+ # step2: Select camera trajectory
419
+ step2_camera_trajectory = gr.Markdown("---\n## Step 2: Select the camera trajectory", show_label=False,
420
+ visible=False)
421
+ step2_camera_trajectory_des = gr.Markdown(f"\n - `{CAMERA_TRAJECTORY_MODES[0]}`: Including 9 camera trajectories extracted from the test set of RealEstate10K dataset, each has 25 frames. \
422
+ \n - `{CAMERA_TRAJECTORY_MODES[1]}`: You can provide the customized camera trajectories in the txt file.",
423
+ show_label=False, visible=False)
424
+ with gr.Row(equal_height=True):
425
+ provide_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[0], visible=False)
426
+ customized_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[1], visible=False)
427
+ with gr.Row():
428
+ with gr.Column():
429
+ provided_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[0]}", show_label=False,
430
+ visible=False)
431
+ provided_camera_trajectory_des = gr.Markdown(f"\n 1. Click one of the provide camera trajectories, such as `Trajectory 1`; \
432
+ \n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
433
+ \n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
434
+ show_label=False, visible=False)
435
+
436
+ customized_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[1]}",
437
+ show_label=False,
438
+ visible=False)
439
+ customized_run_status = gr.Markdown(f"\n 1. Input the txt file containing camera trajectory. \
440
+ \n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
441
+ \n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
442
+ show_label=False, visible=False)
443
+
444
+ with gr.Row():
445
+ provided_trajectories = gr.Dropdown(
446
+ ["Trajectory 1", "Trajectory 2", "Trajectory 3", "Trajectory 4", "Trajectory 5",
447
+ "Trajectory 6", "Trajectory 7", "Trajectory 8", "Trajectory 9"],
448
+ label="Provided Trajectories", interactive=True, visible=False)
449
+ with gr.Row():
450
+ customized_camera_trajectory_file = gr.File(
451
+ label="Upload customized camera trajectory (in .txt format).", visible=False, interactive=True)
452
+
453
+ with gr.Row():
454
+ camera_args = gr.Textbox(value=" ", label="Camera Trajectory Name", visible=False)
455
+ camera_trajectory_path = gr.Textbox(value=" ", visible=False)
456
+
457
+ with gr.Row():
458
+ camera_trajectory_vis = gr.Button(value="Visualize Camera Trajectory", visible=False)
459
+ camera_trajectory_reset = gr.Button(value="Reset Camera Trajectory", visible=False)
460
+ with gr.Column():
461
+ vis_camera_trajectory = gr.Plot(vis_traj, label='Camera Trajectory', visible=False)
462
+
463
+ # step3: Set inference parameters
464
+ with gr.Row():
465
+ with gr.Column():
466
+ step3_title = gr.Markdown(f"---\n## Step3: Setting the inference hyper-parameters.", visible=False)
467
+ step3_des = gr.Markdown(
468
+ f"\n 1. Set the mumber of inference step; \
469
+ \n 2. Set the seed; \
470
+ \n 3. Set the minimum guidance scale and the maximum guidance scale; \
471
+ \n 4. Set the fps; \
472
+ \n - Please refer to the SVD paper for the meaning of the last three parameter",
473
+ visible=False)
474
+ with gr.Row():
475
+ with gr.Column():
476
+ num_inference_steps = gr.Number(value=25, label='Number Inference Steps', step=1, interactive=True,
477
+ visible=False)
478
+ with gr.Column():
479
+ seed = gr.Number(value=42, label='Seed', minimum=1, interactive=True, visible=False, step=1)
480
+ with gr.Column():
481
+ min_guidance_scale = gr.Number(value=1.0, label='Minimum Guidance Scale', minimum=1.0, step=0.5,
482
+ interactive=True, visible=False)
483
+ with gr.Column():
484
+ max_guidance_scale = gr.Number(value=3.0, label='Maximum Guidance Scale', minimum=1.0, step=0.5,
485
+ interactive=True, visible=False)
486
+ with gr.Column():
487
+ fps = gr.Number(value=7, label='FPS', minimum=1, step=1, interactive=True, visible=False)
488
+ with gr.Column():
489
+ _ = gr.Button("Seed", visible=False)
490
+ with gr.Column():
491
+ _ = gr.Button("Seed", visible=False)
492
+ with gr.Column():
493
+ _ = gr.Button("Seed", visible=False)
494
+ with gr.Row():
495
+ with gr.Column():
496
+ _ = gr.Button("Set", visible=False)
497
+ with gr.Column():
498
+ set_button = gr.Button("Set", visible=False)
499
+ with gr.Column():
500
+ _ = gr.Button("Set", visible=False)
501
+
502
+ # step 4: Generate video
503
+ with gr.Row():
504
+ with gr.Column():
505
+ step4_title = gr.Markdown("---\n## Step4 Generating video", show_label=False, visible=False)
506
+ step4_des = gr.Markdown(f"\n - Click the `Start generation !` button to generate the video.; \
507
+ \n - If the content of generated video is not very aligned with the condition image, try to increase the `Minimum Guidance Scale` and `Maximum Guidance Scale`. \
508
+ \n - If the generated videos are distored, try to increase `FPS`.",
509
+ visible=False)
510
+ start_button = gr.Button(value="Start generation !", visible=False)
511
+ with gr.Column():
512
+ generate_video = gr.Video(value=None, label="Generate Video", visible=False)
513
+ resize_crop_button.click(fn=process_input_image, inputs=[input_image, resize_crop_button],
514
+ outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
515
+ provide_trajectory_button, customized_trajectory_button])
516
+ directly_resize_button.click(fn=process_input_image, inputs=[input_image, directly_resize_button],
517
+ outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
518
+ provide_trajectory_button, customized_trajectory_button])
519
+ provide_trajectory_button.click(fn=update_camera_trajectories, inputs=[provide_trajectory_button],
520
+ outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
521
+ provided_trajectories,
522
+ customized_camera_trajectory, customized_run_status,
523
+ customized_camera_trajectory_file,
524
+ camera_args, camera_trajectory_vis, camera_trajectory_reset])
525
+ customized_trajectory_button.click(fn=update_camera_trajectories, inputs=[customized_trajectory_button],
526
+ outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
527
+ provided_trajectories,
528
+ customized_camera_trajectory, customized_run_status,
529
+ customized_camera_trajectory_file,
530
+ camera_args, camera_trajectory_vis, camera_trajectory_reset])
531
+
532
+ provided_trajectories.change(fn=update_camera_args, inputs=[provide_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
533
+ outputs=[camera_args])
534
+ customized_camera_trajectory_file.change(fn=update_camera_args, inputs=[customized_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
535
+ outputs=[camera_args])
536
+ camera_trajectory_reset.click(fn=update_camera_args_reset, inputs=None, outputs=[camera_args])
537
+ camera_trajectory_vis.click(fn=update_trajectory_vis_plot, inputs=[camera_args, provided_trajectories, customized_camera_trajectory_file],
538
+ outputs=[vis_camera_trajectory, vis_camera_trajectory, step3_title, step3_des,
539
+ num_inference_steps, min_guidance_scale, max_guidance_scale, fps,
540
+ seed, set_button, camera_trajectory_path])
541
+ set_button.click(fn=update_set_button, inputs=None, outputs=[step4_title, step4_des, start_button, generate_video])
542
+ start_button.click(fn=sample_video, inputs=[processed_image, camera_trajectory_path, num_inference_steps,
543
+ min_guidance_scale, max_guidance_scale, fps, seed],
544
+ outputs=[generate_video])
545
+
546
+ # set example
547
+ gr.Markdown("## Examples")
548
+ gr.Markdown("\n Choosing the one of the following examples to get a quick start, by selecting an example, "
549
+ "we will set the condition image and camera trajectory automatically. "
550
+ "Then, you can click the `Visualize Camera Trajectory` button to visualize the camera trajectory.")
551
+ gr.Examples(
552
+ fn=update_buttons_for_example,
553
+ run_on_click=True,
554
+ cache_examples=False,
555
+ examples=examples,
556
+ inputs=[input_image, camera_args, provided_trajectories],
557
+ outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des, provide_trajectory_button,
558
+ customized_trajectory_button,
559
+ provided_camera_trajectory, provided_camera_trajectory_des, provided_trajectories,
560
+ customized_camera_trajectory, customized_run_status, customized_camera_trajectory_file,
561
+ camera_args, camera_trajectory_vis, camera_trajectory_reset]
562
+ )
563
+ with gr.Row():
564
+ gr.Markdown(closing_words)
565
+
566
+ demo.launch(**args)
567
+
568
+
569
+ if __name__ == '__main__':
570
+ parser = argparse.ArgumentParser()
571
+ parser.add_argument('--listen', default='0.0.0.0')
572
+ parser.add_argument('--broswer', action='store_true')
573
+ parser.add_argument('--share', action='store_true')
574
+ args = parser.parse_args()
575
+
576
+ launch_kwargs = {'server_name': args.listen,
577
+ 'inbrowser': args.broswer,
578
+ 'share': args.share}
579
+ main(launch_kwargs)
assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png ADDED
assets/example_condition_images/A_car_running_on_Mars..png ADDED
assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png ADDED
assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png ADDED
assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png ADDED
assets/example_condition_images/An_exploding_cheese_house..png ADDED
assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png ADDED
assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png ADDED
assets/example_condition_images/Leaves_are_falling_from_trees..png ADDED
assets/example_condition_images/Rocky_coastline_with_crashing_waves..png ADDED
assets/pose_files/0bf152ef84195293.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=QShWPZxTDoE
2
+ 157323991 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.950234294 0.023969267 -0.310612619 -0.058392330 -0.025083920 0.999685287 0.000406042 0.179560758 0.310524613 0.007405547 0.950536489 -0.411621285
3
+ 157490824 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.932122767 0.029219138 -0.360961705 0.019157260 -0.030671693 0.999528050 0.001705339 0.195243598 0.360841185 0.009481722 0.932579100 -0.489249695
4
+ 157657658 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.912891090 0.034948215 -0.406704396 0.093606521 -0.036429971 0.999327779 0.004101569 0.203909523 0.406574339 0.011071944 0.913550615 -0.570709379
5
+ 157824491 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.892021954 0.039648205 -0.450249761 0.174752186 -0.041337918 0.999126673 0.006083843 0.206605029 0.450097769 0.013185467 0.892881930 -0.657519766
6
+ 157991325 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.870897233 0.043891508 -0.489501357 0.266759117 -0.046065997 0.998909414 0.007609563 0.208293300 0.489301533 0.015922222 0.871969342 -0.739918788
7
+ 158158158 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.850054264 0.048434701 -0.524463415 0.371990684 -0.051002879 0.998652756 0.009560689 0.215520371 0.524219871 0.018622037 0.851379335 -0.814489669
8
+ 158358358 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.823578537 0.052956820 -0.564724684 0.498689894 -0.055313200 0.998385012 0.012955925 0.224118528 0.564498782 0.020566508 0.825177670 -0.889946292
9
+ 158525192 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.801249743 0.056292553 -0.595676124 0.608660065 -0.058902908 0.998149574 0.015096202 0.223320416 0.595423639 0.022991227 0.803082883 -0.943733076
10
+ 158692025 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.780003667 0.059620168 -0.622928321 0.726968666 -0.062449891 0.997897983 0.017311305 0.217967188 0.622651041 0.025398925 0.782087326 -1.002211444
11
+ 158858859 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.758300304 0.062706254 -0.648882568 0.862137737 -0.066019125 0.997632504 0.019256916 0.210050766 0.648553908 0.028236136 0.760644853 -1.055941415
12
+ 159025692 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.733386099 0.066376433 -0.676564097 1.014642875 -0.069476441 0.997329056 0.022534581 0.204417168 0.676252782 0.030478716 0.736038864 -1.100931176
13
+ 159192526 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.703763664 0.069719747 -0.707004845 1.176046236 -0.073094003 0.996997535 0.025557835 0.198280199 0.706663966 0.033691138 0.706746757 -1.127059555
14
+ 159392726 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.662238955 0.074904546 -0.745539367 1.361111617 -0.076822795 0.996534884 0.031882886 0.176548885 0.745344102 0.036160327 0.665698588 -1.136046987
15
+ 159559560 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.629729092 0.078562595 -0.772831917 1.488353738 -0.081706621 0.996052980 0.034676969 0.152860218 0.772505820 0.041308392 0.633662641 -1.137729720
16
+ 159726393 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.602676034 0.081962064 -0.793765604 1.594849532 -0.083513811 0.995727122 0.039407197 0.137400253 0.793603837 0.042540617 0.606945813 -1.154423412
17
+ 159893227 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.580205023 0.084535435 -0.810071528 1.693660542 -0.086233690 0.995384574 0.042109925 0.134338657 0.809892535 0.045423068 0.584816933 -1.189997045
18
+ 160060060 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.559533417 0.086548090 -0.824276507 1.785956560 -0.089039005 0.995054126 0.044038296 0.143407250 0.824011147 0.048751865 0.564472198 -1.233530509
19
+ 160226894 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.539407372 0.088928543 -0.837335885 1.876278781 -0.091476299 0.994710982 0.046713885 0.159821683 0.837061405 0.051398575 0.544689238 -1.287939732
20
+ 160427094 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.515480161 0.090795092 -0.852077723 1.979818582 -0.093906343 0.994367242 0.049146701 0.181896137 0.851740420 0.054681353 0.521102846 -1.359775674
21
+ 160593927 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.497423410 0.091656610 -0.862652302 2.062552118 -0.095314465 0.994156837 0.050668620 0.194005458 0.862255812 0.057019483 0.503253102 -1.415121326
22
+ 160760761 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.484208912 0.092620693 -0.870036304 2.136687359 -0.096262053 0.993984103 0.052242137 0.204385655 0.869640946 0.058455370 0.490211815 -1.477987717
23
+ 160927594 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.475284129 0.093297184 -0.874871790 2.200792438 -0.096743606 0.993874133 0.053430639 0.209217395 0.874497354 0.059243519 0.481398523 -1.547068315
24
+ 161094428 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.468848795 0.093707815 -0.878293574 2.268227083 -0.097071946 0.993799806 0.054212786 0.208793720 0.877928138 0.059840068 0.475038230 -1.634971335
25
+ 161261261 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.463450164 0.093811318 -0.881143212 2.339123750 -0.097783640 0.993721604 0.054366294 0.224513862 0.880711257 0.060965322 0.469713658 -1.732136350
26
+ 161461461 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.458983690 0.093715429 -0.883488178 2.426412787 -0.098171033 0.993681431 0.054402962 0.253829726 0.883004189 0.061762877 0.465283692 -1.863571195
assets/pose_files/0c11dbe781b1c11c.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=a-Unpcomk5k
2
+ 90023267 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.949961841 -0.054589756 0.307558835 0.363597957 0.049778115 0.998484373 0.023474237 0.122943811 -0.308374137 -0.006989930 0.951239467 -0.411649725
3
+ 90190100 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.936324358 -0.058613990 0.346209586 0.384438270 0.053212658 0.998267829 0.025095066 0.136336848 -0.347080797 -0.005074390 0.937821507 -0.495378251
4
+ 90356933 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.920415699 -0.061646536 0.386050045 0.392760735 0.055660341 0.998093307 0.026676189 0.148407963 -0.386958480 -0.003065505 0.922092021 -0.584840288
5
+ 90523767 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.902696550 -0.065121211 0.425321281 0.393740251 0.058245987 0.997876167 0.029164905 0.157571476 -0.426317245 -0.001553800 0.904572368 -0.683591501
6
+ 90690600 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.883015513 -0.069272175 0.464203626 0.383146000 0.061874375 0.997597098 0.031171000 0.171538756 -0.465247482 0.001197834 0.885179818 -0.798848920
7
+ 90857433 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.863586664 -0.074230894 0.498706162 0.378236544 0.067191981 0.997224212 0.032080498 0.194804574 -0.499703199 0.005804764 0.866177261 -0.912604869
8
+ 91057633 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.837207139 -0.080319084 0.540955663 0.348125228 0.072502285 0.996726155 0.035782419 0.216269091 -0.542058706 0.009263224 0.840289593 -1.067256689
9
+ 91224467 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.814787984 -0.085085154 0.573481560 0.311799242 0.076606996 0.996299267 0.038975649 0.234736581 -0.574675500 0.012175811 0.818290770 -1.196836664
10
+ 91391300 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.792760789 -0.089765556 0.602886796 0.270539226 0.081507996 0.995825171 0.041093048 0.259814929 -0.604058623 0.016563140 0.796767771 -1.328140863
11
+ 91558133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.771091938 -0.093814306 0.629774630 0.223948432 0.087357447 0.995320201 0.041307874 0.293608807 -0.630702674 0.023163332 0.775678813 -1.459775674
12
+ 91724967 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.751535058 -0.096625380 0.652578413 0.178494515 0.089747138 0.994993508 0.043969437 0.308307880 -0.653559864 0.025522472 0.756444395 -1.587897834
13
+ 91891800 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.735025227 -0.099746063 0.670662820 0.138219010 0.093737528 0.994570971 0.045186777 0.333516116 -0.671528995 0.029652854 0.740384698 -1.712296424
14
+ 92092000 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.719043434 -0.101635426 0.687493145 0.093557351 0.098243877 0.994179368 0.044221871 0.373143031 -0.687985957 0.035744540 0.724843204 -1.860791364
15
+ 92258833 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.709150374 -0.102171108 0.697615087 0.059169738 0.099845670 0.994025767 0.044086087 0.400782920 -0.697951734 0.038390182 0.715115070 -1.981529677
16
+ 92425667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.703360021 -0.101482928 0.703552365 0.039205180 0.098760851 0.994108558 0.044659954 0.417778776 -0.703939617 0.038071405 0.709238708 -2.106152155
17
+ 92592500 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.700221658 -0.101235874 0.706711292 0.029170036 0.096752122 0.994219005 0.046557475 0.427528111 -0.707339108 0.035775267 0.705968499 -2.234683370
18
+ 92759333 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698873043 -0.100907177 0.708091974 0.024634507 0.096064955 0.994270742 0.046875048 0.444746524 -0.708765149 0.035263114 0.704562664 -2.365965080
19
+ 92926167 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698221087 -0.101446368 0.708657861 0.017460176 0.096007936 0.994235396 0.047733612 0.465147684 -0.709415078 0.034708157 0.703935742 -2.489595036
20
+ 93126367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.701540232 -0.100508377 0.705506444 0.030925799 0.096309878 0.994293392 0.045881305 0.500429136 -0.706091821 0.035759658 0.707216740 -2.635113223
21
+ 93293200 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.706040561 -0.099173397 0.701192796 0.055235645 0.095376909 0.994441032 0.044612732 0.522969947 -0.701719284 0.035379205 0.711574554 -2.748741222
22
+ 93460033 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.711469471 -0.101191826 0.695392966 0.092386154 0.097211346 0.994235933 0.045219962 0.550373275 -0.695960581 0.035427466 0.717205524 -2.869640023
23
+ 93626867 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715287089 -0.106710054 0.690635443 0.110879759 0.101620771 0.993650913 0.048280932 0.574557114 -0.691402614 0.035648178 0.721589625 -3.003606281
24
+ 93793700 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717362285 -0.111333445 0.687747240 0.117481763 0.104680635 0.993167043 0.051586930 0.589310163 -0.688791215 0.034987301 0.724115014 -3.119467820
25
+ 93960533 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717443645 -0.117016122 0.686718166 0.111165224 0.109845228 0.992461383 0.054354500 0.614623166 -0.687901616 0.036436427 0.724888742 -3.219243898
26
+ 94160733 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715022981 -0.122569911 0.688272297 0.080960594 0.115116455 0.991714299 0.057017289 0.647785934 -0.689558089 0.038462799 0.723208308 -3.337481340
assets/pose_files/0c9b371cc6225682.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=_ca03xP_KUU
2
+ 212078000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981108844 0.010863926 -0.193151161 0.019480142 -0.008781361 0.999893725 0.011634931 -0.185801323 0.193257034 -0.009719004 0.981100023 -1.207220396
3
+ 212245000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981206656 0.010493318 -0.192674309 0.047262620 -0.008341321 0.999893486 0.011976899 -0.196644454 0.192779467 -0.010144655 0.981189668 -1.332579514
4
+ 212412000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981131375 0.009989015 -0.193084016 0.089602762 -0.007912135 0.999902308 0.011524491 -0.209028987 0.193180263 -0.009779332 0.981114566 -1.458343512
5
+ 212579000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980986536 0.009889571 -0.193823576 0.142988232 -0.007621351 0.999893546 0.012444697 -0.219217661 0.193926007 -0.010730883 0.980957448 -1.565616727
6
+ 212746000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980907381 0.009417370 -0.194247320 0.202069071 -0.007269385 0.999904335 0.011767862 -0.219211705 0.194339558 -0.010131124 0.980881989 -1.654996418
7
+ 212913000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980841637 0.009196524 -0.194589555 0.262465567 -0.006609587 0.999880970 0.013939449 -0.224018296 0.194694594 -0.012386235 0.980785728 -1.740759996
8
+ 213112000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980812550 0.009127630 -0.194739416 0.343576873 -0.006686049 0.999890625 0.013191325 -0.227157741 0.194838524 -0.011636180 0.980766296 -1.843349559
9
+ 213279000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493903 0.009053789 -0.196340859 0.419120133 -0.006299877 0.999872863 0.014646200 -0.230109231 0.196448520 -0.013123587 0.980426311 -1.921706921
10
+ 213446000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980334044 0.009148465 -0.197133064 0.491943193 -0.006159810 0.999856710 0.015768444 -0.229004834 0.197249070 -0.014244041 0.980249941 -2.001160080
11
+ 213613000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980392158 0.009613466 -0.196821600 0.558373534 -0.006672133 0.999855995 0.015601818 -0.224721707 0.196943253 -0.013982680 0.980315149 -2.074274069
12
+ 213779000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493963 0.009960363 -0.196296573 0.623893674 -0.006936011 0.999846518 0.016088497 -0.223079036 0.196426690 -0.014413159 0.980412602 -2.137999468
13
+ 213946000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980379820 0.010693249 -0.196827397 0.699812821 -0.007542110 0.999831200 0.016752303 -0.227942951 0.196973309 -0.014939127 0.980295002 -2.197760648
14
+ 214146000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.979374588 0.009917496 -0.201809332 0.814920839 -0.006856939 0.999850750 0.015859045 -0.216071799 0.201936498 -0.014148152 0.979296446 -2.259941063
15
+ 214313000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.977240086 0.010313706 -0.211885542 0.923878346 -0.006808316 0.999827743 0.017266726 -0.213645187 0.212027133 -0.015431152 0.977141917 -2.298546075
16
+ 214480000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.974372387 0.010545789 -0.224693686 1.023098265 -0.007014100 0.999839067 0.016510243 -0.202358523 0.224831656 -0.014511099 0.974289536 -2.336883235
17
+ 214647000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.970687985 0.012000944 -0.240043446 1.114468699 -0.008058093 0.999816120 0.017400362 -0.191482059 0.240208134 -0.014956030 0.970606208 -2.373449288
18
+ 214814000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.966870189 0.013693837 -0.254900873 1.198725810 -0.010257521 0.999837756 0.014805457 -0.166976597 0.255062282 -0.011700304 0.966853857 -2.418678595
19
+ 214981000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.964416146 0.015124563 -0.263955861 1.261415498 -0.015491933 0.999879777 0.000689789 -0.124738174 0.263934553 0.003423943 0.964534521 -2.488291986
20
+ 215181000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.961933076 0.016891202 -0.272762626 1.331672110 -0.022902885 0.999559581 -0.018870916 -0.076291319 0.272323757 0.024399608 0.961896241 -2.579417067
21
+ 215348000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.960170150 0.017672766 -0.278856426 1.385267829 -0.031079723 0.998559833 -0.043730423 -0.007773330 0.277681977 0.050655428 0.959336638 -2.653662977
22
+ 215515000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959171832 0.017540060 -0.282279491 1.424533991 -0.041599691 0.995969117 -0.079466961 0.101994757 0.279747814 0.087965213 0.956035197 -2.725926173
23
+ 215681000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958615839 0.017946823 -0.284136623 1.452641041 -0.050743751 0.992801547 -0.108490512 0.201790053 0.280144215 0.118418880 0.952625930 -2.789404412
24
+ 215848000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958124936 0.017704720 -0.285802603 1.475525887 -0.056795157 0.990007460 -0.129071787 0.280252282 0.280661523 0.139899105 0.949556410 -2.857222541
25
+ 216015000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958040893 0.018267807 -0.286048800 1.493161376 -0.062177986 0.987448573 -0.145186886 0.346856721 0.279806226 0.156880915 0.947151959 -2.929304305
26
+ 216215000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959765971 0.017995594 -0.280223966 1.499055201 -0.064145394 0.985608101 -0.156403333 0.410155748 0.273376435 0.168085665 0.947107434 -3.033597428
assets/pose_files/0f47577ab3441480.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=in69BD2eZqg
2
+ 195161633 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999976993 -0.003071866 0.006052452 0.037627942 0.003092153 0.999989629 -0.003345382 0.206876054 -0.006042113 0.003364020 0.999976099 -0.240768750
3
+ 195328467 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999913514 -0.003797470 0.012590170 0.037371090 0.003835545 0.999988139 -0.003001482 0.258472399 -0.012578622 0.003049513 0.999916255 -0.264166944
4
+ 195495300 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999804139 -0.004247059 0.019329911 0.038498871 0.004311955 0.999985218 -0.003316826 0.307481199 -0.019315537 0.003399526 0.999807656 -0.276803884
5
+ 195662133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999610126 -0.005408245 0.027391966 0.038009080 0.005529573 0.999975204 -0.004355530 0.361086350 -0.027367733 0.004505298 0.999615252 -0.278727233
6
+ 195828967 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999336481 -0.006239665 0.035883281 0.034735125 0.006456365 0.999961615 -0.005926326 0.417233500 -0.035844926 0.006154070 0.999338388 -0.270773664
7
+ 195995800 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999090433 -0.007104441 0.042045686 0.033419301 0.007331387 0.999959350 -0.005245875 0.473378445 -0.042006709 0.005549357 0.999101937 -0.261640758
8
+ 196196000 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998589635 -0.007975463 0.052489758 0.032633405 0.008245680 0.999953806 -0.004933467 0.535259197 -0.052447986 0.005359322 0.998609304 -0.250263159
9
+ 196362833 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998181939 -0.008554175 0.059662651 0.028043281 0.008866202 0.999948382 -0.004967103 0.576287383 -0.059617080 0.005487053 0.998206258 -0.238836996
10
+ 196529667 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997807026 -0.009027892 0.065571494 0.020762443 0.009380648 0.999943137 -0.005073830 0.611177122 -0.065521955 0.005677806 0.997834980 -0.221059185
11
+ 196696500 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997390270 -0.009296595 0.071597412 0.013903395 0.009683816 0.999940276 -0.005063088 0.639742116 -0.071546070 0.005743211 0.997420788 -0.192511620
12
+ 196863333 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996968508 -0.009326802 0.077245183 0.007660940 0.009716687 0.999941885 -0.004673062 0.661375479 -0.077197112 0.005409463 0.997001171 -0.161790087
13
+ 197030167 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996511698 -0.009557574 0.082903855 0.000657017 0.009994033 0.999938309 -0.004851241 0.672208252 -0.082852371 0.005662862 0.996545732 -0.126490956
14
+ 197230367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996102691 -0.010129508 0.087617576 -0.013035317 0.010638822 0.999929130 -0.005347892 0.673139255 -0.087557197 0.006259197 0.996139824 -0.073934910
15
+ 197397200 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995961964 -0.010034073 0.089213885 -0.025143057 0.010407046 0.999938965 -0.003716475 0.666403518 -0.089171149 0.004629921 0.996005535 -0.027130940
16
+ 197564033 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995849669 -0.009882330 0.090475440 -0.039230446 0.010261126 0.999940395 -0.003722523 0.652124926 -0.090433262 0.004635453 0.995891750 0.029309661
17
+ 197730867 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995756924 -0.010019524 0.091475174 -0.055068664 0.010371366 0.999940515 -0.003371752 0.630272532 -0.091435947 0.004306168 0.995801628 0.101088973
18
+ 197897700 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995885789 -0.009802628 0.090085521 -0.068138022 0.010283959 0.999935210 -0.004880426 0.600038118 -0.090031847 0.005786783 0.995922089 0.182818315
19
+ 198064533 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995840013 -0.010080555 0.090559855 -0.077047283 0.010255665 0.999946356 -0.001468501 0.569244350 -0.090540186 0.002391143 0.995889962 0.259090585
20
+ 198264733 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995863378 -0.010109165 0.090299115 -0.082291512 0.010262243 0.999946594 -0.001231105 0.534897586 -0.090281844 0.002152683 0.995913923 0.348298991
21
+ 198431567 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996120393 -0.010062339 0.087423652 -0.082856431 0.010250443 0.999945998 -0.001702961 0.509342862 -0.087401800 0.002592486 0.996169746 0.427163225
22
+ 198598400 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996832788 -0.009680700 0.078934617 -0.077252838 0.010075771 0.999938607 -0.004608278 0.480071534 -0.078885160 0.005389010 0.996869147 0.513721870
23
+ 198765233 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997568011 -0.009045602 0.069110014 -0.060091805 0.009451369 0.999939978 -0.005546598 0.444060897 -0.069055691 0.006186293 0.997593641 0.602911453
24
+ 198932067 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998626053 -0.008290987 0.051742285 -0.037270541 0.008407482 0.999962568 -0.002034174 0.410440195 -0.051723484 0.002466401 0.998658419 0.690111645
25
+ 199098900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999656141 -0.006388811 0.025431594 -0.014759559 0.006480854 0.999972761 -0.003538476 0.375793364 -0.025408294 0.003702078 0.999670327 0.777147280
26
+ 199299100 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999947906 -0.003541293 -0.009570339 -0.002547566 0.003502194 0.999985456 -0.004099103 0.343015758 0.009584717 0.004065373 0.999945819 0.878377059
assets/pose_files/0f68374b76390082.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=-aldZQifF2U
2
+ 103837067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.792261064 -0.075338066 0.605513453 -2.753106466 0.083067641 0.996426642 0.015288832 0.122302125 -0.604501545 0.038185827 0.795688212 -1.791608923
3
+ 104003900 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.772824645 -0.077280566 0.629896700 -2.856354365 0.084460691 0.996253133 0.018602582 0.115028772 -0.628974140 0.038824979 0.776456118 -1.799931844
4
+ 104170733 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.752573133 -0.078496389 0.653813422 -2.957175162 0.085868694 0.996090353 0.020750597 0.112823623 -0.652886093 0.040525761 0.756371260 -1.810994932
5
+ 104337567 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.730659664 -0.077806436 0.678293884 -3.062095207 0.087071396 0.995992005 0.020455774 0.121362801 -0.677166939 0.044113789 0.734505892 -1.811030009
6
+ 104504400 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.706461906 -0.074765891 0.703790903 -3.177137127 0.086851373 0.996047020 0.018632174 0.129874960 -0.702401876 0.047962286 0.710162818 -1.792277939
7
+ 104671233 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.681627631 -0.071119592 0.728234708 -3.294052837 0.086847432 0.996093273 0.015989548 0.143049226 -0.726526856 0.052346393 0.685141265 -1.768016440
8
+ 104871433 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.649465024 -0.065721743 0.757545888 -3.442979418 0.086763002 0.996156216 0.012038323 0.166510317 -0.755425274 0.057908490 0.652670860 -1.724684703
9
+ 105038267 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.621174812 -0.061671518 0.781241655 -3.558270668 0.087205477 0.996146977 0.009298084 0.180848136 -0.778804898 0.062352814 0.624159455 -1.675155675
10
+ 105205100 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.591690660 -0.058407109 0.804046512 -3.660407702 0.087778911 0.996109724 0.007763143 0.186383384 -0.801371992 0.065984949 0.594515741 -1.621257762
11
+ 105371933 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.561377883 -0.055783633 0.825677335 -3.752373081 0.089227341 0.995989263 0.006624432 0.194667304 -0.822735310 0.069954179 0.564103782 -1.568545872
12
+ 105538767 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.531322777 -0.053783599 0.845460474 -3.836961453 0.091640897 0.995775461 0.005754844 0.205166191 -0.842198372 0.074421078 0.534006953 -1.522108893
13
+ 105705600 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.501615226 -0.052979972 0.863467038 -3.914896511 0.093892507 0.995560884 0.006539768 0.201989601 -0.859980464 0.077792637 0.504362881 -1.476983336
14
+ 105905800 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.466019660 -0.052434672 0.883219302 -4.004424531 0.098161966 0.995143771 0.007285428 0.209186293 -0.879312158 0.083303392 0.468903631 -1.424243874
15
+ 106072633 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.435604274 -0.051984914 0.898635924 -4.083866394 0.101657487 0.994785070 0.008269622 0.213039517 -0.894379497 0.087750785 0.438617289 -1.372398599
16
+ 106239467 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.404144615 -0.051714677 0.913232028 -4.163043658 0.104999557 0.994423509 0.009845568 0.210349578 -0.908648551 0.091909930 0.407320917 -1.308948274
17
+ 106406300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.372057080 -0.052390546 0.926730156 -4.232320456 0.108426183 0.994023800 0.012664661 0.202983014 -0.921855330 0.095769837 0.375514120 -1.239784641
18
+ 106573133 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.338993609 -0.053159237 0.939285576 -4.297918560 0.111065105 0.993681848 0.016153777 0.191918628 -0.934209764 0.098845825 0.342755914 -1.169019518
19
+ 106739967 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.305330545 -0.054462686 0.950687706 -4.358390475 0.113597691 0.993316948 0.020420864 0.175834622 -0.945446372 0.101760812 0.309476852 -1.098186456
20
+ 106940167 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.264192373 -0.056177936 0.962832510 -4.426586558 0.117604628 0.992729127 0.025652671 0.163465045 -0.957272947 0.106456317 0.268878251 -1.008524756
21
+ 107107000 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.228878200 -0.056077410 0.971838534 -4.485196000 0.120451130 0.992298782 0.028890507 0.159180748 -0.965974271 0.110446639 0.233870149 -0.923927626
22
+ 107273833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.192813009 -0.054965079 0.979694843 -4.547398479 0.122527294 0.991963863 0.031538919 0.153786345 -0.973555446 0.113958240 0.197998255 -0.835885482
23
+ 107440667 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.155427963 -0.053089641 0.986419618 -4.614075971 0.124575593 0.991636276 0.033741303 0.151495104 -0.979960740 0.117639467 0.160741687 -0.738650735
24
+ 107607500 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.117806904 -0.051662166 0.991691768 -4.672324721 0.126608506 0.991277337 0.036600262 0.144364476 -0.984932423 0.121244848 0.123320177 -0.639080225
25
+ 107774333 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.080108978 -0.050879046 0.995486736 -4.716803649 0.129048899 0.990820825 0.040255725 0.133545828 -0.988397181 0.125241622 0.085939527 -0.541709066
26
+ 107974533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.034108389 -0.050325166 0.998150289 -4.758242879 0.132215530 0.990180492 0.045405328 0.118994547 -0.990633965 0.130422264 0.040427230 -0.433560831
assets/pose_files/2c80f9eb0d3b2bb4.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=sLIFyXD2ujI
2
+ 77010267 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.991455436 0.021077231 -0.128731906 -0.119416025 -0.023393147 0.999590099 -0.016504617 0.019347615 0.128331259 0.019375037 0.991542101 -0.092957340
3
+ 77143733 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.988697171 0.023564288 -0.148062930 -0.142632843 -0.026350429 0.999510169 -0.016883694 0.023384606 0.147592559 0.020594381 0.988833785 -0.115024468
4
+ 77277200 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.985320270 0.026362764 -0.168668360 -0.165155176 -0.029713295 0.999407530 -0.017371174 0.028412548 0.168110475 0.022127863 0.985519767 -0.141363672
5
+ 77410667 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.981402338 0.029484071 -0.189684242 -0.188834577 -0.033494804 0.999277294 -0.017972585 0.034114674 0.189017251 0.023991773 0.981680632 -0.169959835
6
+ 77544133 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.976626754 0.033091862 -0.212379664 -0.212527322 -0.037670061 0.999136209 -0.017545532 0.036524990 0.211615592 0.025135791 0.977029681 -0.204014687
7
+ 77677600 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.970793188 0.035301749 -0.237306431 -0.229683177 -0.040712819 0.999009848 -0.017938551 0.042000619 0.236438200 0.027076038 0.971269190 -0.236341621
8
+ 77811067 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.964167893 0.038360216 -0.262504756 -0.246031807 -0.044686489 0.998835802 -0.018170038 0.047261141 0.261502147 0.029249383 0.964759588 -0.276015669
9
+ 77944533 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.956261098 0.040829532 -0.289650917 -0.252766079 -0.048421524 0.998644531 -0.019089982 0.054620904 0.288478881 0.032280345 0.956941962 -0.321621308
10
+ 78078000 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.946828246 0.042435788 -0.318928629 -0.251662187 -0.051583275 0.998462617 -0.020286530 0.062274582 0.317577451 0.035659242 0.947561622 -0.373008852
11
+ 78211467 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.935860872 0.044850163 -0.349503726 -0.247351407 -0.055966165 0.998195410 -0.021766055 0.072942153 0.347896785 0.039930381 0.936682105 -0.431307858
12
+ 78344933 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.923219025 0.046543088 -0.381445110 -0.234020172 -0.059361201 0.997996330 -0.021899769 0.078518674 0.379661530 0.042861324 0.924132049 -0.487708973
13
+ 78478400 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.909880757 0.048629351 -0.412009954 -0.218247042 -0.063676558 0.997708619 -0.022863906 0.088967126 0.409954011 0.047038805 0.910892427 -0.543114491
14
+ 78645233 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.891359746 0.050869841 -0.450433195 -0.185763327 -0.067926541 0.997452736 -0.021771761 0.093745158 0.448178291 0.050002839 0.892544627 -0.611223637
15
+ 78778700 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.877080619 0.053094681 -0.477399796 -0.163606786 -0.072203092 0.997152746 -0.021752052 0.102191599 0.474885583 0.053548045 0.878416896 -0.664313657
16
+ 78912167 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.863215029 0.055334236 -0.501794696 -0.143825544 -0.076518841 0.996831775 -0.021708660 0.111709364 0.499003649 0.057135988 0.864714324 -0.719103228
17
+ 79045633 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.849143744 0.056816563 -0.525096953 -0.122334566 -0.079860382 0.996578217 -0.021311868 0.118459005 0.522089303 0.060031284 0.850775540 -0.775464728
18
+ 79179100 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.835254073 0.059146367 -0.546673834 -0.101344556 -0.084243484 0.996225357 -0.020929486 0.126763936 0.543372452 0.063535146 0.837083995 -0.832841061
19
+ 79312567 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.822106183 0.061935693 -0.565955281 -0.082663275 -0.088697352 0.995860636 -0.019859029 0.133423045 0.562382638 0.066524968 0.824196696 -0.894100189
20
+ 79446033 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.808796465 0.064479858 -0.584543109 -0.062439027 -0.093630031 0.995411158 -0.019748187 0.145033126 0.580587387 0.070703052 0.811122298 -0.951788129
21
+ 79579500 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.794901192 0.066913344 -0.603037894 -0.037377988 -0.097949244 0.995015621 -0.018705536 0.153829045 0.598780453 0.073936157 0.797493160 -1.008854626
22
+ 79712967 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.781648815 0.069040783 -0.619885862 -0.013285614 -0.101820730 0.994646847 -0.017611075 0.161173621 0.615351617 0.076882906 0.784494340 -1.070102980
23
+ 79846433 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.768634439 0.072034292 -0.635619521 0.012177816 -0.107280776 0.994082153 -0.017072625 0.174403322 0.630628169 0.081312358 0.771813691 -1.132424688
24
+ 79979900 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.755315959 0.075072937 -0.651046753 0.040377463 -0.113015875 0.993455172 -0.016559631 0.189742153 0.645542622 0.086086370 0.758856952 -1.193296093
25
+ 80113367 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.742196620 0.078075886 -0.665618777 0.069020519 -0.118120082 0.992882252 -0.015246205 0.202486741 0.659690738 0.089938626 0.746136189 -1.254564875
26
+ 80280200 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.726252913 0.078639805 -0.682914674 0.104603927 -0.119984925 0.992686808 -0.013288199 0.209760187 0.676875412 0.091590062 0.730377257 -1.329527748
assets/pose_files/2f25826f0d0ef09a.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=t-mlAKnESzQ
2
+ 167300000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991854608 -0.011446482 0.126859888 0.441665245 0.012175850 0.999913514 -0.004975420 -0.056449972 -0.126791954 0.006479521 0.991908193 -0.456202583
3
+ 167467000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991945148 -0.011409644 0.126153216 0.506974565 0.012122569 0.999914587 -0.004884966 -0.069421149 -0.126086697 0.006374919 0.991998732 -0.517325825
4
+ 167634000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991982698 -0.011117884 0.125883758 0.561996906 0.011760475 0.999921322 -0.004362585 -0.080740919 -0.125825346 0.005808061 0.992035389 -0.570476997
5
+ 167801000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.992486775 -0.010566713 0.121894784 0.609598405 0.011126007 0.999930441 -0.003908583 -0.087745179 -0.121845007 0.005235420 0.992535353 -0.617968773
6
+ 167968000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993018925 -0.010175236 0.117515638 0.655818155 0.010723241 0.999934375 -0.004031916 -0.098194076 -0.117466904 0.005263917 0.993062854 -0.668642428
7
+ 168134000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993561447 -0.009874708 0.112863302 0.703081750 0.010385432 0.999938309 -0.003938090 -0.108951006 -0.112817451 0.005084869 0.993602693 -0.730919086
8
+ 168335000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994062483 -0.010140099 0.108337507 0.763671544 0.010665529 0.999934018 -0.004271581 -0.104826596 -0.108287044 0.005401695 0.994104981 -0.820197463
9
+ 168501000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994565487 -0.010703885 0.103560977 0.813888661 0.011249267 0.999925733 -0.004683641 -0.095187847 -0.103503153 0.005823173 0.994612098 -0.890086513
10
+ 168668000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994905293 -0.010604435 0.100254886 0.865790711 0.011124965 0.999927402 -0.004634405 -0.086100908 -0.100198455 0.005726126 0.994951010 -0.962092459
11
+ 168835000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995332658 -0.010553311 0.095924698 0.905775925 0.011022467 0.999929726 -0.004362300 -0.075394333 -0.095871925 0.005399267 0.995379031 -1.025694236
12
+ 169002000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995705128 -0.010036361 0.092035979 0.944576676 0.010483396 0.999935448 -0.004374997 -0.058609663 -0.091986135 0.005321057 0.995746076 -1.081030198
13
+ 169169000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996029556 -0.009414902 0.088523701 0.977259045 0.009879347 0.999939620 -0.004809874 -0.042104006 -0.088473074 0.005665333 0.996062458 -1.127427189
14
+ 169369000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996554554 -0.009220830 0.082425818 1.013619994 0.009555685 0.999947608 -0.003668923 -0.018710063 -0.082387671 0.004443917 0.996590436 -1.175459833
15
+ 169536000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996902764 -0.008823335 0.078147218 1.041872487 0.009157063 0.999950409 -0.003913174 0.011864113 -0.078108817 0.004616653 0.996934175 -1.202554477
16
+ 169703000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997331142 -0.008540447 0.072509713 1.068829435 0.008805763 0.999955654 -0.003340158 0.047323405 -0.072477967 0.003969747 0.997362137 -1.214284849
17
+ 169870000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997695088 -0.008219596 0.067356706 1.095289713 0.008451649 0.999959290 -0.003160893 0.090756953 -0.067327984 0.003722883 0.997723937 -1.225599061
18
+ 170036000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997950792 -0.008326715 0.063442364 1.112874795 0.008502332 0.999960721 -0.002498670 0.132311648 -0.063419066 0.003032958 0.997982383 -1.233305313
19
+ 170203000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998197436 -0.008063688 0.059471287 1.125840626 0.008245971 0.999962032 -0.002820280 0.178666038 -0.059446286 0.003305595 0.998226047 -1.240809047
20
+ 170403000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998360872 -0.007677370 0.056714825 1.144370603 0.007830821 0.999966264 -0.002483922 0.248055953 -0.056693841 0.002923975 0.998387337 -1.246230780
21
+ 170570000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998471320 -0.007715963 0.054730706 1.159189486 0.007868989 0.999965727 -0.002581036 0.310163907 -0.054708913 0.003007766 0.998497844 -1.245661417
22
+ 170737000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998507679 -0.007751614 0.054058932 1.165836593 0.007918007 0.999964535 -0.002864495 0.366963293 -0.054034811 0.003288259 0.998533666 -1.241523115
23
+ 170904000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998536825 -0.007817166 0.053508084 1.175042384 0.008036798 0.999960124 -0.003890704 0.423587941 -0.053475536 0.004315045 0.998559833 -1.224956309
24
+ 171071000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998537302 -0.007878507 0.053490099 1.177855699 0.008138275 0.999956131 -0.004640296 0.484100754 -0.053451192 0.005068825 0.998557627 -1.202906710
25
+ 171238000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998549581 -0.007872007 0.053261518 1.180678596 0.008100130 0.999958932 -0.004068544 0.548228374 -0.053227302 0.004494068 0.998572290 -1.184901744
26
+ 171438000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998469293 -0.008281939 0.054685175 1.181414517 0.008542870 0.999953210 -0.004539483 0.618089736 -0.054645021 0.004999703 0.998493314 -1.159911786
assets/pose_files/3f79dc32d575bcdc.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=1qVpRlWxam4
2
+ 87387300 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998291552 0.018666664 0.055367537 -0.431348097 -0.018963017 0.999808490 0.004831879 0.070488701 -0.055266738 -0.005873560 0.998454332 -0.848986490
3
+ 87554133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997851610 0.017319093 0.063184217 -0.464904483 -0.017837154 0.999811709 0.007644337 0.068569507 -0.063039921 -0.008754940 0.997972608 -0.876888649
4
+ 87720967 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997675776 0.016262729 0.066170901 -0.486385324 -0.016915560 0.999813497 0.009317505 0.069230577 -0.066007033 -0.010415167 0.997764826 -0.912234761
5
+ 87887800 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997801185 0.015721651 0.064386748 -0.496646826 -0.016416471 0.999812424 0.010276551 0.072350447 -0.064213105 -0.011310958 0.997872114 -0.952896762
6
+ 88054633 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998319149 0.016226118 0.055637561 -0.489176520 -0.016823635 0.999805570 0.010287891 0.076572802 -0.055459809 -0.011206625 0.998398006 -1.004124831
7
+ 88221467 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998926342 0.017387087 0.042939600 -0.475558168 -0.017787032 0.999801755 0.008949692 0.085218470 -0.042775478 -0.009703852 0.999037564 -1.053459508
8
+ 88421667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999537408 0.020246139 0.022695299 -0.447975333 -0.020412439 0.999766290 0.007119950 0.094693503 -0.022545842 -0.007579923 0.999717057 -1.119813421
9
+ 88588500 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999718606 0.023644496 0.001895716 -0.414069396 -0.023654999 0.999703765 0.005723978 0.102792865 -0.001759814 -0.005767211 0.999981821 -1.180436614
10
+ 88755333 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999452412 0.027961638 -0.017690983 -0.387314056 -0.027902454 0.999604225 0.003583529 0.113408687 0.017784184 -0.003087946 0.999837101 -1.226234160
11
+ 88922167 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998792231 0.032815013 -0.036568113 -0.365929800 -0.032777511 0.999461353 0.001624677 0.124849793 0.036601726 -0.000424103 0.999329865 -1.267691893
12
+ 89089000 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997681975 0.038422074 -0.056164715 -0.342733324 -0.038413495 0.999261141 0.001232749 0.131945819 0.056170583 0.000927592 0.998420775 -1.304181539
13
+ 89255833 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.995796800 0.044428598 -0.080092981 -0.304097608 -0.044486329 0.999009430 0.001064335 0.139304626 0.080060929 0.002503182 0.996786833 -1.346184197
14
+ 89456033 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.992265880 0.051900670 -0.112759680 -0.242293511 -0.051975533 0.998645782 0.002277754 0.141999546 0.112725191 0.003600606 0.993619680 -1.403491443
15
+ 89622867 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.988518834 0.057282198 -0.139818728 -0.191310403 -0.057411496 0.998345733 0.003111851 0.144317113 0.139765680 0.004951079 0.990172207 -1.446433054
16
+ 89789700 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.984156251 0.062501073 -0.165921792 -0.143876127 -0.062264379 0.998037636 0.006632906 0.137240925 0.166010767 0.003803201 0.986116588 -1.485275757
17
+ 89956533 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.979292631 0.066839822 -0.191097870 -0.099323029 -0.066578977 0.997750700 0.007792730 0.139573975 0.191188902 0.005091738 0.981540024 -1.518326120
18
+ 90123367 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.973332286 0.070821166 -0.218194127 -0.042629488 -0.070645541 0.997464299 0.008616162 0.140175484 0.218251050 0.007028054 0.975867331 -1.554681376
19
+ 90290200 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.966279447 0.074934490 -0.246350974 0.028017454 -0.074612871 0.997155666 0.010653359 0.133648148 0.246448576 0.008086831 0.969122112 -1.595505702
20
+ 90490400 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.957038641 0.079540767 -0.278837353 0.115624588 -0.079602204 0.996764660 0.011121323 0.132533757 0.278819799 0.011552530 0.960273921 -1.622873069
21
+ 90657233 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.948620677 0.083499879 -0.305199176 0.195920884 -0.084326349 0.996382892 0.010498469 0.132694923 0.304971874 0.015777269 0.952230692 -1.640734525
22
+ 90824067 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.940162480 0.087165222 -0.329388469 0.271130852 -0.089630231 0.995945156 0.007725847 0.141518901 0.328726262 0.022259612 0.944162905 -1.645387258
23
+ 90990900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.932207108 0.091082737 -0.350276887 0.339575901 -0.095441416 0.995423317 0.004838228 0.149739069 0.349114448 0.028920690 0.936633706 -1.648637528
24
+ 91157733 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.925111592 0.095387921 -0.367518306 0.398473437 -0.101803973 0.994802594 0.001937620 0.156527229 0.365792990 0.035622310 0.930014253 -1.650611039
25
+ 91324567 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.918998003 0.099750437 -0.381434739 0.448520864 -0.108052738 0.994145095 -0.000350893 0.159817007 0.379166484 0.041537538 0.924395680 -1.652156379
26
+ 91524767 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.913031578 0.105410583 -0.394032896 0.493990424 -0.115993932 0.993245184 -0.003064641 0.163621223 0.391048223 0.048503540 0.919091225 -1.650421710
assets/pose_files/4a2d6753676df096.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ https://www.youtube.com/watch?v=mGFQkgadzRQ
2
+ 123373000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.998857915 0.002672890 -0.047704928 -0.388737999 -0.002653247 0.999996364 0.000475094 -0.004533370 0.047706023 -0.000347978 0.998861372 0.139698036
3
+ 123581000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.997534156 0.002900333 -0.070122920 -0.417036011 -0.002881077 0.999995768 0.000375740 -0.005476288 0.070123710 -0.000172784 0.997538269 0.134851393
4
+ 123790000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.995756805 0.003055056 -0.091973245 -0.444572396 -0.003032017 0.999995351 0.000390221 -0.006227409 0.091974013 -0.000109701 0.995761395 0.129660844
5
+ 123999000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.993462563 0.003229393 -0.114112593 -0.472377562 -0.003208589 0.999994814 0.000365978 -0.005932507 0.114113182 0.000002555 0.993467748 0.123959606
6
+ 124207000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.990603268 0.003450655 -0.136723205 -0.500173495 -0.003429445 0.999994040 0.000390680 -0.006082111 0.136723727 0.000081876 0.990609229 0.117333920
7
+ 124416000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.987169921 0.003684058 -0.159630775 -0.528663584 -0.003696360 0.999993145 0.000219867 -0.006000823 0.159630492 0.000373007 0.987176776 0.110039363
8
+ 124666000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.982509255 0.003945273 -0.186171964 -0.561235187 -0.003999375 0.999992013 0.000084966 -0.007105507 0.186170816 0.000661092 0.982517183 0.100220962
9
+ 124874000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.978331029 0.004287674 -0.207002580 -0.586641713 -0.004329238 0.999990582 0.000252201 -0.009076863 0.207001716 0.000649428 0.978340387 0.091930702
10
+ 125083000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.973817825 0.004402113 -0.227287307 -0.611123286 -0.004493149 0.999989927 0.000116853 -0.009074310 0.227285519 0.000907443 0.973827720 0.083304516
11
+ 125292000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.968695283 0.004357880 -0.248214558 -0.636185581 -0.004474392 0.999989986 0.000094731 -0.008011808 0.248212487 0.001018844 0.968705058 0.074442714
12
+ 125500000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.962747812 0.004391920 -0.270365149 -0.662786287 -0.004570082 0.999989569 -0.000029452 -0.006714359 0.270362198 0.001263946 0.962757826 0.064526619
13
+ 125709000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.955500066 0.004409539 -0.294957876 -0.691555299 -0.004778699 0.999988437 -0.000530787 -0.004279872 0.294952124 0.001916682 0.955510139 0.052776269
14
+ 125959000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.946343422 0.004634521 -0.323129416 -0.724457169 -0.005231380 0.999985814 -0.000978639 -0.001732190 0.323120296 0.002616541 0.946354270 0.037519903
15
+ 126167000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.939147174 0.004834646 -0.343481004 -0.749049950 -0.005533603 0.999984145 -0.001054784 -0.002170622 0.343470454 0.002891285 0.939159036 0.026149102
16
+ 126376000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.931589127 0.004859613 -0.363480538 -0.772472331 -0.005669596 0.999983251 -0.001161554 -0.002324411 0.363468796 0.003142879 0.931601048 0.014526636
17
+ 126584000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.923323810 0.004994850 -0.383989871 -0.796474752 -0.005933880 0.999981582 -0.001260800 -0.001656055 0.383976519 0.003442676 0.923336446 0.001805353
18
+ 126793000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.914980114 0.005216786 -0.403465271 -0.819526045 -0.006272262 0.999979496 -0.001294577 -0.001858109 0.403450221 0.003715152 0.914994061 -0.010564998
19
+ 127002000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.906303227 0.005258658 -0.422595352 -0.842292418 -0.006397304 0.999978721 -0.001276282 -0.001621911 0.422579646 0.003860169 0.906317592 -0.023561723
20
+ 127252000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.894903898 0.005439967 -0.446225733 -0.870190326 -0.006754198 0.999976277 -0.001354740 -0.001280526 0.446207762 0.004226258 0.894919395 -0.040196739
21
+ 127460000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.884573221 0.005480251 -0.466369092 -0.894082633 -0.006980692 0.999974549 -0.001489853 -0.000948027 0.466349065 0.004573463 0.884588957 -0.055396928
22
+ 127669000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.873941839 0.005343055 -0.486001104 -0.917185865 -0.007038773 0.999973834 -0.001663705 -0.000687769 0.485979497 0.004874832 0.873956621 -0.070420475
23
+ 127877000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.862660766 0.005402398 -0.505754173 -0.939888304 -0.007276668 0.999972045 -0.001730187 -0.000489221 0.505730629 0.005172769 0.862675905 -0.086411685
24
+ 128086000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.851201892 0.005282878 -0.524811804 -0.961420775 -0.007401088 0.999970734 -0.001938020 -0.000533338 0.524786234 0.005533825 0.851216078 -0.102931062
25
+ 128295000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.839382112 0.005324626 -0.543515682 -0.982849443 -0.007655066 0.999968648 -0.002025823 0.001148876 0.543487847 0.005861088 0.839396596 -0.119132721
26
+ 128545000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.825496972 0.005258156 -0.564382017 -1.006530766 -0.007933038 0.999965906 -0.002286965 0.002382303 0.564350784 0.006365147 0.825510561 -0.138386240
assets/reference_videos/0bf152ef84195293.mp4 ADDED
Binary file (231 kB). View file
 
assets/reference_videos/0c11dbe781b1c11c.mp4 ADDED
Binary file (219 kB). View file
 
assets/reference_videos/0c9b371cc6225682.mp4 ADDED
Binary file (195 kB). View file
 
assets/reference_videos/0f47577ab3441480.mp4 ADDED
Binary file (161 kB). View file
 
assets/reference_videos/0f68374b76390082.mp4 ADDED
Binary file (299 kB). View file
 
assets/reference_videos/2c80f9eb0d3b2bb4.mp4 ADDED
Binary file (173 kB). View file
 
assets/reference_videos/2f25826f0d0ef09a.mp4 ADDED
Binary file (195 kB). View file
 
assets/reference_videos/3f79dc32d575bcdc.mp4 ADDED
Binary file (148 kB). View file
 
assets/reference_videos/4a2d6753676df096.mp4 ADDED
Binary file (229 kB). View file
 
cameractrl/data/dataset.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import json
4
+ import torch
5
+
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+ import torchvision.transforms.functional as F
9
+ import numpy as np
10
+
11
+ from decord import VideoReader
12
+ from torch.utils.data.dataset import Dataset
13
+ from packaging import version as pver
14
+
15
+
16
+ class RandomHorizontalFlipWithPose(nn.Module):
17
+ def __init__(self, p=0.5):
18
+ super(RandomHorizontalFlipWithPose, self).__init__()
19
+ self.p = p
20
+
21
+ def get_flip_flag(self, n_image):
22
+ return torch.rand(n_image) < self.p
23
+
24
+ def forward(self, image, flip_flag=None):
25
+ n_image = image.shape[0]
26
+ if flip_flag is not None:
27
+ assert n_image == flip_flag.shape[0]
28
+ else:
29
+ flip_flag = self.get_flip_flag(n_image)
30
+
31
+ ret_images = []
32
+ for fflag, img in zip(flip_flag, image):
33
+ if fflag:
34
+ ret_images.append(F.hflip(img))
35
+ else:
36
+ ret_images.append(img)
37
+ return torch.stack(ret_images, dim=0)
38
+
39
+
40
+ class Camera(object):
41
+ def __init__(self, entry):
42
+ fx, fy, cx, cy = entry[1:5]
43
+ self.fx = fx
44
+ self.fy = fy
45
+ self.cx = cx
46
+ self.cy = cy
47
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
48
+ w2c_mat_4x4 = np.eye(4)
49
+ w2c_mat_4x4[:3, :] = w2c_mat
50
+ self.w2c_mat = w2c_mat_4x4
51
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
52
+
53
+
54
+ def custom_meshgrid(*args):
55
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
56
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
57
+ return torch.meshgrid(*args)
58
+ else:
59
+ return torch.meshgrid(*args, indexing='ij')
60
+
61
+
62
+ def ray_condition(K, c2w, H, W, device, flip_flag=None):
63
+ # c2w: B, V, 4, 4
64
+ # K: B, V, 4
65
+
66
+ B, V = K.shape[:2]
67
+
68
+ j, i = custom_meshgrid(
69
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
70
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
71
+ )
72
+ i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
73
+ j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
74
+
75
+ n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
76
+ if n_flip > 0:
77
+ j_flip, i_flip = custom_meshgrid(
78
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
79
+ torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
80
+ )
81
+ i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
82
+ j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
83
+ i[:, flip_flag, ...] = i_flip
84
+ j[:, flip_flag, ...] = j_flip
85
+
86
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
87
+
88
+ zs = torch.ones_like(i) # [B, V, HxW]
89
+ xs = (i - cx) / fx * zs
90
+ ys = (j - cy) / fy * zs
91
+ zs = zs.expand_as(ys)
92
+
93
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
94
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
95
+
96
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
97
+ rays_o = c2w[..., :3, 3] # B, V, 3
98
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
99
+ # c2w @ dirctions
100
+ rays_dxo = torch.linalg.cross(rays_o, rays_d) # B, V, HW, 3
101
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
102
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
103
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
104
+ return plucker
105
+
106
+
107
+ class RealEstate10K(Dataset):
108
+ def __init__(
109
+ self,
110
+ root_path,
111
+ annotation_json,
112
+ sample_stride=4,
113
+ sample_n_frames=16,
114
+ sample_size=[256, 384],
115
+ is_image=False,
116
+ ):
117
+ self.root_path = root_path
118
+ self.sample_stride = sample_stride
119
+ self.sample_n_frames = sample_n_frames
120
+ self.is_image = is_image
121
+
122
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
123
+ self.length = len(self.dataset)
124
+
125
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
126
+ pixel_transforms = [transforms.Resize(sample_size),
127
+ transforms.RandomHorizontalFlip(),
128
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
129
+
130
+ self.pixel_transforms = transforms.Compose(pixel_transforms)
131
+
132
+ def load_video_reader(self, idx):
133
+ video_dict = self.dataset[idx]
134
+
135
+ video_path = os.path.join(self.root_path, video_dict['clip_path'])
136
+ video_reader = VideoReader(video_path)
137
+ return video_reader, video_dict['caption']
138
+
139
+ def get_batch(self, idx):
140
+ video_reader, video_caption = self.load_video_reader(idx)
141
+ total_frames = len(video_reader)
142
+
143
+ if self.is_image:
144
+ frame_indice = [random.randint(0, total_frames - 1)]
145
+ else:
146
+ if isinstance(self.sample_stride, int):
147
+ current_sample_stride = self.sample_stride
148
+ else:
149
+ assert len(self.sample_stride) == 2
150
+ assert (self.sample_stride[0] >= 1) and (self.sample_stride[1] >= self.sample_stride[0])
151
+ current_sample_stride = random.randint(self.sample_stride[0], self.sample_stride[1])
152
+
153
+ cropped_length = self.sample_n_frames * current_sample_stride
154
+ start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
155
+ end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
156
+
157
+ assert end_frame_ind - start_frame_ind >= self.sample_n_frames
158
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
159
+
160
+ pixel_values = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
161
+ pixel_values = pixel_values / 255.
162
+
163
+ if self.is_image:
164
+ pixel_values = pixel_values[0]
165
+
166
+ return pixel_values, video_caption
167
+
168
+ def __len__(self):
169
+ return self.length
170
+
171
+ def __getitem__(self, idx):
172
+ while True:
173
+ try:
174
+ video, video_caption = self.get_batch(idx)
175
+ break
176
+
177
+ except Exception as e:
178
+ idx = random.randint(0, self.length - 1)
179
+
180
+ video = self.pixel_transforms(video)
181
+ sample = dict(pixel_values=video, caption=video_caption)
182
+
183
+ return sample
184
+
185
+
186
+ class RealEstate10KPose(Dataset):
187
+ def __init__(
188
+ self,
189
+ root_path,
190
+ annotation_json,
191
+ sample_stride=4,
192
+ minimum_sample_stride=1,
193
+ sample_n_frames=16,
194
+ relative_pose=False,
195
+ zero_t_first_frame=False,
196
+ sample_size=[256, 384],
197
+ rescale_fxy=False,
198
+ shuffle_frames=False,
199
+ use_flip=False,
200
+ return_clip_name=False,
201
+ ):
202
+ self.root_path = root_path
203
+ self.relative_pose = relative_pose
204
+ self.zero_t_first_frame = zero_t_first_frame
205
+ self.sample_stride = sample_stride
206
+ self.minimum_sample_stride = minimum_sample_stride
207
+ self.sample_n_frames = sample_n_frames
208
+ self.return_clip_name = return_clip_name
209
+
210
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
211
+ self.length = len(self.dataset)
212
+
213
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
214
+ self.sample_size = sample_size
215
+ if use_flip:
216
+ pixel_transforms = [transforms.Resize(sample_size),
217
+ RandomHorizontalFlipWithPose(),
218
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
219
+ else:
220
+ pixel_transforms = [transforms.Resize(sample_size),
221
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
222
+ self.rescale_fxy = rescale_fxy
223
+ self.sample_wh_ratio = sample_size[1] / sample_size[0]
224
+
225
+ self.pixel_transforms = pixel_transforms
226
+ self.shuffle_frames = shuffle_frames
227
+ self.use_flip = use_flip
228
+
229
+ def get_relative_pose(self, cam_params):
230
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
231
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
232
+ source_cam_c2w = abs_c2ws[0]
233
+ if self.zero_t_first_frame:
234
+ cam_to_origin = 0
235
+ else:
236
+ cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
237
+ target_cam_c2w = np.array([
238
+ [1, 0, 0, 0],
239
+ [0, 1, 0, -cam_to_origin],
240
+ [0, 0, 1, 0],
241
+ [0, 0, 0, 1]
242
+ ])
243
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
244
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
245
+ ret_poses = np.array(ret_poses, dtype=np.float32)
246
+ return ret_poses
247
+
248
+ def load_video_reader(self, idx):
249
+ video_dict = self.dataset[idx]
250
+
251
+ video_path = os.path.join(self.root_path, video_dict['clip_path'])
252
+ video_reader = VideoReader(video_path)
253
+ return video_dict['clip_name'], video_reader, video_dict['caption']
254
+
255
+ def load_cameras(self, idx):
256
+ video_dict = self.dataset[idx]
257
+ pose_file = os.path.join(self.root_path, video_dict['pose_file'])
258
+ with open(pose_file, 'r') as f:
259
+ poses = f.readlines()
260
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
261
+ cam_params = [[float(x) for x in pose] for pose in poses]
262
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
263
+ return cam_params
264
+
265
+ def get_batch(self, idx):
266
+ clip_name, video_reader, video_caption = self.load_video_reader(idx)
267
+ cam_params = self.load_cameras(idx)
268
+ assert len(cam_params) >= self.sample_n_frames
269
+ total_frames = len(cam_params)
270
+
271
+ current_sample_stride = self.sample_stride
272
+
273
+ if total_frames < self.sample_n_frames * current_sample_stride:
274
+ maximum_sample_stride = int(total_frames // self.sample_n_frames)
275
+ current_sample_stride = random.randint(self.minimum_sample_stride, maximum_sample_stride)
276
+
277
+ cropped_length = self.sample_n_frames * current_sample_stride
278
+ start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
279
+ end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
280
+
281
+ assert end_frame_ind - start_frame_ind >= self.sample_n_frames
282
+ frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
283
+
284
+ condition_image_ind = random.sample(list(set(range(total_frames)) - set(frame_indices.tolist())), 1)
285
+ condition_image = torch.from_numpy(video_reader.get_batch(condition_image_ind).asnumpy()).permute(0, 3, 1, 2).contiguous()
286
+ condition_image = condition_image / 255.
287
+
288
+ if self.shuffle_frames:
289
+ perm = np.random.permutation(self.sample_n_frames)
290
+ frame_indices = frame_indices[perm]
291
+
292
+ pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
293
+ pixel_values = pixel_values / 255.
294
+
295
+ cam_params = [cam_params[indice] for indice in frame_indices]
296
+ if self.rescale_fxy:
297
+ ori_h, ori_w = pixel_values.shape[-2:]
298
+ ori_wh_ratio = ori_w / ori_h
299
+ if ori_wh_ratio > self.sample_wh_ratio: # rescale fx
300
+ resized_ori_w = self.sample_size[0] * ori_wh_ratio
301
+ for cam_param in cam_params:
302
+ cam_param.fx = resized_ori_w * cam_param.fx / self.sample_size[1]
303
+ else: # rescale fy
304
+ resized_ori_h = self.sample_size[1] / ori_wh_ratio
305
+ for cam_param in cam_params:
306
+ cam_param.fy = resized_ori_h * cam_param.fy / self.sample_size[0]
307
+ intrinsics = np.asarray([[cam_param.fx * self.sample_size[1],
308
+ cam_param.fy * self.sample_size[0],
309
+ cam_param.cx * self.sample_size[1],
310
+ cam_param.cy * self.sample_size[0]]
311
+ for cam_param in cam_params], dtype=np.float32)
312
+ intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
313
+ if self.relative_pose:
314
+ c2w_poses = self.get_relative_pose(cam_params)
315
+ else:
316
+ c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
317
+ c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
318
+ if self.use_flip:
319
+ flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
320
+ else:
321
+ flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)
322
+ plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu',
323
+ flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
324
+
325
+ return pixel_values, condition_image, plucker_embedding, video_caption, flip_flag, clip_name
326
+
327
+ def __len__(self):
328
+ return self.length
329
+
330
+ def __getitem__(self, idx):
331
+ while True:
332
+ try:
333
+ video, condition_image, plucker_embedding, video_caption, flip_flag, clip_name = self.get_batch(idx)
334
+ break
335
+
336
+ except Exception as e:
337
+ idx = random.randint(0, self.length - 1)
338
+
339
+ if self.use_flip:
340
+ video = self.pixel_transforms[0](video)
341
+ video = self.pixel_transforms[1](video, flip_flag)
342
+ for transform in self.pixel_transforms[2:]:
343
+ video = transform(video)
344
+ else:
345
+ for transform in self.pixel_transforms:
346
+ video = transform(video)
347
+ for transform in self.pixel_transforms:
348
+ condition_image = transform(condition_image)
349
+ if self.return_clip_name:
350
+ sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding, video_caption=video_caption, clip_name=clip_name)
351
+ else:
352
+ sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding, video_caption=video_caption)
353
+
354
+ return sample
355
+
cameractrl/models/attention.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional
3
+ from diffusers.models.attention import TemporalBasicTransformerBlock, _chunked_feed_forward
4
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
5
+
6
+
7
+ @maybe_allow_in_graph
8
+ class TemporalPoseCondTransformerBlock(TemporalBasicTransformerBlock):
9
+ def forward(
10
+ self,
11
+ hidden_states: torch.FloatTensor, # [bs * num_frame, h * w, c]
12
+ num_frames: int,
13
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # [bs * h * w, 1, c]
14
+ pose_feature: Optional[torch.FloatTensor] = None, # [bs, c, n_frame, h, w]
15
+ ) -> torch.FloatTensor:
16
+ # Notice that normalization is always applied before the real computation in the following blocks.
17
+ # 0. Self-Attention
18
+
19
+ batch_frames, seq_length, channels = hidden_states.shape
20
+ batch_size = batch_frames // num_frames
21
+
22
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
23
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
24
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) # [bs * h * w, frame, c]
25
+
26
+ residual = hidden_states
27
+ hidden_states = self.norm_in(hidden_states)
28
+
29
+ if self._chunk_size is not None:
30
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
31
+ else:
32
+ hidden_states = self.ff_in(hidden_states)
33
+
34
+ if self.is_res:
35
+ hidden_states = hidden_states + residual
36
+
37
+ norm_hidden_states = self.norm1(hidden_states)
38
+ pose_feature = pose_feature.permute(0, 3, 4, 2, 1).reshape(batch_size * seq_length, num_frames, -1)
39
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None, pose_feature=pose_feature)
40
+ hidden_states = attn_output + hidden_states
41
+
42
+ # 3. Cross-Attention
43
+ if self.attn2 is not None:
44
+ norm_hidden_states = self.norm2(hidden_states)
45
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, pose_feature=pose_feature)
46
+ hidden_states = attn_output + hidden_states
47
+
48
+ # 4. Feed-forward
49
+ norm_hidden_states = self.norm3(hidden_states)
50
+
51
+ if self._chunk_size is not None:
52
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
53
+ else:
54
+ ff_output = self.ff(norm_hidden_states)
55
+
56
+ if self.is_res:
57
+ hidden_states = ff_output + hidden_states
58
+ else:
59
+ hidden_states = ff_output
60
+
61
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
62
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
63
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
64
+
65
+ return hidden_states
cameractrl/models/attention_processor.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.init as init
5
+ import logging
6
+ from diffusers.models.attention import Attention
7
+ from diffusers.utils import USE_PEFT_BACKEND, is_xformers_available
8
+ from typing import Optional, Callable
9
+
10
+ from einops import rearrange
11
+
12
+ if is_xformers_available():
13
+ import xformers
14
+ import xformers.ops
15
+ else:
16
+ xformers = None
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class AttnProcessor:
22
+ r"""
23
+ Default processor for performing attention-related computations.
24
+ """
25
+
26
+ def __call__(
27
+ self,
28
+ attn: Attention,
29
+ hidden_states: torch.FloatTensor,
30
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
31
+ attention_mask: Optional[torch.FloatTensor] = None,
32
+ temb: Optional[torch.FloatTensor] = None,
33
+ scale: float = 1.0,
34
+ pose_feature=None
35
+ ) -> torch.Tensor:
36
+ residual = hidden_states
37
+
38
+ args = () if USE_PEFT_BACKEND else (scale,)
39
+
40
+ if attn.spatial_norm is not None:
41
+ hidden_states = attn.spatial_norm(hidden_states, temb)
42
+
43
+ input_ndim = hidden_states.ndim
44
+
45
+ if input_ndim == 4:
46
+ batch_size, channel, height, width = hidden_states.shape
47
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
48
+
49
+ batch_size, sequence_length, _ = (
50
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
51
+ )
52
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
53
+
54
+ if attn.group_norm is not None:
55
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
56
+
57
+ query = attn.to_q(hidden_states, *args)
58
+
59
+ if encoder_hidden_states is None:
60
+ encoder_hidden_states = hidden_states
61
+ elif attn.norm_cross:
62
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
63
+
64
+ key = attn.to_k(encoder_hidden_states, *args)
65
+ value = attn.to_v(encoder_hidden_states, *args)
66
+
67
+ query = attn.head_to_batch_dim(query)
68
+ key = attn.head_to_batch_dim(key)
69
+ value = attn.head_to_batch_dim(value)
70
+
71
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
72
+ hidden_states = torch.bmm(attention_probs, value)
73
+ hidden_states = attn.batch_to_head_dim(hidden_states)
74
+
75
+ # linear proj
76
+ hidden_states = attn.to_out[0](hidden_states, *args)
77
+ # dropout
78
+ hidden_states = attn.to_out[1](hidden_states)
79
+
80
+ if input_ndim == 4:
81
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
82
+
83
+ if attn.residual_connection:
84
+ hidden_states = hidden_states + residual
85
+
86
+ hidden_states = hidden_states / attn.rescale_output_factor
87
+
88
+ return hidden_states
89
+
90
+
91
+ class AttnProcessor2_0:
92
+ r"""
93
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
94
+ """
95
+
96
+ def __init__(self):
97
+ if not hasattr(F, "scaled_dot_product_attention"):
98
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
99
+
100
+ def __call__(
101
+ self,
102
+ attn: Attention,
103
+ hidden_states: torch.FloatTensor,
104
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
105
+ attention_mask: Optional[torch.FloatTensor] = None,
106
+ temb: Optional[torch.FloatTensor] = None,
107
+ scale: float = 1.0,
108
+ pose_feature=None
109
+ ) -> torch.FloatTensor:
110
+ residual = hidden_states
111
+
112
+ args = () if USE_PEFT_BACKEND else (scale,)
113
+
114
+ if attn.spatial_norm is not None:
115
+ hidden_states = attn.spatial_norm(hidden_states, temb)
116
+
117
+ input_ndim = hidden_states.ndim
118
+
119
+ if input_ndim == 4:
120
+ batch_size, channel, height, width = hidden_states.shape
121
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
122
+
123
+ batch_size, sequence_length, _ = (
124
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
125
+ )
126
+
127
+ if attention_mask is not None:
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+ # scaled_dot_product_attention expects attention_mask shape to be
130
+ # (batch, heads, source_length, target_length)
131
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
132
+
133
+ if attn.group_norm is not None:
134
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
135
+
136
+ args = () if USE_PEFT_BACKEND else (scale,)
137
+ query = attn.to_q(hidden_states, *args)
138
+
139
+ if encoder_hidden_states is None:
140
+ encoder_hidden_states = hidden_states
141
+ elif attn.norm_cross:
142
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
143
+
144
+ key = attn.to_k(encoder_hidden_states, *args)
145
+ value = attn.to_v(encoder_hidden_states, *args)
146
+
147
+ inner_dim = key.shape[-1]
148
+ head_dim = inner_dim // attn.heads
149
+
150
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
151
+
152
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
153
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
154
+
155
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
156
+ # TODO: add support for attn.scale when we move to Torch 2.1
157
+ hidden_states = F.scaled_dot_product_attention(
158
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
159
+ )
160
+
161
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
162
+ hidden_states = hidden_states.to(query.dtype)
163
+
164
+ # linear proj
165
+ hidden_states = attn.to_out[0](hidden_states, *args)
166
+ # dropout
167
+ hidden_states = attn.to_out[1](hidden_states)
168
+
169
+ if input_ndim == 4:
170
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
171
+
172
+ if attn.residual_connection:
173
+ hidden_states = hidden_states + residual
174
+
175
+ hidden_states = hidden_states / attn.rescale_output_factor
176
+
177
+ return hidden_states
178
+
179
+
180
+ class XFormersAttnProcessor:
181
+ r"""
182
+ Processor for implementing memory efficient attention using xFormers.
183
+
184
+ Args:
185
+ attention_op (`Callable`, *optional*, defaults to `None`):
186
+ The base
187
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
188
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
189
+ operator.
190
+ """
191
+
192
+ def __init__(self, attention_op: Optional[Callable] = None):
193
+ self.attention_op = attention_op
194
+
195
+ def __call__(
196
+ self,
197
+ attn: Attention,
198
+ hidden_states: torch.FloatTensor,
199
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
200
+ attention_mask: Optional[torch.FloatTensor] = None,
201
+ temb: Optional[torch.FloatTensor] = None,
202
+ scale: float = 1.0,
203
+ pose_feature=None
204
+ ) -> torch.FloatTensor:
205
+ residual = hidden_states
206
+
207
+ args = () if USE_PEFT_BACKEND else (scale,)
208
+
209
+ if attn.spatial_norm is not None:
210
+ hidden_states = attn.spatial_norm(hidden_states, temb)
211
+
212
+ input_ndim = hidden_states.ndim
213
+
214
+ if input_ndim == 4:
215
+ batch_size, channel, height, width = hidden_states.shape
216
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
217
+
218
+ batch_size, key_tokens, _ = (
219
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
220
+ )
221
+
222
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
223
+ if attention_mask is not None:
224
+ # expand our mask's singleton query_tokens dimension:
225
+ # [batch*heads, 1, key_tokens] ->
226
+ # [batch*heads, query_tokens, key_tokens]
227
+ # so that it can be added as a bias onto the attention scores that xformers computes:
228
+ # [batch*heads, query_tokens, key_tokens]
229
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
230
+ _, query_tokens, _ = hidden_states.shape
231
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
232
+
233
+ if attn.group_norm is not None:
234
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
235
+
236
+ query = attn.to_q(hidden_states, *args)
237
+
238
+ if encoder_hidden_states is None:
239
+ encoder_hidden_states = hidden_states
240
+ elif attn.norm_cross:
241
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
242
+
243
+ key = attn.to_k(encoder_hidden_states, *args)
244
+ value = attn.to_v(encoder_hidden_states, *args)
245
+
246
+ query = attn.head_to_batch_dim(query).contiguous()
247
+ key = attn.head_to_batch_dim(key).contiguous()
248
+ value = attn.head_to_batch_dim(value).contiguous()
249
+
250
+ hidden_states = xformers.ops.memory_efficient_attention(
251
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
252
+ )
253
+ hidden_states = hidden_states.to(query.dtype)
254
+ hidden_states = attn.batch_to_head_dim(hidden_states)
255
+
256
+ # linear proj
257
+ hidden_states = attn.to_out[0](hidden_states, *args)
258
+ # dropout
259
+ hidden_states = attn.to_out[1](hidden_states)
260
+
261
+ if input_ndim == 4:
262
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
263
+
264
+ if attn.residual_connection:
265
+ hidden_states = hidden_states + residual
266
+
267
+ hidden_states = hidden_states / attn.rescale_output_factor
268
+
269
+ return hidden_states
270
+
271
+
272
+ class PoseAdaptorAttnProcessor(nn.Module):
273
+ def __init__(self,
274
+ hidden_size, # dimension of hidden state
275
+ pose_feature_dim=None, # dimension of the pose feature
276
+ cross_attention_dim=None, # dimension of the text embedding
277
+ query_condition=False,
278
+ key_value_condition=False,
279
+ scale=1.0):
280
+ super().__init__()
281
+
282
+ self.hidden_size = hidden_size
283
+ self.pose_feature_dim = pose_feature_dim
284
+ self.cross_attention_dim = cross_attention_dim
285
+ self.scale = scale
286
+ self.query_condition = query_condition
287
+ self.key_value_condition = key_value_condition
288
+ assert hidden_size == pose_feature_dim
289
+ if self.query_condition and self.key_value_condition:
290
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
291
+ init.zeros_(self.qkv_merge.weight)
292
+ init.zeros_(self.qkv_merge.bias)
293
+ elif self.query_condition:
294
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
295
+ init.zeros_(self.q_merge.weight)
296
+ init.zeros_(self.q_merge.bias)
297
+ else:
298
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
299
+ init.zeros_(self.kv_merge.weight)
300
+ init.zeros_(self.kv_merge.bias)
301
+
302
+ def forward(self,
303
+ attn,
304
+ hidden_states,
305
+ pose_feature,
306
+ encoder_hidden_states=None,
307
+ attention_mask=None,
308
+ temb=None,
309
+ scale=None,):
310
+ assert pose_feature is not None
311
+ pose_embedding_scale = (scale or self.scale)
312
+
313
+ residual = hidden_states
314
+ if attn.spatial_norm is not None:
315
+ hidden_states = attn.spatial_norm(hidden_states, temb)
316
+
317
+ assert hidden_states.ndim == 3 and pose_feature.ndim == 3
318
+
319
+ if self.query_condition and self.key_value_condition:
320
+ assert encoder_hidden_states is None
321
+
322
+ if encoder_hidden_states is None:
323
+ encoder_hidden_states = hidden_states
324
+
325
+ assert encoder_hidden_states.ndim == 3
326
+
327
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
328
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
329
+
330
+ if attn.group_norm is not None:
331
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
332
+
333
+ if attn.norm_cross:
334
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
335
+
336
+ if self.query_condition and self.key_value_condition: # only self attention
337
+ query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
338
+ key_value_hidden_state = query_hidden_state
339
+ elif self.query_condition:
340
+ query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
341
+ key_value_hidden_state = encoder_hidden_states
342
+ else:
343
+ key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
344
+ query_hidden_state = hidden_states
345
+
346
+ # original attention
347
+ query = attn.to_q(query_hidden_state)
348
+ key = attn.to_k(key_value_hidden_state)
349
+ value = attn.to_v(key_value_hidden_state)
350
+
351
+ query = attn.head_to_batch_dim(query)
352
+ key = attn.head_to_batch_dim(key)
353
+ value = attn.head_to_batch_dim(value)
354
+
355
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
356
+ hidden_states = torch.bmm(attention_probs, value)
357
+ hidden_states = attn.batch_to_head_dim(hidden_states)
358
+
359
+ # linear proj
360
+ hidden_states = attn.to_out[0](hidden_states)
361
+ # dropout
362
+ hidden_states = attn.to_out[1](hidden_states)
363
+
364
+ if attn.residual_connection:
365
+ hidden_states = hidden_states + residual
366
+
367
+ hidden_states = hidden_states / attn.rescale_output_factor
368
+
369
+ return hidden_states
370
+
371
+
372
+ class PoseAdaptorAttnProcessor2_0(nn.Module):
373
+ def __init__(self,
374
+ hidden_size, # dimension of hidden state
375
+ pose_feature_dim=None, # dimension of the pose feature
376
+ cross_attention_dim=None, # dimension of the text embedding
377
+ query_condition=False,
378
+ key_value_condition=False,
379
+ scale=1.0):
380
+ super().__init__()
381
+ if not hasattr(F, "scaled_dot_product_attention"):
382
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
383
+
384
+ self.hidden_size = hidden_size
385
+ self.pose_feature_dim = pose_feature_dim
386
+ self.cross_attention_dim = cross_attention_dim
387
+ self.scale = scale
388
+ self.query_condition = query_condition
389
+ self.key_value_condition = key_value_condition
390
+ assert hidden_size == pose_feature_dim
391
+ if self.query_condition and self.key_value_condition:
392
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
393
+ init.zeros_(self.qkv_merge.weight)
394
+ init.zeros_(self.qkv_merge.bias)
395
+ elif self.query_condition:
396
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
397
+ init.zeros_(self.q_merge.weight)
398
+ init.zeros_(self.q_merge.bias)
399
+ else:
400
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
401
+ init.zeros_(self.kv_merge.weight)
402
+ init.zeros_(self.kv_merge.bias)
403
+
404
+ def forward(self,
405
+ attn,
406
+ hidden_states,
407
+ pose_feature,
408
+ encoder_hidden_states=None,
409
+ attention_mask=None,
410
+ temb=None,
411
+ scale=None,):
412
+ assert pose_feature is not None
413
+ pose_embedding_scale = (scale or self.scale)
414
+
415
+ residual = hidden_states
416
+ if attn.spatial_norm is not None:
417
+ hidden_states = attn.spatial_norm(hidden_states, temb)
418
+
419
+ assert hidden_states.ndim == 3 and pose_feature.ndim == 3
420
+
421
+ if self.query_condition and self.key_value_condition:
422
+ assert encoder_hidden_states is None
423
+
424
+ if encoder_hidden_states is None:
425
+ encoder_hidden_states = hidden_states
426
+
427
+ assert encoder_hidden_states.ndim == 3
428
+
429
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
430
+ if attention_mask is not None:
431
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
432
+ # scaled_dot_product_attention expects attention_mask shape to be
433
+ # (batch, heads, source_length, target_length)
434
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
435
+
436
+ if attn.group_norm is not None:
437
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
438
+
439
+ if attn.norm_cross:
440
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
441
+
442
+ if self.query_condition and self.key_value_condition: # only self attention
443
+ query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
444
+ key_value_hidden_state = query_hidden_state
445
+ elif self.query_condition:
446
+ query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
447
+ key_value_hidden_state = encoder_hidden_states
448
+ else:
449
+ key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
450
+ query_hidden_state = hidden_states
451
+
452
+ # original attention
453
+ query = attn.to_q(query_hidden_state)
454
+ key = attn.to_k(key_value_hidden_state)
455
+ value = attn.to_v(key_value_hidden_state)
456
+
457
+ inner_dim = key.shape[-1]
458
+ head_dim = inner_dim // attn.heads
459
+
460
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [bs, seq_len, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
461
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
462
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
463
+
464
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) # [bs, nhead, seq_len, head_dim]
465
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # [bs, seq_len, dim]
466
+ hidden_states = hidden_states.to(query.dtype)
467
+
468
+ # linear proj
469
+ hidden_states = attn.to_out[0](hidden_states)
470
+ # dropout
471
+ hidden_states = attn.to_out[1](hidden_states)
472
+
473
+ if attn.residual_connection:
474
+ hidden_states = hidden_states + residual
475
+
476
+ hidden_states = hidden_states / attn.rescale_output_factor
477
+
478
+ return hidden_states
479
+
480
+
481
+ class PoseAdaptorXFormersAttnProcessor(nn.Module):
482
+ def __init__(self,
483
+ hidden_size, # dimension of hidden state
484
+ pose_feature_dim=None, # dimension of the pose feature
485
+ cross_attention_dim=None, # dimension of the text embedding
486
+ query_condition=False,
487
+ key_value_condition=False,
488
+ scale=1.0,
489
+ attention_op: Optional[Callable] = None):
490
+ super().__init__()
491
+
492
+ self.hidden_size = hidden_size
493
+ self.pose_feature_dim = pose_feature_dim
494
+ self.cross_attention_dim = cross_attention_dim
495
+ self.scale = scale
496
+ self.query_condition = query_condition
497
+ self.key_value_condition = key_value_condition
498
+ self.attention_op = attention_op
499
+ assert hidden_size == pose_feature_dim
500
+ if self.query_condition and self.key_value_condition:
501
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
502
+ init.zeros_(self.qkv_merge.weight)
503
+ init.zeros_(self.qkv_merge.bias)
504
+ elif self.query_condition:
505
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
506
+ init.zeros_(self.q_merge.weight)
507
+ init.zeros_(self.q_merge.bias)
508
+ else:
509
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
510
+ init.zeros_(self.kv_merge.weight)
511
+ init.zeros_(self.kv_merge.bias)
512
+
513
+ def forward(self,
514
+ attn,
515
+ hidden_states,
516
+ pose_feature,
517
+ encoder_hidden_states=None,
518
+ attention_mask=None,
519
+ temb=None,
520
+ scale=None,):
521
+ assert pose_feature is not None
522
+ pose_embedding_scale = (scale or self.scale)
523
+
524
+ residual = hidden_states
525
+ if attn.spatial_norm is not None:
526
+ hidden_states = attn.spatial_norm(hidden_states, temb)
527
+
528
+ assert hidden_states.ndim == 3 and pose_feature.ndim == 3
529
+
530
+ if self.query_condition and self.key_value_condition:
531
+ assert encoder_hidden_states is None
532
+
533
+ if encoder_hidden_states is None:
534
+ encoder_hidden_states = hidden_states
535
+
536
+ assert encoder_hidden_states.ndim == 3
537
+
538
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
539
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
540
+ if attention_mask is not None:
541
+ # expand our mask's singleton query_tokens dimension:
542
+ # [batch*heads, 1, key_tokens] ->
543
+ # [batch*heads, query_tokens, key_tokens]
544
+ # so that it can be added as a bias onto the attention scores that xformers computes:
545
+ # [batch*heads, query_tokens, key_tokens]
546
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
547
+ _, query_tokens, _ = hidden_states.shape
548
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
549
+
550
+ if attn.group_norm is not None:
551
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
552
+
553
+ if attn.norm_cross:
554
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
555
+
556
+ if self.query_condition and self.key_value_condition: # only self attention
557
+ query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
558
+ key_value_hidden_state = query_hidden_state
559
+ elif self.query_condition:
560
+ query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
561
+ key_value_hidden_state = encoder_hidden_states
562
+ else:
563
+ key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
564
+ query_hidden_state = hidden_states
565
+
566
+ # original attention
567
+ query = attn.to_q(query_hidden_state)
568
+ key = attn.to_k(key_value_hidden_state)
569
+ value = attn.to_v(key_value_hidden_state)
570
+
571
+ query = attn.head_to_batch_dim(query).contiguous()
572
+ key = attn.head_to_batch_dim(key).contiguous()
573
+ value = attn.head_to_batch_dim(value).contiguous()
574
+
575
+ hidden_states = xformers.ops.memory_efficient_attention(
576
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
577
+ )
578
+ hidden_states = hidden_states.to(query.dtype)
579
+ hidden_states = attn.batch_to_head_dim(hidden_states)
580
+
581
+ # linear proj
582
+ hidden_states = attn.to_out[0](hidden_states)
583
+ # dropout
584
+ hidden_states = attn.to_out[1](hidden_states)
585
+
586
+ if attn.residual_connection:
587
+ hidden_states = hidden_states + residual
588
+
589
+ hidden_states = hidden_states / attn.rescale_output_factor
590
+
591
+ return hidden_states
cameractrl/models/motion_module.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.attention import FeedForward
10
+
11
+ from typing import Dict, Any
12
+ from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor
13
+
14
+ from einops import rearrange
15
+ import math
16
+
17
+
18
+ class InflatedGroupNorm(nn.GroupNorm):
19
+ def forward(self, x):
20
+ # return super().forward(x)
21
+
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+ def zero_module(module):
31
+ # Zero out the parameters of a module and return it.
32
+ for p in module.parameters():
33
+ p.detach().zero_()
34
+ return module
35
+
36
+
37
+ @dataclass
38
+ class TemporalTransformer3DModelOutput(BaseOutput):
39
+ sample: torch.FloatTensor
40
+
41
+
42
+ def get_motion_module(
43
+ in_channels,
44
+ motion_module_type: str,
45
+ motion_module_kwargs: dict
46
+ ):
47
+ if motion_module_type == "Vanilla":
48
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
49
+ else:
50
+ raise ValueError
51
+
52
+
53
+ class VanillaTemporalModule(nn.Module):
54
+ def __init__(
55
+ self,
56
+ in_channels,
57
+ num_attention_heads=8,
58
+ num_transformer_block=2,
59
+ attention_block_types=("Temporal_Self",),
60
+ temporal_position_encoding=True,
61
+ temporal_position_encoding_max_len=32,
62
+ temporal_attention_dim_div=1,
63
+ cross_attention_dim=320,
64
+ zero_initialize=True,
65
+ encoder_hidden_states_query=(False, False),
66
+ attention_activation_scale=1.0,
67
+ attention_processor_kwargs: Dict = {},
68
+ causal_temporal_attention=False,
69
+ causal_temporal_attention_mask_type="",
70
+ rescale_output_factor=1.0
71
+ ):
72
+ super().__init__()
73
+
74
+ self.temporal_transformer = TemporalTransformer3DModel(
75
+ in_channels=in_channels,
76
+ num_attention_heads=num_attention_heads,
77
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
78
+ num_layers=num_transformer_block,
79
+ attention_block_types=attention_block_types,
80
+ cross_attention_dim=cross_attention_dim,
81
+ temporal_position_encoding=temporal_position_encoding,
82
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
83
+ encoder_hidden_states_query=encoder_hidden_states_query,
84
+ attention_activation_scale=attention_activation_scale,
85
+ attention_processor_kwargs=attention_processor_kwargs,
86
+ causal_temporal_attention=causal_temporal_attention,
87
+ causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
88
+ rescale_output_factor=rescale_output_factor
89
+ )
90
+
91
+ if zero_initialize:
92
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
93
+
94
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
95
+ cross_attention_kwargs: Dict[str, Any] = {}):
96
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cross_attention_kwargs=cross_attention_kwargs)
97
+
98
+ output = hidden_states
99
+ return output
100
+
101
+
102
+ class TemporalTransformer3DModel(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_channels,
106
+ num_attention_heads,
107
+ attention_head_dim,
108
+ num_layers,
109
+ attention_block_types=("Temporal_Self", "Temporal_Self",),
110
+ dropout=0.0,
111
+ norm_num_groups=32,
112
+ cross_attention_dim=320,
113
+ activation_fn="geglu",
114
+ attention_bias=False,
115
+ upcast_attention=False,
116
+ temporal_position_encoding=False,
117
+ temporal_position_encoding_max_len=32,
118
+ encoder_hidden_states_query=(False, False),
119
+ attention_activation_scale=1.0,
120
+ attention_processor_kwargs: Dict = {},
121
+
122
+ causal_temporal_attention=None,
123
+ causal_temporal_attention_mask_type="",
124
+ rescale_output_factor=1.0
125
+ ):
126
+ super().__init__()
127
+ assert causal_temporal_attention is not None
128
+ self.causal_temporal_attention = causal_temporal_attention
129
+
130
+ assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "")
131
+ self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
132
+ self.causal_temporal_attention_mask = None
133
+
134
+ inner_dim = num_attention_heads * attention_head_dim
135
+
136
+ self.norm = InflatedGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
137
+ self.proj_in = nn.Linear(in_channels, inner_dim)
138
+
139
+ self.transformer_blocks = nn.ModuleList(
140
+ [
141
+ TemporalTransformerBlock(
142
+ dim=inner_dim,
143
+ num_attention_heads=num_attention_heads,
144
+ attention_head_dim=attention_head_dim,
145
+ attention_block_types=attention_block_types,
146
+ dropout=dropout,
147
+ norm_num_groups=norm_num_groups,
148
+ cross_attention_dim=cross_attention_dim,
149
+ activation_fn=activation_fn,
150
+ attention_bias=attention_bias,
151
+ upcast_attention=upcast_attention,
152
+ temporal_position_encoding=temporal_position_encoding,
153
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
154
+ encoder_hidden_states_query=encoder_hidden_states_query,
155
+ attention_activation_scale=attention_activation_scale,
156
+ attention_processor_kwargs=attention_processor_kwargs,
157
+ rescale_output_factor=rescale_output_factor,
158
+ )
159
+ for d in range(num_layers)
160
+ ]
161
+ )
162
+ self.proj_out = nn.Linear(inner_dim, in_channels)
163
+
164
+ def get_causal_temporal_attention_mask(self, hidden_states):
165
+ batch_size, sequence_length, dim = hidden_states.shape
166
+
167
+ if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (
168
+ batch_size, sequence_length, sequence_length):
169
+ if self.causal_temporal_attention_mask_type == "causal":
170
+ # 1. vanilla causal mask
171
+ mask = torch.tril(torch.ones(sequence_length, sequence_length))
172
+
173
+ elif self.causal_temporal_attention_mask_type == "2-seq":
174
+ # 2. 2-seq
175
+ mask = torch.zeros(sequence_length, sequence_length)
176
+ mask[:sequence_length // 2, :sequence_length // 2] = 1
177
+ mask[-sequence_length // 2:, -sequence_length // 2:] = 1
178
+
179
+ elif self.causal_temporal_attention_mask_type == "0-prev":
180
+ # attn to the previous frame
181
+ indices = torch.arange(sequence_length)
182
+ indices_prev = indices - 1
183
+ indices_prev[0] = 0
184
+ mask = torch.zeros(sequence_length, sequence_length)
185
+ mask[:, 0] = 1.
186
+ mask[indices, indices_prev] = 1.
187
+
188
+ elif self.causal_temporal_attention_mask_type == "0":
189
+ # only attn to first frame
190
+ mask = torch.zeros(sequence_length, sequence_length)
191
+ mask[:, 0] = 1
192
+
193
+ elif self.causal_temporal_attention_mask_type == "wo-self":
194
+ indices = torch.arange(sequence_length)
195
+ mask = torch.ones(sequence_length, sequence_length)
196
+ mask[indices, indices] = 0
197
+
198
+ elif self.causal_temporal_attention_mask_type == "circle":
199
+ indices = torch.arange(sequence_length)
200
+ indices_prev = indices - 1
201
+ indices_prev[0] = 0
202
+
203
+ mask = torch.eye(sequence_length)
204
+ mask[indices, indices_prev] = 1
205
+ mask[0, -1] = 1
206
+
207
+ else:
208
+ raise ValueError
209
+
210
+ # generate attention mask fron binary values
211
+ mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
212
+ mask = mask.unsqueeze(0)
213
+ mask = mask.repeat(batch_size, 1, 1)
214
+
215
+ self.causal_temporal_attention_mask = mask.to(hidden_states.device)
216
+
217
+ return self.causal_temporal_attention_mask
218
+
219
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
220
+ cross_attention_kwargs: Dict[str, Any] = {},):
221
+ residual = hidden_states
222
+
223
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
224
+ height, width = hidden_states.shape[-2:]
225
+
226
+ hidden_states = self.norm(hidden_states)
227
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
228
+ hidden_states = self.proj_in(hidden_states)
229
+
230
+ attention_mask = self.get_causal_temporal_attention_mask(
231
+ hidden_states) if self.causal_temporal_attention else attention_mask
232
+
233
+ # Transformer Blocks
234
+ for block in self.transformer_blocks:
235
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states,
236
+ attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs)
237
+ hidden_states = self.proj_out(hidden_states)
238
+
239
+ hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)
240
+
241
+ output = hidden_states + residual
242
+
243
+ return output
244
+
245
+
246
+ class TemporalTransformerBlock(nn.Module):
247
+ def __init__(
248
+ self,
249
+ dim,
250
+ num_attention_heads,
251
+ attention_head_dim,
252
+ attention_block_types=("Temporal_Self", "Temporal_Self",),
253
+ dropout=0.0,
254
+ norm_num_groups=32,
255
+ cross_attention_dim=768,
256
+ activation_fn="geglu",
257
+ attention_bias=False,
258
+ upcast_attention=False,
259
+ temporal_position_encoding=False,
260
+ temporal_position_encoding_max_len=32,
261
+ encoder_hidden_states_query=(False, False),
262
+ attention_activation_scale=1.0,
263
+ attention_processor_kwargs: Dict = {},
264
+ rescale_output_factor=1.0
265
+ ):
266
+ super().__init__()
267
+
268
+ attention_blocks = []
269
+ norms = []
270
+ self.attention_block_types = attention_block_types
271
+
272
+ for block_idx, block_name in enumerate(attention_block_types):
273
+ attention_blocks.append(
274
+ TemporalSelfAttention(
275
+ attention_mode=block_name,
276
+ cross_attention_dim=cross_attention_dim if block_name in ['Temporal_Cross', 'Temporal_Pose_Adaptor'] else None,
277
+ query_dim=dim,
278
+ heads=num_attention_heads,
279
+ dim_head=attention_head_dim,
280
+ dropout=dropout,
281
+ bias=attention_bias,
282
+ upcast_attention=upcast_attention,
283
+ temporal_position_encoding=temporal_position_encoding,
284
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
285
+ rescale_output_factor=rescale_output_factor,
286
+ )
287
+ )
288
+ norms.append(nn.LayerNorm(dim))
289
+
290
+ self.attention_blocks = nn.ModuleList(attention_blocks)
291
+ self.norms = nn.ModuleList(norms)
292
+
293
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
294
+ self.ff_norm = nn.LayerNorm(dim)
295
+
296
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs: Dict[str, Any] = {}):
297
+ for attention_block, norm, attention_block_type in zip(self.attention_blocks, self.norms, self.attention_block_types):
298
+ norm_hidden_states = norm(hidden_states)
299
+ hidden_states = attention_block(
300
+ norm_hidden_states,
301
+ encoder_hidden_states=norm_hidden_states if attention_block_type == 'Temporal_Self' else encoder_hidden_states,
302
+ attention_mask=attention_mask,
303
+ **cross_attention_kwargs
304
+ ) + hidden_states
305
+
306
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
307
+
308
+ output = hidden_states
309
+ return output
310
+
311
+
312
+ class PositionalEncoding(nn.Module):
313
+ def __init__(
314
+ self,
315
+ d_model,
316
+ dropout=0.,
317
+ max_len=32,
318
+ ):
319
+ super().__init__()
320
+ self.dropout = nn.Dropout(p=dropout)
321
+ position = torch.arange(max_len).unsqueeze(1)
322
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
323
+ pe = torch.zeros(1, max_len, d_model)
324
+ pe[0, :, 0::2] = torch.sin(position * div_term)
325
+ pe[0, :, 1::2] = torch.cos(position * div_term)
326
+ self.register_buffer('pe', pe)
327
+
328
+ def forward(self, x):
329
+ x = x + self.pe[:, :x.size(1)]
330
+ return self.dropout(x)
331
+
332
+
333
+ class TemporalSelfAttention(Attention):
334
+ def __init__(
335
+ self,
336
+ attention_mode=None,
337
+ temporal_position_encoding=False,
338
+ temporal_position_encoding_max_len=32,
339
+ rescale_output_factor=1.0,
340
+ *args, **kwargs
341
+ ):
342
+ super().__init__(*args, **kwargs)
343
+ assert attention_mode == "Temporal_Self"
344
+
345
+ self.pos_encoder = PositionalEncoding(
346
+ kwargs["query_dim"],
347
+ max_len=temporal_position_encoding_max_len
348
+ ) if temporal_position_encoding else None
349
+ self.rescale_output_factor = rescale_output_factor
350
+
351
+ def set_use_memory_efficient_attention_xformers(
352
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
353
+ ):
354
+ # disable motion module efficient xformers to avoid bad results, don't know why
355
+ # TODO: fix this bug
356
+ pass
357
+
358
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
359
+ # The `Attention` class can call different attention processors / attention functions
360
+ # here we simply pass along all tensors to the selected processor class
361
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
362
+
363
+ # add position encoding
364
+ if self.pos_encoder is not None:
365
+ hidden_states = self.pos_encoder(hidden_states)
366
+ if "pose_feature" in cross_attention_kwargs:
367
+ pose_feature = cross_attention_kwargs["pose_feature"]
368
+ if pose_feature.ndim == 5:
369
+ pose_feature = rearrange(pose_feature, "b c f h w -> (b h w) f c")
370
+ else:
371
+ assert pose_feature.ndim == 3
372
+ cross_attention_kwargs["pose_feature"] = pose_feature
373
+
374
+ if isinstance(self.processor, PoseAdaptorAttnProcessor):
375
+ return self.processor(
376
+ self,
377
+ hidden_states,
378
+ cross_attention_kwargs.pop('pose_feature'),
379
+ encoder_hidden_states=None,
380
+ attention_mask=attention_mask,
381
+ **cross_attention_kwargs,
382
+ )
383
+ elif hasattr(self.processor, "__call__"):
384
+ return self.processor.__call__(
385
+ self,
386
+ hidden_states,
387
+ encoder_hidden_states=None,
388
+ attention_mask=attention_mask,
389
+ **cross_attention_kwargs,
390
+ )
391
+ else:
392
+ return self.processor(
393
+ self,
394
+ hidden_states,
395
+ encoder_hidden_states=None,
396
+ attention_mask=attention_mask,
397
+ **cross_attention_kwargs,
398
+ )
399
+
cameractrl/models/pose_adaptor.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+ from typing import List, Tuple
6
+ from cameractrl.models.motion_module import TemporalTransformerBlock
7
+
8
+
9
+ def get_parameter_dtype(parameter: torch.nn.Module):
10
+ try:
11
+ params = tuple(parameter.parameters())
12
+ if len(params) > 0:
13
+ return params[0].dtype
14
+
15
+ buffers = tuple(parameter.buffers())
16
+ if len(buffers) > 0:
17
+ return buffers[0].dtype
18
+
19
+ except StopIteration:
20
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
21
+
22
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
23
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
24
+ return tuples
25
+
26
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
27
+ first_tuple = next(gen)
28
+ return first_tuple[1].dtype
29
+
30
+
31
+ def conv_nd(dims, *args, **kwargs):
32
+ """
33
+ Create a 1D, 2D, or 3D convolution module.
34
+ """
35
+ if dims == 1:
36
+ return nn.Conv1d(*args, **kwargs)
37
+ elif dims == 2:
38
+ return nn.Conv2d(*args, **kwargs)
39
+ elif dims == 3:
40
+ return nn.Conv3d(*args, **kwargs)
41
+ raise ValueError(f"unsupported dimensions: {dims}")
42
+
43
+
44
+ def avg_pool_nd(dims, *args, **kwargs):
45
+ """
46
+ Create a 1D, 2D, or 3D average pooling module.
47
+ """
48
+ if dims == 1:
49
+ return nn.AvgPool1d(*args, **kwargs)
50
+ elif dims == 2:
51
+ return nn.AvgPool2d(*args, **kwargs)
52
+ elif dims == 3:
53
+ return nn.AvgPool3d(*args, **kwargs)
54
+ raise ValueError(f"unsupported dimensions: {dims}")
55
+
56
+
57
+ class PoseAdaptor(nn.Module):
58
+ def __init__(self, unet, pose_encoder):
59
+ super().__init__()
60
+ self.unet = unet
61
+ self.pose_encoder = pose_encoder
62
+
63
+ def forward(self, noisy_latents, c_noise, encoder_hidden_states, added_time_ids, pose_embedding):
64
+ assert pose_embedding.ndim == 5
65
+ pose_embedding_features = self.pose_encoder(pose_embedding) # b c f h w
66
+ noise_pred = self.unet(noisy_latents,
67
+ c_noise,
68
+ encoder_hidden_states,
69
+ added_time_ids=added_time_ids,
70
+ pose_features=pose_embedding_features).sample
71
+ return noise_pred
72
+
73
+
74
+ class Downsample(nn.Module):
75
+ """
76
+ A downsampling layer with an optional convolution.
77
+ :param channels: channels in the inputs and outputs.
78
+ :param use_conv: a bool determining if a convolution is applied.
79
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
80
+ downsampling occurs in the inner-two dimensions.
81
+ """
82
+
83
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
84
+ super().__init__()
85
+ self.channels = channels
86
+ self.out_channels = out_channels or channels
87
+ self.use_conv = use_conv
88
+ self.dims = dims
89
+ stride = 2 if dims != 3 else (1, 2, 2)
90
+ if use_conv:
91
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
92
+ else:
93
+ assert self.channels == self.out_channels
94
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
95
+
96
+ def forward(self, x):
97
+ assert x.shape[1] == self.channels
98
+ return self.op(x)
99
+
100
+
101
+ class ResnetBlock(nn.Module):
102
+
103
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
104
+ super().__init__()
105
+ ps = ksize // 2
106
+ if in_c != out_c or sk == False:
107
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
108
+ else:
109
+ self.in_conv = None
110
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
111
+ self.act = nn.ReLU()
112
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
113
+ if sk == False:
114
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
115
+ else:
116
+ self.skep = None
117
+
118
+ self.down = down
119
+ if self.down == True:
120
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
121
+
122
+ def forward(self, x):
123
+ if self.down == True:
124
+ x = self.down_opt(x)
125
+ if self.in_conv is not None: # edit
126
+ x = self.in_conv(x)
127
+
128
+ h = self.block1(x)
129
+ h = self.act(h)
130
+ h = self.block2(h)
131
+ if self.skep is not None:
132
+ return h + self.skep(x)
133
+ else:
134
+ return h + x
135
+
136
+
137
+ class PositionalEncoding(nn.Module):
138
+ def __init__(
139
+ self,
140
+ d_model,
141
+ dropout=0.,
142
+ max_len=32,
143
+ ):
144
+ super().__init__()
145
+ self.dropout = nn.Dropout(p=dropout)
146
+ position = torch.arange(max_len).unsqueeze(1)
147
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
148
+ pe = torch.zeros(1, max_len, d_model)
149
+ pe[0, :, 0::2, ...] = torch.sin(position * div_term)
150
+ pe[0, :, 1::2, ...] = torch.cos(position * div_term)
151
+ pe.unsqueeze_(-1).unsqueeze_(-1)
152
+ self.register_buffer('pe', pe)
153
+
154
+ def forward(self, x):
155
+ x = x + self.pe[:, :x.size(1), ...]
156
+ return self.dropout(x)
157
+
158
+
159
+ class CameraPoseEncoder(nn.Module):
160
+
161
+ def __init__(self,
162
+ downscale_factor,
163
+ channels=[320, 640, 1280, 1280],
164
+ nums_rb=3,
165
+ cin=64,
166
+ ksize=3,
167
+ sk=False,
168
+ use_conv=True,
169
+ compression_factor=1,
170
+ temporal_attention_nhead=8,
171
+ attention_block_types=("Temporal_Self", ),
172
+ temporal_position_encoding=False,
173
+ temporal_position_encoding_max_len=16,
174
+ rescale_output_factor=1.0):
175
+ super(CameraPoseEncoder, self).__init__()
176
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
177
+ self.channels = channels
178
+ self.nums_rb = nums_rb
179
+ self.encoder_down_conv_blocks = nn.ModuleList()
180
+ self.encoder_down_attention_blocks = nn.ModuleList()
181
+ for i in range(len(channels)):
182
+ conv_layers = nn.ModuleList()
183
+ temporal_attention_layers = nn.ModuleList()
184
+ for j in range(nums_rb):
185
+ if j == 0 and i != 0:
186
+ in_dim = channels[i - 1]
187
+ out_dim = int(channels[i] / compression_factor)
188
+ conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
189
+ elif j == 0:
190
+ in_dim = channels[0]
191
+ out_dim = int(channels[i] / compression_factor)
192
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
193
+ elif j == nums_rb - 1:
194
+ in_dim = channels[i] / compression_factor
195
+ out_dim = channels[i]
196
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
197
+ else:
198
+ in_dim = int(channels[i] / compression_factor)
199
+ out_dim = int(channels[i] / compression_factor)
200
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
201
+ temporal_attention_layer = TemporalTransformerBlock(dim=out_dim,
202
+ num_attention_heads=temporal_attention_nhead,
203
+ attention_head_dim=int(out_dim / temporal_attention_nhead),
204
+ attention_block_types=attention_block_types,
205
+ dropout=0.0,
206
+ cross_attention_dim=None,
207
+ temporal_position_encoding=temporal_position_encoding,
208
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
209
+ rescale_output_factor=rescale_output_factor)
210
+ conv_layers.append(conv_layer)
211
+ temporal_attention_layers.append(temporal_attention_layer)
212
+ self.encoder_down_conv_blocks.append(conv_layers)
213
+ self.encoder_down_attention_blocks.append(temporal_attention_layers)
214
+
215
+ self.encoder_conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
216
+
217
+ @property
218
+ def dtype(self) -> torch.dtype:
219
+ """
220
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
221
+ """
222
+ return get_parameter_dtype(self)
223
+
224
+ def forward(self, x):
225
+ # unshuffle
226
+ bs = x.shape[0]
227
+ x = rearrange(x, "b f c h w -> (b f) c h w")
228
+ x = self.unshuffle(x)
229
+ # extract features
230
+ features = []
231
+ x = self.encoder_conv_in(x)
232
+ for res_block, attention_block in zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks):
233
+ for res_layer, attention_layer in zip(res_block, attention_block):
234
+ x = res_layer(x)
235
+ h, w = x.shape[-2:]
236
+ x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs)
237
+ x = attention_layer(x)
238
+ x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w)
239
+ features.append(rearrange(x, '(b f) c h w -> b c f h w', b=bs))
240
+ return features
cameractrl/models/transformer_temporal.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Optional
5
+ from diffusers.models.transformer_temporal import TransformerTemporalModelOutput
6
+ from diffusers.models.attention import BasicTransformerBlock
7
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8
+ from diffusers.models.resnet import AlphaBlender
9
+ from cameractrl.models.attention import TemporalPoseCondTransformerBlock
10
+
11
+
12
+ class TransformerSpatioTemporalModelPoseCond(nn.Module):
13
+ """
14
+ A Transformer model for video-like data.
15
+
16
+ Parameters:
17
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
18
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
19
+ in_channels (`int`, *optional*):
20
+ The number of channels in the input and output (specify if the input is **continuous**).
21
+ out_channels (`int`, *optional*):
22
+ The number of channels in the output (specify if the input is **continuous**).
23
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
24
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ num_attention_heads: int = 16,
30
+ attention_head_dim: int = 88,
31
+ in_channels: int = 320,
32
+ out_channels: Optional[int] = None,
33
+ num_layers: int = 1,
34
+ cross_attention_dim: Optional[int] = None,
35
+ ):
36
+ super().__init__()
37
+ self.num_attention_heads = num_attention_heads
38
+ self.attention_head_dim = attention_head_dim
39
+
40
+ inner_dim = num_attention_heads * attention_head_dim
41
+ self.inner_dim = inner_dim
42
+
43
+ # 2. Define input layers
44
+ self.in_channels = in_channels
45
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
46
+ self.proj_in = nn.Linear(in_channels, inner_dim)
47
+
48
+ # 3. Define transformers blocks
49
+ self.transformer_blocks = nn.ModuleList(
50
+ [
51
+ BasicTransformerBlock(
52
+ inner_dim,
53
+ num_attention_heads,
54
+ attention_head_dim,
55
+ cross_attention_dim=cross_attention_dim,
56
+ )
57
+ for d in range(num_layers)
58
+ ]
59
+ )
60
+
61
+ time_mix_inner_dim = inner_dim
62
+ self.temporal_transformer_blocks = nn.ModuleList(
63
+ [
64
+ TemporalPoseCondTransformerBlock(
65
+ inner_dim,
66
+ time_mix_inner_dim,
67
+ num_attention_heads,
68
+ attention_head_dim,
69
+ cross_attention_dim=cross_attention_dim,
70
+ )
71
+ for _ in range(num_layers)
72
+ ]
73
+ )
74
+
75
+ time_embed_dim = in_channels * 4
76
+ self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
77
+ self.time_proj = Timesteps(in_channels, True, 0)
78
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
79
+
80
+ # 4. Define output layers
81
+ self.out_channels = in_channels if out_channels is None else out_channels
82
+ # TODO: should use out_channels for continuous projections
83
+ self.proj_out = nn.Linear(inner_dim, in_channels)
84
+
85
+ self.gradient_checkpointing = False
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor, # [bs * frame, c, h, w]
90
+ encoder_hidden_states: Optional[torch.Tensor] = None, # [bs * frame, 1, c]
91
+ image_only_indicator: Optional[torch.Tensor] = None, # [bs, frame]
92
+ pose_feature: Optional[torch.Tensor] = None, # [bs, c, frame, h, w]
93
+ return_dict: bool = True,
94
+ ):
95
+ """
96
+ Args:
97
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
98
+ Input hidden_states.
99
+ num_frames (`int`):
100
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
101
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
102
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
103
+ self-attention.
104
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
105
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
106
+ images, 0 indicates that the input contains video frames.
107
+ return_dict (`bool`, *optional*, defaults to `True`):
108
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
109
+ plain tuple.
110
+
111
+ Returns:
112
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
113
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
114
+ returned, otherwise a `tuple` where the first element is the sample tensor.
115
+ """
116
+ # 1. Input
117
+ batch_frames, _, height, width = hidden_states.shape
118
+ num_frames = image_only_indicator.shape[-1]
119
+ batch_size = batch_frames // num_frames
120
+
121
+ time_context = encoder_hidden_states # [bs * frame, 1, c]
122
+ time_context_first_timestep = time_context[None, :].reshape(
123
+ batch_size, num_frames, -1, time_context.shape[-1]
124
+ )[:, 0] # [bs, frame, c]
125
+ time_context = time_context_first_timestep[:, None].broadcast_to(
126
+ batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
127
+ ) # [bs, h*w, 1, c]
128
+ time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) # [bs * h * w, 1, c]
129
+
130
+ residual = hidden_states
131
+
132
+ hidden_states = self.norm(hidden_states) # [bs * frame, c, h, w]
133
+ inner_dim = hidden_states.shape[1]
134
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) # [bs * frame, h * w, c]
135
+ hidden_states = self.proj_in(hidden_states) # [bs * frame, h * w, c]
136
+
137
+ num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
138
+ num_frames_emb = num_frames_emb.repeat(batch_size, 1) # [bs, frame]
139
+ num_frames_emb = num_frames_emb.reshape(-1) # [bs * frame]
140
+ t_emb = self.time_proj(num_frames_emb) # [bs * frame, c]
141
+
142
+ # `Timesteps` does not contain any weights and will always return f32 tensors
143
+ # but time_embedding might actually be running in fp16. so we need to cast here.
144
+ # there might be better ways to encapsulate this.
145
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
146
+
147
+ emb = self.time_pos_embed(t_emb)
148
+ emb = emb[:, None, :] # [bs * frame, 1, c]
149
+
150
+ # 2. Blocks
151
+ for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
152
+ if self.training and self.gradient_checkpointing:
153
+ hidden_states = torch.utils.checkpoint.checkpoint(
154
+ block,
155
+ hidden_states,
156
+ None,
157
+ encoder_hidden_states,
158
+ None,
159
+ use_reentrant=False,
160
+ )
161
+ else:
162
+ hidden_states = block(
163
+ hidden_states, # [bs * frame, h * w, c]
164
+ encoder_hidden_states=encoder_hidden_states, # [bs * frame, 1, c]
165
+ ) # [bs * frame, h * w, c]
166
+
167
+ hidden_states_mix = hidden_states
168
+ hidden_states_mix = hidden_states_mix + emb
169
+
170
+ hidden_states_mix = temporal_block(
171
+ hidden_states_mix, # [bs * frame, h * w, c]
172
+ num_frames=num_frames,
173
+ encoder_hidden_states=time_context, # [bs * h * w, 1, c]
174
+ pose_feature=pose_feature
175
+ )
176
+ hidden_states = self.time_mixer(
177
+ x_spatial=hidden_states,
178
+ x_temporal=hidden_states_mix,
179
+ image_only_indicator=image_only_indicator,
180
+ )
181
+
182
+ # 3. Output
183
+ hidden_states = self.proj_out(hidden_states)
184
+ hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
185
+
186
+ output = hidden_states + residual
187
+
188
+ if not return_dict:
189
+ return (output,)
190
+
191
+ return TransformerTemporalModelOutput(sample=output)
cameractrl/models/unet.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+
7
+ from typing import List, Optional, Tuple, Union, Dict
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, CROSS_ATTENTION_PROCESSORS
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
13
+ from diffusers.loaders import UNet2DConditionLoadersMixin
14
+ from diffusers.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
15
+
16
+ from cameractrl.models.unet_3d_blocks import (
17
+ get_down_block,
18
+ get_up_block,
19
+ UNetMidBlockSpatioTemporalPoseCond
20
+ )
21
+ from cameractrl.models.attention_processor import XFormersAttnProcessor as CustomizedXFormerAttnProcessor
22
+ from cameractrl.models.attention_processor import PoseAdaptorXFormersAttnProcessor
23
+
24
+ if hasattr(F, "scaled_dot_product_attention"):
25
+ from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor2_0 as PoseAdaptorAttnProcessor
26
+ from cameractrl.models.attention_processor import AttnProcessor2_0 as CustomizedAttnProcessor
27
+ else:
28
+ from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor
29
+ from cameractrl.models.attention_processor import AttnProcessor as CustomizedAttnProcessor
30
+
31
+ class UNetSpatioTemporalConditionModelPoseCond(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
32
+ r"""
33
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
34
+ shaped output.
35
+
36
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
37
+ for all models (such as downloading or saving).
38
+
39
+ Parameters:
40
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
41
+ Height and width of input/output sample.
42
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
43
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
44
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
45
+ The tuple of downsample blocks to use.
46
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
47
+ The tuple of upsample blocks to use.
48
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
49
+ The tuple of output channels for each block.
50
+ addition_time_embed_dim: (`int`, defaults to 256):
51
+ Dimension to to encode the additional time ids.
52
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
53
+ The dimension of the projection of encoded `added_time_ids`.
54
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
55
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
56
+ The dimension of the cross attention features.
57
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
58
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
59
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
60
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
61
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
62
+ The number of attention heads.
63
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
64
+ """
65
+
66
+ _supports_gradient_checkpointing = True
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ sample_size: Optional[int] = None,
72
+ in_channels: int = 8,
73
+ out_channels: int = 4,
74
+ down_block_types: Tuple[str] = (
75
+ "CrossAttnDownBlockSpatioTemporalPoseCond",
76
+ "CrossAttnDownBlockSpatioTemporalPoseCond",
77
+ "CrossAttnDownBlockSpatioTemporalPoseCond",
78
+ "DownBlockSpatioTemporal",
79
+ ),
80
+ up_block_types: Tuple[str] = (
81
+ "UpBlockSpatioTemporal",
82
+ "CrossAttnUpBlockSpatioTemporalPoseCond",
83
+ "CrossAttnUpBlockSpatioTemporalPoseCond",
84
+ "CrossAttnUpBlockSpatioTemporalPoseCond",
85
+ ),
86
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
87
+ addition_time_embed_dim: int = 256,
88
+ projection_class_embeddings_input_dim: int = 768,
89
+ layers_per_block: Union[int, Tuple[int]] = 2,
90
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
91
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
92
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
93
+ num_frames: int = 25,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.sample_size = sample_size
98
+
99
+ # Check inputs
100
+ if len(down_block_types) != len(up_block_types):
101
+ raise ValueError(
102
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
103
+ )
104
+
105
+ if len(block_out_channels) != len(down_block_types):
106
+ raise ValueError(
107
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
108
+ )
109
+
110
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
111
+ raise ValueError(
112
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
113
+ )
114
+
115
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
118
+ )
119
+
120
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
121
+ raise ValueError(
122
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
123
+ )
124
+
125
+ # input
126
+ self.conv_in = nn.Conv2d(
127
+ in_channels,
128
+ block_out_channels[0],
129
+ kernel_size=3,
130
+ padding=1,
131
+ )
132
+
133
+ # time
134
+ time_embed_dim = block_out_channels[0] * 4
135
+
136
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
137
+ timestep_input_dim = block_out_channels[0]
138
+
139
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
140
+
141
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
142
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
143
+
144
+ self.down_blocks = nn.ModuleList([])
145
+ self.up_blocks = nn.ModuleList([])
146
+
147
+ if isinstance(num_attention_heads, int):
148
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
149
+
150
+ if isinstance(cross_attention_dim, int):
151
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
152
+
153
+ if isinstance(layers_per_block, int):
154
+ layers_per_block = [layers_per_block] * len(down_block_types)
155
+
156
+ if isinstance(transformer_layers_per_block, int):
157
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
158
+
159
+ blocks_time_embed_dim = time_embed_dim
160
+
161
+ # down
162
+ output_channel = block_out_channels[0]
163
+ for i, down_block_type in enumerate(down_block_types):
164
+ input_channel = output_channel
165
+ output_channel = block_out_channels[i]
166
+ is_final_block = i == len(block_out_channels) - 1
167
+
168
+ down_block = get_down_block(
169
+ down_block_type,
170
+ num_layers=layers_per_block[i],
171
+ transformer_layers_per_block=transformer_layers_per_block[i],
172
+ in_channels=input_channel,
173
+ out_channels=output_channel,
174
+ temb_channels=blocks_time_embed_dim,
175
+ add_downsample=not is_final_block,
176
+ resnet_eps=1e-5,
177
+ cross_attention_dim=cross_attention_dim[i],
178
+ num_attention_heads=num_attention_heads[i],
179
+ resnet_act_fn="silu",
180
+ )
181
+ self.down_blocks.append(down_block)
182
+
183
+ # mid
184
+ self.mid_block = UNetMidBlockSpatioTemporalPoseCond(
185
+ block_out_channels[-1],
186
+ temb_channels=blocks_time_embed_dim,
187
+ transformer_layers_per_block=transformer_layers_per_block[-1],
188
+ cross_attention_dim=cross_attention_dim[-1],
189
+ num_attention_heads=num_attention_heads[-1],
190
+ )
191
+
192
+ # count how many layers upsample the images
193
+ self.num_upsamplers = 0
194
+
195
+ # up
196
+ reversed_block_out_channels = list(reversed(block_out_channels))
197
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
198
+ reversed_layers_per_block = list(reversed(layers_per_block))
199
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
200
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
201
+
202
+ output_channel = reversed_block_out_channels[0]
203
+ for i, up_block_type in enumerate(up_block_types):
204
+ is_final_block = i == len(block_out_channels) - 1
205
+
206
+ prev_output_channel = output_channel
207
+ output_channel = reversed_block_out_channels[i]
208
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
209
+
210
+ # add upsample block for all BUT final layer
211
+ if not is_final_block:
212
+ add_upsample = True
213
+ self.num_upsamplers += 1
214
+ else:
215
+ add_upsample = False
216
+
217
+ up_block = get_up_block(
218
+ up_block_type,
219
+ num_layers=reversed_layers_per_block[i] + 1,
220
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
221
+ in_channels=input_channel,
222
+ out_channels=output_channel,
223
+ prev_output_channel=prev_output_channel,
224
+ temb_channels=blocks_time_embed_dim,
225
+ add_upsample=add_upsample,
226
+ resnet_eps=1e-5,
227
+ resolution_idx=i,
228
+ cross_attention_dim=reversed_cross_attention_dim[i],
229
+ num_attention_heads=reversed_num_attention_heads[i],
230
+ resnet_act_fn="silu",
231
+ )
232
+ self.up_blocks.append(up_block)
233
+ prev_output_channel = output_channel
234
+
235
+ # out
236
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
237
+ self.conv_act = nn.SiLU()
238
+
239
+ self.conv_out = nn.Conv2d(
240
+ block_out_channels[0],
241
+ out_channels,
242
+ kernel_size=3,
243
+ padding=1,
244
+ )
245
+
246
+ @property
247
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
248
+ r"""
249
+ Returns:
250
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
251
+ indexed by its weight name.
252
+ """
253
+ # set recursively
254
+ processors = {}
255
+
256
+ def fn_recursive_add_processors(
257
+ name: str,
258
+ module: torch.nn.Module,
259
+ processors: Dict[str, AttentionProcessor],
260
+ ):
261
+ if hasattr(module, "get_processor"):
262
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
263
+
264
+ for sub_name, child in module.named_children():
265
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
266
+
267
+ return processors
268
+
269
+ for name, module in self.named_children():
270
+ fn_recursive_add_processors(name, module, processors)
271
+
272
+ return processors
273
+
274
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
275
+ r"""
276
+ Sets the attention processor to use to compute attention.
277
+
278
+ Parameters:
279
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
280
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
281
+ for **all** `Attention` layers.
282
+
283
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
284
+ processor. This is strongly recommended when setting trainable attention processors.
285
+
286
+ """
287
+ count = len(self.attn_processors.keys())
288
+
289
+ if isinstance(processor, dict) and len(processor) != count:
290
+ raise ValueError(
291
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
292
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
293
+ )
294
+
295
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
296
+ if hasattr(module, "set_processor"):
297
+ if not isinstance(processor, dict):
298
+ module.set_processor(processor)
299
+ else:
300
+ module.set_processor(processor.pop(f"{name}.processor"))
301
+
302
+ for sub_name, child in module.named_children():
303
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
304
+
305
+ for name, module in self.named_children():
306
+ fn_recursive_attn_processor(name, module, processor)
307
+
308
+ def set_default_attn_processor(self):
309
+ """
310
+ Disables custom attention processors and sets the default attention implementation.
311
+ """
312
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
313
+ processor = AttnProcessor()
314
+ else:
315
+ raise ValueError(
316
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
317
+ )
318
+
319
+ self.set_attn_processor(processor)
320
+
321
+ def _set_gradient_checkpointing(self, module, value=False):
322
+ if hasattr(module, "gradient_checkpointing"):
323
+ module.gradient_checkpointing = value
324
+
325
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
326
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
327
+ """
328
+ Sets the attention processor to use [feed forward
329
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
330
+
331
+ Parameters:
332
+ chunk_size (`int`, *optional*):
333
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
334
+ over each tensor of dim=`dim`.
335
+ dim (`int`, *optional*, defaults to `0`):
336
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
337
+ or dim=1 (sequence length).
338
+ """
339
+ if dim not in [0, 1]:
340
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
341
+
342
+ # By default chunk size is 1
343
+ chunk_size = chunk_size or 1
344
+
345
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
346
+ if hasattr(module, "set_chunk_feed_forward"):
347
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
348
+
349
+ for child in module.children():
350
+ fn_recursive_feed_forward(child, chunk_size, dim)
351
+
352
+ for module in self.children():
353
+ fn_recursive_feed_forward(module, chunk_size, dim)
354
+
355
+ def set_pose_cond_attn_processor(self,
356
+ add_spatial=False,
357
+ add_temporal=False,
358
+ enable_xformers=False,
359
+ attn_processor_name='attn1',
360
+ pose_feature_dimensions=[320, 640, 1280, 1280],
361
+ **attention_processor_kwargs):
362
+ all_attn_processors = {}
363
+ set_processor_names = attn_processor_name.split(',')
364
+ if add_spatial:
365
+ for processor_key in self.attn_processors.keys():
366
+ if 'temporal' in processor_key:
367
+ continue
368
+ processor_name = processor_key.split('.')[-2]
369
+ cross_attention_dim = None if processor_name == 'attn1' else self.config.cross_attention_dim
370
+ if processor_key.startswith("mid_block"):
371
+ hidden_size = self.config.block_out_channels[-1]
372
+ block_id = -1
373
+ add_pose_adaptor = processor_name in set_processor_names
374
+ pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
375
+ elif processor_key.startswith("up_blocks"):
376
+ block_id = int(processor_key[len("up_blocks.")])
377
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
378
+ add_pose_adaptor = processor_name in set_processor_names
379
+ pose_feature_dim = list(reversed(pose_feature_dimensions))[block_id] if add_pose_adaptor else None
380
+ else:
381
+ block_id = int(processor_key[len("down_blocks.")])
382
+ hidden_size = self.config.block_out_channels[block_id]
383
+ add_pose_adaptor = processor_name in set_processor_names
384
+ pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
385
+ if add_pose_adaptor and enable_xformers:
386
+ all_attn_processors[processor_key] = PoseAdaptorXFormersAttnProcessor(hidden_size=hidden_size,
387
+ pose_feature_dim=pose_feature_dim,
388
+ cross_attention_dim=cross_attention_dim,
389
+ **attention_processor_kwargs)
390
+ elif add_pose_adaptor:
391
+ all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
392
+ pose_feature_dim=pose_feature_dim,
393
+ cross_attention_dim=cross_attention_dim,
394
+ **attention_processor_kwargs)
395
+ elif enable_xformers:
396
+ all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
397
+ else:
398
+ all_attn_processors[processor_key] = CustomizedAttnProcessor()
399
+ else:
400
+ for processor_key in self.attn_processors.keys():
401
+ if 'temporal' not in processor_key and enable_xformers:
402
+ all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
403
+ elif 'temporal' not in processor_key:
404
+ all_attn_processors[processor_key] = CustomizedAttnProcessor()
405
+
406
+ if add_temporal:
407
+ for processor_key in self.attn_processors.keys():
408
+ if 'temporal' not in processor_key:
409
+ continue
410
+ processor_name = processor_key.split('.')[-2]
411
+ cross_attention_dim = None if processor_name == 'attn1' else self.config.cross_attention_dim
412
+ if processor_key.startswith("mid_block"):
413
+ hidden_size = self.config.block_out_channels[-1]
414
+ block_id = -1
415
+ add_pose_adaptor = processor_name in set_processor_names
416
+ pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
417
+ elif processor_key.startswith("up_blocks"):
418
+ block_id = int(processor_key[len("up_blocks.")])
419
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
420
+ add_pose_adaptor = (processor_name in set_processor_names)
421
+ pose_feature_dim = list(reversed(pose_feature_dimensions))[block_id] if add_pose_adaptor else None
422
+ else:
423
+ block_id = int(processor_key[len("down_blocks.")])
424
+ hidden_size = self.config.block_out_channels[block_id]
425
+ add_pose_adaptor = processor_name in set_processor_names
426
+ pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
427
+ if add_pose_adaptor and enable_xformers:
428
+ all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
429
+ pose_feature_dim=pose_feature_dim,
430
+ cross_attention_dim=cross_attention_dim,
431
+ **attention_processor_kwargs)
432
+ elif add_pose_adaptor:
433
+ all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
434
+ pose_feature_dim=pose_feature_dim,
435
+ cross_attention_dim=cross_attention_dim,
436
+ **attention_processor_kwargs)
437
+ elif enable_xformers:
438
+ all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
439
+ else:
440
+ all_attn_processors[processor_key] = CustomizedAttnProcessor()
441
+ else:
442
+ for processor_key in self.attn_processors.keys():
443
+ if 'temporal' in processor_key and enable_xformers:
444
+ all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
445
+ elif 'temporal' in processor_key:
446
+ all_attn_processors[processor_key] = CustomizedAttnProcessor()
447
+
448
+ self.set_attn_processor(all_attn_processors)
449
+
450
+ def forward(
451
+ self,
452
+ sample: torch.FloatTensor,
453
+ timestep: Union[torch.Tensor, float, int],
454
+ encoder_hidden_states: torch.Tensor,
455
+ added_time_ids: torch.Tensor,
456
+ pose_features: List[torch.Tensor] = None,
457
+ return_dict: bool = True,
458
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
459
+ r"""
460
+ The [`UNetSpatioTemporalConditionModel`] forward method.
461
+
462
+ Args:
463
+ sample (`torch.FloatTensor`):
464
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
465
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
466
+ encoder_hidden_states (`torch.FloatTensor`):
467
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
468
+ added_time_ids: (`torch.FloatTensor`):
469
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
470
+ embeddings and added to the time embeddings.
471
+ return_dict (`bool`, *optional*, defaults to `True`):
472
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
473
+ tuple.
474
+ Returns:
475
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
476
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
477
+ a `tuple` is returned where the first element is the sample tensor.
478
+ """
479
+ # 1. time
480
+ timesteps = timestep
481
+ if not torch.is_tensor(timesteps):
482
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
483
+ # This would be a good case for the `match` statement (Python 3.10+)
484
+ is_mps = sample.device.type == "mps"
485
+ if isinstance(timestep, float):
486
+ dtype = torch.float32 if is_mps else torch.float64
487
+ else:
488
+ dtype = torch.int32 if is_mps else torch.int64
489
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
490
+ elif len(timesteps.shape) == 0:
491
+ timesteps = timesteps[None].to(sample.device)
492
+
493
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
494
+ batch_size, num_frames = sample.shape[:2]
495
+ timesteps = timesteps.expand(batch_size)
496
+
497
+ t_emb = self.time_proj(timesteps)
498
+
499
+ # `Timesteps` does not contain any weights and will always return f32 tensors
500
+ # but time_embedding might actually be running in fp16. so we need to cast here.
501
+ # there might be better ways to encapsulate this.
502
+ t_emb = t_emb.to(dtype=sample.dtype)
503
+
504
+ emb = self.time_embedding(t_emb)
505
+
506
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
507
+ time_embeds = time_embeds.reshape((batch_size, -1))
508
+ time_embeds = time_embeds.to(emb.dtype)
509
+ aug_emb = self.add_embedding(time_embeds)
510
+ emb = emb + aug_emb
511
+
512
+ # Flatten the batch and frames dimensions
513
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
514
+ sample = sample.flatten(0, 1)
515
+ # Repeat the embeddings num_video_frames times
516
+ # emb: [batch, channels] -> [batch * frames, channels]
517
+ emb = emb.repeat_interleave(num_frames, dim=0)
518
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
519
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
520
+
521
+ # 2. pre-process
522
+ sample = self.conv_in(sample)
523
+
524
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
525
+
526
+ down_block_res_samples = (sample,)
527
+ for block_idx, downsample_block in enumerate(self.down_blocks):
528
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
529
+ sample, res_samples = downsample_block(
530
+ hidden_states=sample,
531
+ temb=emb,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ image_only_indicator=image_only_indicator,
534
+ pose_feature=pose_features[block_idx]
535
+ )
536
+ else:
537
+ sample, res_samples = downsample_block(
538
+ hidden_states=sample,
539
+ temb=emb,
540
+ image_only_indicator=image_only_indicator,
541
+ )
542
+
543
+ down_block_res_samples += res_samples
544
+
545
+ # 4. mid
546
+ sample = self.mid_block(
547
+ hidden_states=sample,
548
+ temb=emb,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ image_only_indicator=image_only_indicator,
551
+ pose_feature=pose_features[-1]
552
+ )
553
+
554
+ # 5. up
555
+ for block_idx, upsample_block in enumerate(self.up_blocks):
556
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
557
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
558
+
559
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
560
+ sample = upsample_block(
561
+ hidden_states=sample,
562
+ temb=emb,
563
+ res_hidden_states_tuple=res_samples,
564
+ encoder_hidden_states=encoder_hidden_states,
565
+ image_only_indicator=image_only_indicator,
566
+ pose_feature=pose_features[-(block_idx + 1)]
567
+ )
568
+ else:
569
+ sample = upsample_block(
570
+ hidden_states=sample,
571
+ temb=emb,
572
+ res_hidden_states_tuple=res_samples,
573
+ image_only_indicator=image_only_indicator,
574
+ )
575
+
576
+ # 6. post-process
577
+ sample = self.conv_norm_out(sample)
578
+ sample = self.conv_act(sample)
579
+ sample = self.conv_out(sample)
580
+
581
+ # 7. Reshape back to original shape
582
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
583
+
584
+ if not return_dict:
585
+ return (sample,)
586
+
587
+ return UNetSpatioTemporalConditionOutput(sample=sample)
cameractrl/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Union, Tuple, Optional, Dict, Any
4
+ from diffusers.utils import is_torch_version
5
+ from diffusers.models.resnet import (
6
+ Downsample2D,
7
+ SpatioTemporalResBlock,
8
+ Upsample2D
9
+ )
10
+ from diffusers.models.unet_3d_blocks import (
11
+ DownBlockSpatioTemporal,
12
+ UpBlockSpatioTemporal,
13
+ )
14
+
15
+ from cameractrl.models.transformer_temporal import TransformerSpatioTemporalModelPoseCond
16
+
17
+
18
+ def get_down_block(
19
+ down_block_type: str,
20
+ num_layers: int,
21
+ in_channels: int,
22
+ out_channels: int,
23
+ temb_channels: int,
24
+ add_downsample: bool,
25
+ num_attention_heads: int,
26
+ cross_attention_dim: Optional[int] = None,
27
+ transformer_layers_per_block: int = 1,
28
+ **kwargs,
29
+ ) -> Union[
30
+ "DownBlockSpatioTemporal",
31
+ "CrossAttnDownBlockSpatioTemporalPoseCond",
32
+ ]:
33
+ if down_block_type == "DownBlockSpatioTemporal":
34
+ # added for SDV
35
+ return DownBlockSpatioTemporal(
36
+ num_layers=num_layers,
37
+ in_channels=in_channels,
38
+ out_channels=out_channels,
39
+ temb_channels=temb_channels,
40
+ add_downsample=add_downsample,
41
+ )
42
+ elif down_block_type == "CrossAttnDownBlockSpatioTemporalPoseCond":
43
+ # added for SDV
44
+ if cross_attention_dim is None:
45
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
46
+ return CrossAttnDownBlockSpatioTemporalPoseCond(
47
+ in_channels=in_channels,
48
+ out_channels=out_channels,
49
+ temb_channels=temb_channels,
50
+ num_layers=num_layers,
51
+ transformer_layers_per_block=transformer_layers_per_block,
52
+ add_downsample=add_downsample,
53
+ cross_attention_dim=cross_attention_dim,
54
+ num_attention_heads=num_attention_heads,
55
+ )
56
+
57
+ raise ValueError(f"{down_block_type} does not exist.")
58
+
59
+
60
+ def get_up_block(
61
+ up_block_type: str,
62
+ num_layers: int,
63
+ in_channels: int,
64
+ out_channels: int,
65
+ prev_output_channel: int,
66
+ temb_channels: int,
67
+ add_upsample: bool,
68
+ num_attention_heads: int,
69
+ resolution_idx: Optional[int] = None,
70
+ cross_attention_dim: Optional[int] = None,
71
+ transformer_layers_per_block: int = 1,
72
+ **kwargs,
73
+ ) -> Union[
74
+ "UpBlockSpatioTemporal",
75
+ "CrossAttnUpBlockSpatioTemporalPoseCond",
76
+ ]:
77
+ if up_block_type == "UpBlockSpatioTemporal":
78
+ # added for SDV
79
+ return UpBlockSpatioTemporal(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ prev_output_channel=prev_output_channel,
84
+ temb_channels=temb_channels,
85
+ resolution_idx=resolution_idx,
86
+ add_upsample=add_upsample,
87
+ )
88
+ elif up_block_type == "CrossAttnUpBlockSpatioTemporalPoseCond":
89
+ # added for SDV
90
+ if cross_attention_dim is None:
91
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
92
+ return CrossAttnUpBlockSpatioTemporalPoseCond(
93
+ in_channels=in_channels,
94
+ out_channels=out_channels,
95
+ prev_output_channel=prev_output_channel,
96
+ temb_channels=temb_channels,
97
+ num_layers=num_layers,
98
+ transformer_layers_per_block=transformer_layers_per_block,
99
+ add_upsample=add_upsample,
100
+ cross_attention_dim=cross_attention_dim,
101
+ num_attention_heads=num_attention_heads,
102
+ resolution_idx=resolution_idx,
103
+ )
104
+
105
+ raise ValueError(f"{up_block_type} does not exist.")
106
+
107
+
108
+ class CrossAttnDownBlockSpatioTemporalPoseCond(nn.Module):
109
+ def __init__(
110
+ self,
111
+ in_channels: int,
112
+ out_channels: int,
113
+ temb_channels: int,
114
+ num_layers: int = 1,
115
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
116
+ num_attention_heads: int = 1,
117
+ cross_attention_dim: int = 1280,
118
+ add_downsample: bool = True,
119
+ ):
120
+ super().__init__()
121
+ resnets = []
122
+ attentions = []
123
+
124
+ self.has_cross_attention = True
125
+ self.num_attention_heads = num_attention_heads
126
+ if isinstance(transformer_layers_per_block, int):
127
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
128
+
129
+ for i in range(num_layers):
130
+ in_channels = in_channels if i == 0 else out_channels
131
+ resnets.append(
132
+ SpatioTemporalResBlock(
133
+ in_channels=in_channels,
134
+ out_channels=out_channels,
135
+ temb_channels=temb_channels,
136
+ eps=1e-6,
137
+ )
138
+ )
139
+ attentions.append(
140
+ TransformerSpatioTemporalModelPoseCond(
141
+ num_attention_heads,
142
+ out_channels // num_attention_heads,
143
+ in_channels=out_channels,
144
+ num_layers=transformer_layers_per_block[i],
145
+ cross_attention_dim=cross_attention_dim,
146
+ )
147
+ )
148
+
149
+ self.attentions = nn.ModuleList(attentions)
150
+ self.resnets = nn.ModuleList(resnets)
151
+
152
+ if add_downsample:
153
+ self.downsamplers = nn.ModuleList(
154
+ [
155
+ Downsample2D(
156
+ out_channels,
157
+ use_conv=True,
158
+ out_channels=out_channels,
159
+ padding=1,
160
+ name="op",
161
+ )
162
+ ]
163
+ )
164
+ else:
165
+ self.downsamplers = None
166
+
167
+ self.gradient_checkpointing = False
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.FloatTensor, # [bs * frame, c, h, w]
172
+ temb: Optional[torch.FloatTensor] = None, # [bs * frame, c]
173
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # [bs * frame, 1, c]
174
+ image_only_indicator: Optional[torch.Tensor] = None, # [bs, frame]
175
+ pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
176
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
177
+ output_states = ()
178
+
179
+ blocks = list(zip(self.resnets, self.attentions))
180
+ for resnet, attn in blocks:
181
+ if self.training and self.gradient_checkpointing: # TODO
182
+
183
+ def create_custom_forward(module, return_dict=None):
184
+ def custom_forward(*inputs):
185
+ if return_dict is not None:
186
+ return module(*inputs, return_dict=return_dict)
187
+ else:
188
+ return module(*inputs)
189
+
190
+ return custom_forward
191
+
192
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
193
+ hidden_states = torch.utils.checkpoint.checkpoint(
194
+ create_custom_forward(resnet),
195
+ hidden_states,
196
+ temb,
197
+ image_only_indicator,
198
+ **ckpt_kwargs,
199
+ )
200
+
201
+ hidden_states = attn(
202
+ hidden_states,
203
+ encoder_hidden_states=encoder_hidden_states,
204
+ image_only_indicator=image_only_indicator,
205
+ return_dict=False,
206
+ )[0]
207
+ else:
208
+ hidden_states = resnet(
209
+ hidden_states,
210
+ temb,
211
+ image_only_indicator=image_only_indicator,
212
+ ) # [bs * frame, c, h, w]
213
+ hidden_states = attn(
214
+ hidden_states,
215
+ encoder_hidden_states=encoder_hidden_states,
216
+ image_only_indicator=image_only_indicator,
217
+ pose_feature=pose_feature,
218
+ return_dict=False,
219
+ )[0]
220
+
221
+ output_states = output_states + (hidden_states,)
222
+
223
+ if self.downsamplers is not None:
224
+ for downsampler in self.downsamplers:
225
+ hidden_states = downsampler(hidden_states)
226
+
227
+ output_states = output_states + (hidden_states,)
228
+
229
+ return hidden_states, output_states
230
+
231
+
232
+ class UNetMidBlockSpatioTemporalPoseCond(nn.Module):
233
+ def __init__(
234
+ self,
235
+ in_channels: int,
236
+ temb_channels: int,
237
+ num_layers: int = 1,
238
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
239
+ num_attention_heads: int = 1,
240
+ cross_attention_dim: int = 1280,
241
+ ):
242
+ super().__init__()
243
+
244
+ self.has_cross_attention = True
245
+ self.num_attention_heads = num_attention_heads
246
+
247
+ # support for variable transformer layers per block
248
+ if isinstance(transformer_layers_per_block, int):
249
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
250
+
251
+ # there is always at least one resnet
252
+ resnets = [
253
+ SpatioTemporalResBlock(
254
+ in_channels=in_channels,
255
+ out_channels=in_channels,
256
+ temb_channels=temb_channels,
257
+ eps=1e-5,
258
+ )
259
+ ]
260
+ attentions = []
261
+
262
+ for i in range(num_layers):
263
+ attentions.append(
264
+ TransformerSpatioTemporalModelPoseCond(
265
+ num_attention_heads,
266
+ in_channels // num_attention_heads,
267
+ in_channels=in_channels,
268
+ num_layers=transformer_layers_per_block[i],
269
+ cross_attention_dim=cross_attention_dim,
270
+ )
271
+ )
272
+
273
+ resnets.append(
274
+ SpatioTemporalResBlock(
275
+ in_channels=in_channels,
276
+ out_channels=in_channels,
277
+ temb_channels=temb_channels,
278
+ eps=1e-5,
279
+ )
280
+ )
281
+
282
+ self.attentions = nn.ModuleList(attentions)
283
+ self.resnets = nn.ModuleList(resnets)
284
+
285
+ self.gradient_checkpointing = False
286
+
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.FloatTensor,
290
+ temb: Optional[torch.FloatTensor] = None,
291
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
292
+ image_only_indicator: Optional[torch.Tensor] = None,
293
+ pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
294
+ ) -> torch.FloatTensor:
295
+ hidden_states = self.resnets[0](
296
+ hidden_states,
297
+ temb,
298
+ image_only_indicator=image_only_indicator,
299
+ )
300
+
301
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
302
+ if self.training and self.gradient_checkpointing: # TODO
303
+
304
+ def create_custom_forward(module, return_dict=None):
305
+ def custom_forward(*inputs):
306
+ if return_dict is not None:
307
+ return module(*inputs, return_dict=return_dict)
308
+ else:
309
+ return module(*inputs)
310
+
311
+ return custom_forward
312
+
313
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
314
+ hidden_states = attn(
315
+ hidden_states,
316
+ encoder_hidden_states=encoder_hidden_states,
317
+ image_only_indicator=image_only_indicator,
318
+ return_dict=False,
319
+ )[0]
320
+ hidden_states = torch.utils.checkpoint.checkpoint(
321
+ create_custom_forward(resnet),
322
+ hidden_states,
323
+ temb,
324
+ image_only_indicator,
325
+ **ckpt_kwargs,
326
+ )
327
+ else:
328
+ hidden_states = attn(
329
+ hidden_states,
330
+ encoder_hidden_states=encoder_hidden_states,
331
+ image_only_indicator=image_only_indicator,
332
+ pose_feature=pose_feature,
333
+ return_dict=False,
334
+ )[0]
335
+ hidden_states = resnet(
336
+ hidden_states,
337
+ temb,
338
+ image_only_indicator=image_only_indicator,
339
+ )
340
+
341
+ return hidden_states
342
+
343
+
344
+ class CrossAttnUpBlockSpatioTemporalPoseCond(nn.Module):
345
+ def __init__(
346
+ self,
347
+ in_channels: int,
348
+ out_channels: int,
349
+ prev_output_channel: int,
350
+ temb_channels: int,
351
+ resolution_idx: Optional[int] = None,
352
+ num_layers: int = 1,
353
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
354
+ resnet_eps: float = 1e-6,
355
+ num_attention_heads: int = 1,
356
+ cross_attention_dim: int = 1280,
357
+ add_upsample: bool = True,
358
+ ):
359
+ super().__init__()
360
+ resnets = []
361
+ attentions = []
362
+
363
+ self.has_cross_attention = True
364
+ self.num_attention_heads = num_attention_heads
365
+
366
+ if isinstance(transformer_layers_per_block, int):
367
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
368
+
369
+ for i in range(num_layers):
370
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
371
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
372
+
373
+ resnets.append(
374
+ SpatioTemporalResBlock(
375
+ in_channels=resnet_in_channels + res_skip_channels,
376
+ out_channels=out_channels,
377
+ temb_channels=temb_channels,
378
+ eps=resnet_eps,
379
+ )
380
+ )
381
+ attentions.append(
382
+ TransformerSpatioTemporalModelPoseCond(
383
+ num_attention_heads,
384
+ out_channels // num_attention_heads,
385
+ in_channels=out_channels,
386
+ num_layers=transformer_layers_per_block[i],
387
+ cross_attention_dim=cross_attention_dim,
388
+ )
389
+ )
390
+
391
+ self.attentions = nn.ModuleList(attentions)
392
+ self.resnets = nn.ModuleList(resnets)
393
+
394
+ if add_upsample:
395
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
396
+ else:
397
+ self.upsamplers = None
398
+
399
+ self.gradient_checkpointing = False
400
+ self.resolution_idx = resolution_idx
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.FloatTensor,
405
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
406
+ temb: Optional[torch.FloatTensor] = None,
407
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
408
+ image_only_indicator: Optional[torch.Tensor] = None,
409
+ pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
410
+ ) -> torch.FloatTensor:
411
+ for resnet, attn in zip(self.resnets, self.attentions):
412
+ # pop res hidden states
413
+ res_hidden_states = res_hidden_states_tuple[-1]
414
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
415
+
416
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
417
+
418
+ if self.training and self.gradient_checkpointing: # TODO
419
+
420
+ def create_custom_forward(module, return_dict=None):
421
+ def custom_forward(*inputs):
422
+ if return_dict is not None:
423
+ return module(*inputs, return_dict=return_dict)
424
+ else:
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
430
+ hidden_states = torch.utils.checkpoint.checkpoint(
431
+ create_custom_forward(resnet),
432
+ hidden_states,
433
+ temb,
434
+ image_only_indicator,
435
+ **ckpt_kwargs,
436
+ )
437
+ hidden_states = attn(
438
+ hidden_states,
439
+ encoder_hidden_states=encoder_hidden_states,
440
+ image_only_indicator=image_only_indicator,
441
+ return_dict=False,
442
+ )[0]
443
+ else:
444
+ hidden_states = resnet(
445
+ hidden_states,
446
+ temb,
447
+ image_only_indicator=image_only_indicator,
448
+ )
449
+ hidden_states = attn(
450
+ hidden_states,
451
+ encoder_hidden_states=encoder_hidden_states,
452
+ image_only_indicator=image_only_indicator,
453
+ pose_feature=pose_feature,
454
+ return_dict=False,
455
+ )[0]
456
+
457
+ if self.upsamplers is not None:
458
+ for upsampler in self.upsamplers:
459
+ hidden_states = upsampler(hidden_states)
460
+
461
+ return hidden_states
cameractrl/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+
5
+ import pandas as pd
6
+ import torch
7
+ import PIL.Image
8
+
9
+ from typing import Callable, List, Optional, Union, Dict
10
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11
+ from diffusers.models import AutoencoderKLTemporalDecoder
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.schedulers import EulerDiscreteScheduler
15
+ from diffusers.utils import logging
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
18
+ _resize_with_antialiasing,
19
+ _append_dims,
20
+ tensor2vid,
21
+ StableVideoDiffusionPipelineOutput
22
+ )
23
+
24
+ from cameractrl.models.pose_adaptor import CameraPoseEncoder
25
+ from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class StableVideoDiffusionPipelinePoseCond(DiffusionPipeline):
32
+ r"""
33
+ Pipeline to generate video from an input image using Stable Video Diffusion.
34
+
35
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
36
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
37
+
38
+ Args:
39
+ vae ([`AutoencoderKL`]):
40
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
41
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
42
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
43
+ unet ([`UNetSpatioTemporalConditionModel`]):
44
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
45
+ scheduler ([`EulerDiscreteScheduler`]):
46
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
47
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
48
+ A `CLIPImageProcessor` to extract features from generated images.
49
+ """
50
+
51
+ model_cpu_offload_seq = "image_encoder->unet->vae"
52
+ _callback_tensor_inputs = ["latents"]
53
+
54
+ def __init__(
55
+ self,
56
+ vae: AutoencoderKLTemporalDecoder,
57
+ image_encoder: CLIPVisionModelWithProjection,
58
+ unet: UNetSpatioTemporalConditionModelPoseCond,
59
+ scheduler: EulerDiscreteScheduler,
60
+ feature_extractor: CLIPImageProcessor,
61
+ pose_encoder: CameraPoseEncoder
62
+ ):
63
+ super().__init__()
64
+
65
+ self.register_modules(
66
+ vae=vae,
67
+ image_encoder=image_encoder,
68
+ unet=unet,
69
+ scheduler=scheduler,
70
+ feature_extractor=feature_extractor,
71
+ pose_encoder=pose_encoder
72
+ )
73
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
74
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
75
+
76
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, do_resize_normalize):
77
+ dtype = next(self.image_encoder.parameters()).dtype
78
+
79
+ if not isinstance(image, torch.Tensor):
80
+ image = self.image_processor.pil_to_numpy(image)
81
+ image = self.image_processor.numpy_to_pt(image)
82
+
83
+ # We normalize the image before resizing to match with the original implementation.
84
+ # Then we unnormalize it after resizing.
85
+ image = image * 2.0 - 1.0
86
+ image = _resize_with_antialiasing(image, (224, 224))
87
+ image = (image + 1.0) / 2.0
88
+
89
+ # Normalize the image with for CLIP input
90
+ image = self.feature_extractor(
91
+ images=image,
92
+ do_normalize=True,
93
+ do_center_crop=False,
94
+ do_resize=False,
95
+ do_rescale=False,
96
+ return_tensors="pt",
97
+ ).pixel_values
98
+ elif do_resize_normalize:
99
+ image = _resize_with_antialiasing(image, (224, 224))
100
+ image = (image + 1.0) / 2.0
101
+ # Normalize the image with for CLIP input
102
+ image = self.feature_extractor(
103
+ images=image,
104
+ do_normalize=True,
105
+ do_center_crop=False,
106
+ do_resize=False,
107
+ do_rescale=False,
108
+ return_tensors="pt",
109
+ ).pixel_values
110
+
111
+ image = image.to(device=device, dtype=dtype)
112
+ image_embeddings = self.image_encoder(image).image_embeds
113
+ image_embeddings = image_embeddings.unsqueeze(1)
114
+
115
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
116
+ bs_embed, seq_len, _ = image_embeddings.shape
117
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
118
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
119
+
120
+ if do_classifier_free_guidance:
121
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
122
+
123
+ # For classifier free guidance, we need to do two forward passes.
124
+ # Here we concatenate the unconditional and text embeddings into a single batch
125
+ # to avoid doing two forward passes
126
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
127
+
128
+ return image_embeddings
129
+
130
+ def _encode_vae_image(
131
+ self,
132
+ image: torch.Tensor,
133
+ device,
134
+ num_videos_per_prompt,
135
+ do_classifier_free_guidance,
136
+ ):
137
+ image = image.to(device=device)
138
+ image_latents = self.vae.encode(image).latent_dist.mode()
139
+
140
+ if do_classifier_free_guidance:
141
+ negative_image_latents = torch.zeros_like(image_latents)
142
+
143
+ # For classifier free guidance, we need to do two forward passes.
144
+ # Here we concatenate the unconditional and text embeddings into a single batch
145
+ # to avoid doing two forward passes
146
+ image_latents = torch.cat([negative_image_latents, image_latents])
147
+
148
+ # duplicate image_latents for each generation per prompt, using mps friendly method
149
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
150
+
151
+ return image_latents
152
+
153
+ def _get_add_time_ids(
154
+ self,
155
+ fps,
156
+ motion_bucket_id,
157
+ noise_aug_strength,
158
+ dtype,
159
+ batch_size,
160
+ num_videos_per_prompt,
161
+ do_classifier_free_guidance,
162
+ ):
163
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
164
+
165
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
166
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
167
+
168
+ if expected_add_embed_dim != passed_add_embed_dim:
169
+ raise ValueError(
170
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
171
+ )
172
+
173
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
174
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
175
+
176
+ if do_classifier_free_guidance:
177
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
178
+
179
+ return add_time_ids
180
+
181
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
182
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
183
+ latents = latents.flatten(0, 1)
184
+
185
+ latents = 1 / self.vae.config.scaling_factor * latents
186
+
187
+ accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
188
+
189
+ # decode decode_chunk_size frames at a time to avoid OOM
190
+ frames = []
191
+ for i in range(0, latents.shape[0], decode_chunk_size):
192
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
193
+ decode_kwargs = {}
194
+ if accepts_num_frames:
195
+ # we only pass num_frames_in if it's expected
196
+ decode_kwargs["num_frames"] = num_frames_in
197
+
198
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
199
+ frames.append(frame)
200
+ frames = torch.cat(frames, dim=0)
201
+
202
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
203
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
204
+
205
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
206
+ frames = frames.float()
207
+ return frames
208
+
209
+ def check_inputs(self, image, height, width):
210
+ if (
211
+ not isinstance(image, torch.Tensor)
212
+ and not isinstance(image, PIL.Image.Image)
213
+ and not isinstance(image, list)
214
+ ):
215
+ raise ValueError(
216
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
217
+ f" {type(image)}"
218
+ )
219
+
220
+ if height % 8 != 0 or width % 8 != 0:
221
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
222
+
223
+ def prepare_latents(
224
+ self,
225
+ batch_size,
226
+ num_frames,
227
+ num_channels_latents,
228
+ height,
229
+ width,
230
+ dtype,
231
+ device,
232
+ generator,
233
+ latents=None,
234
+ ):
235
+ shape = (
236
+ batch_size,
237
+ num_frames,
238
+ num_channels_latents // 2,
239
+ height // self.vae_scale_factor,
240
+ width // self.vae_scale_factor,
241
+ )
242
+ if isinstance(generator, list) and len(generator) != batch_size:
243
+ raise ValueError(
244
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
245
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
246
+ )
247
+
248
+ if latents is None:
249
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
250
+ else:
251
+ latents = latents.to(device)
252
+
253
+ # scale the initial noise by the standard deviation required by the scheduler
254
+ latents = latents * self.scheduler.init_noise_sigma
255
+ return latents
256
+
257
+ @property
258
+ def guidance_scale(self):
259
+ return self._guidance_scale
260
+
261
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
262
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
263
+ # corresponds to doing no classifier free guidance.
264
+ @property
265
+ def do_classifier_free_guidance(self):
266
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
267
+
268
+ @property
269
+ def num_timesteps(self):
270
+ return self._num_timesteps
271
+
272
+ @torch.no_grad()
273
+ def __call__(
274
+ self,
275
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
276
+ pose_embedding: torch.FloatTensor,
277
+ height: int = 576,
278
+ width: int = 1024,
279
+ num_frames: Optional[int] = None,
280
+ num_inference_steps: int = 25,
281
+ min_guidance_scale: float = 1.0,
282
+ max_guidance_scale: float = 3.0,
283
+ fps: int = 7,
284
+ motion_bucket_id: int = 127,
285
+ noise_aug_strength: int = 0.02,
286
+ do_resize_normalize: bool = True,
287
+ do_image_process: bool = False,
288
+ decode_chunk_size: Optional[int] = None,
289
+ num_videos_per_prompt: Optional[int] = 1,
290
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
291
+ latents: Optional[torch.FloatTensor] = None,
292
+ output_type: Optional[str] = "pil",
293
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
294
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
295
+ return_dict: bool = True,
296
+ ):
297
+ r"""
298
+ The call function to the pipeline for generation.
299
+
300
+ Args:
301
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
302
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
303
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
304
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
305
+ The height in pixels of the generated image.
306
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
307
+ The width in pixels of the generated image.
308
+ num_frames (`int`, *optional*):
309
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
310
+ num_inference_steps (`int`, *optional*, defaults to 25):
311
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
312
+ expense of slower inference. This parameter is modulated by `strength`.
313
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
314
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
315
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
316
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
317
+ fps (`int`, *optional*, defaults to 7):
318
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
319
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
320
+ motion_bucket_id (`int`, *optional*, defaults to 127):
321
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
322
+ noise_aug_strength (`int`, *optional*, defaults to 0.02):
323
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
324
+ decode_chunk_size (`int`, *optional*):
325
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
326
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
327
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
328
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
329
+ The number of images to generate per prompt.
330
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
331
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
332
+ generation deterministic.
333
+ latents (`torch.FloatTensor`, *optional*):
334
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
335
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
336
+ tensor is generated by sampling using the supplied random `generator`.
337
+ output_type (`str`, *optional*, defaults to `"pil"`):
338
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
339
+ callback_on_step_end (`Callable`, *optional*):
340
+ A function that calls at the end of each denoising steps during the inference. The function is called
341
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
342
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
343
+ `callback_on_step_end_tensor_inputs`.
344
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
345
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
346
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
347
+ `._callback_tensor_inputs` attribute of your pipeline class.
348
+ return_dict (`bool`, *optional*, defaults to `True`):
349
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
350
+ plain tuple.
351
+
352
+ Returns:
353
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
354
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
355
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
356
+
357
+ Examples:
358
+
359
+ ```py
360
+ from diffusers import StableVideoDiffusionPipeline
361
+ from diffusers.utils import load_image, export_to_video
362
+
363
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
364
+ pipe.to("cuda")
365
+
366
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
367
+ image = image.resize((1024, 576))
368
+
369
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
370
+ export_to_video(frames, "generated.mp4", fps=7)
371
+ ```
372
+ """
373
+ # 0. Default height and width to unet
374
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
375
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
376
+
377
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
378
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
379
+
380
+ # 1. Check inputs. Raise error if not correct
381
+ self.check_inputs(image, height, width)
382
+
383
+ # 2. Define call parameters
384
+ if isinstance(image, PIL.Image.Image):
385
+ batch_size = 1
386
+ elif isinstance(image, list):
387
+ batch_size = len(image)
388
+ else:
389
+ batch_size = image.shape[0]
390
+ device = pose_embedding.device
391
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
392
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
393
+ # corresponds to doing no classifier free guidance.
394
+ do_classifier_free_guidance = max_guidance_scale > 1.0
395
+
396
+ # 3. Encode input image
397
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, do_resize_normalize=do_resize_normalize)
398
+
399
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
400
+ # is why it is reduced here.
401
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
402
+ fps = fps - 1
403
+
404
+ # 4. Encode input image using VAE
405
+ if do_image_process:
406
+ image = self.image_processor.preprocess(image, height=height, width=width).to(image_embeddings.device)
407
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
408
+ image = image + noise_aug_strength * noise
409
+
410
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
411
+ if needs_upcasting:
412
+ self.vae.to(dtype=torch.float32)
413
+
414
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
415
+ image_latents = image_latents.to(image_embeddings.dtype)
416
+
417
+ # cast back to fp16 if needed
418
+ if needs_upcasting:
419
+ self.vae.to(dtype=torch.float16)
420
+
421
+ # Repeat the image latents for each frame so we can concatenate them with the noise
422
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
423
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
424
+
425
+ # 5. Get Added Time IDs
426
+ added_time_ids = self._get_add_time_ids(
427
+ fps,
428
+ motion_bucket_id,
429
+ noise_aug_strength,
430
+ image_embeddings.dtype,
431
+ batch_size,
432
+ num_videos_per_prompt,
433
+ do_classifier_free_guidance,
434
+ )
435
+ added_time_ids = added_time_ids.to(device)
436
+
437
+ # 4. Prepare timesteps
438
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
439
+ timesteps = self.scheduler.timesteps
440
+
441
+ # 5. Prepare latent variables
442
+ num_channels_latents = self.unet.config.in_channels
443
+ latents = self.prepare_latents(
444
+ batch_size * num_videos_per_prompt,
445
+ num_frames,
446
+ num_channels_latents,
447
+ height,
448
+ width,
449
+ image_embeddings.dtype,
450
+ device,
451
+ generator,
452
+ latents,
453
+ ) # [bs, frame, c, h, w]
454
+
455
+ # 7. Prepare guidance scale
456
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
457
+ guidance_scale = guidance_scale.to(device, latents.dtype)
458
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
459
+ guidance_scale = _append_dims(guidance_scale, latents.ndim) # [bs, frame, 1, 1, 1]
460
+
461
+ self._guidance_scale = guidance_scale
462
+
463
+ # 8. Prepare pose features
464
+ assert pose_embedding.ndim == 5 # [b, f, c, h, w]
465
+ pose_features = self.pose_encoder(pose_embedding) # list of [b, c, f, h, w]
466
+ pose_features = [torch.cat([x, x], dim=0) for x in pose_features] if do_classifier_free_guidance else pose_features
467
+
468
+ # 9. Denoising loop
469
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
470
+ self._num_timesteps = len(timesteps)
471
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
472
+ for i, t in enumerate(timesteps):
473
+ # expand the latents if we are doing classifier free guidance
474
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
475
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
476
+
477
+ # Concatenate image_latents over channels dimention
478
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
479
+
480
+ # predict the noise residual
481
+ noise_pred = self.unet(
482
+ latent_model_input,
483
+ t,
484
+ encoder_hidden_states=image_embeddings,
485
+ added_time_ids=added_time_ids,
486
+ pose_features=pose_features,
487
+ return_dict=False,
488
+ )[0]
489
+
490
+ # perform guidance
491
+ if do_classifier_free_guidance:
492
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
493
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
494
+
495
+ # compute the previous noisy sample x_t -> x_t-1
496
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
497
+
498
+ if callback_on_step_end is not None:
499
+ callback_kwargs = {}
500
+ for k in callback_on_step_end_tensor_inputs:
501
+ callback_kwargs[k] = locals()[k]
502
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
503
+
504
+ latents = callback_outputs.pop("latents", latents)
505
+
506
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
507
+ progress_bar.update()
508
+
509
+ if not output_type == "latent":
510
+ # cast back to fp16 if needed
511
+ if needs_upcasting:
512
+ self.vae.to(dtype=torch.float16)
513
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size) # [b, c, f, h, w]
514
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
515
+ else:
516
+ frames = latents
517
+
518
+ self.maybe_free_model_hooks()
519
+
520
+ if not return_dict:
521
+ return frames
522
+
523
+ return StableVideoDiffusionPipelineOutput(frames=frames)
cameractrl/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from transformers import CLIPTextModel
19
+
20
+ def shave_segments(path, n_shave_prefix_segments=1):
21
+ """
22
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
23
+ """
24
+ if n_shave_prefix_segments >= 0:
25
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
26
+ else:
27
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
28
+
29
+
30
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
31
+ """
32
+ Updates paths inside resnets to the new naming scheme (local renaming)
33
+ """
34
+ mapping = []
35
+ for old_item in old_list:
36
+ new_item = old_item.replace("in_layers.0", "norm1")
37
+ new_item = new_item.replace("in_layers.2", "conv1")
38
+
39
+ new_item = new_item.replace("out_layers.0", "norm2")
40
+ new_item = new_item.replace("out_layers.3", "conv2")
41
+
42
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
43
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
44
+
45
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
46
+
47
+ mapping.append({"old": old_item, "new": new_item})
48
+
49
+ return mapping
50
+
51
+
52
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
53
+ """
54
+ Updates paths inside resnets to the new naming scheme (local renaming)
55
+ """
56
+ mapping = []
57
+ for old_item in old_list:
58
+ new_item = old_item
59
+
60
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
61
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
62
+
63
+ mapping.append({"old": old_item, "new": new_item})
64
+
65
+ return mapping
66
+
67
+
68
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
69
+ """
70
+ Updates paths inside attentions to the new naming scheme (local renaming)
71
+ """
72
+ mapping = []
73
+ for old_item in old_list:
74
+ new_item = old_item
75
+
76
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
77
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
78
+
79
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
80
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
81
+
82
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
83
+
84
+ mapping.append({"old": old_item, "new": new_item})
85
+
86
+ return mapping
87
+
88
+
89
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
90
+ """
91
+ Updates paths inside attentions to the new naming scheme (local renaming)
92
+ """
93
+ mapping = []
94
+ for old_item in old_list:
95
+ new_item = old_item
96
+
97
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
98
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
99
+
100
+ new_item = new_item.replace("q.weight", "query.weight")
101
+ new_item = new_item.replace("q.bias", "query.bias")
102
+
103
+ new_item = new_item.replace("k.weight", "key.weight")
104
+ new_item = new_item.replace("k.bias", "key.bias")
105
+
106
+ new_item = new_item.replace("v.weight", "value.weight")
107
+ new_item = new_item.replace("v.bias", "value.bias")
108
+
109
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
110
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
111
+
112
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def assign_to_checkpoint(
120
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
121
+ ):
122
+ """
123
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
124
+ attention layers, and takes into account additional replacements that may arise.
125
+
126
+ Assigns the weights to the new checkpoint.
127
+ """
128
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
129
+
130
+ # Splits the attention layers into three variables.
131
+ if attention_paths_to_split is not None:
132
+ for path, path_map in attention_paths_to_split.items():
133
+ old_tensor = old_checkpoint[path]
134
+ channels = old_tensor.shape[0] // 3
135
+
136
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
137
+
138
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
139
+
140
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
141
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
142
+
143
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
144
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
145
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
146
+
147
+ for path in paths:
148
+ new_path = path["new"]
149
+
150
+ # These have already been assigned
151
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
152
+ continue
153
+
154
+ # Global renaming happens here
155
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
156
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
157
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
158
+
159
+ if additional_replacements is not None:
160
+ for replacement in additional_replacements:
161
+ new_path = new_path.replace(replacement["old"], replacement["new"])
162
+
163
+ # proj_attn.weight has to be converted from conv 1D to linear
164
+ if "proj_attn.weight" in new_path:
165
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
166
+ else:
167
+ checkpoint[new_path] = old_checkpoint[path["old"]]
168
+
169
+
170
+ def conv_attn_to_linear(checkpoint):
171
+ keys = list(checkpoint.keys())
172
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
173
+ for key in keys:
174
+ if ".".join(key.split(".")[-2:]) in attn_keys:
175
+ if checkpoint[key].ndim > 2:
176
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
177
+ elif "proj_attn.weight" in key:
178
+ if checkpoint[key].ndim > 2:
179
+ checkpoint[key] = checkpoint[key][:, :, 0]
180
+
181
+
182
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
183
+ """
184
+ Takes a state dict and a config, and returns a converted checkpoint.
185
+ """
186
+
187
+ # extract state_dict for UNet
188
+ unet_state_dict = {}
189
+ keys = list(checkpoint.keys())
190
+
191
+ if controlnet:
192
+ unet_key = "control_model."
193
+ else:
194
+ unet_key = "model.diffusion_model."
195
+
196
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
197
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
198
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
199
+ print(
200
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
201
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
202
+ )
203
+ for key in keys:
204
+ if key.startswith("model.diffusion_model"):
205
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
206
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
207
+ else:
208
+ if sum(k.startswith("model_ema") for k in keys) > 100:
209
+ print(
210
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
211
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
212
+ )
213
+
214
+ for key in keys:
215
+ if key.startswith(unet_key):
216
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
217
+
218
+ new_checkpoint = {}
219
+
220
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
221
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
222
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
223
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
224
+
225
+ if config["class_embed_type"] is None:
226
+ # No parameters to port
227
+ ...
228
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
229
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
230
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
231
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
232
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
233
+ else:
234
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
235
+
236
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
237
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
238
+
239
+ if not controlnet:
240
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
241
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
242
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
243
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
244
+
245
+ # Retrieves the keys for the input blocks only
246
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
247
+ input_blocks = {
248
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
249
+ for layer_id in range(num_input_blocks)
250
+ }
251
+
252
+ # Retrieves the keys for the middle blocks only
253
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
254
+ middle_blocks = {
255
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
256
+ for layer_id in range(num_middle_blocks)
257
+ }
258
+
259
+ # Retrieves the keys for the output blocks only
260
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
261
+ output_blocks = {
262
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
263
+ for layer_id in range(num_output_blocks)
264
+ }
265
+
266
+ for i in range(1, num_input_blocks):
267
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
268
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
269
+
270
+ resnets = [
271
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
272
+ ]
273
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
274
+
275
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
276
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
277
+ f"input_blocks.{i}.0.op.weight"
278
+ )
279
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
280
+ f"input_blocks.{i}.0.op.bias"
281
+ )
282
+
283
+ paths = renew_resnet_paths(resnets)
284
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
285
+ assign_to_checkpoint(
286
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
287
+ )
288
+
289
+ if len(attentions):
290
+ paths = renew_attention_paths(attentions)
291
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
292
+ assign_to_checkpoint(
293
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
294
+ )
295
+
296
+ resnet_0 = middle_blocks[0]
297
+ attentions = middle_blocks[1]
298
+ resnet_1 = middle_blocks[2]
299
+
300
+ resnet_0_paths = renew_resnet_paths(resnet_0)
301
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
302
+
303
+ resnet_1_paths = renew_resnet_paths(resnet_1)
304
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
305
+
306
+ attentions_paths = renew_attention_paths(attentions)
307
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
308
+ assign_to_checkpoint(
309
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
310
+ )
311
+
312
+ for i in range(num_output_blocks):
313
+ block_id = i // (config["layers_per_block"] + 1)
314
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
315
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
316
+ output_block_list = {}
317
+
318
+ for layer in output_block_layers:
319
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
320
+ if layer_id in output_block_list:
321
+ output_block_list[layer_id].append(layer_name)
322
+ else:
323
+ output_block_list[layer_id] = [layer_name]
324
+
325
+ if len(output_block_list) > 1:
326
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
327
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
328
+
329
+ resnet_0_paths = renew_resnet_paths(resnets)
330
+ paths = renew_resnet_paths(resnets)
331
+
332
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
333
+ assign_to_checkpoint(
334
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
335
+ )
336
+
337
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
338
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
339
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
340
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
341
+ f"output_blocks.{i}.{index}.conv.weight"
342
+ ]
343
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
344
+ f"output_blocks.{i}.{index}.conv.bias"
345
+ ]
346
+
347
+ # Clear attentions as they have been attributed above.
348
+ if len(attentions) == 2:
349
+ attentions = []
350
+
351
+ if len(attentions):
352
+ paths = renew_attention_paths(attentions)
353
+ meta_path = {
354
+ "old": f"output_blocks.{i}.1",
355
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
356
+ }
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+ else:
361
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
362
+ for path in resnet_0_paths:
363
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
364
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
365
+
366
+ new_checkpoint[new_path] = unet_state_dict[old_path]
367
+
368
+ if controlnet:
369
+ # conditioning embedding
370
+
371
+ orig_index = 0
372
+
373
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
374
+ f"input_hint_block.{orig_index}.weight"
375
+ )
376
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
377
+ f"input_hint_block.{orig_index}.bias"
378
+ )
379
+
380
+ orig_index += 2
381
+
382
+ diffusers_index = 0
383
+
384
+ while diffusers_index < 6:
385
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
386
+ f"input_hint_block.{orig_index}.weight"
387
+ )
388
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
389
+ f"input_hint_block.{orig_index}.bias"
390
+ )
391
+ diffusers_index += 1
392
+ orig_index += 2
393
+
394
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
395
+ f"input_hint_block.{orig_index}.weight"
396
+ )
397
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
398
+ f"input_hint_block.{orig_index}.bias"
399
+ )
400
+
401
+ # down blocks
402
+ for i in range(num_input_blocks):
403
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
404
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
405
+
406
+ # mid block
407
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
408
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
409
+
410
+ return new_checkpoint
411
+
412
+
413
+ def convert_ldm_vae_checkpoint(checkpoint, config):
414
+ # extract state dict for VAE
415
+ vae_state_dict = {}
416
+ keys = list(checkpoint.keys())
417
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
418
+ for key in keys:
419
+ if key.startswith(vae_key):
420
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
421
+
422
+ new_checkpoint = {}
423
+
424
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
425
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
426
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
427
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
428
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
429
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
430
+
431
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
432
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
433
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
434
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
435
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
436
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
437
+
438
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
439
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
440
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
441
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
442
+
443
+ # Retrieves the keys for the encoder down blocks only
444
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
445
+ down_blocks = {
446
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
447
+ }
448
+
449
+ # Retrieves the keys for the decoder up blocks only
450
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
451
+ up_blocks = {
452
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
453
+ }
454
+
455
+ for i in range(num_down_blocks):
456
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
457
+
458
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
459
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
460
+ f"encoder.down.{i}.downsample.conv.weight"
461
+ )
462
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
463
+ f"encoder.down.{i}.downsample.conv.bias"
464
+ )
465
+
466
+ paths = renew_vae_resnet_paths(resnets)
467
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
468
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
469
+
470
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
471
+ num_mid_res_blocks = 2
472
+ for i in range(1, num_mid_res_blocks + 1):
473
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
474
+
475
+ paths = renew_vae_resnet_paths(resnets)
476
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
477
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
478
+
479
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
480
+ paths = renew_vae_attention_paths(mid_attentions)
481
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
482
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
483
+ conv_attn_to_linear(new_checkpoint)
484
+
485
+ for i in range(num_up_blocks):
486
+ block_id = num_up_blocks - 1 - i
487
+ resnets = [
488
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
489
+ ]
490
+
491
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
492
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
493
+ f"decoder.up.{block_id}.upsample.conv.weight"
494
+ ]
495
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
496
+ f"decoder.up.{block_id}.upsample.conv.bias"
497
+ ]
498
+
499
+ paths = renew_vae_resnet_paths(resnets)
500
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
501
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
502
+
503
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
504
+ num_mid_res_blocks = 2
505
+ for i in range(1, num_mid_res_blocks + 1):
506
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
507
+
508
+ paths = renew_vae_resnet_paths(resnets)
509
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
510
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
511
+
512
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
513
+ paths = renew_vae_attention_paths(mid_attentions)
514
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
515
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
516
+ conv_attn_to_linear(new_checkpoint)
517
+ return new_checkpoint
518
+
519
+
520
+ def convert_ldm_clip_checkpoint(checkpoint):
521
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
522
+ keys = list(checkpoint.keys())
523
+
524
+ text_model_dict = {}
525
+
526
+ for key in keys:
527
+ if key.startswith("cond_stage_model.transformer"):
528
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
529
+
530
+ text_model.load_state_dict(text_model_dict)
531
+
532
+ return text_model
533
+
534
+
535
+ textenc_conversion_lst = [
536
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
537
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
538
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
539
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
540
+ ]
541
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
542
+
543
+ textenc_transformer_conversion_lst = [
544
+ # (stable-diffusion, HF Diffusers)
545
+ ("resblocks.", "text_model.encoder.layers."),
546
+ ("ln_1", "layer_norm1"),
547
+ ("ln_2", "layer_norm2"),
548
+ (".c_fc.", ".fc1."),
549
+ (".c_proj.", ".fc2."),
550
+ (".attn", ".self_attn"),
551
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
552
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
553
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
554
+ ]
555
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
556
+ textenc_pattern = re.compile("|".join(protected.keys()))
cameractrl/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+
27
+
28
+ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29
+ # directly update weight in diffusers model
30
+ for key in state_dict:
31
+ # only process lora down key
32
+ if "up." in key: continue
33
+
34
+ up_key = key.replace(".down.", ".up.")
35
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
36
+ model_key = model_key.replace("to_out.", "to_out.0.")
37
+ layer_infos = model_key.split(".")[:-1]
38
+
39
+ curr_layer = pipeline.unet
40
+ while len(layer_infos) > 0:
41
+ temp_name = layer_infos.pop(0)
42
+ curr_layer = curr_layer.__getattr__(temp_name)
43
+
44
+ weight_down = state_dict[key]
45
+ weight_up = state_dict[up_key]
46
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
47
+
48
+ return pipeline
49
+
50
+
51
+
52
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
53
+ # load base model
54
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
55
+
56
+ # load LoRA weight from .safetensors
57
+ # state_dict = load_file(checkpoint_path)
58
+
59
+ visited = []
60
+
61
+ # directly update weight in diffusers model
62
+ for key in state_dict:
63
+ # it is suggested to print out the key, it usually will be something like below
64
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
65
+
66
+ # as we have set the alpha beforehand, so just skip
67
+ if ".alpha" in key or key in visited:
68
+ continue
69
+
70
+ if "text" in key:
71
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
72
+ curr_layer = pipeline.text_encoder
73
+ else:
74
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
75
+ curr_layer = pipeline.unet
76
+
77
+ # find the target layer
78
+ temp_name = layer_infos.pop(0)
79
+ while len(layer_infos) > -1:
80
+ try:
81
+ curr_layer = curr_layer.__getattr__(temp_name)
82
+ if len(layer_infos) > 0:
83
+ temp_name = layer_infos.pop(0)
84
+ elif len(layer_infos) == 0:
85
+ break
86
+ except Exception:
87
+ if len(temp_name) > 0:
88
+ temp_name += "_" + layer_infos.pop(0)
89
+ else:
90
+ temp_name = layer_infos.pop(0)
91
+
92
+ pair_keys = []
93
+ if "lora_down" in key:
94
+ pair_keys.append(key.replace("lora_down", "lora_up"))
95
+ pair_keys.append(key)
96
+ else:
97
+ pair_keys.append(key)
98
+ pair_keys.append(key.replace("lora_up", "lora_down"))
99
+
100
+ # update weight
101
+ if len(state_dict[pair_keys[0]].shape) == 4:
102
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
105
+ else:
106
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
107
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
108
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
109
+
110
+ # update visited list
111
+ for item in pair_keys:
112
+ visited.append(item)
113
+
114
+ return pipeline
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
122
+ )
123
+ parser.add_argument(
124
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
125
+ )
126
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127
+ parser.add_argument(
128
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129
+ )
130
+ parser.add_argument(
131
+ "--lora_prefix_text_encoder",
132
+ default="lora_te",
133
+ type=str,
134
+ help="The prefix of text encoder weight in safetensors",
135
+ )
136
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137
+ parser.add_argument(
138
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139
+ )
140
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141
+
142
+ args = parser.parse_args()
143
+
144
+ base_model_path = args.base_model_path
145
+ checkpoint_path = args.checkpoint_path
146
+ dump_path = args.dump_path
147
+ lora_prefix_unet = args.lora_prefix_unet
148
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
149
+ alpha = args.alpha
150
+
151
+ pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
152
+
153
+ pipe = pipe.to(args.device)
154
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
cameractrl/utils/util.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import functools
3
+ import logging
4
+ import sys
5
+ import imageio
6
+ import atexit
7
+ import importlib
8
+ import torch
9
+ import torchvision
10
+ import numpy as np
11
+ from termcolor import colored
12
+
13
+ from einops import rearrange
14
+
15
+
16
+ def instantiate_from_config(config, **additional_kwargs):
17
+ if not "target" in config:
18
+ if config == '__is_first_stage__':
19
+ return None
20
+ elif config == "__is_unconditional__":
21
+ return None
22
+ raise KeyError("Expected key `target` to instantiate.")
23
+
24
+ additional_kwargs.update(config.get("kwargs", dict()))
25
+ return get_obj_from_str(config["target"])(**additional_kwargs)
26
+
27
+
28
+ def get_obj_from_str(string, reload=False):
29
+ module, cls = string.rsplit(".", 1)
30
+ if reload:
31
+ module_imp = importlib.import_module(module)
32
+ importlib.reload(module_imp)
33
+ return getattr(importlib.import_module(module, package=None), cls)
34
+
35
+
36
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
37
+ videos = rearrange(videos, "b c t h w -> t b c h w")
38
+ outputs = []
39
+ for x in videos:
40
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
41
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
42
+ if rescale:
43
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
44
+ x = (x * 255).numpy().astype(np.uint8)
45
+ outputs.append(x)
46
+
47
+ os.makedirs(os.path.dirname(path), exist_ok=True)
48
+ imageio.mimsave(path, outputs, fps=fps)
49
+
50
+
51
+ # Logger utils are copied from detectron2
52
+ class _ColorfulFormatter(logging.Formatter):
53
+ def __init__(self, *args, **kwargs):
54
+ self._root_name = kwargs.pop("root_name") + "."
55
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
56
+ if len(self._abbrev_name):
57
+ self._abbrev_name = self._abbrev_name + "."
58
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
59
+
60
+ def formatMessage(self, record):
61
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
62
+ log = super(_ColorfulFormatter, self).formatMessage(record)
63
+ if record.levelno == logging.WARNING:
64
+ prefix = colored("WARNING", "red", attrs=["blink"])
65
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
66
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
67
+ else:
68
+ return log
69
+ return prefix + " " + log
70
+
71
+
72
+ # cache the opened file object, so that different calls to `setup_logger`
73
+ # with the same file name can safely write to the same file.
74
+ @functools.lru_cache(maxsize=None)
75
+ def _cached_log_stream(filename):
76
+ # use 1K buffer if writing to cloud storage
77
+ io = open(filename, "a", buffering=1024 if "://" in filename else -1)
78
+ atexit.register(io.close)
79
+ return io
80
+
81
+ @functools.lru_cache()
82
+ def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None):
83
+ logger = logging.getLogger(name)
84
+ logger.setLevel(logging.DEBUG)
85
+ logger.propagate = False
86
+
87
+ if abbrev_name is None:
88
+ abbrev_name = 'AD'
89
+ plain_formatter = logging.Formatter(
90
+ "[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
91
+ )
92
+
93
+ # stdout logging: master only
94
+ if distributed_rank == 0:
95
+ ch = logging.StreamHandler(stream=sys.stdout)
96
+ ch.setLevel(logging.DEBUG)
97
+ if color:
98
+ formatter = _ColorfulFormatter(
99
+ colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s",
100
+ datefmt="%m/%d %H:%M:%S",
101
+ root_name=name,
102
+ abbrev_name=str(abbrev_name),
103
+ )
104
+ else:
105
+ formatter = plain_formatter
106
+ ch.setFormatter(formatter)
107
+ logger.addHandler(ch)
108
+
109
+ # file logging: all workers
110
+ if output is not None:
111
+ if output.endswith(".txt") or output.endswith(".log"):
112
+ filename = output
113
+ else:
114
+ filename = os.path.join(output, "log.txt")
115
+ if distributed_rank > 0:
116
+ filename = filename + ".rank{}".format(distributed_rank)
117
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
118
+
119
+ fh = logging.StreamHandler(_cached_log_stream(filename))
120
+ fh.setLevel(logging.DEBUG)
121
+ fh.setFormatter(plain_formatter)
122
+ logger.addHandler(fh)
123
+
124
+ return logger
125
+
126
+
127
+ def format_time(elapsed_time):
128
+ # Time thresholds
129
+ minute = 60
130
+ hour = 60 * minute
131
+ day = 24 * hour
132
+
133
+ days, remainder = divmod(elapsed_time, day)
134
+ hours, remainder = divmod(remainder, hour)
135
+ minutes, seconds = divmod(remainder, minute)
136
+
137
+ formatted_time = ""
138
+
139
+ if days > 0:
140
+ formatted_time += f"{int(days)} days "
141
+ if hours > 0:
142
+ formatted_time += f"{int(hours)} hours "
143
+ if minutes > 0:
144
+ formatted_time += f"{int(minutes)} minutes "
145
+ if seconds > 0:
146
+ formatted_time += f"{seconds:.2f} seconds"
147
+
148
+ return formatted_time.strip()
configs/train_cameractrl/svd_320_576_cameractrl.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: "output/cameractrl_model"
2
+ pretrained_model_path: "/mnt/petrelfs/liangzhengyang.d/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid/snapshots/2586584918a955489b599d4dc76b6bb3fdb3fbb2"
3
+ unet_subfolder: "unet"
4
+ down_block_types: ['CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'DownBlockSpatioTemporal']
5
+ up_block_types: ['UpBlockSpatioTemporal', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond']
6
+
7
+ train_data:
8
+ root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
9
+ annotation_json: "annotations/train.json"
10
+ sample_stride: 8
11
+ sample_n_frames: 14
12
+ relative_pose: true
13
+ zero_t_first_frame: true
14
+ sample_size: [320, 576]
15
+ rescale_fxy: true
16
+ shuffle_frames: false
17
+ use_flip: false
18
+
19
+ validation_data:
20
+ root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
21
+ annotation_json: "annotations/validation.json"
22
+ sample_stride: 8
23
+ sample_n_frames: 14
24
+ relative_pose: true
25
+ zero_t_first_frame: true
26
+ sample_size: [320, 576]
27
+ rescale_fxy: true
28
+ shuffle_frames: false
29
+ use_flip: false
30
+ return_clip_name: true
31
+
32
+ random_null_image_ratio: 0.15
33
+
34
+ pose_encoder_kwargs:
35
+ downscale_factor: 8
36
+ channels: [320, 640, 1280, 1280]
37
+ nums_rb: 2
38
+ cin: 384
39
+ ksize: 1
40
+ sk: true
41
+ use_conv: false
42
+ compression_factor: 1
43
+ temporal_attention_nhead: 8
44
+ attention_block_types: ["Temporal_Self", ]
45
+ temporal_position_encoding: true
46
+ temporal_position_encoding_max_len: 14
47
+
48
+ attention_processor_kwargs:
49
+ add_spatial: false
50
+ add_temporal: true
51
+ attn_processor_name: 'attn1'
52
+ pose_feature_dimensions: [320, 640, 1280, 1280]
53
+ query_condition: true
54
+ key_value_condition: true
55
+ scale: 1.0
56
+
57
+ do_sanity_check: true
58
+ sample_before_training: false
59
+
60
+ max_train_epoch: -1
61
+ max_train_steps: 50000
62
+ validation_steps: 2500
63
+ validation_steps_tuple: [500, ]
64
+
65
+ learning_rate: 3.e-5
66
+
67
+ P_mean: 0.7
68
+ P_std: 1.6
69
+ condition_image_noise_mean: -3.0
70
+ condition_image_noise_std: 0.5
71
+ sample_latent: true
72
+ first_image_cond: true
73
+
74
+ num_inference_steps: 25
75
+ min_guidance_scale: 1.0
76
+ max_guidance_scale: 3.0
77
+
78
+ num_workers: 8
79
+ train_batch_size: 1
80
+ checkpointing_epochs: -1
81
+ checkpointing_steps: 10000
82
+
83
+ mixed_precision_training: false
84
+ enable_xformers_memory_efficient_attention: true
85
+
86
+ global_seed: 42
87
+ logger_interval: 10
configs/train_cameractrl/svdxt_320_576_cameractrl.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: "output/cameractrl_model"
2
+ pretrained_model_path: "/mnt/petrelfs/liangzhengyang.d/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid-xt/snapshots/4420c0886aad9930787308c62d9dd8befd4900f6"
3
+ unet_subfolder: "unet"
4
+ down_block_types: ['CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'DownBlockSpatioTemporal']
5
+ up_block_types: ['UpBlockSpatioTemporal', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond']
6
+
7
+ train_data:
8
+ root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
9
+ annotation_json: "annotations/train.json"
10
+ sample_stride: 5
11
+ sample_n_frames: 25
12
+ relative_pose: true
13
+ zero_t_first_frame: true
14
+ sample_size: [320, 576]
15
+ rescale_fxy: true
16
+ shuffle_frames: false
17
+ use_flip: false
18
+
19
+ validation_data:
20
+ root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
21
+ annotation_json: "annotations/validation.json"
22
+ sample_stride: 5
23
+ sample_n_frames: 25
24
+ relative_pose: true
25
+ zero_t_first_frame: true
26
+ sample_size: [320, 576]
27
+ rescale_fxy: true
28
+ shuffle_frames: false
29
+ use_flip: false
30
+ return_clip_name: true
31
+
32
+ random_null_image_ratio: 0.15
33
+
34
+ pose_encoder_kwargs:
35
+ downscale_factor: 8
36
+ channels: [320, 640, 1280, 1280]
37
+ nums_rb: 2
38
+ cin: 384
39
+ ksize: 1
40
+ sk: true
41
+ use_conv: false
42
+ compression_factor: 1
43
+ temporal_attention_nhead: 8
44
+ attention_block_types: ["Temporal_Self", ]
45
+ temporal_position_encoding: true
46
+ temporal_position_encoding_max_len: 25
47
+
48
+ attention_processor_kwargs:
49
+ add_spatial: false
50
+ add_temporal: true
51
+ attn_processor_name: 'attn1'
52
+ pose_feature_dimensions: [320, 640, 1280, 1280]
53
+ query_condition: true
54
+ key_value_condition: true
55
+ scale: 1.0
56
+
57
+ do_sanity_check: false
58
+ sample_before_training: false
59
+ video_length: 25
60
+
61
+ max_train_epoch: -1
62
+ max_train_steps: 50000
63
+ validation_steps: 2500
64
+ validation_steps_tuple: [1000, ]
65
+
66
+ learning_rate: 3.e-5
67
+
68
+ P_mean: 0.7
69
+ P_std: 1.6
70
+ condition_image_noise_mean: -3.0
71
+ condition_image_noise_std: 0.5
72
+ sample_latent: true
73
+ first_image_cond: true
74
+
75
+ num_inference_steps: 25
76
+ min_guidance_scale: 1.0
77
+ max_guidance_scale: 3.0
78
+
79
+ num_workers: 8
80
+ train_batch_size: 1
81
+ checkpointing_epochs: -1
82
+ checkpointing_steps: 10000
83
+
84
+ mixed_precision_training: false
85
+ enable_xformers_memory_efficient_attention: true
86
+
87
+ global_seed: 42
88
+ logger_interval: 10
inference_cameractrl.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+ from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
11
+ from diffusers.utils.import_utils import is_xformers_available
12
+ from diffusers.models.attention_processor import AttnProcessor2_0
13
+ from packaging import version as pver
14
+
15
+ from cameractrl.pipelines.pipeline_animation import StableVideoDiffusionPipelinePoseCond
16
+ from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
17
+ from cameractrl.models.pose_adaptor import CameraPoseEncoder
18
+ from cameractrl.utils.util import save_videos_grid
19
+
20
+
21
+ class Camera(object):
22
+ def __init__(self, entry):
23
+ fx, fy, cx, cy = entry[1:5]
24
+ self.fx = fx
25
+ self.fy = fy
26
+ self.cx = cx
27
+ self.cy = cy
28
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
29
+ w2c_mat_4x4 = np.eye(4)
30
+ w2c_mat_4x4[:3, :] = w2c_mat
31
+ self.w2c_mat = w2c_mat_4x4
32
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
33
+
34
+
35
+ def setup_for_distributed(is_master):
36
+ """
37
+ This function disables printing when not in master process
38
+ """
39
+ import builtins as __builtin__
40
+ builtin_print = __builtin__.print
41
+
42
+ def print(*args, **kwargs):
43
+ force = kwargs.pop('force', False)
44
+ if is_master or force:
45
+ builtin_print(*args, **kwargs)
46
+
47
+ __builtin__.print = print
48
+
49
+
50
+ def custom_meshgrid(*args):
51
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
52
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
53
+ return torch.meshgrid(*args)
54
+ else:
55
+ return torch.meshgrid(*args, indexing='ij')
56
+
57
+
58
+ def get_relative_pose(cam_params, zero_first_frame_scale):
59
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
60
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
61
+ source_cam_c2w = abs_c2ws[0]
62
+ if zero_first_frame_scale:
63
+ cam_to_origin = 0
64
+ else:
65
+ cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
66
+ target_cam_c2w = np.array([
67
+ [1, 0, 0, 0],
68
+ [0, 1, 0, -cam_to_origin],
69
+ [0, 0, 1, 0],
70
+ [0, 0, 0, 1]
71
+ ])
72
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
73
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
74
+ ret_poses = np.array(ret_poses, dtype=np.float32)
75
+ return ret_poses
76
+
77
+
78
+ def ray_condition(K, c2w, H, W, device):
79
+ # c2w: B, V, 4, 4
80
+ # K: B, V, 4
81
+
82
+ B = K.shape[0]
83
+
84
+ j, i = custom_meshgrid(
85
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
86
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
87
+ )
88
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
89
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
90
+
91
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
92
+
93
+ zs = torch.ones_like(i) # [B, HxW]
94
+ xs = (i - cx) / fx * zs
95
+ ys = (j - cy) / fy * zs
96
+ zs = zs.expand_as(ys)
97
+
98
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
99
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
100
+
101
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
102
+ rays_o = c2w[..., :3, 3] # B, V, 3
103
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
104
+ # c2w @ dirctions
105
+ rays_dxo = torch.linalg.cross(rays_o, rays_d)
106
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
107
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
108
+ return plucker
109
+
110
+
111
+ def get_pipeline(ori_model_path, unet_subfolder, down_block_types, up_block_types, pose_encoder_kwargs,
112
+ attention_processor_kwargs, pose_adaptor_ckpt, enable_xformers, device):
113
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(ori_model_path, subfolder="scheduler")
114
+ feature_extractor = CLIPImageProcessor.from_pretrained(ori_model_path, subfolder="feature_extractor")
115
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(ori_model_path, subfolder="image_encoder")
116
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(ori_model_path, subfolder="vae")
117
+ unet = UNetSpatioTemporalConditionModelPoseCond.from_pretrained(ori_model_path,
118
+ subfolder=unet_subfolder,
119
+ down_block_types=down_block_types,
120
+ up_block_types=up_block_types)
121
+ pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs)
122
+ print("Setting the attention processors")
123
+ unet.set_pose_cond_attn_processor(enable_xformers=(enable_xformers and is_xformers_available()), **attention_processor_kwargs)
124
+ print(f"Loading weights of camera encoder and attention processor from {pose_adaptor_ckpt}")
125
+ ckpt_dict = torch.load(pose_adaptor_ckpt, map_location=unet.device)
126
+ pose_encoder_state_dict = ckpt_dict['pose_encoder_state_dict']
127
+ pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict)
128
+ assert len(pose_encoder_m) == 0 and len(pose_encoder_u) == 0
129
+ attention_processor_state_dict = ckpt_dict['attention_processor_state_dict']
130
+ _, attention_processor_u = unet.load_state_dict(attention_processor_state_dict, strict=False)
131
+ assert len(attention_processor_u) == 0
132
+ print("Loading done")
133
+ vae.set_attn_processor(AttnProcessor2_0())
134
+ vae.to(device)
135
+ image_encoder.to(device)
136
+ unet.to(device)
137
+ pipeline = StableVideoDiffusionPipelinePoseCond(
138
+ vae=vae,
139
+ image_encoder=image_encoder,
140
+ unet=unet,
141
+ scheduler=noise_scheduler,
142
+ feature_extractor=feature_extractor,
143
+ pose_encoder=pose_encoder
144
+ )
145
+ pipeline = pipeline.to(device)
146
+ return pipeline
147
+
148
+
149
+ def main(args):
150
+ os.makedirs(os.path.join(args.out_root, 'generated_videos'), exist_ok=True)
151
+ os.makedirs(os.path.join(args.out_root, 'reference_images'), exist_ok=True)
152
+ rank = args.local_rank
153
+ setup_for_distributed(rank == 0)
154
+ gpu_id = rank % torch.cuda.device_count()
155
+ model_configs = OmegaConf.load(args.model_config)
156
+ device = f"cuda:{gpu_id}"
157
+ print(f'Constructing pipeline')
158
+ pipeline = get_pipeline(args.ori_model_path, model_configs['unet_subfolder'], model_configs['down_block_types'],
159
+ model_configs['up_block_types'], model_configs['pose_encoder_kwargs'],
160
+ model_configs['attention_processor_kwargs'], args.pose_adaptor_ckpt, args.enable_xformers, device)
161
+ print('Done')
162
+
163
+ print('Loading K, R, t matrix')
164
+ with open(args.trajectory_file, 'r') as f:
165
+ poses = f.readlines()
166
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
167
+ cam_params = [[float(x) for x in pose] for pose in poses]
168
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
169
+
170
+ sample_wh_ratio = args.image_width / args.image_height
171
+ pose_wh_ratio = args.original_pose_width / args.original_pose_height
172
+ if pose_wh_ratio > sample_wh_ratio:
173
+ resized_ori_w = args.image_height * pose_wh_ratio
174
+ for cam_param in cam_params:
175
+ cam_param.fx = resized_ori_w * cam_param.fx / args.image_width
176
+ else:
177
+ resized_ori_h = args.image_width / pose_wh_ratio
178
+ for cam_param in cam_params:
179
+ cam_param.fy = resized_ori_h * cam_param.fy / args.image_height
180
+ intrinsic = np.asarray([[cam_param.fx * args.image_width,
181
+ cam_param.fy * args.image_height,
182
+ cam_param.cx * args.image_width,
183
+ cam_param.cy * args.image_height]
184
+ for cam_param in cam_params], dtype=np.float32)
185
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
186
+ c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
187
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
188
+ plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu') # b f h w 6
189
+ plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)
190
+
191
+ prompt_dict = json.load(open(args.prompt_file, 'r'))
192
+ prompt_images = prompt_dict['image_paths']
193
+ prompt_captions = prompt_dict['captions']
194
+ N = int(len(prompt_images) // args.n_procs)
195
+ remainder = int(len(prompt_images) % args.n_procs)
196
+ prompts_per_gpu = [N + 1 if gpu_id < remainder else N for gpu_id in range(args.n_procs)]
197
+ low_idx = sum(prompts_per_gpu[:gpu_id])
198
+ high_idx = low_idx + prompts_per_gpu[gpu_id]
199
+ prompt_images = prompt_images[low_idx: high_idx]
200
+ prompt_captions = prompt_captions[low_idx: high_idx]
201
+ print(f"rank {rank} / {torch.cuda.device_count()}, number of prompts: {len(prompt_images)}")
202
+
203
+ generator = torch.Generator(device=device)
204
+ generator.manual_seed(42)
205
+
206
+ for prompt_image, prompt_caption in tqdm(zip(prompt_images, prompt_captions)):
207
+ save_name = "_".join(prompt_caption.split(" "))
208
+ condition_image = Image.open(prompt_image)
209
+ with torch.no_grad():
210
+ sample = pipeline(
211
+ image=condition_image,
212
+ pose_embedding=plucker_embedding,
213
+ height=args.image_height,
214
+ width=args.image_width,
215
+ num_frames=args.num_frames,
216
+ num_inference_steps=args.num_inference_steps,
217
+ min_guidance_scale=args.min_guidance_scale,
218
+ max_guidance_scale=args.max_guidance_scale,
219
+ do_image_process=True,
220
+ generator=generator,
221
+ output_type='pt'
222
+ ).frames[0].transpose(0, 1).cpu() # [3, f, h, w] 0-1
223
+ resized_condition_image = condition_image.resize((args.image_width, args.image_height))
224
+ save_videos_grid(sample[None], f"{os.path.join(args.out_root, 'generated_videos')}/{save_name}.mp4", rescale=False)
225
+ resized_condition_image.save(os.path.join(args.out_root, 'reference_images', f'{save_name}.png'))
226
+
227
+
228
+ if __name__ == '__main__':
229
+ parser = argparse.ArgumentParser()
230
+ parser.add_argument("--out_root", type=str)
231
+ parser.add_argument("--image_height", type=int, default=320)
232
+ parser.add_argument("--image_width", type=int, default=576)
233
+ parser.add_argument("--num_frames", type=int, default=14)
234
+ parser.add_argument("--ori_model_path", type=str)
235
+ parser.add_argument("--unet_subfolder", type=str, default='unet')
236
+ parser.add_argument("--enable_xformers", action='store_true')
237
+ parser.add_argument("--pose_adaptor_ckpt", default=None)
238
+ parser.add_argument("--num_inference_steps", type=int, default=25)
239
+ parser.add_argument("--min_guidance_scale", type=float, default=1.0)
240
+ parser.add_argument("--max_guidance_scale", type=float, default=3.0)
241
+ parser.add_argument("--prompt_file", required=True, help='prompts path, json or txt')
242
+ parser.add_argument("--trajectory_file", required=True)
243
+ parser.add_argument("--original_pose_width", type=int, default=1280)
244
+ parser.add_argument("--original_pose_height", type=int, default=720)
245
+ parser.add_argument("--model_config", required=True)
246
+ parser.add_argument("--n_procs", type=int, default=8)
247
+
248
+ # DDP args
249
+ parser.add_argument("--world_size", default=1, type=int,
250
+ help="number of the distributed processes.")
251
+ parser.add_argument('--local-rank', type=int, default=-1,
252
+ help='Replica rank on the current node. This field is required '
253
+ 'by `torch.distributed.launch`.')
254
+ args = parser.parse_args()
255
+ main(args)
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ torch
3
+ torchvision
4
+ diffusers==0.24.0
5
+ imageio==2.27.0
6
+ transformers==4.39.3
7
+ gradio==4.26.0
8
+ imageio==2.27.0
9
+ imageio-ffmpeg==0.4.9
10
+ accelerate==0.30.0
11
+ opencv-python
12
+ gdown
13
+ einops
14
+ decord
15
+ omegaconf
16
+ safetensors
17
+ gradio
18
+ wandb
19
+ triton
20
+ termcolor