marselgames9 committed on
Commit
a3806ed
1 Parent(s): 60f46cc

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ .DS_Store
2
+ venv
3
+ .idea
4
+ __pycache__
5
+ build
6
+ *.egg-info
CITATION.cff ADDED
@@ -0,0 +1,66 @@
1
+ # This CITATION.cff file was generated with cffinit.
2
+ # Visit https://bit.ly/cffinit to generate yours today!
3
+
4
+ cff-version: 1.2.0
5
+ title: Hotshot-XL
6
+ message: Personalized GIF Generation with Diffusion Models
7
+ type: software
8
+ authors:
9
+ - given-names: John
10
+ family-names: Mullan
11
12
+ affiliation: 'Natural Synthetics, Inc.'
13
+ - given-names: Duncan
14
+ family-names: Crawbuck
15
16
+ affiliation: 'Natural Synthetics, Inc.'
17
+ - given-names: Aakash
18
+ family-names: Sastry
19
20
+ affiliation: 'Natural Synthetics, Inc.'
21
+ identifiers:
22
+ - type: url
23
+ value: 'https://hotshot.co'
24
+ description: Hotshot Website
25
+ repository-code: 'https://github.com/hotshotco/hotshot-xl'
26
+ url: 'https://hotshot.co'
27
+ repository-artifact: 'https://huggingface.co/hotshotco/Hotshot-XL'
28
+ abstract: >-
29
+ Hotshot-XL is an AI text-to-GIF model trained to work
30
+ alongside Stable Diffusion XL. Hotshot-XL can generate
31
+ GIFs with any fine-tuned SDXL model.
32
+
33
+
34
+ Hotshot-XL is able to make GIFs with any existing or newly
35
+ fine-tuned SDXL model you may want to use. If you'd like
36
+ to make GIFs of personalized subjects, you can load your
37
+ own SDXL based LORAs, and not have to worry about
38
+ fine-tuning Hotshot-XL. This is awesome because it’s
39
+ usually much easier to find suitable images for training
40
+ data than it is to find videos.
41
+
42
+
43
+ Hotshot-XL is compatible with SDXL ControlNet to make GIFs
44
+ in the composition/layout you’d like.
45
+
46
+
47
+ Hotshot-XL was trained to generate 1 second GIFs at 8 FPS.
48
+
49
+
50
+ Hotshot-XL was trained on various aspect ratios. To
51
+ achieve more efficient training + inference, we fine tuned
52
+ SDXL at/around 512 resolution prior to training
53
+ Hotshot-XL. We also publish our fine tuned SDXL spatial
54
+ model for use among the research community.
55
+ keywords:
56
+ - ai
57
+ - text-to-video
58
+ - sdxl
59
+ - text-to-video-generation
60
+ - text-to-gif
61
+ - hotshot-xl
62
+ - hotshot
63
+ license: Apache-2.0
64
+ commit: 16f99c4e8cbf8cebd038a282173767d609836889
65
+ version: 1.0.0
66
+ date-released: '2023-10-03'
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,270 @@
1
- ---
2
- title: Marselgames9 Gif135animation
3
- emoji: 📊
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.36.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <h1 align="center"><img src="https://i.imgur.com/HsWXQTW.png" width="24px" alt="logo" /> Hotshot-XL</h1>
2
+
3
+ <h1 align="center">
4
+ <a href="https://www.hotshot.co">🌐 Try it</a>
5
+ &nbsp;
6
+ <a href="https://huggingface.co/hotshotco/Hotshot-XL">🃏 Model card</a>
7
+ &nbsp;
8
+ <a href="https://discord.gg/2FjCRRxHCz">💬 Discord</a>
9
+ </h1>
10
+
11
+ <p align="center">
12
+ <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_e8a50e1e-0b2e-4ebc-8229-817703585405.gif" alt="a barbie doll smiling in kitchen, oven on fire, disaster, pink wes anderson vibes, cinematic" width="195px" height="111.42px"/>
13
+ &nbsp;
14
+ <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_f6ca56a3-30b8-4b2a-9342-111353e85b96.gif" alt="a teddy bear writing a letter" width="195px" height="111.42px"/>
15
+ &nbsp;
16
+ <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_6c219102-7f72-45e9-b4fa-b7a07c004ae1.gif" alt="dslr photo of mark zuckerberg happy, pulling on threads, lots of threads everywhere, laughing, hd, 8k" width="195px" height="111.42px"/>
17
+ &nbsp;
18
+ <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_2dd3c30f-42c5-4f37-8fa6-b2494fcac4b4.gif" alt="a cat laughing" width="195px" height="111.42px"/>
19
+ &nbsp;
20
+ </p>
21
+
22
+ Hotshot-XL is an AI text-to-GIF model trained to work alongside [Stable Diffusion XL](https://stability.ai/stable-diffusion).
23
+
24
+ Hotshot-XL can generate GIFs with any fine-tuned SDXL model. This means two things:
25
+ 1. You’ll be able to make GIFs with any existing or newly fine-tuned SDXL model you may want to use.
26
+ 2. If you'd like to make GIFs of personalized subjects, you can load your own SDXL-based LoRAs and not have to worry about fine-tuning Hotshot-XL. This is awesome because it's usually much easier to find suitable images for training data than it is to find videos. It also hopefully fits into everyone's existing LoRA usage/workflows :) See more [here](#text-to-gif-with-personalized-loras).
27
+
28
+ Hotshot-XL is compatible with SDXL ControlNet to make GIFs in the composition/layout you’d like. See the [ControlNet](#text-to-gif-with-controlnet) section below.
29
+
30
+ Hotshot-XL was trained to generate 1 second GIFs at 8 FPS.
31
+
32
+ Hotshot-XL was trained on various aspect ratios. For best results with the base Hotshot-XL model, we recommend using it with an SDXL model that has been fine-tuned with 512x512 images. You can find an SDXL model we fine-tuned for 512x512 resolutions [here](https://huggingface.co/hotshotco/SDXL-512).
33
+
34
+ # 🌐 Try It
35
+
36
+ Try Hotshot-XL yourself here: https://www.hotshot.co
37
+
38
+ Or, if you'd like to run Hotshot-XL yourself locally, continue on to the sections below.
39
+
40
+ If you run Hotshot-XL yourself, you'll have a lot more flexibility and control over the model. As a very simple example, you'll be able to change the sampler. We've seen the best results with Euler-A so far, but you may find interesting results with other samplers.
41
+
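For example, here's a minimal sketch of swapping in the Euler-A sampler through the standard diffusers scheduler interface. The pipeline class and call signature below mirror the validation code in `fine_tune.py`; treat this as a sketch rather than the canonical inference path (`inference.py`):

```python
# Sketch: swap the sampler for Euler Ancestral ("Euler-A") before generating.
# Assumes HotshotXLPipeline exposes the standard diffusers scheduler interface.
import torch
from diffusers import EulerAncestralDiscreteScheduler
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline

pipe = HotshotXLPipeline.from_pretrained("hotshotco/Hotshot-XL").to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

video = pipe(
    "a bulldog in the captains chair of a spaceship, hd, high quality",
    width=512,
    height=512,
    num_inference_steps=30,
    video_length=8,          # 8 frames -> a 1 second GIF at 8 FPS
    output_type="tensor",
    generator=torch.Generator().manual_seed(111),
).videos
```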
42
+ # 🔧 Setup
43
+
44
+ ### Environment Setup
45
+ ```
46
+ pip install virtualenv --upgrade
47
+ virtualenv -p $(which python3) venv
48
+ source venv/bin/activate
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ ### Download the Hotshot-XL Weights
53
+
54
+ ```
55
+ # Make sure you have git-lfs installed (https://git-lfs.com)
56
+ git lfs install
57
+ git clone https://huggingface.co/hotshotco/Hotshot-XL
58
+ ```
59
+
60
+ or visit [https://huggingface.co/hotshotco/Hotshot-XL](https://huggingface.co/hotshotco/Hotshot-XL)
61
+
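If you'd rather not use git-lfs, the same weights can be fetched with the `huggingface_hub` Python client (a minimal sketch; assumes `huggingface_hub` is installed, which it already is in the environment above since diffusers and transformers depend on it):

```python
# Sketch: download the Hotshot-XL weights without git-lfs.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="hotshotco/Hotshot-XL")
print(f"Weights downloaded to {local_dir}")
```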
62
+ ### Download our fine-tuned SDXL model (or BYOSDXL)
63
+
64
+ - *Note*: To maximize data and training efficiency, Hotshot-XL was trained at various aspect ratios around 512x512 resolution. For best results with the base Hotshot-XL model, we recommend using it with an SDXL model that has been fine-tuned with images around the 512x512 resolution. You can download an SDXL model we trained with images at 512x512 resolution below, or bring your own SDXL base model.
65
+
66
+ ```
67
+ # Make sure you have git-lfs installed (https://git-lfs.com)
68
+ git lfs install
69
+ git clone https://huggingface.co/hotshotco/SDXL-512
70
+ ```
71
+
72
+ or visit [https://huggingface.co/hotshotco/SDXL-512](https://huggingface.co/hotshotco/SDXL-512)
73
+
74
+ # 🔮 Inference
75
+
76
+ ### Text-to-GIF
77
+ ```
78
+ python inference.py \
79
+ --prompt="a bulldog in the captains chair of a spaceship, hd, high quality" \
80
+ --output="output.gif"
81
+ ```
82
+
83
+ *What to Expect:*
84
+ | **Prompt** | Sasquatch scuba diving | a camel smoking a cigarette | Ronald McDonald sitting at a vanity mirror putting on lipstick | drake licking his lips and staring through a window at a cupcake |
85
+ |-----------|----------|----------|----------|----------|
86
+ | **Output** | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_441b7ea2-9887-4124-a52b-14c9db1d15aa.gif" /> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_7956a022-0464-4441-88b8-15a6de953335.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_35f55a64-7ed9-498e-894e-6ec7a8026fba.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/image-gen/gif_df5f52cb-d74d-40b5-a066-2ce567dae512.gif"/> |
87
+
88
+ ### Text-to-GIF with personalized LORAs
89
+
90
+ ```
91
+ python inference.py \
92
+ --prompt="a bulldog in the captains chair of a spaceship, hd, high quality" \
93
+ --output="output.gif" \
94
+ --spatial_unet_base="path/to/stabilityai/stable-diffusion-xl-base-1.0/unet" \
95
+ --lora="path/to/lora"
96
+ ```
97
+
98
+ *What to Expect:*
99
+
100
+ *Note*: The outputs below use the DDIMScheduler.
101
+
102
+ | **Prompt** | sks person screaming at a capri sun | sks person kissing kermit the frog | sks person wearing a tuxedo holding up a glass of champagne, fireworks in background, hd, high quality, 4K |
103
+ |-----------|----------|----------|----------|
104
+ | **Output** | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/79a20eae-ffeb-4d24-8d22-609fa77c292f.gif" /> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/r/aakash.gif" /> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/4fa34a16-2835-4a12-8c59-348caa4f3891.gif" /> |
105
+
106
+ ### Text-to-GIF with ControlNet
107
+ ```
108
+ python inference.py \
109
+ --prompt="a girl jumping up and down and pumping her fist, hd, high quality" \
110
+ --output="output.gif" \
111
+ --control_type="depth" \
112
+ --gif="https://media1.giphy.com/media/v1.Y2lkPTc5MGI3NjExbXNneXJicG1mOHJ2dzQ2Y2JteDY1ZWlrdjNjMjl3ZWxyeWFxY2EzdyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/YOTAoXBgMCmFeQQzuZ/giphy.gif"
113
+ ```
114
+
115
+ By default, Hotshot-XL will create key frames from your source GIF using 8 equally spaced frames and crop the key frames to the default aspect ratio. For finer-grained control, learn how to [vary aspect ratios](#varying-aspect-ratios) and [vary frame rates/lengths](#varying-frame-rates--lengths-experimental).
116
+
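For reference, "8 equally spaced frames" boils down to a sampling step like the sketch below (illustrative only, using PIL; not necessarily the exact code path `inference.py` takes):

```python
# Sketch: pick n equally spaced key frames from a source GIF with PIL.
# Illustrative only -- not necessarily the exact logic inference.py uses.
from PIL import Image, ImageSequence

def sample_key_frames(gif_path: str, n: int = 8) -> list:
    frames = [frame.convert("RGB") for frame in ImageSequence.Iterator(Image.open(gif_path))]
    # Evenly spaced indices across the clip, always including the first and last frame.
    indices = [round(i * (len(frames) - 1) / (n - 1)) for i in range(n)]
    return [frames[i] for i in indices]

key_frames = sample_key_frames("source.gif", n=8)
```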
117
+ Hotshot-XL currently supports the use of one ControlNet model at a time; supporting Multi-ControlNet would be [exciting](#-further-work).
118
+
119
+ *What to Expect:*
120
+ | **Prompt** | pixar style girl putting two thumbs up, happy, high quality, 8k, 3d, animated disney render | keanu reaves holding a sign that says "HELP", hd, high quality | a woman laughing, hd, high quality | barack obama making a rainbow with their hands, the word "MAGIC" in front of them, wearing a blue and white striped hoodie, hd, high quality |
121
+ |-----------|----------|----------|----------|----------|
122
+ | **Output** | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/387d8b68-7289-45e3-9b21-1a9e6ad8a782.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot%2Finf-temp/047543b2-d499-4de8-8fd2-3712c3a6c446.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/8f50f4d8-4b86-4df7-a643-aae3e9d8634d.gif"> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/c133d8b7-46ad-4469-84fd-b7f7444a47a0.gif"/> |
123
+ | **Control** |<img src="https://media1.giphy.com/media/3o6Zt8qDiPE2d3kayI/giphy.gif?cid=ecf05e47igskj73xpl62pv8kyk9m39brlualxcz1j68vk8ul&ep=v1_gifs_related&rid=giphy.gif&ct=g"/> | <img src="https://media2.giphy.com/media/IoXVrbzUIuvTy/giphy.gif?cid=ecf05e47ill5r35i1bhxk0tr7quqbpruqivjtuy7gcgkfmx5&ep=v1_gifs_search&rid=giphy.gif&ct=g"/> | <img src="https://media0.giphy.com/media/12msOFU8oL1eww/giphy.gif"> | <img src="https://media4.giphy.com/media/3o84U6421OOWegpQhq/giphy.gif?cid=ecf05e47eufup08cz2up9fn9bitkgltb88ez37829mxz43cc&ep=v1_gifs_related&rid=giphy.gif&ct=g"/> |
124
+
125
+ ### Varying Aspect Ratios
126
+
127
+ - *Note*: The base SDXL model works best when generating images around 1024x1024 resolution. To maximize data and training efficiency, Hotshot-XL was trained at aspect ratios around 512x512 resolution. Please see [Additional Notes](#supported-aspect-ratios) for a list of aspect ratios the base Hotshot-XL model was trained with.
128
+
129
+ Like SDXL, Hotshot-XL was trained at various aspect ratios with aspect ratio bucketing, and includes support for SDXL parameters like target-size and original-size. This means you can create GIFs at several different aspect ratios and resolutions, just with the base Hotshot-XL model.
130
+
131
+ ```
132
+ python inference.py \
133
+ --prompt="a bulldog in the captains chair of a spaceship, hd, high quality" \
134
+ --output="output.gif" \
135
+ --width=<WIDTH> \
136
+ --height=<HEIGHT>
137
+ ```
138
+
139
+ *What to Expect:*
140
+ | | 512x512 | 672x384 | 384x672 |
141
+ |-----------|----------|----------|----------|
142
+ | **a monkey playing guitar, nature footage, hd, high quality** | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/2295c6af-c345-47a4-8afe-62e77f84141b.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/909a86c5-60df-459a-b662-ce4e85706303.gif"/> | <img src="https://dvfx9cgvtgnyd.cloudfront.net/hotshot/inf-temp/8512854d-66ea-41ff-919e-6e36d6e6a541.gif"> |
143
+
144
+ ### Varying frame rates & lengths (*Experimental*)
145
+ By default, Hotshot-XL is trained to generate GIFs that are 1 second long at 8 FPS. If you'd like to play with generating GIFs with varying frame rates and durations, you can try out the parameters `video_length` and `video_duration`.
146
+
147
+ `video_length` sets the number of frames. The default value is 8.
148
+
149
+ `video_duration` sets the runtime of the output gif in milliseconds. The default value is 1000.
150
+
151
+ Please note that you should expect unstable/"jittery" results when modifying these parameters, as the model was only trained on 1-second videos at 8 FPS. You'll be able to improve the stability of results for different durations and frame rates by [fine-tuning Hotshot-XL](#-fine-tuning). Please let us know if you do!
152
+
153
+ ```
154
+ python inference.py \
155
+ --prompt="a bulldog in the captains chair of a spaceship, hd, high quality" \
156
+ --output="output.gif" \
157
+ --video_length=16 \
158
+ --video_duration=2000
159
+ ```
160
+
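As a quick sanity check on how the two parameters interact: the effective frame rate is just the frame count spread over the duration, so the example above still plays back at 8 FPS (assuming frames are spaced evenly across `video_duration`):

```python
# Sketch: effective frame rate for a given frame count and GIF duration.
video_length = 16       # frames
video_duration = 2000   # milliseconds

effective_fps = video_length / (video_duration / 1000)
print(effective_fps)    # 8.0 -- the same rate as the default 8 frames over 1000 ms
```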
161
+ ### Spatial Layers Only
162
+ Hotshot-XL is trained to generate GIFs alongside SDXL. If you'd like to generate just an image, you can simply set `video_length=1` in your inference call and the Hotshot-XL temporal layers will be ignored, as you'd expect.
163
+
164
+ ```
165
+ python inference.py \
166
+ --prompt="a bulldog in the captains chair of a spaceship, hd, high quality" \
167
+ --output="output.jpg" \
168
+ --video_length=1
169
+ ```
170
+
171
+ ### Additional Notes
172
+
173
+ #### Supported Aspect Ratios
174
+ Hotshot-XL was trained at the following aspect ratios; to reliably generate GIFs outside the range of these aspect ratios, you will want to fine-tune Hotshot-XL with videos at the resolution of your desired aspect ratio.
175
+
176
+ | Aspect Ratio | Size |
177
+ |--------------|------|
178
+ | 0.42 |320 x 768|
179
+ | 0.57 |384 x 672|
180
+ | 0.68 |416 x 608|
181
+ | 1.00 |512 x 512|
182
+ | 1.46 |608 x 416|
183
+ | 1.75 |672 x 384|
184
+ | 2.40 |768 x 320|
185
+
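If you're unsure which size to request for a given aspect ratio, a small helper like the sketch below can snap an arbitrary width/height to the nearest trained bucket (a hypothetical convenience, not part of the repo; the sizes come straight from the table above):

```python
# Sketch: map an arbitrary width/height to the nearest trained aspect-ratio bucket.
# Hypothetical helper -- the (width, height) sizes come from the table above.
TRAINED_BUCKETS = {
    0.42: (320, 768),
    0.57: (384, 672),
    0.68: (416, 608),
    1.00: (512, 512),
    1.46: (608, 416),
    1.75: (672, 384),
    2.40: (768, 320),
}

def closest_bucket(width: int, height: int) -> tuple:
    ratio = width / height
    nearest = min(TRAINED_BUCKETS, key=lambda r: abs(r - ratio))
    return TRAINED_BUCKETS[nearest]

print(closest_bucket(1920, 1080))  # -> (672, 384), the 1.75 bucket
```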
186
+
187
+ # 💪 Fine-Tuning
188
+ The following section relates to fine-tuning the Hotshot-XL temporal model with additional text/video pairs. If you're trying to generate GIFs of personalized concepts/subjects, we'd recommend not fine-tuning Hotshot-XL, but instead training your own SDXL-based LoRAs and [just loading those](#text-to-gif-with-personalized-loras).
189
+
190
+ ### Fine-Tuning Hotshot-XL
191
+
192
+ #### Dataset Preparation
193
+
194
+ The `fine_tune.py` script expects your samples to be structured like this:
195
+
196
+ ```
197
+ fine_tune_dataset
198
+ ├── sample_001
199
+ │ ├── 0.jpg
200
+ │ ├── 1.jpg
201
+ │ ├── 2.jpg
202
+ ...
203
+ ...
204
+ │ ├── n.jpg
205
+ │ └── prompt.txt
206
+ ```
207
+
208
+ Each sample directory should contain your **n key frames** and a `prompt.txt` file which contains the prompt.
209
+ The final checkpoint will be saved to `output_dir`.
210
+ We've found it useful to send validation GIFs to [Weights & Biases](https://wandb.ai) every so often. If you choose to use validation with Weights & Biases, you can set how often this runs with the `validate_every_steps` parameter.
211
+
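For reference, a sample directory in the layout above could be written with something like the sketch below (a hypothetical helper; `fine_tune.py` only reads this structure, it doesn't create it):

```python
# Sketch: write one training sample (key frames + prompt.txt) in the expected layout.
# Hypothetical helper -- fine_tune.py only reads this structure, it does not create it.
import os

def write_sample(sample_dir: str, frames: list, prompt: str) -> None:
    os.makedirs(sample_dir, exist_ok=True)
    # frames: a list of PIL.Image key frames extracted from your clip elsewhere.
    for i, frame in enumerate(frames):
        frame.convert("RGB").save(os.path.join(sample_dir, f"{i}.jpg"))
    with open(os.path.join(sample_dir, "prompt.txt"), "w") as f:
        f.write(prompt)

# e.g. write_sample("fine_tune_dataset/sample_001", frames, "a teddy bear writing a letter")
```

The `accelerate launch` command below then points `--data_dir` at the root of that dataset.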
212
+ ```
213
+ accelerate launch fine_tune.py \
214
+ --output_dir="<OUTPUT_DIR>" \
215
+ --data_dir="fine_tune_dataset" \
216
+ --report_to="wandb" \
217
+ --run_validation_at_start \
218
+ --resolution=512 \
219
+ --mixed_precision=fp16 \
220
+ --train_batch_size=4 \
221
+ --learning_rate=1.25e-05 \
222
+ --lr_scheduler="constant" \
223
+ --lr_warmup_steps=0 \
224
+ --max_train_steps=1000 \
225
+ --save_n_steps=20 \
226
+ --validate_every_steps=50 \
227
+ --vae_b16 \
228
+ --gradient_checkpointing \
229
+ --noise_offset=0.05 \
230
+ --snr_gamma \
231
+ --test_prompts="man sits at a table in a cafe, he greets another man with a smile and a handshake"
232
+ ```
233
+
234
+ # 📝 Further work
235
+ There are lots of ways we are excited about improving Hotshot-XL. For example:
236
+
237
+ - [ ] Fine-Tuning Hotshot-XL at larger frame rates to create longer/higher frame-rate GIFs
238
+ - [ ] Fine-Tuning Hotshot-XL at larger resolutions to create higher resolution GIFs
239
+ - [ ] Training temporal layers for a latent upscaler to produce higher resolution GIFs
240
+ - [ ] Training an image conditioned "frame prediction" model for more coherent, longer GIFs
241
+ - [ ] Training temporal layers for a VAE to mitigate flickering/dithering in outputs
242
+ - [ ] Supporting Multi-ControlNet for greater control over GIF generation
243
+ - [ ] Training & integrating different ControlNet models for further control over GIF generation (finer facial expression control would be very cool)
244
+ - [ ] Moving Hotshot-XL into [AITemplate](https://github.com/facebookincubator/AITemplate) for faster inference times
245
+
246
+ We 💗 contributions from the open-source community! Please let us know in the issues or PRs if you're interested in working on these improvements or anything else!
247
+
248
+ # 📚 BibTeX
249
+ ```
250
+ @software{Mullan_Hotshot-XL_2023,
251
+ author = {Mullan, John and Crawbuck, Duncan and Sastry, Aakash},
252
+ license = {Apache-2.0},
253
+ month = oct,
254
+ title = {{Hotshot-XL}},
255
+ url = {https://github.com/hotshotco/hotshot-xl},
256
+ version = {1.0.0},
257
+ year = {2023}
258
+ }
259
+ ```
260
+
261
+ # 🙏 Acknowledgements
262
+ Text-to-Video models are improving quickly and the development of Hotshot-XL has been greatly inspired by the following amazing works and teams:
263
+
264
+ - [SDXL](https://stability.ai/stable-diffusion)
265
+ - [Align Your Latents](https://research.nvidia.com/labs/toronto-ai/VideoLDM/)
266
+ - [Make-A-Video](https://makeavideo.studio/)
267
+ - [AnimateDiff](https://animatediff.github.io/)
268
+ - [Imagen Video](https://imagen.research.google/video/)
269
+
270
+ We hope that releasing this model/codebase helps the community to continue pushing these creative tools forward in an open and responsible way.
docker/Dockerfile ADDED
@@ -0,0 +1,5 @@
1
+ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
2
+
3
+ COPY requirements.txt .
4
+
5
+ RUN pip install --no-cache-dir -r requirements.txt
docker/Readme.md ADDED
@@ -0,0 +1,41 @@
1
+ # Setup
2
+
3
+ This Dockerfile is for the **environment only**. This keeps the Docker image as small as possible!
4
+
5
+ ## Quickstart
6
+
7
+ Hotshot has its own Docker image you can use directly:
8
+ ```
9
+ docker pull hotshotapp/hotshot-xl-env:latest
10
+ ```
11
+
12
+ Or you can build it yourself:
13
+
14
+ ```
15
+ cd docker
16
+ docker build -t hotshotapp/hotshot-xl-env:latest .
17
+ ```
18
+
19
+ ## Running the docker image
20
+
21
+ We recommend storing the weights locally on your machine. That way the weights persist if you kill the container!
22
+
23
+ - Install the models to a folder locally (Optional)
24
+ ```
25
+ cd /path/to/models
26
+ git lfs install
27
+ git clone https://huggingface.co/hotshotco/Hotshot-XL
28
+ ```
29
+ - Run the docker from the project root
30
+ - **Linux**
31
+ ```
32
+ docker run -it --gpus=all --rm -v $(pwd):/local -v /path/to/models:/models hotshotapp/hotshot-xl-env:latest
33
+ ```
34
+ - **Windows (Powershell)**
35
+ ```
36
+ docker run -it --gpus=all --rm -v ${PWD}:/local -v C:\path\to\models:/models hotshotapp/hotshot-xl-env:latest
37
+ ```
38
+
39
+ If you want to download the models from within the container itself, then you do not need to map the volumes and ` -v /path/to/models:/models` can be removed.
40
+
41
+ **Note**: Ensure you have NVIDIA Docker runtime installed if you want to utilize GPU support with `--gpus=all`.
docker/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ accelerate==0.23.0
2
+ einops==0.7.0
3
+ diffusers==0.21.4
4
+ transformers==4.34.0
5
+ wandb==0.15.11
6
+ moviepy==1.0.3
7
+ imageio==2.31.5
8
+ xformers==0.0.22
fine_tune.py ADDED
@@ -0,0 +1,987 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import math
17
+ import os
18
+ import traceback
19
+ from pathlib import Path
20
+ import time
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ import torch.multiprocessing as mp
24
+ from accelerate import Accelerator
25
+ from accelerate.logging import get_logger
26
+ from accelerate.utils import set_seed
27
+ from diffusers import AutoencoderKL
28
+ from diffusers.optimization import get_scheduler
29
+ from diffusers import DDPMScheduler
30
+ from torchvision import transforms
31
+ from tqdm.auto import tqdm
32
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
33
+ import torch.nn.functional as F
34
+ import gc
35
+ from typing import Callable
36
+ from PIL import Image
37
+ import numpy as np
38
+ from concurrent.futures import ThreadPoolExecutor
39
+ from hotshot_xl.models.unet import UNet3DConditionModel
40
+ from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
41
+ from hotshot_xl.utils import get_crop_coordinates, res_to_aspect_map, scale_aspect_fill
42
+ from einops import rearrange
43
+ from torch.utils.data import Dataset, DataLoader
44
+ from datetime import timedelta
45
+ from accelerate.utils.dataclasses import InitProcessGroupKwargs
46
+ from diffusers.utils import is_wandb_available
47
+
48
+ if is_wandb_available():
49
+ import wandb
50
+
51
+ logger = get_logger(__file__)
52
+
53
+
54
+ class HotshotXLDataset(Dataset):
55
+
56
+ def __init__(self, directory: str, make_sample_fn: Callable):
57
+ """
58
+
59
+ Training data folder needs to look like:
60
+ + training_samples
61
+ --- + sample_001
62
+ ------- + frame_0.jpg
63
+ ------- + frame_1.jpg
64
+ ------- + ...
65
+ ------- + frame_n.jpg
66
+ ------- + prompt.txt
67
+ --- + sample_002
68
+ ------- + frame_0.jpg
69
+ ------- + frame_1.jpg
70
+ ------- + ...
71
+ ------- + frame_n.jpg
72
+ ------- + prompt.txt
73
+
74
+ Args:
75
+ directory: base directory of the training samples
76
+ make_sample_fn: a delegate call to load the images and prep the sample for batching
77
+ """
78
+ samples_dir = [os.path.join(directory, p) for p in os.listdir(directory)]
79
+ samples_dir = [p for p in samples_dir if os.path.isdir(p)]
80
+ samples = []
81
+
82
+ for d in samples_dir:
83
+ file_paths = [os.path.join(d, p) for p in os.listdir(d)]
84
+ image_fps = [f for f in file_paths if os.path.splitext(f)[1] in {".png", ".jpg"}]
85
+ with open(os.path.join(d, "prompt.txt")) as f:
86
+ prompt = f.read().strip()
87
+
88
+ samples.append({
89
+ "image_fps": image_fps,
90
+ "prompt": prompt
91
+ })
92
+
93
+ self.samples = samples
94
+ self.length = len(samples)
95
+ self.make_sample_fn = make_sample_fn
96
+
97
+ def __len__(self):
98
+ return self.length
99
+
100
+ def __getitem__(self, index):
101
+ return self.make_sample_fn(
102
+ self.samples[index]
103
+ )
104
+
105
+
106
+ def parse_args():
107
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
108
+ parser.add_argument(
109
+ "--pretrained_model_name_or_path",
110
+ type=str,
111
+ default="hotshotco/Hotshot-XL",
112
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
113
+ )
114
+ parser.add_argument(
115
+ "--unet_resume_path",
116
+ type=str,
117
+ default=None,
118
+ help="Path to a previously saved Hotshot-XL checkpoint directory (containing a unet subfolder) to resume training from.",
119
+ )
120
+
121
+ parser.add_argument(
122
+ "--data_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Path to data to train.",
126
+ )
127
+
128
+ parser.add_argument(
129
+ "--report_to",
130
+ type=str,
131
+ default="wandb",
132
+ help=(
133
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
134
+ ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
135
+ ),
136
+ )
137
+
138
+ parser.add_argument("--run_validation_at_start", action="store_true")
139
+ parser.add_argument("--max_vae_encode", type=int, default=None)
140
+ parser.add_argument("--vae_b16", action="store_true")
141
+ parser.add_argument("--disable_optimizer_restore", action="store_true")
142
+
143
+ parser.add_argument(
144
+ "--latent_nan_checking",
145
+ action="store_true",
146
+ help="Check if latents contain nans - important if vae is f16",
147
+ )
148
+ parser.add_argument(
149
+ "--test_prompts",
150
+ type=str,
151
+ default=None,
152
+ )
153
+ parser.add_argument(
154
+ "--project_name",
155
+ type=str,
156
+ default="fine-tune-hotshot-xl",
157
+ help="the name of the project",
158
+ )
159
+ parser.add_argument(
160
+ "--run_name",
161
+ type=str,
162
+ default="run-01",
163
+ help="the name of the run",
164
+ )
165
+ parser.add_argument(
166
+ "--output_dir",
167
+ type=str,
168
+ default="output",
169
+ help="The output directory where the model predictions and checkpoints will be written.",
170
+ )
171
+ parser.add_argument("--noise_offset", type=float, default=0.05, help="The scale of noise offset.")
172
+ parser.add_argument("--seed", type=int, default=111, help="A seed for reproducible training.")
173
+ parser.add_argument(
174
+ "--resolution",
175
+ type=int,
176
+ default=512,
177
+ help=(
178
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
179
+ " resolution"
180
+ ),
181
+ )
182
+ parser.add_argument(
183
+ "--aspect_ratio",
184
+ type=str,
185
+ default="1.75",
186
+ choices=list(res_to_aspect_map[512].keys()),
187
+ help="Aspect ratio to train at",
188
+ )
189
+
190
+ parser.add_argument("--xformers", action="store_true")
191
+
192
+ parser.add_argument(
193
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
194
+ )
195
+
196
+ parser.add_argument("--num_train_epochs", type=int, default=1)
197
+
198
+ parser.add_argument(
199
+ "--max_train_steps",
200
+ type=int,
201
+ default=9999999,
202
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
203
+ )
204
+ parser.add_argument(
205
+ "--gradient_accumulation_steps",
206
+ type=int,
207
+ default=1,
208
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
209
+ )
210
+ parser.add_argument(
211
+ "--gradient_checkpointing",
212
+ action="store_true",
213
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--learning_rate",
218
+ type=float,
219
+ default=5e-6,
220
+ help="Initial learning rate (after the potential warmup period) to use.",
221
+ )
222
+
223
+ parser.add_argument(
224
+ "--scale_lr",
225
+ action="store_true",
226
+ default=False,
227
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
228
+ )
229
+ parser.add_argument(
230
+ "--lr_scheduler",
231
+ type=str,
232
+ default="constant",
233
+ help=(
234
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
235
+ ' "constant", "constant_with_warmup"]'
236
+ ),
237
+ )
238
+ parser.add_argument(
239
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
240
+ )
241
+ parser.add_argument(
242
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
243
+ )
244
+
245
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
246
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
247
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
248
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
249
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
250
+
251
+ parser.add_argument(
252
+ "--logging_dir",
253
+ type=str,
254
+ default="logs",
255
+ help=(
256
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
257
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
258
+ ),
259
+ )
260
+
261
+ parser.add_argument(
262
+ "--mixed_precision",
263
+ type=str,
264
+ default="no",
265
+ choices=["no", "fp16", "bf16"],
266
+ help=(
267
+ "Whether to use mixed precision. Choose "
268
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
269
+ "and an Nvidia Ampere GPU."
270
+ ),
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--validate_every_steps",
275
+ type=int,
276
+ default=100,
277
+ help="Run validation inference every n global steps",
278
+ )
279
+
280
+ parser.add_argument(
281
+ "--save_n_steps",
282
+ type=int,
283
+ default=100,
284
+ help="Save the model every n global_steps",
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--save_starting_step",
289
+ type=int,
290
+ default=100,
291
+ help="The step from which it starts saving intermediary checkpoints",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--nccl_timeout",
296
+ type=int,
297
+ help="nccl_timeout",
298
+ default=3600
299
+ )
300
+
301
+ parser.add_argument("--snr_gamma", action="store_true")
302
+
303
+ args = parser.parse_args()
304
+
305
+ return args
306
+
307
+
308
+ def add_time_ids(
309
+ unet_config,
310
+ unet_add_embedding,
311
+ text_encoder_2: CLIPTextModelWithProjection,
312
+ original_size: tuple,
313
+ crops_coords_top_left: tuple,
314
+ target_size: tuple,
315
+ dtype: torch.dtype):
316
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
317
+
318
+ passed_add_embed_dim = (
319
+ unet_config.addition_time_embed_dim * len(add_time_ids) + text_encoder_2.config.projection_dim
320
+ )
321
+ expected_add_embed_dim = unet_add_embedding.linear_1.in_features
322
+
323
+ if expected_add_embed_dim != passed_add_embed_dim:
324
+ raise ValueError(
325
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
326
+ )
327
+
328
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
329
+ return add_time_ids
330
+
331
+
332
+ def main():
333
+ global_step = 0
334
+ min_steps_before_validation = 0
335
+
336
+ args = parse_args()
337
+
338
+ next_save_iter = args.save_starting_step
339
+
340
+ if args.save_starting_step < 1:
341
+ next_save_iter = None
342
+
343
+ if args.report_to == "wandb":
344
+ if not is_wandb_available():
345
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
346
+
347
+ accelerator = Accelerator(
348
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
349
+ mixed_precision=args.mixed_precision,
350
+ log_with=args.report_to,
351
+ kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=args.nccl_timeout))]
352
+ )
353
+
354
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
355
+ def save_model_hook(models, weights, output_dir):
356
+ nonlocal global_step
357
+
358
+ for model in models:
359
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
360
+ model.save_pretrained(os.path.join(output_dir, 'unet'))
361
+ # make sure to pop weight so that corresponding model is not saved again
362
+ weights.pop()
363
+
364
+ accelerator.register_save_state_pre_hook(save_model_hook)
365
+
366
+ set_seed(args.seed)
367
+
368
+ # Handle the repository creation
369
+ if accelerator.is_local_main_process:
370
+ if args.output_dir is not None:
371
+ os.makedirs(args.output_dir, exist_ok=True)
372
+
373
+ # Load the tokenizer
374
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
375
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
376
+
377
+ # Load models and create wrapper for stable diffusion
378
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
379
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path,
380
+ subfolder="text_encoder_2")
381
+
382
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
383
+
384
+ optimizer_resume_path = None
385
+
386
+ if args.unet_resume_path:
387
+ optimizer_fp = os.path.join(args.unet_resume_path, "optimizer.bin")
388
+
389
+ if os.path.exists(optimizer_fp):
390
+ optimizer_resume_path = optimizer_fp
391
+
392
+ unet = UNet3DConditionModel.from_pretrained(args.unet_resume_path,
393
+ subfolder="unet",
394
+ low_cpu_mem_usage=False,
395
+ device_map=None)
396
+
397
+ else:
398
+ unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
399
+
400
+ if args.xformers:
401
+ vae.set_use_memory_efficient_attention_xformers(True, None)
402
+ unet.set_use_memory_efficient_attention_xformers(True, None)
403
+
404
+ unet_config = unet.config
405
+ unet_add_embedding = unet.add_embedding
406
+
407
+ unet.requires_grad_(False)
408
+
409
+ temporal_params = unet.temporal_parameters()
410
+
411
+ for p in temporal_params:
412
+ p.requires_grad_(True)
413
+
414
+ vae.requires_grad_(False)
415
+ text_encoder.requires_grad_(False)
416
+ text_encoder_2.requires_grad_(False)
417
+
418
+ if args.gradient_checkpointing:
419
+ unet.enable_gradient_checkpointing()
420
+
421
+ if args.scale_lr:
422
+ args.learning_rate = (
423
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
424
+ )
425
+
426
+ # Use 8-bit Adam for lower memory usage
427
+ if args.use_8bit_adam:
428
+ try:
429
+ import bitsandbytes as bnb
430
+ except ImportError:
431
+ raise ImportError(
432
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
433
+ )
434
+
435
+ optimizer_class = bnb.optim.AdamW8bit
436
+ else:
437
+ optimizer_class = torch.optim.AdamW
438
+
439
+ learning_rate = args.learning_rate
440
+
441
+ params_to_optimize = [
442
+ {'params': temporal_params, "lr": learning_rate},
443
+ ]
444
+
445
+ optimizer = optimizer_class(
446
+ params_to_optimize,
447
+ lr=args.learning_rate,
448
+ betas=(args.adam_beta1, args.adam_beta2),
449
+ weight_decay=args.adam_weight_decay,
450
+ eps=args.adam_epsilon,
451
+ )
452
+
453
+ if optimizer_resume_path and not args.disable_optimizer_restore:
454
+ logger.info("Restoring the optimizer.")
455
+ try:
456
+
457
+ old_optimizer_state_dict = torch.load(optimizer_resume_path)
458
+
459
+ # Extract only the state
460
+ old_state = old_optimizer_state_dict['state']
461
+
462
+ # Set the state of the new optimizer
463
+ optimizer.load_state_dict({'state': old_state, 'param_groups': optimizer.param_groups})
464
+
465
+ del old_optimizer_state_dict
466
+ del old_state
467
+
468
+ torch.cuda.empty_cache()
469
+ torch.cuda.synchronize()
470
+ gc.collect()
471
+
472
+ logger.info("Restored the optimizer.")
473
+
474
+ except:
475
+ logger.error("Failed to restore the optimizer...", exc_info=True)
476
+ traceback.print_exc()
477
+ raise
478
+
479
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
480
+
481
+ def compute_snr(timesteps):
482
+ """
483
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
484
+ """
485
+ alphas_cumprod = noise_scheduler.alphas_cumprod
486
+ sqrt_alphas_cumprod = alphas_cumprod ** 0.5
487
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
488
+
489
+ # Expand the tensors.
490
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
491
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
492
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
493
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
494
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
495
+
496
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
497
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
498
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
499
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
500
+
501
+ # Compute SNR.
502
+ snr = (alpha / sigma) ** 2
503
+ return snr
504
+
505
+ device = torch.device('cuda')
506
+
507
+ image_transforms = transforms.Compose(
508
+ [
509
+ transforms.ToTensor(),
510
+ transforms.Normalize([0.5], [0.5]),
511
+ ]
512
+ )
513
+
514
+ def image_to_tensor(img):
515
+ with torch.no_grad():
516
+
517
+ if img.mode != "RGB":
518
+ img = img.convert("RGB")
519
+
520
+ image = image_transforms(img).to(accelerator.device)
521
+
522
+ if image.shape[0] == 1:
523
+ image = image.repeat(3, 1, 1)
524
+
525
+ if image.shape[0] > 3:
526
+ image = image[:3, :, :]
527
+
528
+ return image
529
+
530
+ def make_sample(sample):
531
+
532
+ nonlocal unet_config
533
+ nonlocal unet_add_embedding
534
+
535
+ images = [Image.open(img) for img in sample['image_fps']]
536
+
537
+ og_size = images[0].size
538
+
539
+ for i, im in enumerate(images):
540
+ if im.mode != "RGB":
541
+ images[i] = im.convert("RGB")
542
+
543
+ aspect_ratio_map = res_to_aspect_map[args.resolution]
544
+
545
+ required_size = tuple(aspect_ratio_map[args.aspect_ratio])
546
+
547
+ if required_size != og_size:
548
+
549
+ def resize_image(x):
550
+ img_size = x.size
551
+ if img_size == required_size:
552
+ return x.resize(required_size, Image.LANCZOS)
553
+
554
+ return scale_aspect_fill(x, required_size[0], required_size[1])
555
+
556
+ with ThreadPoolExecutor(max_workers=len(images)) as executor:
557
+ images = list(executor.map(resize_image, images))
558
+
559
+ frames = torch.stack([image_to_tensor(x) for x in images])
560
+
561
+ l, u, *_ = get_crop_coordinates(og_size, images[0].size)
562
+ crop_coords = (l, u)
563
+
564
+ additional_time_ids = add_time_ids(
565
+ unet_config,
566
+ unet_add_embedding,
567
+ text_encoder_2,
568
+ og_size,
569
+ crop_coords,
570
+ (required_size[0], required_size[1]),
571
+ dtype=torch.float32
572
+ ).to(device)
573
+
574
+ input_ids_0 = tokenizer(
575
+ sample['prompt'],
576
+ padding="do_not_pad",
577
+ truncation=True,
578
+ max_length=tokenizer.model_max_length,
579
+ ).input_ids
580
+
581
+ input_ids_1 = tokenizer_2(
582
+ sample['prompt'],
583
+ padding="do_not_pad",
584
+ truncation=True,
585
+ max_length=tokenizer.model_max_length,
586
+ ).input_ids
587
+
588
+ return {
589
+ "frames": frames,
590
+ "input_ids_0": input_ids_0,
591
+ "input_ids_1": input_ids_1,
592
+ "additional_time_ids": additional_time_ids,
593
+ }
594
+
595
+ def collate_fn(examples: list) -> dict:
596
+
597
+ # Two Text encoders
598
+ # First Text Encoder -> Penultimate Layer
599
+ # Second Text Encoder -> Pooled Layer
600
+
601
+ input_ids_0 = [example['input_ids_0'] for example in examples]
602
+ input_ids_0 = tokenizer.pad({"input_ids": input_ids_0}, padding="max_length",
603
+ max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
604
+
605
+ prompt_embeds_0 = text_encoder(
606
+ input_ids_0.to(device),
607
+ output_hidden_states=True,
608
+ )
609
+
610
+ # we take penultimate embeddings from the first text encoder
611
+ prompt_embeds_0 = prompt_embeds_0.hidden_states[-2]
612
+
613
+ input_ids_1 = [example['input_ids_1'] for example in examples]
614
+ input_ids_1 = tokenizer_2.pad({"input_ids": input_ids_1}, padding="max_length",
615
+ max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
616
+
617
+ # We are only ALWAYS interested in the pooled output of the final text encoder
618
+ prompt_embeds = text_encoder_2(
619
+ input_ids_1.to(device),
620
+ output_hidden_states=True
621
+ )
622
+
623
+ pooled_prompt_embeds = prompt_embeds[0]
624
+ prompt_embeds_1 = prompt_embeds.hidden_states[-2]
625
+
626
+ prompt_embeds = torch.concat([prompt_embeds_0, prompt_embeds_1], dim=-1)
627
+
628
+ *_, h, w = examples[0]['frames'].shape
629
+
630
+ return {
631
+ "frames": torch.stack([x['frames'] for x in examples]).to(memory_format=torch.contiguous_format).float(),
632
+ "prompt_embeds": prompt_embeds.to(memory_format=torch.contiguous_format).float(),
633
+ "pooled_prompt_embeds": pooled_prompt_embeds,
634
+ "additional_time_ids": torch.stack([x['additional_time_ids'] for x in examples]),
635
+ }
636
+
637
+ # Region - Dataloaders
638
+ dataset = HotshotXLDataset(args.data_dir, make_sample)
639
+ dataloader = DataLoader(dataset, args.train_batch_size, shuffle=True, collate_fn=collate_fn)
640
+
641
+ # Scheduler and math around the number of training steps.
642
+ overrode_max_train_steps = False
643
+ num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
644
+
645
+ if args.max_train_steps is None:
646
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
647
+ overrode_max_train_steps = True
648
+
649
+ lr_scheduler = get_scheduler(
650
+ args.lr_scheduler,
651
+ optimizer=optimizer,
652
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
653
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
654
+ )
655
+
656
+ unet, optimizer, lr_scheduler, dataloader = accelerator.prepare(
657
+ unet, optimizer, lr_scheduler, dataloader
658
+ )
659
+
660
+ def to_images(video_frames: torch.Tensor):
661
+ import torchvision.transforms as transforms
662
+ to_pil = transforms.ToPILImage()
663
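+ # Move the frame axis ahead of the channel axis so each frame can be converted to a PIL image.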
+ video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
664
+ bsz = video_frames.shape[0]
665
+ images = []
666
+ for i in range(bsz):
667
+ video = video_frames[i]
668
+ for j in range(video.shape[0]):
669
+ image = to_pil(video[j])
670
+ images.append(image)
671
+ return images
672
+
673
+ def to_video_frames(images: list) -> np.ndarray:
674
+ x = np.stack([np.asarray(img) for img in images])
675
+ return np.transpose(x, (0, 3, 1, 2))
676
+
677
+ def run_validation(step=0, node_index=0):
678
+
679
+ nonlocal global_step
680
+ nonlocal accelerator
681
+
682
+ if args.test_prompts:
683
+ prompts = args.test_prompts.split("|")
684
+ else:
685
+ prompts = [
686
+ "a woman is lifting weights in a gym",
687
+ "a group of people are dancing at a party",
688
+ "a teddy bear doing the front crawl"
689
+ ]
690
+
691
+ torch.cuda.empty_cache()
692
+ gc.collect()
693
+
694
+ logger.info(f"Running inference to test model at {step} steps")
695
+ with torch.no_grad():
696
+
697
+ pipe = HotshotXLPipeline.from_pretrained(
698
+ args.pretrained_model_name_or_path,
699
+ unet=accelerator.unwrap_model(unet),
700
+ text_encoder=text_encoder,
701
+ text_encoder_2=text_encoder_2,
702
+ vae=vae,
703
+ )
704
+
705
+ videos = []
706
+
707
+ aspect_ratio_map = res_to_aspect_map[args.resolution]
708
+ w, h = aspect_ratio_map[args.aspect_ratio]
709
+
710
+ for prompt in prompts:
711
+ video = pipe(prompt,
712
+ width=w,
713
+ height=h,
714
+ original_size=(1920, 1080), # todo - pass in as args?
715
+ target_size=(args.resolution, args.resolution),
716
+ num_inference_steps=30,
717
+ video_length=8,
718
+ output_type="tensor",
719
+ generator=torch.Generator().manual_seed(111)).videos
720
+
721
+ videos.append(to_images(video))
722
+
723
+ for tracker in accelerator.trackers:
724
+
725
+ if tracker.name == "wandb":
726
+ tracker.log(
727
+ {
728
+ "validation": [wandb.Video(to_video_frames(video), fps=8, format='mp4') for video in
729
+ videos],
730
+ }, step=global_step
731
+ )
732
+
733
+ del pipe
734
+
735
+ return
736
+
737
+ # Move text_encode and vae to gpu.
738
+ vae.to(accelerator.device, dtype=torch.bfloat16 if args.vae_b16 else torch.float32)
739
+ text_encoder.to(accelerator.device)
740
+ text_encoder_2.to(accelerator.device)
741
+
742
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
743
+
744
+ num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
745
+ if overrode_max_train_steps:
746
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
747
+ # Afterward we recalculate our number of training epochs
748
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
749
+
750
+ # We need to initialize the trackers we use, and also store our configuration.
751
+ # The trackers initialize automatically on the main process.
752
+
753
+ if accelerator.is_main_process:
754
+ accelerator.init_trackers(args.project_name)
755
+
756
+ def bar(prg):
757
+ br = '|' + '█' * prg + ' ' * (25 - prg) + '|'
758
+ return br
759
+
760
+ # Train!
761
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
762
+
763
+ if accelerator.is_main_process:
764
+ logger.info("***** Running training *****")
765
+ logger.info(f" Num examples = {len(dataset)}")
766
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
767
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
768
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
769
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
770
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
771
+
772
+ # Only show the progress bar once on each machine.
773
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
774
+
775
+ latents_scaler = vae.config.scaling_factor
776
+
777
+ def save_checkpoint():
778
+ save_dir = Path(args.output_dir)
779
+ save_dir = str(save_dir)
780
+ save_dir = save_dir.replace(" ", "_")
781
+ if not os.path.exists(save_dir):
782
+ os.makedirs(save_dir, exist_ok=True)
783
+ accelerator.save_state(save_dir)
784
+
785
+ def save_checkpoint_and_wait():
786
+ if accelerator.is_main_process:
787
+ save_checkpoint()
788
+ accelerator.wait_for_everyone()
789
+
790
+ def save_model_and_wait():
791
+ if accelerator.is_main_process:
792
+ HotshotXLPipeline.from_pretrained(
793
+ args.pretrained_model_name_or_path,
794
+ unet=accelerator.unwrap_model(unet),
795
+ text_encoder=text_encoder,
796
+ text_encoder_2=text_encoder_2,
797
+ vae=vae,
798
+ ).save_pretrained(args.output_dir, safe_serialization=True)
799
+ accelerator.wait_for_everyone()
800
+
801
+ def compute_loss_from_batch(batch: dict):
802
+ frames = batch["frames"]
803
+ bsz, number_of_frames, c, w, h = frames.shape
804
+
805
+ # Convert images to latent space
806
+ with torch.no_grad():
807
+
808
+ if args.max_vae_encode:
809
+ latents = []
810
+
811
+ x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")
812
+
813
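+ # Encode the flattened frames in chunks of at most `max_vae_encode` images to limit peak VAE memory.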
+ for latent_index in range(0, x.shape[0], args.max_vae_encode):
814
+ sample = x[latent_index: latent_index + args.max_vae_encode]
815
+
816
+ latent = vae.encode(sample.to(dtype=vae.dtype)).latent_dist.sample().float()
817
+ if len(latent.shape) == 3:
818
+ latent = latent.unsqueeze(0)
819
+
820
+ latents.append(latent)
821
+ torch.cuda.empty_cache()
822
+
823
+ latents = torch.cat(latents, dim=0)
824
+ else:
825
+
826
+ # convert the latents from 5d -> 4d, so we can run it though the vae encoder
827
+ x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")
828
+
829
+ del frames
830
+
831
+ torch.cuda.empty_cache()
832
+
833
+ latents = vae.encode(x.to(dtype=vae.dtype)).latent_dist.sample().float()
834
+
835
+ if args.latent_nan_checking and torch.any(torch.isnan(latents)):
836
+ accelerator.print("NaN found in latents, replacing with zeros")
837
+ latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
838
+
839
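+ # Restore the frame axis: (batch * frames, c, h, w) -> (batch, c, frames, h, w).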
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", b=bsz)
840
+
841
+ torch.cuda.empty_cache()
842
+
843
+ noise = torch.randn_like(latents, device=latents.device)
844
+
845
+ if args.noise_offset:
846
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
847
+ noise += args.noise_offset * torch.randn(
848
+ (latents.shape[0], latents.shape[1], 1, 1, 1), device=latents.device
849
+ )
850
+
851
+ # Sample a random timestep for each image
852
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
853
+ timesteps = timesteps.long() # .repeat_interleave(number_of_frames)
854
+ latents = latents * latents_scaler
855
+
856
+ # Add noise to the latents according to the noise magnitude at each timestep
857
+ # (this is the forward diffusion process)
858
+
859
+ prompt_embeds = batch['prompt_embeds']
860
+ add_text_embeds = batch['pooled_prompt_embeds']
861
+
862
+ additional_time_ids = batch['additional_time_ids'] # .repeat_interleave(number_of_frames, dim=0)
863
+
864
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": additional_time_ids}
865
+
866
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
867
+
868
+ if noise_scheduler.config.prediction_type == "epsilon":
869
+ target = noise
870
+ elif noise_scheduler.config.prediction_type == "v_prediction":
871
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
872
+ else:
873
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
874
+
875
+ noisy_latents.requires_grad = True
876
+
877
+ model_pred = unet(noisy_latents,
878
+ timesteps,
879
+ cross_attention_kwargs=None,
880
+ encoder_hidden_states=prompt_embeds,
881
+ added_cond_kwargs=added_cond_kwargs,
882
+ return_dict=False,
883
+ )[0]
884
+
885
+ if args.snr_gamma:
886
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
887
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
888
+ # This is discussed in Section 4.2 of the same paper.
889
+ snr = compute_snr(timesteps)
890
+ mse_loss_weights = (
891
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
892
+ )
893
+ # We first calculate the original loss. Then we mean over the non-batch dimensions and
894
+ # rebalance the sample-wise losses with their respective loss weights.
895
+ # Finally, we take the mean of the rebalanced loss.
896
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
897
+
898
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
899
+ return loss.mean()
900
+ else:
901
+ return F.mse_loss(model_pred.float(), target.float(), reduction='mean')
902
+
903
+ def process_batch(batch: dict):
904
+ nonlocal global_step
905
+ nonlocal next_save_iter
906
+
907
+ now = time.time()
908
+
909
+ with accelerator.accumulate(unet):
910
+
911
+ logging_data = {}
912
+ if global_step == 0:
913
+ # print(f"Running initial validation at step")
914
+ if accelerator.is_main_process and args.run_validation_at_start:
915
+ run_validation(step=global_step, node_index=accelerator.process_index // 8)
916
+ accelerator.wait_for_everyone()
917
+
918
+ loss = compute_loss_from_batch(batch)
919
+
920
+ accelerator.backward(loss)
921
+
922
+ if accelerator.sync_gradients:
923
+ accelerator.clip_grad_norm_(temporal_params, args.max_grad_norm)
924
+
925
+ optimizer.step()
926
+
927
+ lr_scheduler.step()
928
+ optimizer.zero_grad()
929
+
930
+ # Checks if the accelerator has performed an optimization step behind the scenes
931
+ if accelerator.sync_gradients:
932
+ progress_bar.update(1)
933
+ global_step += 1
934
+
935
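+ # Map overall progress (0-100%) onto the 25-character progress bar drawn by bar().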
+ fll = round((global_step * 100) / args.max_train_steps)
936
+ fll = round(fll / 4)
937
+ pr = bar(fll)
938
+
939
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "loss_time": (time.time() - now)}
940
+
941
+ if args.validate_every_steps is not None and global_step > min_steps_before_validation and global_step % args.validate_every_steps == 0:
942
+ if accelerator.is_main_process:
943
+ run_validation(step=global_step, node_index=accelerator.process_index // 8)
944
+
945
+ accelerator.wait_for_everyone()
946
+
947
+ for key, val in logging_data.items():
948
+ logs[key] = val
949
+
950
+ progress_bar.set_postfix(**logs)
951
+ progress_bar.set_description_str("Progress:" + pr)
952
+ accelerator.log(logs, step=global_step)
953
+
954
+ if accelerator.is_main_process \
955
+ and next_save_iter is not None \
956
+ and global_step < args.max_train_steps \
957
+ and global_step + 1 == next_save_iter:
958
+ save_checkpoint()
959
+
960
+ torch.cuda.empty_cache()
961
+ gc.collect()
962
+
963
+ next_save_iter += args.save_n_steps
964
+
965
+ for epoch in range(args.num_train_epochs):
966
+ unet.train()
967
+
968
+ for step, batch in enumerate(dataloader):
969
+ process_batch(batch)
970
+
971
+ if global_step >= args.max_train_steps:
972
+ break
973
+
974
+ if global_step >= args.max_train_steps:
975
+ logger.info("Max train steps reached. Breaking while loop")
976
+ break
977
+
978
+ accelerator.wait_for_everyone()
979
+
980
+ save_model_and_wait()
981
+
982
+ accelerator.end_training()
983
+
984
+
985
+ if __name__ == "__main__":
986
+ mp.set_start_method('spawn')
987
+ main()
hotshot_xl/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Union
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ # don't remove these imports - they are needed to load from pretrain.
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from .models.unet import UNet3DConditionModel
18
+
19
+ from diffusers.utils import (
20
+ BaseOutput,
21
+ )
22
+
23
+ @dataclass
24
+ class HotshotPipelineXLOutput(BaseOutput):
25
+ videos: Union[torch.Tensor, np.ndarray]
hotshot_xl/models/__init__.py ADDED
File without changes
hotshot_xl/models/resnet.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers.models.resnet import Upsample2D, Downsample2D, LoRACompatibleConv
12
+ from einops import rearrange
13
+
14
+
15
+ class Upsample3D(Upsample2D):
16
+ def forward(self, hidden_states, output_size=None, scale: float = 1.0):
17
+ f = hidden_states.shape[2]
18
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
19
+ hidden_states = super(Upsample3D, self).forward(hidden_states, output_size, scale)
20
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
21
+
22
+
23
+ class Downsample3D(Downsample2D):
24
+
25
+ def forward(self, hidden_states, scale: float = 1.0):
26
+ f = hidden_states.shape[2]
27
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
28
+ hidden_states = super(Downsample3D, self).forward(hidden_states, scale)
29
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
30
+
31
+
32
+ class Conv3d(LoRACompatibleConv):
33
+ def forward(self, hidden_states, scale: float = 1.0):
34
+ f = hidden_states.shape[2]
35
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
36
+ hidden_states = super().forward(hidden_states, scale)
37
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
38
+
39
+
40
+ class ResnetBlock3D(nn.Module):
41
+ def __init__(
42
+ self,
43
+ *,
44
+ in_channels,
45
+ out_channels=None,
46
+ conv_shortcut=False,
47
+ dropout=0.0,
48
+ temb_channels=512,
49
+ groups=32,
50
+ groups_out=None,
51
+ pre_norm=True,
52
+ eps=1e-6,
53
+ non_linearity="silu",
54
+ time_embedding_norm="default",
55
+ output_scale_factor=1.0,
56
+ use_in_shortcut=None,
57
+ conv_shortcut_bias: bool = True,
58
+ ):
59
+ super().__init__()
60
+ self.pre_norm = pre_norm
61
+ self.pre_norm = True
62
+ self.in_channels = in_channels
63
+ out_channels = in_channels if out_channels is None else out_channels
64
+ self.out_channels = out_channels
65
+ self.use_conv_shortcut = conv_shortcut
66
+ self.time_embedding_norm = time_embedding_norm
67
+ self.output_scale_factor = output_scale_factor
68
+
69
+ if groups_out is None:
70
+ groups_out = groups
71
+
72
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
73
+ self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
74
+
75
+ if temb_channels is not None:
76
+ if self.time_embedding_norm == "default":
77
+ time_emb_proj_out_channels = out_channels
78
+ elif self.time_embedding_norm == "scale_shift":
79
+ time_emb_proj_out_channels = out_channels * 2
80
+ else:
81
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
82
+
83
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
84
+ else:
85
+ self.time_emb_proj = None
86
+
87
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
88
+ self.dropout = torch.nn.Dropout(dropout)
89
+ self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
90
+
91
+ assert non_linearity == "silu"
92
+
93
+ self.nonlinearity = nn.SiLU()
94
+
95
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
96
+
97
+ self.conv_shortcut = None
98
+ if self.use_in_shortcut:
99
+ self.conv_shortcut = Conv3d(
100
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
101
+ )
102
+
103
+ def forward(self, input_tensor, temb):
104
+ hidden_states = input_tensor
105
+
106
+ hidden_states = self.norm1(hidden_states)
107
+ hidden_states = self.nonlinearity(hidden_states)
108
+
109
+ hidden_states = self.conv1(hidden_states)
110
+
111
+ if temb is not None:
112
+ temb = self.nonlinearity(temb)
113
+ temb = self.time_emb_proj(temb)[:, :, None, None, None]
114
+
115
+ if temb is not None and self.time_embedding_norm == "default":
116
+ hidden_states = hidden_states + temb
117
+
118
+ hidden_states = self.norm2(hidden_states)
119
+
120
+ if temb is not None and self.time_embedding_norm == "scale_shift":
121
+ scale, shift = torch.chunk(temb, 2, dim=1)
122
+ hidden_states = hidden_states * (1 + scale) + shift
123
+
124
+ hidden_states = self.nonlinearity(hidden_states)
125
+
126
+ hidden_states = self.dropout(hidden_states)
127
+ hidden_states = self.conv2(hidden_states)
128
+
129
+ if self.conv_shortcut is not None:
130
+ input_tensor = self.conv_shortcut(input_tensor)
131
+
132
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
133
+
134
+ return output_tensor
hotshot_xl/models/transformer_3d.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+ import torch
12
+ from torch import nn
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.models.transformer_2d import Transformer2DModel
15
+ from einops import rearrange, repeat
16
+ from typing import Dict, Any
17
+
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ """
22
+ The output of [`Transformer3DModel`].
23
+
24
+ Args:
25
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
26
+ The hidden states output conditioned on the `encoder_hidden_states` input.
27
+ """
28
+
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ class Transformer3DModel(Transformer2DModel):
33
+
34
+ def __init__(self, *args, **kwargs):
35
+ super(Transformer3DModel, self).__init__(*args, **kwargs)
36
+ nn.init.zeros_(self.proj_out.weight.data)
37
+ nn.init.zeros_(self.proj_out.bias.data)
38
+
39
+ def forward(
40
+ self,
41
+ hidden_states: torch.Tensor,
42
+ encoder_hidden_states: Optional[torch.Tensor] = None,
43
+ timestep: Optional[torch.LongTensor] = None,
44
+ class_labels: Optional[torch.LongTensor] = None,
45
+ cross_attention_kwargs: Dict[str, Any] = None,
46
+ attention_mask: Optional[torch.Tensor] = None,
47
+ encoder_attention_mask: Optional[torch.Tensor] = None,
48
+ enable_temporal_layers: bool = True,
49
+ positional_embedding: Optional[torch.Tensor] = None,
50
+ return_dict: bool = True,
51
+ ):
52
+
53
+ is_video = len(hidden_states.shape) == 5
54
+
55
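+ # For 5D video input, fold frames into the batch and repeat the text conditioning per frame so the 2D transformer runs unchanged.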
+ if is_video:
56
+ f = hidden_states.shape[2]
57
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
58
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=f)
59
+
60
+ hidden_states = super(Transformer3DModel, self).forward(hidden_states,
61
+ encoder_hidden_states,
62
+ timestep,
63
+ class_labels,
64
+ cross_attention_kwargs,
65
+ attention_mask,
66
+ encoder_attention_mask,
67
+ return_dict=False)[0]
68
+
69
+ if is_video:
70
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
71
+
72
+ if not return_dict:
73
+ return (hidden_states,)
74
+
75
+ return Transformer3DModelOutput(sample=hidden_states)
hotshot_xl/models/transformer_temporal.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+
9
+ import torch
10
+ import math
11
+ from dataclasses import dataclass
12
+ from torch import nn
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.models.attention import Attention, FeedForward
15
+ from einops import rearrange, repeat
16
+ from typing import Optional
17
+
18
+
19
+ class PositionalEncoding(nn.Module):
20
+ """
21
+ Implements positional encoding as described in "Attention Is All You Need".
22
+ Adds sinusoidal based positional encodings to the input tensor.
23
+ """
24
+
25
+ _SCALE_FACTOR = 10000.0 # Scale factor used in the positional encoding computation.
26
+
27
+ def __init__(self, dim: int, dropout: float = 0.0, max_length: int = 24):
28
+ super(PositionalEncoding, self).__init__()
29
+
30
+ self.dropout = nn.Dropout(p=dropout)
31
+
32
+ # The size is (1, max_length, dim) to allow easy addition to input tensors.
33
+ positional_encoding = torch.zeros(1, max_length, dim)
34
+
35
+ # Position and dim are used in the sinusoidal computation.
36
+ position = torch.arange(max_length).unsqueeze(1)
37
+ div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(self._SCALE_FACTOR) / dim))
38
+
39
+ positional_encoding[0, :, 0::2] = torch.sin(position * div_term)
40
+ positional_encoding[0, :, 1::2] = torch.cos(position * div_term)
41
+
42
+ # Register the positional encoding matrix as a buffer,
43
+ # so it's part of the model's state but not the parameters.
44
+ self.register_buffer('positional_encoding', positional_encoding)
45
+
46
+ def forward(self, hidden_states: torch.Tensor, length: int) -> torch.Tensor:
47
+ hidden_states = hidden_states + self.positional_encoding[:, :length]
48
+ return self.dropout(hidden_states)
49
+
50
+
51
+ class TemporalAttention(Attention):
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__(*args, **kwargs)
54
+ self.pos_encoder = PositionalEncoding(kwargs["query_dim"], dropout=0)
55
+
56
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=8):
57
+ sequence_length = hidden_states.shape[1]
58
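+ # Fold spatial positions into the batch so attention runs across the frame axis at each spatial location.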
+ hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=number_of_frames)
59
+ hidden_states = self.pos_encoder(hidden_states, length=number_of_frames)
60
+
61
+ if encoder_hidden_states is not None:
62
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b s) n c", s=sequence_length)
63
+
64
+ hidden_states = super().forward(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
65
+
66
+ return rearrange(hidden_states, "(b s) f c -> (b f) s c", s=sequence_length)
67
+
68
+
69
+ @dataclass
70
+ class TransformerTemporalOutput(BaseOutput):
71
+ sample: torch.FloatTensor
72
+
73
+
74
+ class TransformerTemporal(nn.Module):
75
+ def __init__(
76
+ self,
77
+ num_attention_heads: int,
78
+ attention_head_dim: int,
79
+ in_channels: int,
80
+ num_layers: int = 1,
81
+ dropout: float = 0.0,
82
+ norm_num_groups: int = 32,
83
+ cross_attention_dim: Optional[int] = None,
84
+ attention_bias: bool = False,
85
+ activation_fn: str = "geglu",
86
+ upcast_attention: bool = False,
87
+ ):
88
+ super().__init__()
89
+
90
+ inner_dim = num_attention_heads * attention_head_dim
91
+
92
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
93
+ self.proj_in = nn.Linear(in_channels, inner_dim)
94
+
95
+ self.transformer_blocks = nn.ModuleList(
96
+ [
97
+ TransformerBlock(
98
+ dim=inner_dim,
99
+ num_attention_heads=num_attention_heads,
100
+ attention_head_dim=attention_head_dim,
101
+ dropout=dropout,
102
+ activation_fn=activation_fn,
103
+ attention_bias=attention_bias,
104
+ upcast_attention=upcast_attention,
105
+ cross_attention_dim=cross_attention_dim
106
+ )
107
+ for _ in range(num_layers)
108
+ ]
109
+ )
110
+ self.proj_out = nn.Linear(inner_dim, in_channels)
111
+
112
+ def forward(self, hidden_states, encoder_hidden_states=None):
113
+ _, num_channels, f, height, width = hidden_states.shape
114
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
115
+
116
+ skip = hidden_states
117
+
118
+ hidden_states = self.norm(hidden_states)
119
+ hidden_states = rearrange(hidden_states, "bf c h w -> bf (h w) c")
120
+ hidden_states = self.proj_in(hidden_states)
121
+
122
+ for block in self.transformer_blocks:
123
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, number_of_frames=f)
124
+
125
+ hidden_states = self.proj_out(hidden_states)
126
+ hidden_states = rearrange(hidden_states, "bf (h w) c -> bf c h w", h=height, w=width).contiguous()
127
+
128
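+ # Residual connection: add the temporal transformer output back onto the untouched input features.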
+ output = hidden_states + skip
129
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=f)
130
+
131
+ return output
132
+
133
+
134
+ class TransformerBlock(nn.Module):
135
+ def __init__(
136
+ self,
137
+ dim,
138
+ num_attention_heads,
139
+ attention_head_dim,
140
+ dropout=0.0,
141
+ activation_fn="geglu",
142
+ attention_bias=False,
143
+ upcast_attention=False,
144
+ depth=2,
145
+ cross_attention_dim: Optional[int] = None
146
+ ):
147
+ super().__init__()
148
+
149
+ self.is_cross = cross_attention_dim is not None
150
+
151
+ attention_blocks = []
152
+ norms = []
153
+
154
+ for _ in range(depth):
155
+ attention_blocks.append(
156
+ TemporalAttention(
157
+ query_dim=dim,
158
+ cross_attention_dim=cross_attention_dim,
159
+ heads=num_attention_heads,
160
+ dim_head=attention_head_dim,
161
+ dropout=dropout,
162
+ bias=attention_bias,
163
+ upcast_attention=upcast_attention,
164
+ )
165
+ )
166
+ norms.append(nn.LayerNorm(dim))
167
+
168
+ self.attention_blocks = nn.ModuleList(attention_blocks)
169
+ self.norms = nn.ModuleList(norms)
170
+
171
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
172
+ self.ff_norm = nn.LayerNorm(dim)
173
+
174
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=None):
175
+
176
+ if not self.is_cross:
177
+ encoder_hidden_states = None
178
+
179
+ for block, norm in zip(self.attention_blocks, self.norms):
180
+ norm_hidden_states = norm(hidden_states)
181
+ hidden_states = block(
182
+ norm_hidden_states,
183
+ encoder_hidden_states=encoder_hidden_states,
184
+ attention_mask=attention_mask,
185
+ number_of_frames=number_of_frames
186
+ ) + hidden_states
187
+
188
+ norm_hidden_states = self.ff_norm(hidden_states)
189
+ hidden_states = self.ff(norm_hidden_states) + hidden_states
190
+
191
+ output = hidden_states
192
+ return output
hotshot_xl/models/unet.py ADDED
@@ -0,0 +1,982 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Unet now supports SDXL
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.utils.checkpoint
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.loaders import UNet2DConditionLoadersMixin
28
+ from diffusers.utils import BaseOutput, logging
29
+ from diffusers.models.activations import get_activation
30
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
31
+ from diffusers.models.embeddings import (
32
+ GaussianFourierProjection,
33
+ ImageHintTimeEmbedding,
34
+ ImageProjection,
35
+ ImageTimeEmbedding,
36
+ TextImageProjection,
37
+ TextImageTimeEmbedding,
38
+ TextTimeEmbedding,
39
+ TimestepEmbedding,
40
+ Timesteps,
41
+ )
42
+
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
45
+ from .unet_blocks import (
46
+ CrossAttnDownBlock3D,
47
+ CrossAttnUpBlock3D,
48
+ DownBlock3D,
49
+ UNetMidBlock3DCrossAttn,
50
+ UpBlock3D,
51
+ get_down_block,
52
+ get_up_block,
53
+ )
54
+
55
+ from .resnet import Conv3d
56
+
57
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
58
+
59
+
60
+ @dataclass
61
+ class UNet3DConditionOutput(BaseOutput):
62
+ """
63
+ The output of [`UNet2DConditionModel`].
64
+
65
+ Args:
66
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
67
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
68
+ """
69
+
70
+ sample: torch.FloatTensor = None
71
+
72
+
73
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
74
+ _supports_gradient_checkpointing = True
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ sample_size: Optional[int] = None,
80
+ in_channels: int = 4,
81
+ out_channels: int = 4,
82
+ center_input_sample: bool = False,
83
+ flip_sin_to_cos: bool = True,
84
+ freq_shift: int = 0,
85
+ down_block_types: Tuple[str] = (
86
+ "CrossAttnDownBlock3D",
87
+ "CrossAttnDownBlock3D",
88
+ "DownBlock3D",
89
+ ),
90
+ mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
91
+ up_block_types: Tuple[str] = (
92
+ "UpBlock3D",
93
+ "CrossAttnUpBlock3D",
94
+ "CrossAttnUpBlock3D",
95
+ ),
96
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
97
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
98
+ layers_per_block: Union[int, Tuple[int]] = 2,
99
+ downsample_padding: int = 1,
100
+ mid_block_scale_factor: float = 1,
101
+ act_fn: str = "silu",
102
+ norm_num_groups: Optional[int] = 32,
103
+ norm_eps: float = 1e-5,
104
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
105
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
106
+ encoder_hid_dim: Optional[int] = None,
107
+ encoder_hid_dim_type: Optional[str] = None,
108
+ attention_head_dim: Union[int, Tuple[int]] = 8,
109
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
110
+ dual_cross_attention: bool = False,
111
+ use_linear_projection: bool = False,
112
+ class_embed_type: Optional[str] = None,
113
+ addition_embed_type: Optional[str] = None,
114
+ addition_time_embed_dim: Optional[int] = None,
115
+ num_class_embeds: Optional[int] = None,
116
+ upcast_attention: bool = False,
117
+ resnet_time_scale_shift: str = "default",
118
+ resnet_skip_time_act: bool = False,
119
+ resnet_out_scale_factor: int = 1.0,
120
+ time_embedding_type: str = "positional",
121
+ time_embedding_dim: Optional[int] = None,
122
+ time_embedding_act_fn: Optional[str] = None,
123
+ timestep_post_act: Optional[str] = None,
124
+ time_cond_proj_dim: Optional[int] = None,
125
+ conv_in_kernel: int = 3,
126
+ conv_out_kernel: int = 3,
127
+ projection_class_embeddings_input_dim: Optional[int] = None,
128
+ class_embeddings_concat: bool = False,
129
+ mid_block_only_cross_attention: Optional[bool] = None,
130
+ cross_attention_norm: Optional[str] = None,
131
+ addition_embed_type_num_heads=64,
132
+ ):
133
+ super().__init__()
134
+
135
+ self.sample_size = sample_size
136
+
137
+ if num_attention_heads is not None:
138
+ raise ValueError(
139
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
140
+ )
141
+
142
+ # If `num_attention_heads` is not defined (which is the case for most models)
143
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
144
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
145
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
146
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
147
+ # which is why we correct for the naming here.
148
+ num_attention_heads = num_attention_heads or attention_head_dim
149
+
150
+ # Check inputs
151
+ if len(down_block_types) != len(up_block_types):
152
+ raise ValueError(
153
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
154
+ )
155
+
156
+ if len(block_out_channels) != len(down_block_types):
157
+ raise ValueError(
158
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
159
+ )
160
+
161
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
162
+ raise ValueError(
163
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
164
+ )
165
+
166
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
167
+ raise ValueError(
168
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
169
+ )
170
+
171
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
172
+ raise ValueError(
173
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
174
+ )
175
+
176
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
177
+ raise ValueError(
178
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
179
+ )
180
+
181
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
182
+ raise ValueError(
183
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
184
+ )
185
+
186
+ # input
187
+ conv_in_padding = (conv_in_kernel - 1) // 2
188
+
189
+ self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
190
+
191
+ # time
192
+ if time_embedding_type == "fourier":
193
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
194
+ if time_embed_dim % 2 != 0:
195
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
196
+ self.time_proj = GaussianFourierProjection(
197
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
198
+ )
199
+ timestep_input_dim = time_embed_dim
200
+ elif time_embedding_type == "positional":
201
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
202
+
203
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
204
+ timestep_input_dim = block_out_channels[0]
205
+ else:
206
+ raise ValueError(
207
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
208
+ )
209
+
210
+ self.time_embedding = TimestepEmbedding(
211
+ timestep_input_dim,
212
+ time_embed_dim,
213
+ act_fn=act_fn,
214
+ post_act_fn=timestep_post_act,
215
+ cond_proj_dim=time_cond_proj_dim,
216
+ )
217
+
218
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
219
+ encoder_hid_dim_type = "text_proj"
220
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
221
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
222
+
223
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
224
+ raise ValueError(
225
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
226
+ )
227
+
228
+ if encoder_hid_dim_type == "text_proj":
229
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
230
+ elif encoder_hid_dim_type == "text_image_proj":
231
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
232
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
233
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
234
+ self.encoder_hid_proj = TextImageProjection(
235
+ text_embed_dim=encoder_hid_dim,
236
+ image_embed_dim=cross_attention_dim,
237
+ cross_attention_dim=cross_attention_dim,
238
+ )
239
+ elif encoder_hid_dim_type == "image_proj":
240
+ # Kandinsky 2.2
241
+ self.encoder_hid_proj = ImageProjection(
242
+ image_embed_dim=encoder_hid_dim,
243
+ cross_attention_dim=cross_attention_dim,
244
+ )
245
+ elif encoder_hid_dim_type is not None:
246
+ raise ValueError(
247
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
248
+ )
249
+ else:
250
+ self.encoder_hid_proj = None
251
+
252
+ # class embedding
253
+ if class_embed_type is None and num_class_embeds is not None:
254
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
255
+ elif class_embed_type == "timestep":
256
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
257
+ elif class_embed_type == "identity":
258
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
259
+ elif class_embed_type == "projection":
260
+ if projection_class_embeddings_input_dim is None:
261
+ raise ValueError(
262
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
263
+ )
264
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
265
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
266
+ # 2. it projects from an arbitrary input dimension.
267
+ #
268
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
269
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
270
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
271
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
272
+ elif class_embed_type == "simple_projection":
273
+ if projection_class_embeddings_input_dim is None:
274
+ raise ValueError(
275
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
276
+ )
277
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
278
+ else:
279
+ self.class_embedding = None
280
+
281
+ if addition_embed_type == "text":
282
+ if encoder_hid_dim is not None:
283
+ text_time_embedding_from_dim = encoder_hid_dim
284
+ else:
285
+ text_time_embedding_from_dim = cross_attention_dim
286
+
287
+ self.add_embedding = TextTimeEmbedding(
288
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
289
+ )
290
+ elif addition_embed_type == "text_image":
291
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
292
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
293
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
294
+ self.add_embedding = TextImageTimeEmbedding(
295
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
296
+ )
297
+ elif addition_embed_type == "text_time":
298
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
299
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
300
+ elif addition_embed_type == "image":
301
+ # Kandinsky 2.2
302
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
303
+ elif addition_embed_type == "image_hint":
304
+ # Kandinsky 2.2 ControlNet
305
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
306
+ elif addition_embed_type is not None:
307
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
308
+
309
+ if time_embedding_act_fn is None:
310
+ self.time_embed_act = None
311
+ else:
312
+ self.time_embed_act = get_activation(time_embedding_act_fn)
313
+
314
+ self.down_blocks = nn.ModuleList([])
315
+ self.up_blocks = nn.ModuleList([])
316
+
317
+ if isinstance(only_cross_attention, bool):
318
+ if mid_block_only_cross_attention is None:
319
+ mid_block_only_cross_attention = only_cross_attention
320
+
321
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
322
+
323
+ if mid_block_only_cross_attention is None:
324
+ mid_block_only_cross_attention = False
325
+
326
+ if isinstance(num_attention_heads, int):
327
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
328
+
329
+ if isinstance(attention_head_dim, int):
330
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
331
+
332
+ if isinstance(cross_attention_dim, int):
333
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
334
+
335
+ if isinstance(layers_per_block, int):
336
+ layers_per_block = [layers_per_block] * len(down_block_types)
337
+
338
+ if isinstance(transformer_layers_per_block, int):
339
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
340
+
341
+ if class_embeddings_concat:
342
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
343
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
344
+ # regular time embeddings
345
+ blocks_time_embed_dim = time_embed_dim * 2
346
+ else:
347
+ blocks_time_embed_dim = time_embed_dim
348
+
349
+ # down
350
+ output_channel = block_out_channels[0]
351
+ for i, down_block_type in enumerate(down_block_types):
352
+ res = 2 ** i
353
+ input_channel = output_channel
354
+ output_channel = block_out_channels[i]
355
+ is_final_block = i == len(block_out_channels) - 1
356
+
357
+ down_block = get_down_block(
358
+ down_block_type,
359
+ num_layers=layers_per_block[i],
360
+ transformer_layers_per_block=transformer_layers_per_block[i],
361
+ in_channels=input_channel,
362
+ out_channels=output_channel,
363
+ temb_channels=blocks_time_embed_dim,
364
+ add_downsample=not is_final_block,
365
+ resnet_eps=norm_eps,
366
+ resnet_act_fn=act_fn,
367
+ resnet_groups=norm_num_groups,
368
+ cross_attention_dim=cross_attention_dim[i],
369
+ num_attention_heads=num_attention_heads[i],
370
+ downsample_padding=downsample_padding,
371
+ dual_cross_attention=dual_cross_attention,
372
+ use_linear_projection=use_linear_projection,
373
+ only_cross_attention=only_cross_attention[i],
374
+ upcast_attention=upcast_attention,
375
+ resnet_time_scale_shift=resnet_time_scale_shift,
376
+ resnet_skip_time_act=resnet_skip_time_act,
377
+ resnet_out_scale_factor=resnet_out_scale_factor,
378
+ cross_attention_norm=cross_attention_norm,
379
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
380
+ )
381
+ self.down_blocks.append(down_block)
382
+
383
+ # mid
384
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
385
+ self.mid_block = UNetMidBlock3DCrossAttn(
386
+ transformer_layers_per_block=transformer_layers_per_block[-1],
387
+ in_channels=block_out_channels[-1],
388
+ temb_channels=blocks_time_embed_dim,
389
+ resnet_eps=norm_eps,
390
+ resnet_act_fn=act_fn,
391
+ output_scale_factor=mid_block_scale_factor,
392
+ resnet_time_scale_shift=resnet_time_scale_shift,
393
+ cross_attention_dim=cross_attention_dim[-1],
394
+ num_attention_heads=num_attention_heads[-1],
395
+ resnet_groups=norm_num_groups,
396
+ dual_cross_attention=dual_cross_attention,
397
+ use_linear_projection=use_linear_projection,
398
+ upcast_attention=upcast_attention,
399
+ )
400
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
401
+ raise ValueError("UNetMidBlock2DSimpleCrossAttn not supported")
402
+
403
+ elif mid_block_type is None:
404
+ self.mid_block = None
405
+ else:
406
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
407
+
408
+ # count how many layers upsample the images
409
+ self.num_upsamplers = 0
410
+
411
+ # up
412
+ reversed_block_out_channels = list(reversed(block_out_channels))
413
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
414
+ reversed_layers_per_block = list(reversed(layers_per_block))
415
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
416
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
417
+ only_cross_attention = list(reversed(only_cross_attention))
418
+
419
+ output_channel = reversed_block_out_channels[0]
420
+ for i, up_block_type in enumerate(up_block_types):
421
+ res = 2 ** (len(up_block_types) - 1 - i)
422
+ is_final_block = i == len(block_out_channels) - 1
423
+
424
+ prev_output_channel = output_channel
425
+ output_channel = reversed_block_out_channels[i]
426
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
427
+
428
+ # add upsample block for all BUT final layer
429
+ if not is_final_block:
430
+ add_upsample = True
431
+ self.num_upsamplers += 1
432
+ else:
433
+ add_upsample = False
434
+
435
+ up_block = get_up_block(
436
+ up_block_type,
437
+ num_layers=reversed_layers_per_block[i] + 1,
438
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
439
+ in_channels=input_channel,
440
+ out_channels=output_channel,
441
+ prev_output_channel=prev_output_channel,
442
+ temb_channels=blocks_time_embed_dim,
443
+ add_upsample=add_upsample,
444
+ resnet_eps=norm_eps,
445
+ resnet_act_fn=act_fn,
446
+ resnet_groups=norm_num_groups,
447
+ cross_attention_dim=reversed_cross_attention_dim[i],
448
+ num_attention_heads=reversed_num_attention_heads[i],
449
+ dual_cross_attention=dual_cross_attention,
450
+ use_linear_projection=use_linear_projection,
451
+ only_cross_attention=only_cross_attention[i],
452
+ upcast_attention=upcast_attention,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ resnet_skip_time_act=resnet_skip_time_act,
455
+ resnet_out_scale_factor=resnet_out_scale_factor,
456
+ cross_attention_norm=cross_attention_norm,
457
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
+ )
459
+ self.up_blocks.append(up_block)
460
+ prev_output_channel = output_channel
461
+
462
+ # out
463
+ if norm_num_groups is not None:
464
+ self.conv_norm_out = nn.GroupNorm(
465
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
466
+ )
467
+
468
+ self.conv_act = get_activation(act_fn)
469
+
470
+ else:
471
+ self.conv_norm_out = None
472
+ self.conv_act = None
473
+
474
+ conv_out_padding = (conv_out_kernel - 1) // 2
475
+
476
+ self.conv_out = Conv3d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel,
477
+ padding=conv_out_padding)
478
+
479
+ def temporal_parameters(self) -> list:
480
+ output = []
481
+ all_blocks = self.down_blocks + self.up_blocks + [self.mid_block]
482
+ for block in all_blocks:
483
+ output.extend(block.temporal_parameters())
484
+ return output
485
+
486
+ @property
487
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
488
+ return self.get_attn_processors(include_temporal_layers=False)
489
+
490
+ def get_attn_processors(self, include_temporal_layers=True) -> Dict[str, AttentionProcessor]:
491
+ r"""
492
+ Returns:
493
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
494
+ indexed by its weight name.
495
+ """
496
+ # set recursively
497
+ processors = {}
498
+
499
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
500
+
501
+ if not include_temporal_layers:
502
+ if 'temporal' in name:
503
+ return processors
504
+
505
+ if hasattr(module, "set_processor"):
506
+ processors[f"{name}.processor"] = module.processor
507
+
508
+ for sub_name, child in module.named_children():
509
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
510
+
511
+ return processors
512
+
513
+ for name, module in self.named_children():
514
+ fn_recursive_add_processors(name, module, processors)
515
+
516
+ return processors
517
+
518
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
519
+ include_temporal_layers=False):
520
+ r"""
521
+ Sets the attention processor to use to compute attention.
522
+
523
+ Parameters:
524
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
525
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
526
+ for **all** `Attention` layers.
527
+
528
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
529
+ processor. This is strongly recommended when setting trainable attention processors.
530
+
531
+ """
532
+ count = len(self.get_attn_processors(include_temporal_layers=include_temporal_layers).keys())
533
+
534
+ if isinstance(processor, dict) and len(processor) != count:
535
+ raise ValueError(
536
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
537
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
538
+ )
539
+
540
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
541
+
542
+ if not include_temporal_layers:
543
+ if "temporal" in name:
544
+ return
545
+
546
+ if hasattr(module, "set_processor"):
547
+ if not isinstance(processor, dict):
548
+ module.set_processor(processor)
549
+ else:
550
+ module.set_processor(processor.pop(f"{name}.processor"))
551
+
552
+ for sub_name, child in module.named_children():
553
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
554
+
555
+ for name, module in self.named_children():
556
+ fn_recursive_attn_processor(name, module, processor)
557
+
558
+ def set_default_attn_processor(self):
559
+ """
560
+ Disables custom attention processors and sets the default attention implementation.
561
+ """
562
+ self.set_attn_processor(AttnProcessor())
563
+
564
+ def set_attention_slice(self, slice_size):
565
+ r"""
566
+ Enable sliced attention computation.
567
+
568
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
569
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
570
+
571
+ Args:
572
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
573
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
574
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
575
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
576
+ must be a multiple of `slice_size`.
577
+ """
578
+ sliceable_head_dims = []
579
+
580
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
581
+ if hasattr(module, "set_attention_slice"):
582
+ sliceable_head_dims.append(module.sliceable_head_dim)
583
+
584
+ for child in module.children():
585
+ fn_recursive_retrieve_sliceable_dims(child)
586
+
587
+ # retrieve number of attention layers
588
+ for module in self.children():
589
+ fn_recursive_retrieve_sliceable_dims(module)
590
+
591
+ num_sliceable_layers = len(sliceable_head_dims)
592
+
593
+ if slice_size == "auto":
594
+ # half the attention head size is usually a good trade-off between
595
+ # speed and memory
596
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
597
+ elif slice_size == "max":
598
+ # make smallest slice possible
599
+ slice_size = num_sliceable_layers * [1]
600
+
601
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
602
+
603
+ if len(slice_size) != len(sliceable_head_dims):
604
+ raise ValueError(
605
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
606
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
607
+ )
608
+
609
+ for i in range(len(slice_size)):
610
+ size = slice_size[i]
611
+ dim = sliceable_head_dims[i]
612
+ if size is not None and size > dim:
613
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
614
+
615
+ # Recursively walk through all the children.
616
+ # Any children which exposes the set_attention_slice method
617
+ # gets the message
618
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
619
+ if hasattr(module, "set_attention_slice"):
620
+ module.set_attention_slice(slice_size.pop())
621
+
622
+ for child in module.children():
623
+ fn_recursive_set_attention_slice(child, slice_size)
624
+
625
+ reversed_slice_size = list(reversed(slice_size))
626
+ for module in self.children():
627
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
628
+
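+ # A minimal usage sketch, following the docstring above:
+ #
+ #     unet.set_attention_slice("auto")  # halve every sliceable head dim (two steps)
+ #     unet.set_attention_slice("max")   # slice size 1, maximum memory savings
+ #     unet.set_attention_slice(2)       # attention_head_dim must be divisible by 2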
629
+ def _set_gradient_checkpointing(self, module, value=False):
630
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
631
+ module.gradient_checkpointing = value
632
+
633
+ def forward(
634
+ self,
635
+ sample: torch.FloatTensor,
636
+ timestep: Union[torch.Tensor, float, int],
637
+ encoder_hidden_states: torch.Tensor,
638
+ class_labels: Optional[torch.Tensor] = None,
639
+ timestep_cond: Optional[torch.Tensor] = None,
640
+ attention_mask: Optional[torch.Tensor] = None,
641
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
642
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
643
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
644
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
645
+ encoder_attention_mask: Optional[torch.Tensor] = None,
646
+ return_dict: bool = True,
647
+ enable_temporal_attentions: bool = True
648
+ ) -> Union[UNet3DConditionOutput, Tuple]:
649
+ r"""
650
+ The [`UNet3DConditionModel`] forward method.
651
+
652
+ Args:
653
+ sample (`torch.FloatTensor`):
654
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
655
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
656
+ encoder_hidden_states (`torch.FloatTensor`):
657
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
658
+ encoder_attention_mask (`torch.Tensor`):
659
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
660
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
661
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
662
+ return_dict (`bool`, *optional*, defaults to `True`):
663
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain
664
+ tuple.
665
+ cross_attention_kwargs (`dict`, *optional*):
666
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
667
+ added_cond_kwargs: (`dict`, *optional*):
668
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
669
+ are passed along to the UNet blocks.
670
+
671
+ Returns:
672
+ [`UNet3DConditionOutput`] or `tuple`:
673
+ If `return_dict` is True, a [`UNet3DConditionOutput`] is returned, otherwise
674
+ a `tuple` is returned where the first element is the sample tensor.
675
+ """
676
+ # By default samples have to be at least a multiple of the overall upsampling factor.
677
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
678
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
679
+ # on the fly if necessary.
680
+ default_overall_up_factor = 2 ** self.num_upsamplers
681
+
682
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
683
+ forward_upsample_size = False
684
+ upsample_size = None
685
+
686
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
687
+ logger.info("Forward upsample size to force interpolation output size.")
688
+ forward_upsample_size = True
689
+
690
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
691
+ # expects mask of shape:
692
+ # [batch, key_tokens]
693
+ # adds singleton query_tokens dimension:
694
+ # [batch, 1, key_tokens]
695
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
696
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
697
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
698
+ if attention_mask is not None:
699
+ # assume that mask is expressed as:
700
+ # (1 = keep, 0 = discard)
701
+ # convert mask into a bias that can be added to attention scores:
702
+ # (keep = +0, discard = -10000.0)
703
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
704
+ attention_mask = attention_mask.unsqueeze(1)
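+ # e.g. a keep/discard mask [[1, 1, 0]] becomes the additive bias [[[0.0, 0.0, -10000.0]]],
+ # which broadcasts over the query dimension when added to the attention scores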
705
+
706
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
707
+ if encoder_attention_mask is not None:
708
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
709
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
710
+
711
+ # 0. center input if necessary
712
+ if self.config.center_input_sample:
713
+ sample = 2 * sample - 1.0
714
+
715
+ # 1. time
716
+ timesteps = timestep
717
+ if not torch.is_tensor(timesteps):
718
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
719
+ # This would be a good case for the `match` statement (Python 3.10+)
720
+ is_mps = sample.device.type == "mps"
721
+ if isinstance(timestep, float):
722
+ dtype = torch.float32 if is_mps else torch.float64
723
+ else:
724
+ dtype = torch.int32 if is_mps else torch.int64
725
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
726
+ elif len(timesteps.shape) == 0:
727
+ timesteps = timesteps[None].to(sample.device)
728
+
729
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
730
+ timesteps = timesteps.expand(sample.shape[0])
731
+
732
+ t_emb = self.time_proj(timesteps)
733
+
734
+ # `Timesteps` does not contain any weights and will always return f32 tensors
735
+ # but time_embedding might actually be running in fp16. so we need to cast here.
736
+ # there might be better ways to encapsulate this.
737
+ t_emb = t_emb.to(dtype=sample.dtype)
738
+
739
+ emb = self.time_embedding(t_emb, timestep_cond)
740
+ aug_emb = None
741
+
742
+ if self.class_embedding is not None:
743
+ if class_labels is None:
744
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
745
+
746
+ if self.config.class_embed_type == "timestep":
747
+ class_labels = self.time_proj(class_labels)
748
+
749
+ # `Timesteps` does not contain any weights and will always return f32 tensors
750
+ # there might be better ways to encapsulate this.
751
+ class_labels = class_labels.to(dtype=sample.dtype)
752
+
753
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
754
+
755
+ if self.config.class_embeddings_concat:
756
+ emb = torch.cat([emb, class_emb], dim=-1)
757
+ else:
758
+ emb = emb + class_emb
759
+
760
+ if self.config.addition_embed_type == "text":
761
+ aug_emb = self.add_embedding(encoder_hidden_states)
762
+ elif self.config.addition_embed_type == "text_image":
763
+ # Kandinsky 2.1 - style
764
+ if "image_embeds" not in added_cond_kwargs:
765
+ raise ValueError(
766
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
767
+ )
768
+
769
+ image_embs = added_cond_kwargs.get("image_embeds")
770
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
771
+ aug_emb = self.add_embedding(text_embs, image_embs)
772
+ elif self.config.addition_embed_type == "text_time":
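+ # For SDXL-style micro-conditioning this branch expects `added_cond_kwargs` shaped
+ # roughly like {"text_embeds": pooled_prompt_embeds, "time_ids": tensor of shape (batch, 6)},
+ # where the six time_ids encode original size, crop coordinates and target size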
773
+ if "text_embeds" not in added_cond_kwargs:
774
+ raise ValueError(
775
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
776
+ )
777
+ text_embeds = added_cond_kwargs.get("text_embeds")
778
+ if "time_ids" not in added_cond_kwargs:
779
+ raise ValueError(
780
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
781
+ )
782
+ time_ids = added_cond_kwargs.get("time_ids")
783
+ time_embeds = self.add_time_proj(time_ids.flatten())
784
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
785
+
786
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
787
+ add_embeds = add_embeds.to(emb.dtype)
788
+ aug_emb = self.add_embedding(add_embeds)
789
+ elif self.config.addition_embed_type == "image":
790
+ # Kandinsky 2.2 - style
791
+ if "image_embeds" not in added_cond_kwargs:
792
+ raise ValueError(
793
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
794
+ )
795
+ image_embs = added_cond_kwargs.get("image_embeds")
796
+ aug_emb = self.add_embedding(image_embs)
797
+ elif self.config.addition_embed_type == "image_hint":
798
+ # Kandinsky 2.2 - style
799
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
800
+ raise ValueError(
801
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
802
+ )
803
+ image_embs = added_cond_kwargs.get("image_embeds")
804
+ hint = added_cond_kwargs.get("hint")
805
+ aug_emb, hint = self.add_embedding(image_embs, hint)
806
+ sample = torch.cat([sample, hint], dim=1)
807
+
808
+ emb = emb + aug_emb if aug_emb is not None else emb
809
+
810
+ if self.time_embed_act is not None:
811
+ emb = self.time_embed_act(emb)
812
+
813
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
814
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
815
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
816
+ # Kandinsky 2.1 - style
817
+ if "image_embeds" not in added_cond_kwargs:
818
+ raise ValueError(
819
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
820
+ )
821
+
822
+ image_embeds = added_cond_kwargs.get("image_embeds")
823
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
824
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
825
+ # Kandinsky 2.2 - style
826
+ if "image_embeds" not in added_cond_kwargs:
827
+ raise ValueError(
828
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
829
+ )
830
+ image_embeds = added_cond_kwargs.get("image_embeds")
831
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
832
+ # 2. pre-process
833
+
834
+ sample = self.conv_in(sample)
835
+
836
+ # 3. down
837
+ down_block_res_samples = (sample,)
838
+ for downsample_block in self.down_blocks:
839
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
840
+ sample, res_samples = downsample_block(
841
+ hidden_states=sample,
842
+ temb=emb,
843
+ encoder_hidden_states=encoder_hidden_states,
844
+ attention_mask=attention_mask,
845
+ cross_attention_kwargs=cross_attention_kwargs,
846
+ enable_temporal_attentions=enable_temporal_attentions
847
+ )
848
+ else:
849
+ sample, res_samples = downsample_block(hidden_states=sample,
850
+ temb=emb,
851
+ encoder_hidden_states=encoder_hidden_states,
852
+ enable_temporal_attentions=enable_temporal_attentions)
853
+
854
+ down_block_res_samples += res_samples
855
+
856
+ if down_block_additional_residuals is not None:
857
+ new_down_block_res_samples = ()
858
+
859
+ for down_block_res_sample, down_block_additional_residual in zip(
860
+ down_block_res_samples, down_block_additional_residuals
861
+ ):
862
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
863
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
864
+
865
+ down_block_res_samples = new_down_block_res_samples
866
+
867
+ # 4. mid
868
+ if self.mid_block is not None:
869
+ sample = self.mid_block(
870
+ sample,
871
+ emb,
872
+ encoder_hidden_states=encoder_hidden_states,
873
+ attention_mask=attention_mask,
874
+ cross_attention_kwargs=cross_attention_kwargs,
875
+ enable_temporal_attentions=enable_temporal_attentions
876
+ )
877
+
878
+ if mid_block_additional_residual is not None:
879
+ sample = sample + mid_block_additional_residual
880
+
881
+ # 5. up
882
+ for i, upsample_block in enumerate(self.up_blocks):
883
+ is_final_block = i == len(self.up_blocks) - 1
884
+
885
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
886
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
887
+
888
+ # if we have not reached the final block and need to forward the
889
+ # upsample size, we do it here
890
+ if not is_final_block and forward_upsample_size:
891
+ upsample_size = down_block_res_samples[-1].shape[2:]
892
+
893
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
894
+ sample = upsample_block(
895
+ hidden_states=sample,
896
+ temb=emb,
897
+ res_hidden_states_tuple=res_samples,
898
+ encoder_hidden_states=encoder_hidden_states,
899
+ cross_attention_kwargs=cross_attention_kwargs,
900
+ upsample_size=upsample_size,
901
+ attention_mask=attention_mask,
902
+ enable_temporal_attentions=enable_temporal_attentions
903
+ )
904
+ else:
905
+ sample = upsample_block(
906
+ hidden_states=sample,
907
+ temb=emb,
908
+ res_hidden_states_tuple=res_samples,
909
+ upsample_size=upsample_size,
910
+ encoder_hidden_states=encoder_hidden_states,
911
+ enable_temporal_attentions=enable_temporal_attentions
912
+ )
913
+
914
+ # 6. post-process
915
+ if self.conv_norm_out:
916
+ sample = self.conv_norm_out(sample)
917
+ sample = self.conv_act(sample)
918
+
919
+ sample = self.conv_out(sample)
920
+
921
+ if not return_dict:
922
+ return (sample,)
923
+
924
+ return UNet3DConditionOutput(sample=sample)
925
+
926
+ @classmethod
927
+ def from_pretrained_spatial(cls, pretrained_model_path, subfolder=None):
928
+
929
+ import os
930
+ import json
931
+
932
+ if subfolder is not None:
933
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
934
+
935
+ config_file = os.path.join(pretrained_model_path, 'config.json')
936
+
937
+ with open(config_file, "r") as f:
938
+ config = json.load(f)
939
+
940
+ config["_class_name"] = "UNet3DConditionModel"
941
+
942
+ config["down_block_types"] = [
943
+ "DownBlock3D",
944
+ "CrossAttnDownBlock3D",
945
+ "CrossAttnDownBlock3D",
946
+ ]
947
+ config["up_block_types"] = [
948
+ "CrossAttnUpBlock3D",
949
+ "CrossAttnUpBlock3D",
950
+ "UpBlock3D"
951
+ ]
952
+
953
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
954
+
955
+ model = cls.from_config(config)
956
+
957
+ model_files = [
958
+ os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'),
959
+ os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors')
960
+ ]
961
+
962
+ model_file = None
963
+
964
+ for fp in model_files:
965
+ if os.path.exists(fp):
966
+ model_file = fp
967
+
968
+ if not model_file:
969
+ raise RuntimeError(f"No diffusion_pytorch_model weights (.bin or .safetensors) found in {pretrained_model_path}")
970
+
971
+ if model_file.split(".")[-1] == "safetensors":
972
+ from safetensors import safe_open
973
+ state_dict = {}
974
+ with safe_open(model_file, framework="pt", device="cuda") as f:
975
+ for key in f.keys():
976
+ state_dict[key] = f.get_tensor(key)
977
+ else:
978
+ state_dict = torch.load(model_file, map_location="cpu")
979
+
980
+ model.load_state_dict(state_dict, strict=False)
981
+
982
+ return model
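+ # A minimal usage sketch (the local path is illustrative):
+ #
+ #     unet = UNet3DConditionModel.from_pretrained_spatial("./fine-tuned-sdxl", subfolder="unet")
+ #
+ # This reads the spatial SDXL UNet config and weights from disk, swaps the block types
+ # for their 3D counterparts, and loads the 2D weights with strict=False so the newly
+ # added temporal layers keep their fresh initialization.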
hotshot_xl/models/unet_blocks.py ADDED
@@ -0,0 +1,740 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Add temporal transformers to unet blocks
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from .transformer_3d import Transformer3DModel
23
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
24
+ from .transformer_temporal import TransformerTemporal
25
+
26
+
27
+ def get_down_block(
28
+ down_block_type,
29
+ num_layers,
30
+ in_channels,
31
+ out_channels,
32
+ temb_channels,
33
+ add_downsample,
34
+ resnet_eps,
35
+ resnet_act_fn,
36
+ transformer_layers_per_block=1,
37
+ num_attention_heads=None,
38
+ resnet_groups=None,
39
+ cross_attention_dim=None,
40
+ downsample_padding=None,
41
+ dual_cross_attention=False,
42
+ use_linear_projection=False,
43
+ only_cross_attention=False,
44
+ upcast_attention=False,
45
+ resnet_time_scale_shift="default",
46
+ resnet_skip_time_act=False,
47
+ resnet_out_scale_factor=1.0,
48
+ cross_attention_norm=None,
49
+ attention_head_dim=None,
50
+ downsample_type=None,
51
+ ):
52
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
53
+ if down_block_type == "DownBlock3D":
54
+ return DownBlock3D(
55
+ num_layers=num_layers,
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ temb_channels=temb_channels,
59
+ add_downsample=add_downsample,
60
+ resnet_eps=resnet_eps,
61
+ resnet_act_fn=resnet_act_fn,
62
+ resnet_groups=resnet_groups,
63
+ downsample_padding=downsample_padding,
64
+ resnet_time_scale_shift=resnet_time_scale_shift,
65
+ )
66
+ elif down_block_type == "CrossAttnDownBlock3D":
67
+ if cross_attention_dim is None:
68
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
69
+ return CrossAttnDownBlock3D(
70
+ num_layers=num_layers,
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ transformer_layers_per_block=transformer_layers_per_block,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ cross_attention_dim=cross_attention_dim,
81
+ num_attention_heads=num_attention_heads,
82
+ dual_cross_attention=dual_cross_attention,
83
+ use_linear_projection=use_linear_projection,
84
+ only_cross_attention=only_cross_attention,
85
+ upcast_attention=upcast_attention,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ )
88
+ raise ValueError(f"{down_block_type} does not exist.")
89
+
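+ # A minimal factory-usage sketch (the channel/layer values are illustrative, not taken
+ # from the Hotshot-XL config):
+ #
+ #     block = get_down_block(
+ #         "CrossAttnDownBlock3D",
+ #         num_layers=2, in_channels=320, out_channels=640, temb_channels=1280,
+ #         add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+ #         num_attention_heads=10, cross_attention_dim=2048,
+ #     )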
90
+
91
+ def get_up_block(
92
+ up_block_type,
93
+ num_layers,
94
+ in_channels,
95
+ out_channels,
96
+ prev_output_channel,
97
+ temb_channels,
98
+ add_upsample,
99
+ resnet_eps,
100
+ resnet_act_fn,
101
+ transformer_layers_per_block=1,
102
+ num_attention_heads=None,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+ resnet_skip_time_act=False,
111
+ resnet_out_scale_factor=1.0,
112
+ cross_attention_norm=None,
113
+ attention_head_dim=None,
114
+ upsample_type=None,
115
+ ):
116
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
117
+ if up_block_type == "UpBlock3D":
118
+ return UpBlock3D(
119
+ num_layers=num_layers,
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ prev_output_channel=prev_output_channel,
123
+ temb_channels=temb_channels,
124
+ add_upsample=add_upsample,
125
+ resnet_eps=resnet_eps,
126
+ resnet_act_fn=resnet_act_fn,
127
+ resnet_groups=resnet_groups,
128
+ resnet_time_scale_shift=resnet_time_scale_shift,
129
+ )
130
+ elif up_block_type == "CrossAttnUpBlock3D":
131
+ if cross_attention_dim is None:
132
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
133
+ return CrossAttnUpBlock3D(
134
+ num_layers=num_layers,
135
+ in_channels=in_channels,
136
+ transformer_layers_per_block=transformer_layers_per_block,
137
+ out_channels=out_channels,
138
+ prev_output_channel=prev_output_channel,
139
+ temb_channels=temb_channels,
140
+ add_upsample=add_upsample,
141
+ resnet_eps=resnet_eps,
142
+ resnet_act_fn=resnet_act_fn,
143
+ resnet_groups=resnet_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ num_attention_heads=num_attention_heads,
146
+ dual_cross_attention=dual_cross_attention,
147
+ use_linear_projection=use_linear_projection,
148
+ only_cross_attention=only_cross_attention,
149
+ upcast_attention=upcast_attention,
150
+ resnet_time_scale_shift=resnet_time_scale_shift,
151
+ )
152
+ raise ValueError(f"{up_block_type} does not exist.")
153
+
154
+
155
+ class UNetMidBlock3DCrossAttn(nn.Module):
156
+ def __init__(
157
+ self,
158
+ in_channels: int,
159
+ temb_channels: int,
160
+ dropout: float = 0.0,
161
+ num_layers: int = 1,
162
+ transformer_layers_per_block: int = 1,
163
+ resnet_eps: float = 1e-6,
164
+ resnet_time_scale_shift: str = "default",
165
+ resnet_act_fn: str = "swish",
166
+ resnet_groups: int = 32,
167
+ resnet_pre_norm: bool = True,
168
+ num_attention_heads=1,
169
+ output_scale_factor=1.0,
170
+ cross_attention_dim=1280,
171
+ dual_cross_attention=False,
172
+ use_linear_projection=False,
173
+ upcast_attention=False,
174
+ ):
175
+ super().__init__()
176
+
177
+ self.has_cross_attention = True
178
+ self.num_attention_heads = num_attention_heads
179
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
180
+
181
+ # there is always at least one resnet
182
+ resnets = [
183
+ ResnetBlock3D(
184
+ in_channels=in_channels,
185
+ out_channels=in_channels,
186
+ temb_channels=temb_channels,
187
+ eps=resnet_eps,
188
+ groups=resnet_groups,
189
+ dropout=dropout,
190
+ time_embedding_norm=resnet_time_scale_shift,
191
+ non_linearity=resnet_act_fn,
192
+ output_scale_factor=output_scale_factor,
193
+ pre_norm=resnet_pre_norm,
194
+ )
195
+ ]
196
+ attentions = []
197
+
198
+ for _ in range(num_layers):
199
+ if dual_cross_attention:
200
+ raise NotImplementedError
201
+ attentions.append(
202
+ Transformer3DModel(
203
+ num_attention_heads,
204
+ in_channels // num_attention_heads,
205
+ in_channels=in_channels,
206
+ num_layers=transformer_layers_per_block,
207
+ cross_attention_dim=cross_attention_dim,
208
+ norm_num_groups=resnet_groups,
209
+ use_linear_projection=use_linear_projection,
210
+ upcast_attention=upcast_attention,
211
+ )
212
+ )
213
+
214
+ resnets.append(
215
+ ResnetBlock3D(
216
+ in_channels=in_channels,
217
+ out_channels=in_channels,
218
+ temb_channels=temb_channels,
219
+ eps=resnet_eps,
220
+ groups=resnet_groups,
221
+ dropout=dropout,
222
+ time_embedding_norm=resnet_time_scale_shift,
223
+ non_linearity=resnet_act_fn,
224
+ output_scale_factor=output_scale_factor,
225
+ pre_norm=resnet_pre_norm,
226
+ )
227
+ )
228
+
229
+ self.attentions = nn.ModuleList(attentions)
230
+ self.resnets = nn.ModuleList(resnets)
231
+
232
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
233
+ cross_attention_kwargs=None, enable_temporal_attentions: bool = True):
234
+ hidden_states = self.resnets[0](hidden_states, temb)
235
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
236
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
237
+ hidden_states = resnet(hidden_states, temb)
238
+
239
+ return hidden_states
240
+
241
+ def temporal_parameters(self) -> list:
242
+ return []
243
+
244
+
245
+ class CrossAttnDownBlock3D(nn.Module):
246
+ def __init__(
247
+ self,
248
+ in_channels: int,
249
+ out_channels: int,
250
+ temb_channels: int,
251
+ dropout: float = 0.0,
252
+ num_layers: int = 1,
253
+ transformer_layers_per_block: int = 1,
254
+ resnet_eps: float = 1e-6,
255
+ resnet_time_scale_shift: str = "default",
256
+ resnet_act_fn: str = "swish",
257
+ resnet_groups: int = 32,
258
+ resnet_pre_norm: bool = True,
259
+ num_attention_heads=1,
260
+ cross_attention_dim=1280,
261
+ output_scale_factor=1.0,
262
+ downsample_padding=1,
263
+ add_downsample=True,
264
+ dual_cross_attention=False,
265
+ use_linear_projection=False,
266
+ only_cross_attention=False,
267
+ upcast_attention=False,
268
+ ):
269
+ super().__init__()
270
+ resnets = []
271
+ attentions = []
272
+ temporal_attentions = []
273
+
274
+ self.has_cross_attention = True
275
+ self.num_attention_heads = num_attention_heads
276
+
277
+ for i in range(num_layers):
278
+ in_channels = in_channels if i == 0 else out_channels
279
+ resnets.append(
280
+ ResnetBlock3D(
281
+ in_channels=in_channels,
282
+ out_channels=out_channels,
283
+ temb_channels=temb_channels,
284
+ eps=resnet_eps,
285
+ groups=resnet_groups,
286
+ dropout=dropout,
287
+ time_embedding_norm=resnet_time_scale_shift,
288
+ non_linearity=resnet_act_fn,
289
+ output_scale_factor=output_scale_factor,
290
+ pre_norm=resnet_pre_norm,
291
+ )
292
+ )
293
+ if dual_cross_attention:
294
+ raise NotImplementedError
295
+ attentions.append(
296
+ Transformer3DModel(
297
+ num_attention_heads,
298
+ out_channels // num_attention_heads,
299
+ in_channels=out_channels,
300
+ num_layers=transformer_layers_per_block,
301
+ cross_attention_dim=cross_attention_dim,
302
+ norm_num_groups=resnet_groups,
303
+ use_linear_projection=use_linear_projection,
304
+ only_cross_attention=only_cross_attention,
305
+ upcast_attention=upcast_attention,
306
+ )
307
+ )
308
+ temporal_attentions.append(
309
+ TransformerTemporal(
310
+ num_attention_heads=8,
311
+ attention_head_dim=out_channels // 8,
312
+ in_channels=out_channels,
313
+ cross_attention_dim=None,
314
+ )
315
+ )
316
+
317
+ self.attentions = nn.ModuleList(attentions)
318
+ self.resnets = nn.ModuleList(resnets)
319
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
320
+
321
+ if add_downsample:
322
+ self.downsamplers = nn.ModuleList(
323
+ [
324
+ Downsample3D(
325
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
326
+ )
327
+ ]
328
+ )
329
+ else:
330
+ self.downsamplers = None
331
+
332
+ self.gradient_checkpointing = False
333
+
334
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
335
+ cross_attention_kwargs=None, enable_temporal_attentions: bool = True):
336
+ output_states = ()
337
+
338
+ for resnet, attn, temporal_attention \
339
+ in zip(self.resnets, self.attentions, self.temporal_attentions):
340
+ if self.training and self.gradient_checkpointing:
341
+
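+ # torch.utils.checkpoint re-runs the wrapped module's forward pass during backward
+ # instead of storing its intermediate activations, trading extra compute for memory;
+ # the closure below lets keyword arguments such as `return_dict` pass through
+ # checkpoint's positional-only calling convention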
342
+ def create_custom_forward(module, return_dict=None):
343
+ def custom_forward(*inputs):
344
+ if return_dict is not None:
345
+ return module(*inputs, return_dict=return_dict)
346
+ else:
347
+ return module(*inputs)
348
+
349
+ return custom_forward
350
+
351
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
352
+ use_reentrant=False)
353
+
354
+ hidden_states = torch.utils.checkpoint.checkpoint(
355
+ create_custom_forward(attn, return_dict=False),
356
+ hidden_states,
357
+ encoder_hidden_states,
358
+ use_reentrant=False
359
+ )[0]
360
+ if enable_temporal_attentions and temporal_attention is not None:
361
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
362
+ hidden_states, encoder_hidden_states,
363
+ use_reentrant=False)
364
+
365
+ else:
366
+ hidden_states = resnet(hidden_states, temb)
367
+
368
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
369
+
370
+ if temporal_attention and enable_temporal_attentions:
371
+ hidden_states = temporal_attention(hidden_states,
372
+ encoder_hidden_states=encoder_hidden_states)
373
+
374
+ output_states += (hidden_states,)
375
+
376
+ if self.downsamplers is not None:
377
+ for downsampler in self.downsamplers:
378
+ hidden_states = downsampler(hidden_states)
379
+
380
+ output_states += (hidden_states,)
381
+
382
+ return hidden_states, output_states
383
+
384
+ def temporal_parameters(self) -> list:
385
+ output = []
386
+ for block in self.temporal_attentions:
387
+ if block:
388
+ output.extend(block.parameters())
389
+ return output
390
+
391
+
392
+ class DownBlock3D(nn.Module):
393
+ def __init__(
394
+ self,
395
+ in_channels: int,
396
+ out_channels: int,
397
+ temb_channels: int,
398
+ dropout: float = 0.0,
399
+ num_layers: int = 1,
400
+ resnet_eps: float = 1e-6,
401
+ resnet_time_scale_shift: str = "default",
402
+ resnet_act_fn: str = "swish",
403
+ resnet_groups: int = 32,
404
+ resnet_pre_norm: bool = True,
405
+ output_scale_factor=1.0,
406
+ add_downsample=True,
407
+ downsample_padding=1,
408
+ ):
409
+ super().__init__()
410
+ resnets = []
411
+ temporal_attentions = []
412
+
413
+ for i in range(num_layers):
414
+ in_channels = in_channels if i == 0 else out_channels
415
+ resnets.append(
416
+ ResnetBlock3D(
417
+ in_channels=in_channels,
418
+ out_channels=out_channels,
419
+ temb_channels=temb_channels,
420
+ eps=resnet_eps,
421
+ groups=resnet_groups,
422
+ dropout=dropout,
423
+ time_embedding_norm=resnet_time_scale_shift,
424
+ non_linearity=resnet_act_fn,
425
+ output_scale_factor=output_scale_factor,
426
+ pre_norm=resnet_pre_norm,
427
+ )
428
+ )
429
+ temporal_attentions.append(
430
+ TransformerTemporal(
431
+ num_attention_heads=8,
432
+ attention_head_dim=out_channels // 8,
433
+ in_channels=out_channels,
434
+ cross_attention_dim=None
435
+ )
436
+ )
437
+
438
+ self.resnets = nn.ModuleList(resnets)
439
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
440
+
441
+ if add_downsample:
442
+ self.downsamplers = nn.ModuleList(
443
+ [
444
+ Downsample3D(
445
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
446
+ )
447
+ ]
448
+ )
449
+ else:
450
+ self.downsamplers = None
451
+
452
+ self.gradient_checkpointing = False
453
+
454
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, enable_temporal_attentions: bool = True):
455
+ output_states = ()
456
+
457
+ for resnet, temporal_attention in zip(self.resnets, self.temporal_attentions):
458
+ if self.training and self.gradient_checkpointing:
459
+ def create_custom_forward(module):
460
+ def custom_forward(*inputs):
461
+ return module(*inputs)
462
+
463
+ return custom_forward
464
+
465
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
466
+ use_reentrant=False)
467
+ if enable_temporal_attentions and temporal_attention is not None:
468
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
469
+ hidden_states, encoder_hidden_states,
470
+ use_reentrant=False)
471
+ else:
472
+ hidden_states = resnet(hidden_states, temb)
473
+
474
+ if enable_temporal_attentions and temporal_attention:
475
+ hidden_states = temporal_attention(hidden_states, encoder_hidden_states=encoder_hidden_states)
476
+
477
+ output_states += (hidden_states,)
478
+
479
+ if self.downsamplers is not None:
480
+ for downsampler in self.downsamplers:
481
+ hidden_states = downsampler(hidden_states)
482
+
483
+ output_states += (hidden_states,)
484
+
485
+ return hidden_states, output_states
486
+
487
+ def temporal_parameters(self) -> list:
488
+ output = []
489
+ for block in self.temporal_attentions:
490
+ if block:
491
+ output.extend(block.parameters())
492
+ return output
493
+
494
+
495
+ class CrossAttnUpBlock3D(nn.Module):
496
+ def __init__(
497
+ self,
498
+ in_channels: int,
499
+ out_channels: int,
500
+ prev_output_channel: int,
501
+ temb_channels: int,
502
+ dropout: float = 0.0,
503
+ num_layers: int = 1,
504
+ transformer_layers_per_block: int = 1,
505
+ resnet_eps: float = 1e-6,
506
+ resnet_time_scale_shift: str = "default",
507
+ resnet_act_fn: str = "swish",
508
+ resnet_groups: int = 32,
509
+ resnet_pre_norm: bool = True,
510
+ num_attention_heads=1,
511
+ cross_attention_dim=1280,
512
+ output_scale_factor=1.0,
513
+ add_upsample=True,
514
+ dual_cross_attention=False,
515
+ use_linear_projection=False,
516
+ only_cross_attention=False,
517
+ upcast_attention=False,
518
+ ):
519
+ super().__init__()
520
+ resnets = []
521
+ attentions = []
522
+ temporal_attentions = []
523
+
524
+ self.has_cross_attention = True
525
+ self.num_attention_heads = num_attention_heads
526
+
527
+ for i in range(num_layers):
528
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
529
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
530
+
531
+ resnets.append(
532
+ ResnetBlock3D(
533
+ in_channels=resnet_in_channels + res_skip_channels,
534
+ out_channels=out_channels,
535
+ temb_channels=temb_channels,
536
+ eps=resnet_eps,
537
+ groups=resnet_groups,
538
+ dropout=dropout,
539
+ time_embedding_norm=resnet_time_scale_shift,
540
+ non_linearity=resnet_act_fn,
541
+ output_scale_factor=output_scale_factor,
542
+ pre_norm=resnet_pre_norm,
543
+ )
544
+ )
545
+ if dual_cross_attention:
546
+ raise NotImplementedError
547
+ attentions.append(
548
+ Transformer3DModel(
549
+ num_attention_heads,
550
+ out_channels // num_attention_heads,
551
+ in_channels=out_channels,
552
+ num_layers=transformer_layers_per_block,
553
+ cross_attention_dim=cross_attention_dim,
554
+ norm_num_groups=resnet_groups,
555
+ use_linear_projection=use_linear_projection,
556
+ only_cross_attention=only_cross_attention,
557
+ upcast_attention=upcast_attention,
558
+ )
559
+ )
560
+ temporal_attentions.append(
561
+ TransformerTemporal(
562
+ num_attention_heads=8,
563
+ attention_head_dim=out_channels // 8,
564
+ in_channels=out_channels,
565
+ cross_attention_dim=None
566
+ )
567
+ )
568
+
569
+ self.attentions = nn.ModuleList(attentions)
570
+ self.resnets = nn.ModuleList(resnets)
571
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
572
+
573
+ if add_upsample:
574
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
575
+ else:
576
+ self.upsamplers = None
577
+
578
+ self.gradient_checkpointing = False
579
+
580
+ def forward(
581
+ self,
582
+ hidden_states,
583
+ res_hidden_states_tuple,
584
+ temb=None,
585
+ encoder_hidden_states=None,
586
+ upsample_size=None,
587
+ cross_attention_kwargs=None,
588
+ attention_mask=None,
589
+ enable_temporal_attentions: bool = True
590
+ ):
591
+ for resnet, attn, temporal_attention \
592
+ in zip(self.resnets, self.attentions, self.temporal_attentions):
593
+ # pop res hidden states
594
+ res_hidden_states = res_hidden_states_tuple[-1]
595
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
596
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
597
+
598
+ if self.training and self.gradient_checkpointing:
599
+
600
+ def create_custom_forward(module, return_dict=None):
601
+ def custom_forward(*inputs):
602
+ if return_dict is not None:
603
+ return module(*inputs, return_dict=return_dict)
604
+ else:
605
+ return module(*inputs)
606
+
607
+ return custom_forward
608
+
609
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
610
+ use_reentrant=False)
611
+
612
+ hidden_states = torch.utils.checkpoint.checkpoint(
613
+ create_custom_forward(attn, return_dict=False),
614
+ hidden_states,
615
+ encoder_hidden_states,
616
+ use_reentrant=False,
617
+ )[0]
618
+ if enable_temporal_attentions and temporal_attention is not None:
619
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
620
+ hidden_states, encoder_hidden_states,
621
+ use_reentrant=False)
622
+
623
+ else:
624
+ hidden_states = resnet(hidden_states, temb)
625
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
626
+
627
+ if enable_temporal_attentions and temporal_attention:
628
+ hidden_states = temporal_attention(hidden_states,
629
+ encoder_hidden_states=encoder_hidden_states)
630
+
631
+ if self.upsamplers is not None:
632
+ for upsampler in self.upsamplers:
633
+ hidden_states = upsampler(hidden_states, upsample_size)
634
+
635
+ return hidden_states
636
+
637
+ def temporal_parameters(self) -> list:
638
+ output = []
639
+ for block in self.temporal_attentions:
640
+ if block:
641
+ output.extend(block.parameters())
642
+ return output
643
+
644
+
645
+ class UpBlock3D(nn.Module):
646
+ def __init__(
647
+ self,
648
+ in_channels: int,
649
+ prev_output_channel: int,
650
+ out_channels: int,
651
+ temb_channels: int,
652
+ dropout: float = 0.0,
653
+ num_layers: int = 1,
654
+ resnet_eps: float = 1e-6,
655
+ resnet_time_scale_shift: str = "default",
656
+ resnet_act_fn: str = "swish",
657
+ resnet_groups: int = 32,
658
+ resnet_pre_norm: bool = True,
659
+ output_scale_factor=1.0,
660
+ add_upsample=True,
661
+ ):
662
+ super().__init__()
663
+ resnets = []
664
+ temporal_attentions = []
665
+
666
+ for i in range(num_layers):
667
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
668
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
669
+
670
+ resnets.append(
671
+ ResnetBlock3D(
672
+ in_channels=resnet_in_channels + res_skip_channels,
673
+ out_channels=out_channels,
674
+ temb_channels=temb_channels,
675
+ eps=resnet_eps,
676
+ groups=resnet_groups,
677
+ dropout=dropout,
678
+ time_embedding_norm=resnet_time_scale_shift,
679
+ non_linearity=resnet_act_fn,
680
+ output_scale_factor=output_scale_factor,
681
+ pre_norm=resnet_pre_norm,
682
+ )
683
+ )
684
+ temporal_attentions.append(
685
+ TransformerTemporal(
686
+ num_attention_heads=8,
687
+ attention_head_dim=out_channels // 8,
688
+ in_channels=out_channels,
689
+ cross_attention_dim=None
690
+ )
691
+ )
692
+
693
+ self.resnets = nn.ModuleList(resnets)
694
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
695
+
696
+ if add_upsample:
697
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
698
+ else:
699
+ self.upsamplers = None
700
+
701
+ self.gradient_checkpointing = False
702
+
703
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,
704
+ enable_temporal_attentions: bool = True):
705
+ for resnet, temporal_attention in zip(self.resnets, self.temporal_attentions):
706
+ # pop res hidden states
707
+ res_hidden_states = res_hidden_states_tuple[-1]
708
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
709
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
710
+
711
+ if self.training and self.gradient_checkpointing:
712
+ def create_custom_forward(module):
713
+ def custom_forward(*inputs):
714
+ return module(*inputs)
715
+
716
+ return custom_forward
717
+
718
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
719
+ use_reentrant=False)
720
+ if enable_temporal_attentions and temporal_attention is not None:
721
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
722
+ hidden_states, encoder_hidden_states,
723
+ use_reentrant=False)
724
+ else:
725
+ hidden_states = resnet(hidden_states, temb)
726
+ if enable_temporal_attentions and temporal_attention is not None:
727
+ hidden_states = temporal_attention(hidden_states, encoder_hidden_states=encoder_hidden_states)
728
+
729
+ if self.upsamplers is not None:
730
+ for upsampler in self.upsamplers:
731
+ hidden_states = upsampler(hidden_states, upsample_size)
732
+
733
+ return hidden_states
734
+
735
+ def temporal_parameters(self) -> list:
736
+ output = []
737
+ for block in self.temporal_attentions:
738
+ if block:
739
+ output.extend(block.parameters())
740
+ return output
hotshot_xl/pipelines/__init__.py ADDED
File without changes
hotshot_xl/pipelines/hotshot_xl_controlnet_pipeline.py ADDED
@@ -0,0 +1,1389 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Adapted the SDXL Controlnet Pipeline to work temporally
18
+
19
+ import inspect
20
+ import os
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import PIL.Image
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
28
+
29
+ from hotshot_xl import HotshotPipelineXLOutput
30
+
31
+ from diffusers.image_processor import VaeImageProcessor
32
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
33
+ from diffusers.models import AutoencoderKL, ControlNetModel
34
+ from diffusers.models.attention_processor import (
35
+ AttnProcessor2_0,
36
+ LoRAAttnProcessor2_0,
37
+ LoRAXFormersAttnProcessor,
38
+ XFormersAttnProcessor,
39
+ )
40
+ from diffusers.schedulers import KarrasDiffusionSchedulers
41
+ from diffusers.utils import (
42
+ is_accelerate_available,
43
+ is_accelerate_version,
44
+ logging,
45
+ replace_example_docstring,
46
+ )
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
49
+
50
+ from ..models.unet import UNet3DConditionModel
51
+
52
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
53
+ from einops import rearrange
54
+ from tqdm import tqdm
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
59
+ """
60
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
61
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
62
+ """
63
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
69
+ return noise_cfg
70
+
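+ # A minimal usage sketch: inside the denoising loop, after classifier-free guidance has
+ # combined the conditional and unconditional predictions, something like
+ #
+ #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
+ #
+ # mixes the rescaled prediction back in; the paper above reports values around 0.7 working well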
71
+ EXAMPLE_DOC_STRING = """
72
+ Examples:
73
+ ```py
74
+ >>> import torch
75
+ >>> import cv2
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline
76
+ >>> from diffusers import ControlNetModel
77
+
78
+ >>> pipe = HotshotXLControlNetPipeline.from_pretrained(
79
+ ... "hotshotco/Hotshot-XL",
80
+ ... controlnet=ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
81
+ ... )
82
+
83
+ >>> def canny(image):
84
+ >>> image = cv2.Canny(image, 100, 200)
85
+ >>> image = image[:, :, None]
86
+ >>> image = np.concatenate([image, image, image], axis=2)
87
+ >>> return Image.fromarray(image)
88
+
89
+ >>> # assuming you have 8 keyframes in current directory...
90
+
91
+ >>> keyframes = [f"image_{i}.jpg" for i in range(8)]
92
+ >>> control_images = [canny(Image.open(fp)) for fp in keyframes]
93
+
94
+ >>> pipe = pipe.to("cuda")
95
+
96
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
97
+ >>> video = pipe(prompt,
98
+ ... width=672, height=384,
99
+ ... original_size=(1920, 1080),
100
+ ... target_size=(512, 512),
101
+ ... output_type="tensor",
102
+ ... controlnet_conditioning_scale=0.7,
103
+ ... control_images=control_images
104
+ ).video
105
+ ```
106
+ """
107
+ class HotshotXLControlNetPipeline(
108
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
109
+ ):
110
+ r"""
111
+ Pipeline for text-to-GIF generation using Stable Diffusion XL with ControlNet guidance.
112
+
113
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
114
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
115
+
116
+ The pipeline also inherits the following loading methods:
117
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
118
+ - [`loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
119
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
120
+
121
+ Args:
122
+ vae ([`AutoencoderKL`]):
123
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
124
+ text_encoder ([`~transformers.CLIPTextModel`]):
125
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
126
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
127
+ Second frozen text-encoder
128
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
129
+ tokenizer ([`~transformers.CLIPTokenizer`]):
130
+ A `CLIPTokenizer` to tokenize text.
131
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
132
+ A `CLIPTokenizer` to tokenize text.
133
+ unet ([`UNet3DConditionModel`]):
134
+ A `UNet3DConditionModel` to denoise the encoded image latents.
135
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
136
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
137
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
138
+ additional conditioning.
139
+ scheduler ([`SchedulerMixin`]):
140
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
141
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
142
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
143
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
144
+ `stabilityai/stable-diffusion-xl-base-1-0`.
145
+ add_watermarker (`bool`, *optional*):
146
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
147
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
148
+ watermarker is used.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ vae: AutoencoderKL,
154
+ text_encoder: CLIPTextModel,
155
+ text_encoder_2: CLIPTextModelWithProjection,
156
+ tokenizer: CLIPTokenizer,
157
+ tokenizer_2: CLIPTokenizer,
158
+ unet: UNet3DConditionModel,
159
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
160
+ scheduler: KarrasDiffusionSchedulers,
161
+ force_zeros_for_empty_prompt: bool = True,
162
+ add_watermarker: Optional[bool] = None,
163
+ ):
164
+ super().__init__()
165
+
166
+ if isinstance(controlnet, (list, tuple)):
167
+ controlnet = MultiControlNetModel(controlnet)
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ text_encoder_2=text_encoder_2,
173
+ tokenizer=tokenizer,
174
+ tokenizer_2=tokenizer_2,
175
+ unet=unet,
176
+ controlnet=controlnet,
177
+ scheduler=scheduler,
178
+ )
179
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
181
+ self.control_image_processor = VaeImageProcessor(
182
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
183
+ )
184
+
185
+ self.watermark = None
186
+
187
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
188
+
189
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
190
+ def enable_vae_slicing(self):
191
+ r"""
192
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
193
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
194
+ """
195
+ self.vae.enable_slicing()
196
+
197
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
198
+ def disable_vae_slicing(self):
199
+ r"""
200
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
201
+ computing decoding in one step.
202
+ """
203
+ self.vae.disable_slicing()
204
+
205
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
206
+ def enable_vae_tiling(self):
207
+ r"""
208
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
209
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
210
+ processing larger images.
211
+ """
212
+ self.vae.enable_tiling()
213
+
214
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
215
+ def disable_vae_tiling(self):
216
+ r"""
217
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
218
+ computing decoding in one step.
219
+ """
220
+ self.vae.disable_tiling()
221
+
222
+ def enable_model_cpu_offload(self, gpu_id=0):
223
+ r"""
224
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
225
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
226
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
227
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
228
+ """
229
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
230
+ from accelerate import cpu_offload_with_hook
231
+ else:
232
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
233
+
234
+ device = torch.device(f"cuda:{gpu_id}")
235
+
236
+ if self.device.type != "cpu":
237
+ self.to("cpu", silence_dtype_warnings=True)
238
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
239
+
240
+ model_sequence = (
241
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
242
+ )
243
+ model_sequence.extend([self.unet, self.vae])
244
+
245
+ hook = None
246
+ for cpu_offloaded_model in model_sequence:
247
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
248
+
249
+ cpu_offload_with_hook(self.controlnet, device)
250
+
251
+ # We'll offload the last model manually.
252
+ self.final_offload_hook = hook
253
+
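+ # A minimal usage sketch: call `pipe.enable_model_cpu_offload()` once, after the
+ # pipeline is constructed and before the first generation; each sub-model is then
+ # moved to the GPU only while its own forward pass runs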
254
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
255
+ def encode_prompt(
256
+ self,
257
+ prompt: str,
258
+ prompt_2: Optional[str] = None,
259
+ device: Optional[torch.device] = None,
260
+ num_images_per_prompt: int = 1,
261
+ do_classifier_free_guidance: bool = True,
262
+ negative_prompt: Optional[str] = None,
263
+ negative_prompt_2: Optional[str] = None,
264
+ prompt_embeds: Optional[torch.FloatTensor] = None,
265
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
266
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
267
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
268
+ lora_scale: Optional[float] = None,
269
+ ):
270
+ r"""
271
+ Encodes the prompt into text encoder hidden states.
272
+
273
+ Args:
274
+ prompt (`str` or `List[str]`, *optional*):
275
+ prompt to be encoded
276
+ prompt_2 (`str` or `List[str]`, *optional*):
277
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
278
+ used in both text-encoders
279
+ device: (`torch.device`):
280
+ torch device
281
+ num_images_per_prompt (`int`):
282
+ number of images that should be generated per prompt
283
+ do_classifier_free_guidance (`bool`):
284
+ whether to use classifier free guidance or not
285
+ negative_prompt (`str` or `List[str]`, *optional*):
286
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
287
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
288
+ less than `1`).
289
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
290
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
291
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
292
+ prompt_embeds (`torch.FloatTensor`, *optional*):
293
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
294
+ provided, text embeddings will be generated from `prompt` input argument.
295
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
296
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
297
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
298
+ argument.
299
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
301
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
302
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
304
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
305
+ input argument.
306
+ lora_scale (`float`, *optional*):
307
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
308
+ """
309
+ device = device or self._execution_device
310
+
311
+ # set lora scale so that monkey patched LoRA
312
+ # function of text encoder can correctly access it
313
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
314
+ self._lora_scale = lora_scale
315
+
316
+ if prompt is not None and isinstance(prompt, str):
317
+ batch_size = 1
318
+ elif prompt is not None and isinstance(prompt, list):
319
+ batch_size = len(prompt)
320
+ else:
321
+ batch_size = prompt_embeds.shape[0]
322
+
323
+ # Define tokenizers and text encoders
324
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
325
+ text_encoders = (
326
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
327
+ )
328
+
329
+ if prompt_embeds is None:
330
+ prompt_2 = prompt_2 or prompt
331
+ # textual inversion: process multi-vector tokens if necessary
332
+ prompt_embeds_list = []
333
+ prompts = [prompt, prompt_2]
334
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
335
+ if isinstance(self, TextualInversionLoaderMixin):
336
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
337
+
338
+ text_inputs = tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+
346
+ text_input_ids = text_inputs.input_ids
347
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
348
+
349
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
350
+ text_input_ids, untruncated_ids
351
+ ):
352
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
353
+ logger.warning(
354
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
355
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
356
+ )
357
+
358
+ prompt_embeds = text_encoder(
359
+ text_input_ids.to(device),
360
+ output_hidden_states=True,
361
+ )
362
+
363
+ # Only the pooled output of the final text encoder is ultimately kept (each loop iteration overwrites it)
364
+ pooled_prompt_embeds = prompt_embeds[0]
365
+ prompt_embeds = prompt_embeds.hidden_states[-2]
366
+
367
+ prompt_embeds_list.append(prompt_embeds)
368
+
369
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
370
+
371
+ # get unconditional embeddings for classifier free guidance
372
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
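+ # SDXL-style behaviour: when no negative prompt is given and the pipeline was registered with
+ # `force_zeros_for_empty_prompt`, all-zero embeddings are used for the unconditional branch
+ # instead of encoding an empty string.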
373
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
374
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
375
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
376
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
377
+ negative_prompt = negative_prompt or ""
378
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
379
+
380
+ uncond_tokens: List[str]
381
+ if prompt is not None and type(prompt) is not type(negative_prompt):
382
+ raise TypeError(
383
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
384
+ f" {type(prompt)}."
385
+ )
386
+ elif isinstance(negative_prompt, str):
387
+ uncond_tokens = [negative_prompt, negative_prompt_2]
388
+ elif batch_size != len(negative_prompt):
389
+ raise ValueError(
390
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
391
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
392
+ " the batch size of `prompt`."
393
+ )
394
+ else:
395
+ uncond_tokens = [negative_prompt, negative_prompt_2]
396
+
397
+ negative_prompt_embeds_list = []
398
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
399
+ if isinstance(self, TextualInversionLoaderMixin):
400
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
401
+
402
+ max_length = prompt_embeds.shape[1]
403
+ uncond_input = tokenizer(
404
+ negative_prompt,
405
+ padding="max_length",
406
+ max_length=max_length,
407
+ truncation=True,
408
+ return_tensors="pt",
409
+ )
410
+
411
+ negative_prompt_embeds = text_encoder(
412
+ uncond_input.input_ids.to(device),
413
+ output_hidden_states=True,
414
+ )
415
+ # Only the pooled output of the final text encoder is ultimately kept (each loop iteration overwrites it)
416
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
417
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
418
+
419
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
420
+
421
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
422
+
423
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
424
+ bs_embed, seq_len, _ = prompt_embeds.shape
425
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
426
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
427
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
428
+
429
+ if do_classifier_free_guidance:
430
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
431
+ seq_len = negative_prompt_embeds.shape[1]
432
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
433
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
434
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
435
+
436
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
437
+ bs_embed * num_images_per_prompt, -1
438
+ )
439
+ if do_classifier_free_guidance:
440
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
441
+ bs_embed * num_images_per_prompt, -1
442
+ )
443
+
444
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
445
+
446
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
447
+ def prepare_extra_step_kwargs(self, generator, eta):
448
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
449
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
450
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
451
+ # and should be between [0, 1]
452
+
453
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
454
+ extra_step_kwargs = {}
455
+ if accepts_eta:
456
+ extra_step_kwargs["eta"] = eta
457
+
458
+ # check if the scheduler accepts generator
459
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
460
+ if accepts_generator:
461
+ extra_step_kwargs["generator"] = generator
462
+ return extra_step_kwargs
463
+
464
+ def check_inputs(
465
+ self,
466
+ prompt,
467
+ prompt_2,
468
+ control_images,
469
+ video_length,
470
+ callback_steps,
471
+ negative_prompt=None,
472
+ negative_prompt_2=None,
473
+ prompt_embeds=None,
474
+ negative_prompt_embeds=None,
475
+ pooled_prompt_embeds=None,
476
+ negative_pooled_prompt_embeds=None,
477
+ controlnet_conditioning_scale=1.0,
478
+ control_guidance_start=0.0,
479
+ control_guidance_end=1.0,
480
+ ):
481
+ if (callback_steps is None) or (
482
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
483
+ ):
484
+ raise ValueError(
485
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
486
+ f" {type(callback_steps)}."
487
+ )
488
+
489
+ if prompt is not None and prompt_embeds is not None:
490
+ raise ValueError(
491
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
492
+ " only forward one of the two."
493
+ )
494
+ elif prompt_2 is not None and prompt_embeds is not None:
495
+ raise ValueError(
496
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
497
+ " only forward one of the two."
498
+ )
499
+ elif prompt is None and prompt_embeds is None:
500
+ raise ValueError(
501
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
502
+ )
503
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
504
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
505
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
506
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
507
+
508
+ if negative_prompt is not None and negative_prompt_embeds is not None:
509
+ raise ValueError(
510
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
511
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
512
+ )
513
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
514
+ raise ValueError(
515
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
516
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
517
+ )
518
+
519
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
520
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
521
+ raise ValueError(
522
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
523
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
524
+ f" {negative_prompt_embeds.shape}."
525
+ )
526
+
527
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
528
+ raise ValueError(
529
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
530
+ )
531
+
532
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
533
+ raise ValueError(
534
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
535
+ )
536
+
537
+ # `prompt` needs more sophisticated handling when there are multiple
538
+ # conditionings.
539
+ if isinstance(self.controlnet, MultiControlNetModel):
540
+ if isinstance(prompt, list):
541
+ logger.warning(
542
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
543
+ " prompts. The conditionings will be fixed across the prompts."
544
+ )
545
+
546
+ # Check `control_images`
547
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
548
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
549
+ )
550
+ if (
551
+ isinstance(self.controlnet, ControlNetModel)
552
+ or is_compiled
553
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
554
+ ):
555
+
556
+ assert len(control_images) == video_length
557
+ # for image in control_images:
558
+ # self.check_image(image, prompt, prompt_embeds)
559
+ elif (
560
+ isinstance(self.controlnet, MultiControlNetModel)
561
+ or is_compiled
562
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
563
+ ):
564
+ ...
565
+ # todo
566
+ #
567
+ # if not isinstance(image, list):
568
+ # raise TypeError("For multiple controlnets: `image` must be type `list`")
569
+ #
570
+ # # When `image` is a nested list:
571
+ # # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
572
+ # elif any(isinstance(i, list) for i in image):
573
+ # raise ValueError("A single batch of multiple conditionings are supported at the moment.")
574
+ # elif len(image) != len(self.controlnet.nets):
575
+ # raise ValueError(
576
+ # f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
577
+ # )
578
+ #
579
+ # for image_ in image:
580
+ # self.check_image(image_, prompt, prompt_embeds)
581
+ else:
582
+ assert False
583
+
584
+ # Check `controlnet_conditioning_scale`
585
+ if (
586
+ isinstance(self.controlnet, ControlNetModel)
587
+ or is_compiled
588
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
589
+ ):
590
+ if not isinstance(controlnet_conditioning_scale, float):
591
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
592
+ elif (
593
+ isinstance(self.controlnet, MultiControlNetModel)
594
+ or is_compiled
595
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
596
+ ):
597
+ if isinstance(controlnet_conditioning_scale, list):
598
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
599
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
600
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
601
+ self.controlnet.nets
602
+ ):
603
+ raise ValueError(
604
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
605
+ " the same length as the number of controlnets"
606
+ )
607
+ else:
608
+ assert False
609
+
610
+ if not isinstance(control_guidance_start, (tuple, list)):
611
+ control_guidance_start = [control_guidance_start]
612
+
613
+ if not isinstance(control_guidance_end, (tuple, list)):
614
+ control_guidance_end = [control_guidance_end]
615
+
616
+ if len(control_guidance_start) != len(control_guidance_end):
617
+ raise ValueError(
618
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
619
+ )
620
+
621
+ if isinstance(self.controlnet, MultiControlNetModel):
622
+ if len(control_guidance_start) != len(self.controlnet.nets):
623
+ raise ValueError(
624
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
625
+ )
626
+
627
+ for start, end in zip(control_guidance_start, control_guidance_end):
628
+ if start >= end:
629
+ raise ValueError(
630
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
631
+ )
632
+ if start < 0.0:
633
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
634
+ if end > 1.0:
635
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
636
+
637
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
638
+ def check_image(self, image, prompt, prompt_embeds):
639
+ image_is_pil = isinstance(image, PIL.Image.Image)
640
+ image_is_tensor = isinstance(image, torch.Tensor)
641
+ image_is_np = isinstance(image, np.ndarray)
642
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
643
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
644
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
645
+
646
+ if (
647
+ not image_is_pil
648
+ and not image_is_tensor
649
+ and not image_is_np
650
+ and not image_is_pil_list
651
+ and not image_is_tensor_list
652
+ and not image_is_np_list
653
+ ):
654
+ raise TypeError(
655
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
656
+ )
657
+
658
+ if image_is_pil:
659
+ image_batch_size = 1
660
+ else:
661
+ image_batch_size = len(image)
662
+
663
+ if prompt is not None and isinstance(prompt, str):
664
+ prompt_batch_size = 1
665
+ elif prompt is not None and isinstance(prompt, list):
666
+ prompt_batch_size = len(prompt)
667
+ elif prompt_embeds is not None:
668
+ prompt_batch_size = prompt_embeds.shape[0]
669
+
670
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
671
+ raise ValueError(
672
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
673
+ )
674
+
675
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
676
+ def prepare_images(
677
+ self,
678
+ images,
679
+ width,
680
+ height,
681
+ batch_size,
682
+ num_images_per_prompt,
683
+ device,
684
+ dtype,
685
+ do_classifier_free_guidance=False,
686
+ guess_mode=False,
687
+ ):
688
+ images_pre_processed = [self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) for image in images]
689
+
690
+ images_pre_processed = torch.cat(images_pre_processed, dim=0)
691
+
692
+ repeat_factor = [1] * len(images_pre_processed.shape)
693
+ repeat_factor[0] = batch_size * num_images_per_prompt
694
+ images_pre_processed = images_pre_processed.repeat(*repeat_factor)
695
+
696
+ images = images_pre_processed.unsqueeze(0)
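+ # shape sketch (assuming each preprocessed control image is (1, c, h, w)): after torch.cat + repeat
+ # the tensor is (frames * batch, c, h, w); unsqueeze(0) adds a leading axis, and the tensor is
+ # flattened again (and doubled when classifier-free guidance duplicates it) just before the
+ # denoising loop.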
697
+
698
+ # image_batch_size = image.shape[0]
699
+ #
700
+ # if image_batch_size == 1:
701
+ # repeat_by = batch_size
702
+ # else:
703
+ # # image batch size is the same as prompt batch size
704
+ # repeat_by = num_images_per_prompt
705
+
706
+ #image = image.repeat_interleave(repeat_by, dim=0)
707
+
708
+ images = images.to(device=device, dtype=dtype)
709
+
710
+ if do_classifier_free_guidance and not guess_mode:
711
+ repeat_factor = [1] * len(images.shape)
712
+ repeat_factor[0] = 2
713
+ images = images.repeat(*repeat_factor)
714
+
715
+ return images
716
+
717
+ # def prepare_images(self,
718
+ # images: list,
719
+ # width,
720
+ # height,
721
+ # batch_size,
722
+ # num_images_per_prompt,
723
+ # device,
724
+ # dtype,
725
+ # do_classifier_free_guidance=False,
726
+ # guess_mode=False):
727
+ #
728
+ # images = [self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) for image in images]
729
+ #
730
+ # image_batch_size = image.shape[0]
731
+ #
732
+ # if image_batch_size == 1:
733
+ # repeat_by = batch_size
734
+ # else:
735
+ # # image batch size is the same as prompt batch size
736
+ # repeat_by = num_images_per_prompt
737
+ #
738
+ # image = image.repeat_interleave(repeat_by, dim=0)
739
+ #
740
+ # image = image.to(device=device, dtype=dtype)
741
+ #
742
+ # if do_classifier_free_guidance and not guess_mode:
743
+ # image = torch.cat([image] * 2)
744
+ #
745
+ # return image
746
+
747
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
748
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
749
+ #shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
750
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
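+ # Unlike the 4D SDXL latents, these video latents carry an extra frames axis:
+ # (batch, channels, frames, height // vae_scale_factor, width // vae_scale_factor).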
751
+ if isinstance(generator, list) and len(generator) != batch_size:
752
+ raise ValueError(
753
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
754
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
755
+ )
756
+
757
+ if latents is None:
758
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
759
+ else:
760
+ latents = latents.to(device)
761
+
762
+ # scale the initial noise by the standard deviation required by the scheduler
763
+ latents = latents * self.scheduler.init_noise_sigma
764
+ return latents
765
+
766
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
767
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
768
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
769
+
770
+ passed_add_embed_dim = (
771
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
772
+ )
773
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
774
+
775
+ if expected_add_embed_dim != passed_add_embed_dim:
776
+ raise ValueError(
777
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
778
+ )
779
+
780
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
781
+ return add_time_ids
782
+
783
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
784
+ def upcast_vae(self):
785
+ dtype = self.vae.dtype
786
+ self.vae.to(dtype=torch.float32)
787
+ use_torch_2_0_or_xformers = isinstance(
788
+ self.vae.decoder.mid_block.attentions[0].processor,
789
+ (
790
+ AttnProcessor2_0,
791
+ XFormersAttnProcessor,
792
+ LoRAXFormersAttnProcessor,
793
+ LoRAAttnProcessor2_0,
794
+ ),
795
+ )
796
+ # if xformers or torch_2_0 is used attention block does not need
797
+ # to be in float32 which can save lots of memory
798
+ if use_torch_2_0_or_xformers:
799
+ self.vae.post_quant_conv.to(dtype)
800
+ self.vae.decoder.conv_in.to(dtype)
801
+ self.vae.decoder.mid_block.to(dtype)
802
+
803
+ @torch.no_grad()
804
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
805
+ def __call__(
806
+ self,
807
+ prompt: Union[str, List[str]] = None,
808
+ prompt_2: Optional[Union[str, List[str]]] = None,
809
+ video_length: Optional[int] = 8,
810
+ control_images: List[PIL.Image.Image] = None,
811
+ height: Optional[int] = None,
812
+ width: Optional[int] = None,
813
+ num_inference_steps: int = 50,
814
+ guidance_scale: float = 5.0,
815
+ negative_prompt: Optional[Union[str, List[str]]] = None,
816
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
817
+ num_images_per_prompt: Optional[int] = 1,
818
+ eta: float = 0.0,
819
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
820
+ latents: Optional[torch.FloatTensor] = None,
821
+ prompt_embeds: Optional[torch.FloatTensor] = None,
822
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
823
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
824
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
825
+ output_type: Optional[str] = "pil",
826
+ return_dict: bool = True,
827
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
828
+ callback_steps: int = 1,
829
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
830
+ guidance_rescale: float = 0.0,
831
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
832
+ guess_mode: bool = False,
833
+ control_guidance_start: Union[float, List[float]] = 0.0,
834
+ control_guidance_end: Union[float, List[float]] = 1.0,
835
+ original_size: Tuple[int, int] = None,
836
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
837
+ target_size: Tuple[int, int] = None,
838
+ negative_original_size: Optional[Tuple[int, int]] = None,
839
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
840
+ negative_target_size: Optional[Tuple[int, int]] = None,
841
+ ):
842
+ r"""
843
+ The call function to the pipeline for generation.
844
+
845
+ Args:
846
+ prompt (`str` or `List[str]`, *optional*):
847
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
848
+ prompt_2 (`str` or `List[str]`, *optional*):
849
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
850
+ used in both text-encoders.
851
+ video_length (`int`, *optional*, defaults to 8):
+ The number of frames to generate.
+ control_images (`List[PIL.Image.Image]`):
+ The ControlNet input conditions used to guide the `unet`, one control image per generated
+ frame (the pipeline expects `video_length * batch_size` images). If `height` and/or `width`
+ are passed, each control image is resized accordingly before being passed to the ControlNet.
859
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
860
+ The height in pixels of the generated image.
861
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
862
+ The width in pixels of the generated image.
863
+ num_inference_steps (`int`, *optional*, defaults to 50):
864
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
865
+ expense of slower inference.
866
+ guidance_scale (`float`, *optional*, defaults to 5.0):
867
+ A higher guidance scale value encourages the model to generate images closely linked to the text
868
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
869
+ negative_prompt (`str` or `List[str]`, *optional*):
870
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
871
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
872
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
873
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
874
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
875
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
876
+ The number of images to generate per prompt.
877
+ eta (`float`, *optional*, defaults to 0.0):
878
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
879
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
880
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
881
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
882
+ generation deterministic.
883
+ latents (`torch.FloatTensor`, *optional*):
884
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
885
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
886
+ tensor is generated by sampling using the supplied random `generator`.
887
+ prompt_embeds (`torch.FloatTensor`, *optional*):
888
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
889
+ provided, text embeddings are generated from the `prompt` input argument.
890
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
891
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
892
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
893
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
894
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
895
+ not provided, pooled text embeddings are generated from `prompt` input argument.
896
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
897
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
898
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
899
+ argument.
900
+ output_type (`str`, *optional*, defaults to `"pil"`):
901
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
902
+ return_dict (`bool`, *optional*, defaults to `True`):
903
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
904
+ plain tuple.
905
+ callback (`Callable`, *optional*):
906
+ A function that is called every `callback_steps` steps during inference, with the
907
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
908
+ callback_steps (`int`, *optional*, defaults to 1):
909
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
910
+ every step.
911
+ cross_attention_kwargs (`dict`, *optional*):
912
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
913
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
914
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
915
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
916
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
917
+ the corresponding scale as a list.
918
+ guess_mode (`bool`, *optional*, defaults to `False`):
919
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
920
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
921
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
922
+ The percentage of total steps at which the ControlNet starts applying.
923
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
924
+ The percentage of total steps at which the ControlNet stops applying.
925
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
926
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
927
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
928
+ explained in section 2.2 of
929
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
930
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
931
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
932
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
933
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
934
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
935
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
936
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
937
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
938
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
939
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
940
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
941
+ micro-conditioning as explained in section 2.2 of
942
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
943
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
944
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
945
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
946
+ micro-conditioning as explained in section 2.2 of
947
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
948
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
949
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
950
+ To negatively condition the generation process based on a target image resolution. It should be the same
951
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
952
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
953
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
954
+
955
+ Examples:
956
+
957
+ Returns:
958
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
959
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
960
+ otherwise a `tuple` is returned containing the output images.
961
+ """
962
+
963
+
964
+ if video_length > 1 and num_images_per_prompt > 1:
965
+ print(f"Warning - setting num_images_per_prompt = 1 because video_length = {video_length}")
966
+ num_images_per_prompt = 1
967
+
968
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
969
+
970
+ # align format for control guidance
971
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
972
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
973
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
974
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
975
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
976
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
977
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
978
+ control_guidance_end
979
+ ]
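+ # After this alignment step, scalar `control_guidance_start` / `control_guidance_end` values have
+ # been broadcast to lists (one entry per ControlNet in the multi-ControlNet case); list inputs are
+ # validated in `check_inputs` below.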
980
+
981
+ # 1. Check inputs. Raise error if not correct
982
+ self.check_inputs(
983
+ prompt,
984
+ prompt_2,
985
+ control_images,
986
+ video_length,
987
+ callback_steps,
988
+ negative_prompt,
989
+ negative_prompt_2,
990
+ prompt_embeds,
991
+ negative_prompt_embeds,
992
+ pooled_prompt_embeds,
993
+ negative_pooled_prompt_embeds,
994
+ controlnet_conditioning_scale,
995
+ control_guidance_start,
996
+ control_guidance_end,
997
+ )
998
+
999
+ # 2. Define call parameters
1000
+ if prompt is not None and isinstance(prompt, str):
1001
+ batch_size = 1
1002
+ elif prompt is not None and isinstance(prompt, list):
1003
+ batch_size = len(prompt)
1004
+ else:
1005
+ batch_size = prompt_embeds.shape[0]
1006
+
1007
+ device = self._execution_device
1008
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1009
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1010
+ # corresponds to doing no classifier free guidance.
1011
+ do_classifier_free_guidance = guidance_scale > 1.0
1012
+
1013
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1014
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1015
+
1016
+ global_pool_conditions = (
1017
+ controlnet.config.global_pool_conditions
1018
+ if isinstance(controlnet, ControlNetModel)
1019
+ else controlnet.nets[0].config.global_pool_conditions
1020
+ )
1021
+ guess_mode = guess_mode or global_pool_conditions
1022
+
1023
+ # 3. Encode input prompt
1024
+ text_encoder_lora_scale = (
1025
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1026
+ )
1027
+ (
1028
+ prompt_embeds,
1029
+ negative_prompt_embeds,
1030
+ pooled_prompt_embeds,
1031
+ negative_pooled_prompt_embeds,
1032
+ ) = self.encode_prompt(
1033
+ prompt,
1034
+ prompt_2,
1035
+ device,
1036
+ num_images_per_prompt,
1037
+ do_classifier_free_guidance,
1038
+ negative_prompt,
1039
+ negative_prompt_2,
1040
+ prompt_embeds=prompt_embeds,
1041
+ negative_prompt_embeds=negative_prompt_embeds,
1042
+ pooled_prompt_embeds=pooled_prompt_embeds,
1043
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1044
+ lora_scale=text_encoder_lora_scale,
1045
+ )
1046
+
1047
+
1048
+ # 4. Prepare control images
1049
+ if isinstance(controlnet, ControlNetModel):
1050
+
1051
+ assert len(control_images) == video_length * batch_size
1052
+
1053
+ images = self.prepare_images(
1054
+ images=control_images,
1055
+ width=width,
1056
+ height=height,
1057
+ batch_size=batch_size * num_images_per_prompt,
1058
+ num_images_per_prompt=num_images_per_prompt,
1059
+ device=device,
1060
+ dtype=controlnet.dtype,
1061
+ do_classifier_free_guidance=do_classifier_free_guidance,
1062
+ guess_mode=guess_mode,
1063
+ )
1064
+
1065
+ height, width = images.shape[-2:]
1066
+ elif isinstance(controlnet, MultiControlNetModel):
1067
+
1068
+ raise Exception("not supported yet")
1069
+
1070
+ # images = []
1071
+ #
1072
+ # for image_ in control_images:
1073
+ # image_ = self.prepare_image(
1074
+ # image=image_,
1075
+ # width=width,
1076
+ # height=height,
1077
+ # batch_size=batch_size * num_images_per_prompt,
1078
+ # num_images_per_prompt=num_images_per_prompt,
1079
+ # device=device,
1080
+ # dtype=controlnet.dtype,
1081
+ # do_classifier_free_guidance=do_classifier_free_guidance,
1082
+ # guess_mode=guess_mode,
1083
+ # )
1084
+ #
1085
+ # images.append(image_)
1086
+ #
1087
+ # image = images
1088
+ # height, width = image[0].shape[-2:]
1089
+ else:
1090
+ assert False
1091
+
1092
+ # 5. Prepare timesteps
1093
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1094
+ timesteps = self.scheduler.timesteps
1095
+
1096
+ # 6. Prepare latent variables
1097
+ num_channels_latents = self.unet.config.in_channels
1098
+ latents = self.prepare_latents(
1099
+ batch_size * num_images_per_prompt,
1100
+ num_channels_latents,
1101
+ video_length,
1102
+ height,
1103
+ width,
1104
+ prompt_embeds.dtype,
1105
+ device,
1106
+ generator,
1107
+ latents,
1108
+ )
1109
+
1110
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1111
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1112
+
1113
+ # 7.1 Create tensor stating which controlnets to keep
1114
+ controlnet_keep = []
1115
+ for i in range(len(timesteps)):
1116
+ keeps = [
1117
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1118
+ for s, e in zip(control_guidance_start, control_guidance_end)
1119
+ ]
1120
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
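+ # controlnet_keep[i] is 1.0 while step i lies inside the [control_guidance_start, control_guidance_end]
+ # window and 0.0 outside it, so the ControlNet residuals are effectively switched off outside that window.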
1121
+
1122
+ # 7.2 Prepare added time ids & embeddings
1123
+ # if isinstance(image, list):
1124
+ # original_size = original_size or image[0].shape[-2:]
1125
+ # else:
1126
+ original_size = original_size or images.shape[-2:]
1127
+ target_size = target_size or (height, width)
1128
+
1129
+ add_text_embeds = pooled_prompt_embeds
1130
+ add_time_ids = self._get_add_time_ids(
1131
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
1132
+ )
1133
+
1134
+ if negative_original_size is not None and negative_target_size is not None:
1135
+ negative_add_time_ids = self._get_add_time_ids(
1136
+ negative_original_size,
1137
+ negative_crops_coords_top_left,
1138
+ negative_target_size,
1139
+ dtype=prompt_embeds.dtype,
1140
+ )
1141
+ else:
1142
+ negative_add_time_ids = add_time_ids
1143
+
1144
+ if do_classifier_free_guidance:
1145
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1146
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1147
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1148
+
1149
+ prompt_embeds = prompt_embeds.to(device)
1150
+ add_text_embeds = add_text_embeds.to(device)
1151
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1152
+
1153
+ # 8. Denoising loop
1154
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1155
+
1156
+ images = rearrange(images, "b f c h w -> (b f) c h w")
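+ # Fold the frames axis into the batch axis so the control images line up with
+ # control_model_input, which is flattened the same way inside the loop below.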
1157
+
1158
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1159
+ for i, t in enumerate(timesteps):
1160
+ # expand the latents if we are doing classifier free guidance
1161
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1162
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1163
+
1164
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1165
+
1166
+ # controlnet(s) inference
1167
+ if guess_mode and do_classifier_free_guidance:
1168
+ # Infer ControlNet only for the conditional batch.
1169
+ control_model_input = latents
1170
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1171
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1172
+ controlnet_added_cond_kwargs = {
1173
+ "text_embeds": add_text_embeds.chunk(2)[1],
1174
+ "time_ids": add_time_ids.chunk(2)[1],
1175
+ }
1176
+ else:
1177
+ control_model_input = latent_model_input
1178
+ controlnet_prompt_embeds = prompt_embeds
1179
+ controlnet_added_cond_kwargs = added_cond_kwargs
1180
+
1181
+ if isinstance(controlnet_keep[i], list):
1182
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1183
+ else:
1184
+ controlnet_cond_scale = controlnet_conditioning_scale
1185
+ if isinstance(controlnet_cond_scale, list):
1186
+ controlnet_cond_scale = controlnet_cond_scale[0]
1187
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1188
+
1189
+
1190
+ # after this rearrangement the frames are grouped per batch item rather than interlaced
1191
+ control_model_input = rearrange(control_model_input, "b c f h w -> (b f) c h w")
1192
+ # if we chunked this by 2, the first half of the frames would be the unconditional (negative)
1193
+ # half for cfg and the second half the conditional (positive) half...
1194
+
1195
+ if video_length > 1:
1196
+ # use repeat_interleave as we need to match the rearrangement above.
1197
+
1198
+ controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(video_length, dim=0)
1199
+ controlnet_added_cond_kwargs = {
1200
+ "text_embeds": controlnet_added_cond_kwargs['text_embeds'].repeat_interleave(video_length, dim=0),
1201
+ "time_ids": controlnet_added_cond_kwargs['time_ids'].repeat_interleave(video_length, dim=0)
1202
+ }
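+ # Each per-prompt embedding is repeated once per frame so the ControlNet conditioning batch
+ # matches the (b f)-flattened control_model_input above.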
1203
+
1204
+ # if type(image) is list:
1205
+ # image = torch.cat(image, dim=0)
1206
+
1207
+ # todo - check if video_length > 1 this needs to produce num_frames * batch_size samples...
1208
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1209
+ control_model_input,
1210
+ t,
1211
+ encoder_hidden_states=controlnet_prompt_embeds,
1212
+ controlnet_cond=images,
1213
+ conditioning_scale=cond_scale,
1214
+ guess_mode=guess_mode,
1215
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1216
+ return_dict=False,
1217
+ )
1218
+
1219
+ for j, sample in enumerate(down_block_res_samples):
1220
+ down_block_res_samples[j] = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
1221
+
1222
+ mid_block_res_sample = rearrange(mid_block_res_sample, "(b f) c h w -> b c f h w", f=video_length)
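+ # The ControlNet residuals come back frame-major as (b*f, c, h, w); restore the frames axis
+ # so they can be added to the 5D latents inside the temporal UNet.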
1223
+
1224
+ if guess_mode and do_classifier_free_guidance:
1225
+ # Inferred ControlNet only for the conditional batch.
1226
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1227
+ # add 0 to the unconditional batch to keep it unchanged.
1228
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1229
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1230
+
1231
+ # predict the noise residual
1232
+ noise_pred = self.unet(
1233
+ latent_model_input,
1234
+ t,
1235
+ encoder_hidden_states=prompt_embeds,
1236
+ cross_attention_kwargs=cross_attention_kwargs,
1237
+ down_block_additional_residuals=down_block_res_samples,
1238
+ mid_block_additional_residual=mid_block_res_sample,
1239
+ added_cond_kwargs=added_cond_kwargs,
1240
+ return_dict=False,
1241
+ enable_temporal_attentions=video_length > 1
1242
+ )[0]
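+ # note: `enable_temporal_attentions` is an extra argument of this repo's UNet3DConditionModel
+ # (not a stock diffusers kwarg); it presumably disables the temporal attention layers when only
+ # a single frame is being generated.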
1243
+
1244
+ # perform guidance
1245
+ if do_classifier_free_guidance:
1246
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1247
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1248
+
1249
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1250
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1251
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1252
+
1253
+ # compute the previous noisy sample x_t -> x_t-1
1254
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1255
+
1256
+ # call the callback, if provided
1257
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1258
+ progress_bar.update()
1259
+ if callback is not None and i % callback_steps == 0:
1260
+ callback(i, t, latents)
1261
+
1262
+ # make sure the VAE is in float32 mode, as it overflows in float16
1263
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1264
+ self.upcast_vae()
1265
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1266
+
1267
+ # If we do sequential model offloading, let's offload unet and controlnet
1268
+ # manually for max memory savings
1269
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1270
+ self.unet.to("cpu")
1271
+ self.controlnet.to("cpu")
1272
+ torch.cuda.empty_cache()
1273
+
1274
+ # if not output_type == "latent":
1275
+ # # make sure the VAE is in float32 mode, as it overflows in float16
1276
+ # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1277
+ #
1278
+ # if needs_upcasting:
1279
+ # self.upcast_vae()
1280
+ # latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1281
+ #
1282
+ # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1283
+ #
1284
+ # # cast back to fp16 if needed
1285
+ # if needs_upcasting:
1286
+ # self.vae.to(dtype=torch.float16)
1287
+ # else:
1288
+ # image = latents
1289
+ # return StableDiffusionXLPipelineOutput(images=image)
1290
+
1291
+ video = self.decode_latents(latents)
1292
+
1293
+ # Convert to tensor
1294
+ if output_type == "tensor":
1295
+ video = torch.from_numpy(video)
1296
+
1297
+ if not return_dict:
1298
+ return video
1299
+
1300
+ return HotshotPipelineXLOutput(videos=video)
1301
+
1302
+ def decode_latents(self, latents):
1303
+ video_length = latents.shape[2]
1304
+ latents = 1 / self.vae.config.scaling_factor * latents
1305
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
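+ # Decode frame by frame rather than all at once to keep peak VAE memory low;
+ # the frames axis is folded into the batch here and restored after decoding.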
1306
+ # video = self.vae.decode(latents).sample
1307
+ video = []
1308
+ for frame_idx in tqdm(range(latents.shape[0])):
1309
+ video.append(self.vae.decode(
1310
+ latents[frame_idx:frame_idx+1]).sample)
1311
+ video = torch.cat(video)
1312
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
1313
+ video = (video / 2.0 + 0.5).clamp(0, 1)
1314
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1315
+ video = video.cpu().float().numpy()
1316
+ return video
1317
+
1318
+ # Override to properly handle the loading and unloading of the additional text encoder.
1319
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
1320
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
1321
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
1322
+ # it here explicitly to be able to tell that it's coming from an SDXL
1323
+ # pipeline.
1324
+ state_dict, network_alphas = self.lora_state_dict(
1325
+ pretrained_model_name_or_path_or_dict,
1326
+ unet_config=self.unet.config,
1327
+ **kwargs,
1328
+ )
1329
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
1330
+
1331
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1332
+ if len(text_encoder_state_dict) > 0:
1333
+ self.load_lora_into_text_encoder(
1334
+ text_encoder_state_dict,
1335
+ network_alphas=network_alphas,
1336
+ text_encoder=self.text_encoder,
1337
+ prefix="text_encoder",
1338
+ lora_scale=self.lora_scale,
1339
+ )
1340
+
1341
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1342
+ if len(text_encoder_2_state_dict) > 0:
1343
+ self.load_lora_into_text_encoder(
1344
+ text_encoder_2_state_dict,
1345
+ network_alphas=network_alphas,
1346
+ text_encoder=self.text_encoder_2,
1347
+ prefix="text_encoder_2",
1348
+ lora_scale=self.lora_scale,
1349
+ )
1350
+
1351
+ @classmethod
1352
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
1353
+ def save_lora_weights(
1354
+ self,
1355
+ save_directory: Union[str, os.PathLike],
1356
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1357
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1358
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1359
+ is_main_process: bool = True,
1360
+ weight_name: str = None,
1361
+ save_function: Callable = None,
1362
+ safe_serialization: bool = True,
1363
+ ):
1364
+ state_dict = {}
1365
+
1366
+ def pack_weights(layers, prefix):
1367
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
1368
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
1369
+ return layers_state_dict
1370
+
1371
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
1372
+
1373
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
1374
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
1375
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1376
+
1377
+ self.write_lora_layers(
1378
+ state_dict=state_dict,
1379
+ save_directory=save_directory,
1380
+ is_main_process=is_main_process,
1381
+ weight_name=weight_name,
1382
+ save_function=save_function,
1383
+ safe_serialization=safe_serialization,
1384
+ )
1385
+
1386
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
1387
+ def _remove_text_encoder_monkey_patch(self):
1388
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
1389
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
hotshot_xl/pipelines/hotshot_xl_pipeline.py ADDED
@@ -0,0 +1,996 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Adapted the SDXL Pipeline to work temporally
18
+
19
+
20
+ import os
21
+ import inspect
22
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
26
+ from hotshot_xl import HotshotPipelineXLOutput
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL
31
+ from hotshot_xl.models.unet import UNet3DConditionModel
32
+ from diffusers.models.attention_processor import (
33
+ AttnProcessor2_0,
34
+ LoRAAttnProcessor2_0,
35
+ LoRAXFormersAttnProcessor,
36
+ XFormersAttnProcessor,
37
+ )
38
+ from diffusers.schedulers import KarrasDiffusionSchedulers
39
+ from diffusers.utils import (
40
+ is_accelerate_available,
41
+ is_accelerate_version,
42
+ logging,
43
+ replace_example_docstring,
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
47
+ from tqdm import tqdm
48
+ from einops import repeat, rearrange
49
+ from diffusers.utils import deprecate, logging
50
+ import gc
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
59
+
60
+ >>> pipe = HotshotXLPipeline.from_pretrained(
61
+ ... "hotshotco/Hotshot-XL"
62
+ ... )
63
+ >>> pipe = pipe.to("cuda")
64
+
65
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
66
+ >>> video = pipe(prompt,
67
+ ... width=672, height=384,
68
+ ... original_size=(1920, 1080),
69
+ ... target_size=(512, 512),
70
+ ... output_type="tensor"
71
+ ... ).videos
72
+ ```
73
+ """
74
+
75
+
76
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
77
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
78
+ """
79
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
80
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
81
+ """
82
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
83
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
84
+ # rescale the results from guidance (fixes overexposure)
85
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
86
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
87
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
88
+ return noise_cfg
89
+
90
+
91
+
92
+
93
+ class HotshotXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
94
+ r"""
95
+ Pipeline for text-to-GIF/video generation with Hotshot-XL, built on top of Stable Diffusion XL.
96
+
97
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
98
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
99
+
100
+ In addition the pipeline inherits the following loading methods:
101
+ - *LoRA*: [`HotshotXLPipeline.load_lora_weights`]
102
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
103
+
104
+ as well as the following saving methods:
105
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
106
+
107
+ Args:
108
+ vae ([`AutoencoderKL`]):
109
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
110
+ text_encoder ([`CLIPTextModel`]):
111
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
112
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
113
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
114
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
115
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
116
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
117
+ specifically the
118
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
119
+ variant.
120
+ tokenizer (`CLIPTokenizer`):
121
+ Tokenizer of class
122
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
123
+ tokenizer_2 (`CLIPTokenizer`):
124
+ Second Tokenizer of class
125
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
126
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
127
+ scheduler ([`SchedulerMixin`]):
128
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
129
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ vae: AutoencoderKL,
135
+ text_encoder: CLIPTextModel,
136
+ text_encoder_2: CLIPTextModelWithProjection,
137
+ tokenizer: CLIPTokenizer,
138
+ tokenizer_2: CLIPTokenizer,
139
+ unet: UNet3DConditionModel,
140
+ scheduler: KarrasDiffusionSchedulers,
141
+ force_zeros_for_empty_prompt: bool = True,
142
+ add_watermarker: Optional[bool] = None,
143
+ ):
144
+ super().__init__()
145
+
146
+ self.register_modules(
147
+ vae=vae,
148
+ text_encoder=text_encoder,
149
+ text_encoder_2=text_encoder_2,
150
+ tokenizer=tokenizer,
151
+ tokenizer_2=tokenizer_2,
152
+ unet=unet,
153
+ scheduler=scheduler,
154
+ )
155
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
156
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
157
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
158
+ self.default_sample_size = self.unet.config.sample_size
159
+ self.watermark = None
160
+
161
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
162
+ def enable_vae_slicing(self):
163
+ r"""
164
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
165
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
166
+ """
167
+ self.vae.enable_slicing()
168
+
169
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
170
+ def disable_vae_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
173
+ computing decoding in one step.
174
+ """
175
+ self.vae.disable_slicing()
176
+
177
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
178
+ def enable_vae_tiling(self):
179
+ r"""
180
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
181
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
182
+ processing larger images.
183
+ """
184
+ self.vae.enable_tiling()
185
+
186
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
187
+ def disable_vae_tiling(self):
188
+ r"""
189
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
190
+ computing decoding in one step.
191
+ """
192
+ self.vae.disable_tiling()
193
+
194
+ def enable_model_cpu_offload(self, gpu_id=0):
195
+ r"""
196
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
197
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
198
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
199
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
200
+ """
201
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
202
+ from accelerate import cpu_offload_with_hook
203
+ else:
204
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
205
+
206
+ device = torch.device(f"cuda:{gpu_id}")
207
+
208
+ if self.device.type != "cpu":
209
+ self.to("cpu", silence_dtype_warnings=True)
210
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
211
+
212
+ model_sequence = (
213
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
214
+ )
215
+ model_sequence.extend([self.unet, self.vae])
216
+
217
+ hook = None
218
+ for cpu_offloaded_model in model_sequence:
219
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
220
+
221
+ # We'll offload the last model manually.
222
+ self.final_offload_hook = hook
223
+
224
+ def encode_prompt(
225
+ self,
226
+ prompt: str,
227
+ prompt_2: Optional[str] = None,
228
+ device: Optional[torch.device] = None,
229
+ num_images_per_prompt: int = 1,
230
+ do_classifier_free_guidance: bool = True,
231
+ negative_prompt: Optional[str] = None,
232
+ negative_prompt_2: Optional[str] = None,
233
+ prompt_embeds: Optional[torch.FloatTensor] = None,
234
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
235
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
236
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
237
+ lora_scale: Optional[float] = None,
238
+ ):
239
+ r"""
240
+ Encodes the prompt into text encoder hidden states.
241
+
242
+ Args:
243
+ prompt (`str` or `List[str]`, *optional*):
244
+ prompt to be encoded
245
+ prompt_2 (`str` or `List[str]`, *optional*):
246
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
247
+ used in both text-encoders
248
+ device: (`torch.device`):
249
+ torch device
250
+ num_images_per_prompt (`int`):
251
+ number of images that should be generated per prompt
252
+ do_classifier_free_guidance (`bool`):
253
+ whether to use classifier free guidance or not
254
+ negative_prompt (`str` or `List[str]`, *optional*):
255
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
256
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
257
+ less than `1`).
258
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
259
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
260
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
261
+ prompt_embeds (`torch.FloatTensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263
+ provided, text embeddings will be generated from `prompt` input argument.
264
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
265
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
266
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
267
+ argument.
268
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
269
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
270
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
271
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
272
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
273
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
274
+ input argument.
275
+ lora_scale (`float`, *optional*):
276
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
277
+ """
278
+ device = device or self._execution_device
279
+
280
+ # set lora scale so that monkey patched LoRA
281
+ # function of text encoder can correctly access it
282
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
283
+ self._lora_scale = lora_scale
284
+
285
+ if prompt is not None and isinstance(prompt, str):
286
+ batch_size = 1
287
+ elif prompt is not None and isinstance(prompt, list):
288
+ batch_size = len(prompt)
289
+ else:
290
+ batch_size = prompt_embeds.shape[0]
291
+
292
+ # Define tokenizers and text encoders
293
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
294
+ text_encoders = (
295
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
296
+ )
297
+
298
+ if prompt_embeds is None:
299
+ prompt_2 = prompt_2 or prompt
300
+ # textual inversion: process multi-vector tokens if necessary
301
+ prompt_embeds_list = []
302
+ prompts = [prompt, prompt_2]
303
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
304
+ if isinstance(self, TextualInversionLoaderMixin):
305
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
306
+
307
+ text_inputs = tokenizer(
308
+ prompt,
309
+ padding="max_length",
310
+ max_length=tokenizer.model_max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+
318
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
319
+ text_input_ids, untruncated_ids
320
+ ):
321
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
322
+ logger.warning(
323
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
324
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
325
+ )
326
+
327
+ prompt_embeds = text_encoder(
328
+ text_input_ids.to(device),
329
+ output_hidden_states=True,
330
+ )
331
+
332
+ # We are only ALWAYS interested in the pooled output of the final text encoder
333
+ pooled_prompt_embeds = prompt_embeds[0]
334
+ prompt_embeds = prompt_embeds.hidden_states[-2]
335
+
336
+ prompt_embeds_list.append(prompt_embeds)
337
+
338
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
339
+
340
+ # get unconditional embeddings for classifier free guidance
341
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
342
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
343
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
344
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
345
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
346
+ negative_prompt = negative_prompt or ""
347
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
348
+
349
+ uncond_tokens: List[str]
350
+ if prompt is not None and type(prompt) is not type(negative_prompt):
351
+ raise TypeError(
352
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
353
+ f" {type(prompt)}."
354
+ )
355
+ elif isinstance(negative_prompt, str):
356
+ uncond_tokens = [negative_prompt, negative_prompt_2]
357
+ elif batch_size != len(negative_prompt):
358
+ raise ValueError(
359
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
360
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
361
+ " the batch size of `prompt`."
362
+ )
363
+ else:
364
+ uncond_tokens = [negative_prompt, negative_prompt_2]
365
+
366
+ negative_prompt_embeds_list = []
367
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
368
+ if isinstance(self, TextualInversionLoaderMixin):
369
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
370
+
371
+ max_length = prompt_embeds.shape[1]
372
+ uncond_input = tokenizer(
373
+ negative_prompt,
374
+ padding="max_length",
375
+ max_length=max_length,
376
+ truncation=True,
377
+ return_tensors="pt",
378
+ )
379
+
380
+ negative_prompt_embeds = text_encoder(
381
+ uncond_input.input_ids.to(device),
382
+ output_hidden_states=True,
383
+ )
384
+ # We are only ALWAYS interested in the pooled output of the final text encoder
385
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
386
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
387
+
388
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
389
+
390
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
391
+
392
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
393
+ bs_embed, seq_len, _ = prompt_embeds.shape
394
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
395
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
396
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
397
+
398
+ if do_classifier_free_guidance:
399
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
400
+ seq_len = negative_prompt_embeds.shape[1]
401
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
402
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
403
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
404
+
405
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
406
+ bs_embed * num_images_per_prompt, -1
407
+ )
408
+ if do_classifier_free_guidance:
409
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
410
+ bs_embed * num_images_per_prompt, -1
411
+ )
412
+
413
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
414
+
415
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
416
+ def prepare_extra_step_kwargs(self, generator, eta):
417
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
418
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
419
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
420
+ # and should be between [0, 1]
421
+
422
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
423
+ extra_step_kwargs = {}
424
+ if accepts_eta:
425
+ extra_step_kwargs["eta"] = eta
426
+
427
+ # check if the scheduler accepts generator
428
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
429
+ if accepts_generator:
430
+ extra_step_kwargs["generator"] = generator
431
+ return extra_step_kwargs
432
+
433
+ def check_inputs(
434
+ self,
435
+ prompt,
436
+ prompt_2,
437
+ height,
438
+ width,
439
+ callback_steps,
440
+ negative_prompt=None,
441
+ negative_prompt_2=None,
442
+ prompt_embeds=None,
443
+ negative_prompt_embeds=None,
444
+ pooled_prompt_embeds=None,
445
+ negative_pooled_prompt_embeds=None,
446
+ ):
447
+ if height % 8 != 0 or width % 8 != 0:
448
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
449
+
450
+ if (callback_steps is None) or (
451
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
452
+ ):
453
+ raise ValueError(
454
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
455
+ f" {type(callback_steps)}."
456
+ )
457
+
458
+ if prompt is not None and prompt_embeds is not None:
459
+ raise ValueError(
460
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
461
+ " only forward one of the two."
462
+ )
463
+ elif prompt_2 is not None and prompt_embeds is not None:
464
+ raise ValueError(
465
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
466
+ " only forward one of the two."
467
+ )
468
+ elif prompt is None and prompt_embeds is None:
469
+ raise ValueError(
470
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
471
+ )
472
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
473
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
474
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
475
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
476
+
477
+ if negative_prompt is not None and negative_prompt_embeds is not None:
478
+ raise ValueError(
479
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
480
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
481
+ )
482
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+
488
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
489
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
490
+ raise ValueError(
491
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
492
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
493
+ f" {negative_prompt_embeds.shape}."
494
+ )
495
+
496
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
497
+ raise ValueError(
498
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
499
+ )
500
+
501
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
502
+ raise ValueError(
503
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
504
+ )
505
+
506
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents to add a video_length (frames) dimension
507
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
508
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
509
+ if isinstance(generator, list) and len(generator) != batch_size:
510
+ raise ValueError(
511
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
512
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
513
+ )
514
+
515
+ if latents is None:
516
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
517
+ else:
518
+ latents = latents.to(device)
519
+
520
+ # scale the initial noise by the standard deviation required by the scheduler
521
+ latents = latents * self.scheduler.init_noise_sigma
522
+ return latents
523
+
524
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
525
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
526
+
527
+ passed_add_embed_dim = (
528
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
529
+ )
530
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
531
+
532
+ if expected_add_embed_dim != passed_add_embed_dim:
533
+ raise ValueError(
534
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
535
+ )
536
+
537
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
538
+ return add_time_ids
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
541
+ def upcast_vae(self):
542
+ dtype = self.vae.dtype
543
+ self.vae.to(dtype=torch.float32)
544
+ use_torch_2_0_or_xformers = isinstance(
545
+ self.vae.decoder.mid_block.attentions[0].processor,
546
+ (
547
+ AttnProcessor2_0,
548
+ XFormersAttnProcessor,
549
+ LoRAXFormersAttnProcessor,
550
+ LoRAAttnProcessor2_0,
551
+ ),
552
+ )
553
+ # if xformers or torch_2_0 is used attention block does not need
554
+ # to be in float32 which can save lots of memory
555
+ if use_torch_2_0_or_xformers:
556
+ self.vae.post_quant_conv.to(dtype)
557
+ self.vae.decoder.conv_in.to(dtype)
558
+ self.vae.decoder.mid_block.to(dtype)
559
+
560
+ @torch.no_grad()
561
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
562
+ def __call__(
563
+ self,
564
+ prompt: Union[str, List[str]] = None,
565
+ prompt_2: Optional[Union[str, List[str]]] = None,
566
+ video_length: Optional[int] = 8,
567
+ num_images_per_prompt: Optional[int] = 1,
568
+ height: Optional[int] = None,
569
+ width: Optional[int] = None,
570
+ num_inference_steps: int = 50,
571
+ denoising_end: Optional[float] = None,
572
+ guidance_scale: float = 5.0,
573
+ negative_prompt: Optional[Union[str, List[str]]] = None,
574
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
575
+ eta: float = 0.0,
576
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
577
+ latents: Optional[torch.FloatTensor] = None,
578
+ prompt_embeds: Optional[torch.FloatTensor] = None,
579
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
580
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
581
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
582
+ output_type: Optional[str] = "pil",
583
+ return_dict: bool = True,
584
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
585
+ callback_steps: int = 1,
586
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
587
+ guidance_rescale: float = 0.0,
588
+ original_size: Optional[Tuple[int, int]] = None,
589
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
590
+ target_size: Optional[Tuple[int, int]] = None,
591
+ low_vram_mode: Optional[bool] = False
592
+ ):
593
+ r"""
594
+ Function invoked when calling the pipeline for generation.
595
+
596
+ Args:
597
+ prompt (`str` or `List[str]`, *optional*):
598
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
599
+ instead.
600
+ prompt_2 (`str` or `List[str]`, *optional*):
601
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
602
+ used in both text-encoders
603
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
604
+ The height in pixels of the generated image.
605
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
606
+ The width in pixels of the generated image.
607
+ num_inference_steps (`int`, *optional*, defaults to 50):
608
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
609
+ expense of slower inference.
610
+ denoising_end (`float`, *optional*):
611
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
612
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
613
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
614
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
615
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
616
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
617
+ guidance_scale (`float`, *optional*, defaults to 5.0):
618
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
619
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
620
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
621
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
622
+ usually at the expense of lower image quality.
623
+ negative_prompt (`str` or `List[str]`, *optional*):
624
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
625
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
626
+ less than `1`).
627
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
628
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
629
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
630
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
631
+ The number of images to generate per prompt.
632
+ eta (`float`, *optional*, defaults to 0.0):
633
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
634
+ [`schedulers.DDIMScheduler`], will be ignored for others.
635
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
636
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
637
+ to make generation deterministic.
638
+ latents (`torch.FloatTensor`, *optional*):
639
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
640
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
641
+ tensor will be generated by sampling using the supplied random `generator`.
642
+ prompt_embeds (`torch.FloatTensor`, *optional*):
643
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
644
+ provided, text embeddings will be generated from `prompt` input argument.
645
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
646
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
647
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
648
+ argument.
649
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
650
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
651
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
652
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
653
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
654
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
655
+ input argument.
656
+ output_type (`str`, *optional*, defaults to `"pil"`):
657
+ The output format of the generated frames. Choose between
658
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
659
+ return_dict (`bool`, *optional*, defaults to `True`):
660
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
661
+ of a plain tuple.
662
+ callback (`Callable`, *optional*):
663
+ A function that will be called every `callback_steps` steps during inference. The function will be
664
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
665
+ callback_steps (`int`, *optional*, defaults to 1):
666
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
667
+ called at every step.
668
+ cross_attention_kwargs (`dict`, *optional*):
669
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
670
+ `self.processor` in
671
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
672
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
673
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
674
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
675
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
676
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
677
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
678
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
679
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
680
+ explained in section 2.2 of
681
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
682
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
683
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
684
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
685
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
686
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
687
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
688
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
689
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
690
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
691
+
692
+ Examples:
693
+
694
+ Returns:
695
+ [`~hotshot_xl.HotshotPipelineXLOutput`] or `tuple`:
696
+ [`~hotshot_xl.HotshotPipelineXLOutput`] if `return_dict` is True, otherwise a
697
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
698
+ """
699
+ self.low_vram_mode = low_vram_mode
700
+
701
+ if video_length > 1:
702
+ logger.warning(f"Setting num_images_per_prompt = 1 because video_length = {video_length}")
703
+ num_images_per_prompt = 1
704
+
705
+ # 0. Default height and width to unet
706
+ height = height or self.default_sample_size * self.vae_scale_factor
707
+ width = width or self.default_sample_size * self.vae_scale_factor
708
+
709
+ original_size = original_size or (height, width)
710
+ target_size = target_size or (height, width)
711
+
712
+ # 1. Check inputs. Raise error if not correct
713
+ self.check_inputs(
714
+ prompt,
715
+ prompt_2,
716
+ height,
717
+ width,
718
+ callback_steps,
719
+ negative_prompt,
720
+ negative_prompt_2,
721
+ prompt_embeds,
722
+ negative_prompt_embeds,
723
+ pooled_prompt_embeds,
724
+ negative_pooled_prompt_embeds,
725
+ )
726
+
727
+ # 2. Define call parameters
728
+ if prompt is not None and isinstance(prompt, str):
729
+ batch_size = 1
730
+ elif prompt is not None and isinstance(prompt, list):
731
+ batch_size = len(prompt)
732
+ else:
733
+ batch_size = prompt_embeds.shape[0]
734
+
735
+ device = self._execution_device
736
+
737
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
738
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
739
+ # corresponds to doing no classifier free guidance.
740
+ do_classifier_free_guidance = guidance_scale > 1.0
741
+
742
+ if self.low_vram_mode:
743
+ self.text_encoder.to(device)
744
+ self.text_encoder_2.to(device)
745
+
746
+ # 3. Encode input prompt
747
+ text_encoder_lora_scale = (
748
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
749
+ )
750
+ (
751
+ prompt_embeds,
752
+ negative_prompt_embeds,
753
+ pooled_prompt_embeds,
754
+ negative_pooled_prompt_embeds,
755
+ ) = self.encode_prompt(
756
+ prompt=prompt,
757
+ prompt_2=prompt_2,
758
+ device=device,
759
+ num_images_per_prompt=num_images_per_prompt,
760
+ do_classifier_free_guidance=do_classifier_free_guidance,
761
+ negative_prompt=negative_prompt,
762
+ negative_prompt_2=negative_prompt_2,
763
+ prompt_embeds=prompt_embeds,
764
+ negative_prompt_embeds=negative_prompt_embeds,
765
+ pooled_prompt_embeds=pooled_prompt_embeds,
766
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
767
+ lora_scale=text_encoder_lora_scale,
768
+ )
769
+
770
+ if self.low_vram_mode:
771
+ self.text_encoder.to(torch.device("cpu"))
772
+ self.text_encoder_2.to(torch.device("cpu"))
773
+ self.vae.to(torch.device("cpu"))
774
+ torch.cuda.empty_cache()
775
+ torch.cuda.synchronize()
776
+ gc.collect()
777
+
778
+ # 4. Prepare timesteps
779
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
780
+
781
+ timesteps = self.scheduler.timesteps
782
+
783
+ # 5. Prepare latent variables
784
+ num_channels_latents = self.unet.config.in_channels
785
+ latents = self.prepare_latents(
786
+ batch_size * num_images_per_prompt,
787
+ num_channels_latents,
788
+ video_length,
789
+ height,
790
+ width,
791
+ prompt_embeds.dtype,
792
+ device,
793
+ generator,
794
+ latents,
795
+ )
796
+
797
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
798
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
799
+
800
+ # 7. Prepare added time ids & embeddings
801
+ add_text_embeds = pooled_prompt_embeds
802
+ add_time_ids = self._get_add_time_ids(
803
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
804
+ )
805
+
806
+ # todo - negative_original_size from latest diffusers for cfg
807
+
808
+ if do_classifier_free_guidance:
809
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
810
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
811
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
812
+
813
+ prompt_embeds = prompt_embeds.to(device)
814
+ add_text_embeds = add_text_embeds.to(device)
815
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
816
+
817
+ # 8. Denoising loop
818
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
819
+
820
+ # 7.1 Apply denoising_end
821
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
822
+ discrete_timestep_cutoff = int(
823
+ round(
824
+ self.scheduler.config.num_train_timesteps
825
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
826
+ )
827
+ )
828
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
829
+ timesteps = timesteps[:num_inference_steps]
830
+
831
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
832
+ for i, t in enumerate(timesteps):
833
+ # expand the latents if we are doing classifier free guidance
834
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
835
+
836
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
837
+
838
+ # predict the noise residual
839
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
840
+ noise_pred = self.unet(
841
+ latent_model_input,
842
+ t,
843
+ encoder_hidden_states=prompt_embeds,
844
+ cross_attention_kwargs=cross_attention_kwargs,
845
+ added_cond_kwargs=added_cond_kwargs,
846
+ return_dict=False,
847
+ enable_temporal_attentions=video_length > 1
848
+ )[0]
849
+
850
+ # perform guidance
851
+ if do_classifier_free_guidance:
852
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
853
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
854
+
855
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
856
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
857
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
858
+
859
+ # compute the previous noisy sample x_t -> x_t-1
860
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
861
+
862
+ # call the callback, if provided
863
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
864
+ progress_bar.update()
865
+ if callback is not None and i % callback_steps == 0:
866
+ callback(i, t, latents)
867
+
868
+ # make sure the VAE is in float32 mode, as it overflows in float16
869
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
870
+ self.upcast_vae()
871
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
872
+
873
+ # if not output_type == "latent":
874
+ # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
875
+ # else:
876
+ # image = latents
877
+ # return StableDiffusionXLPipelineOutput(images=image)
878
+
879
+ # apply watermark if available
880
+ # if self.watermark is not None:
881
+ # image = self.watermark.apply_watermark(image)
882
+
883
+ #image = self.image_processor.postprocess(image, output_type=output_type)
884
+
885
+ if self.low_vram_mode:
886
+ self.vae.to(device)
887
+ torch.cuda.empty_cache()
888
+ torch.cuda.synchronize()
889
+ gc.collect()
890
+
891
+ video = self.decode_latents(latents)
892
+
893
+ # Convert to tensor
894
+ if output_type == "tensor":
895
+ video = torch.from_numpy(video)
896
+
897
+ if not return_dict:
898
+ return video
899
+
900
+ return HotshotPipelineXLOutput(videos=video)
901
+
902
+ #
903
+ # # Offload last model to CPU
904
+ # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
905
+ # self.final_offload_hook.offload()
906
+ #
907
+ # if not return_dict:
908
+ # return (image,)
909
+ #
910
+ # return StableDiffusionXLPipelineOutput(images=image)
911
+
912
+ # Override to properly handle the loading and unloading of the additional text encoder.
913
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
914
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
915
+ # it here explicitly to be able to tell that it's coming from an SDXL
916
+ # pipeline.
917
+ state_dict, network_alphas = self.lora_state_dict(
918
+ pretrained_model_name_or_path_or_dict,
919
+ unet_config=self.unet.config,
920
+ **kwargs,
921
+ )
922
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
923
+
924
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
925
+ if len(text_encoder_state_dict) > 0:
926
+ self.load_lora_into_text_encoder(
927
+ text_encoder_state_dict,
928
+ network_alphas=network_alphas,
929
+ text_encoder=self.text_encoder,
930
+ prefix="text_encoder",
931
+ lora_scale=self.lora_scale,
932
+ )
933
+
934
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
935
+ if len(text_encoder_2_state_dict) > 0:
936
+ self.load_lora_into_text_encoder(
937
+ text_encoder_2_state_dict,
938
+ network_alphas=network_alphas,
939
+ text_encoder=self.text_encoder_2,
940
+ prefix="text_encoder_2",
941
+ lora_scale=self.lora_scale,
942
+ )
943
+
944
+ @classmethod
945
+ def save_lora_weights(
946
+ self,
947
+ save_directory: Union[str, os.PathLike],
948
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
949
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
950
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
951
+ is_main_process: bool = True,
952
+ weight_name: str = None,
953
+ save_function: Callable = None,
954
+ safe_serialization: bool = False,
955
+ ):
956
+ state_dict = {}
957
+
958
+ def pack_weights(layers, prefix):
959
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
960
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
961
+ return layers_state_dict
962
+
963
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
964
+
965
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
966
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
967
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
968
+
969
+ self.write_lora_layers(
970
+ state_dict=state_dict,
971
+ save_directory=save_directory,
972
+ is_main_process=is_main_process,
973
+ weight_name=weight_name,
974
+ save_function=save_function,
975
+ safe_serialization=safe_serialization,
976
+ )
977
+
978
+ def decode_latents(self, latents):
979
+ video_length = latents.shape[2]
980
+ latents = 1 / self.vae.config.scaling_factor * latents
981
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
982
+ # video = self.vae.decode(latents).sample
983
+ video = []
984
+ for frame_idx in tqdm(range(latents.shape[0])):
985
+ video.append(self.vae.decode(
986
+ latents[frame_idx:frame_idx+1]).sample)
987
+ video = torch.cat(video)
988
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
989
+ video = (video / 2.0 + 0.5).clamp(0, 1)
990
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
991
+ video = video.cpu().float().numpy()
992
+ return video
993
+
994
+ def _remove_text_encoder_monkey_patch(self):
995
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
996
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
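For reference, a hedged end-to-end sketch of the pipeline above. The call parameters mirror `EXAMPLE_DOC_STRING` and the defaults in `inference.py`; the fp16 dtype, the CUDA device and the output filename are assumptions, and `save_as_gif` comes from the utils module added next:

```py
# Hedged usage sketch for HotshotXLPipeline (not part of the committed files).
import torch
from einops import rearrange
from torchvision.transforms import ToPILImage
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
from hotshot_xl.utils import save_as_gif

pipe = HotshotXLPipeline.from_pretrained(
    "hotshotco/Hotshot-XL", torch_dtype=torch.float16
).to("cuda")

result = pipe(
    "a photo of an astronaut riding a horse on mars",
    width=672, height=384,
    original_size=(1920, 1080), target_size=(512, 512),
    video_length=8, num_inference_steps=30,
    output_type="tensor",
)
video = result.videos[0]  # (channels, frames, height, width), values in [0, 1]

to_pil = ToPILImage()
frames = [to_pil(frame) for frame in rearrange(video, "c f h w -> f c h w")]
save_as_gif(frames, "astronaut.gif", duration=1000 // 8)  # per-frame duration in ms (~8 fps)
```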
hotshot_xl/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Union
16
+ from io import BytesIO
17
+ import PIL
18
+ from PIL import ImageSequence, Image
19
+ import requests
20
+ import os
21
+ import numpy as np
22
+ import imageio
23
+
24
+
25
+ def get_image(img_path) -> PIL.Image.Image:
26
+ if img_path.startswith("http"):
27
+ return PIL.Image.open(requests.get(img_path, stream=True).raw)
28
+ if os.path.exists(img_path):
29
+ return Image.open(img_path)
30
+ raise FileNotFoundError(f"File not found: {img_path}")
31
+
32
+ def images_to_gif_bytes(images: List, duration: int = 1000) -> bytes:
33
+ with BytesIO() as output_buffer:
34
+ # Save the first image
35
+ images[0].save(output_buffer,
36
+ format='GIF',
37
+ save_all=True,
38
+ append_images=images[1:],
39
+ duration=duration,
40
+ loop=0) # 0 means the GIF will loop indefinitely
41
+
42
+ # Get the byte array from the buffer
43
+ gif_bytes = output_buffer.getvalue()
44
+
45
+ return gif_bytes
46
+
47
+ def save_as_gif(images: List, file_path: str, duration: int = 1000):
48
+ with open(file_path, "wb") as f:
49
+ f.write(images_to_gif_bytes(images, duration))
50
+
51
+ def images_to_mp4_bytes(images: List[Image.Image], duration: int = 1000) -> bytes:
52
+ with BytesIO() as output_buffer:
53
+ with imageio.get_writer(output_buffer, format='mp4', fps=1/(duration/1000)) as writer:
54
+ for img in images:
55
+ writer.append_data(np.array(img))
56
+ mp4_bytes = output_buffer.getvalue()
57
+
58
+ return mp4_bytes
59
+
60
+ def save_as_mp4(images: List[Image.Image], file_path: str, duration: int = 1000):
61
+ with open(file_path, "wb") as f:
62
+ f.write(images_to_mp4_bytes(images, duration))
63
+
64
+ def scale_aspect_fill(img, new_width, new_height):
65
+ new_width = int(new_width)
66
+ new_height = int(new_height)
67
+
68
+ original_width, original_height = img.size
69
+ ratio_w = float(new_width) / original_width
70
+ ratio_h = float(new_height) / original_height
71
+
72
+ if ratio_w > ratio_h:
73
+ # It must be fixed by width
74
+ resize_width = new_width
75
+ resize_height = round(original_height * ratio_w)
76
+ else:
77
+ # Fixed by height
78
+ resize_width = round(original_width * ratio_h)
79
+ resize_height = new_height
80
+
81
+ img_resized = img.resize((resize_width, resize_height), Image.LANCZOS)
82
+
83
+ # Calculate cropping boundaries and do crop
84
+ left = (resize_width - new_width) / 2
85
+ top = (resize_height - new_height) / 2
86
+ right = (resize_width + new_width) / 2
87
+ bottom = (resize_height + new_height) / 2
88
+
89
+ img_cropped = img_resized.crop((left, top, right, bottom))
90
+
91
+ return img_cropped
92
+
93
+ def extract_gif_frames_from_midpoint(image: Union[str, PIL.Image.Image], fps: int=8, target_duration: int=1000) -> list:
94
+ # Load the GIF
95
+ image = get_image(image) if type(image) is str else image
96
+
97
+ frames = []
98
+
99
+ estimated_frame_time = None
100
+
101
+ # some gifs contain the duration - others don't
102
+ # so if there is a duration we will grab it otherwise we will fall back
103
+
104
+ for frame in ImageSequence.Iterator(image):
105
+
106
+ frames.append(frame.copy())
107
+ if 'duration' in frame.info:
108
+ frame_info_duration = frame.info['duration']
109
+ if frame_info_duration > 0:
110
+ estimated_frame_time = frame_info_duration
111
+
112
+ if estimated_frame_time is None:
113
+ if len(frames) <= 16:
114
+ # assume it's 8fps
115
+ estimated_frame_time = 1000 // 8
116
+ else:
117
+ # assume roughly 14 fps (70 ms per frame)
118
+ estimated_frame_time = 70
119
+
120
+ if len(frames) < fps:
121
+ raise ValueError(f"fps of {fps} is too small for this gif as it only has {len(frames)} frames.")
122
+
123
+ skip = len(frames) // fps
124
+ upper_bound_index = len(frames) - 1
125
+
126
+ best_indices = [x for x in range(0, len(frames), skip)][:fps]
127
+ offset = int(upper_bound_index - best_indices[-1]) // 2
128
+ best_indices = [x + offset for x in best_indices]
129
+ best_duration = (best_indices[-1] - best_indices[0]) * estimated_frame_time
130
+
131
+ while True:
132
+
133
+ skip -= 1
134
+
135
+ if skip == 0:
136
+ break
137
+
138
+ indices = [x for x in range(0, len(frames), skip)][:fps]
139
+
140
+ # center the indices, so we sample the middle of the gif...
141
+ offset = int(upper_bound_index - indices[-1]) // 2
142
+ if offset == 0:
143
+ # can't shift
144
+ break
145
+ indices = [x + offset for x in indices]
146
+
147
+ # is the new duration closer to the target than last guess?
148
+ duration = (indices[-1] - indices[0]) * estimated_frame_time
149
+ if abs(duration - target_duration) > abs(best_duration - target_duration):
150
+ break
151
+
152
+ best_indices = indices
153
+ best_duration = duration
154
+
155
+ return [frames[index] for index in best_indices]
156
+
157
+ def get_crop_coordinates(old_size: tuple, new_size: tuple) -> tuple:
158
+ """
159
+ Calculate the crop coordinates after scaling an image to fit a new size.
160
+
161
+ :param old_size: tuple of the form (width, height) representing the original size of the image.
162
+ :param new_size: tuple of the form (width, height) representing the desired size after scaling.
163
+ :return: tuple of the form (left, upper, right, lower) giving the crop box in pixels of the resized image.
164
+ """
165
+ # Check if the input tuples have the right form (width, height)
166
+ if not (isinstance(old_size, tuple) and isinstance(new_size, tuple) and
167
+ len(old_size) == 2 and len(new_size) == 2):
168
+ raise ValueError("old_size and new_size should be tuples of the form (width, height)")
169
+
170
+ # Extract the width and height from the old and new sizes
171
+ old_width, old_height = old_size
172
+ new_width, new_height = new_size
173
+
174
+ # Calculate the ratios for width and height
175
+ ratio_w = float(new_width) / old_width
176
+ ratio_h = float(new_height) / old_height
177
+
178
+ # Determine which dimension is fixed (width or height)
179
+ if ratio_w > ratio_h:
180
+ # It must be fixed by width
181
+ resize_width = new_width
182
+ resize_height = round(old_height * ratio_w)
183
+ else:
184
+ # Fixed by height
185
+ resize_width = round(old_width * ratio_h)
186
+ resize_height = new_height
187
+
188
+ # Calculate cropping boundaries in the resized image space
189
+ left = (resize_width - new_width) / 2
190
+ upper = (resize_height - new_height) / 2
191
+ right = (resize_width + new_width) / 2
192
+ lower = (resize_height + new_height) / 2
193
+
194
+ # Note: the coordinates are in pixels of the resized image; no normalization is applied
195
+
196
+ # Return the crop coordinates as a tuple
197
+ return (left, upper, right, lower)
198
+
199
+ aspect_ratio_to_1024_map = {
200
+ "0.42": [640, 1536],
201
+ "0.57": [768, 1344],
202
+ "0.68": [832, 1216],
203
+ "1.00": [1024, 1024],
204
+ "1.46": [1216, 832],
205
+ "1.75": [1344, 768],
206
+ "2.40": [1536, 640]
207
+ }
208
+
209
+ res_to_aspect_map = {
210
+ 1024: aspect_ratio_to_1024_map,
211
+ 512: {key: [value[0] // 2, value[1] // 2] for key, value in aspect_ratio_to_1024_map.items()},
212
+ }
213
+
214
+ def best_aspect_ratio(aspect_ratio: float, resolution: int):
215
+
216
+ map = res_to_aspect_map[resolution]
217
+
218
+ d = 99999999
219
+ res = None
220
+ for key, value in map.items():
221
+ ar = value[0] / value[1]
222
+ diff = abs(aspect_ratio - ar)
223
+ if diff < d:
224
+ d = diff
225
+ res = value
226
+
227
+ ar = res[0] / res[1]
228
+ return f"{ar:.2f}", res
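A hedged sketch of how the resolution helpers above fit together when preparing an arbitrary image for the 512-resolution aspect buckets; the input path is a placeholder:

```py
# Hedged sketch: snap an image to the closest 512-bucket aspect ratio and center-crop to fit.
from hotshot_xl.utils import best_aspect_ratio, get_crop_coordinates, get_image, scale_aspect_fill

img = get_image("path/to/input.png")  # placeholder path or URL
w, h = img.size

ar_key, (bucket_w, bucket_h) = best_aspect_ratio(w / h, resolution=512)
fitted = scale_aspect_fill(img, bucket_w, bucket_h)  # resize to cover, then center crop
left, upper, right, lower = get_crop_coordinates((w, h), (bucket_w, bucket_h))

print(f"aspect {ar_key}: {w}x{h} -> {fitted.size}, crop box {(left, upper, right, lower)}")
```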
inference.py ADDED
@@ -0,0 +1,231 @@
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import sys
+
+ sys.path.append("/")
+ import os
+ import argparse
+ import torch
+ from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
+ from hotshot_xl.pipelines.hotshot_xl_controlnet_pipeline import HotshotXLControlNetPipeline
+ from hotshot_xl.models.unet import UNet3DConditionModel
+ import torchvision.transforms as transforms
+ from einops import rearrange
+ from hotshot_xl.utils import save_as_gif, save_as_mp4, extract_gif_frames_from_midpoint, scale_aspect_fill
+ from torch import autocast
+ from diffusers import ControlNetModel
+ from contextlib import contextmanager
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+ from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
+
+ SCHEDULERS = {
+     'EulerAncestralDiscreteScheduler': EulerAncestralDiscreteScheduler,
+     'EulerDiscreteScheduler': EulerDiscreteScheduler,
+     'default': None,
+     # add more here
+ }
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Hotshot-XL inference")
+     parser.add_argument("--pretrained_path", type=str, default="hotshotco/Hotshot-XL")
+     parser.add_argument("--xformers", action="store_true")
+     parser.add_argument("--spatial_unet_base", type=str)
+     parser.add_argument("--lora", type=str)
+     parser.add_argument("--output", type=str, required=True)
+     parser.add_argument("--steps", type=int, default=30)
+     parser.add_argument("--prompt", type=str,
+                         default="a bulldog in the captains chair of a spaceship, hd, high quality")
+     parser.add_argument("--negative_prompt", type=str, default="blurry")
+     parser.add_argument("--seed", type=int, default=455)
+     parser.add_argument("--width", type=int, default=672)
+     parser.add_argument("--height", type=int, default=384)
+     parser.add_argument("--target_width", type=int, default=512)
+     parser.add_argument("--target_height", type=int, default=512)
+     parser.add_argument("--og_width", type=int, default=1920)
+     parser.add_argument("--og_height", type=int, default=1080)
+     parser.add_argument("--video_length", type=int, default=8)
+     parser.add_argument("--video_duration", type=int, default=1000)
+     parser.add_argument("--low_vram_mode", action="store_true")
+     parser.add_argument('--scheduler', type=str, default='EulerAncestralDiscreteScheduler',
+                         help='Name of the scheduler to use')
+
+     parser.add_argument("--control_type", type=str, default=None, choices=["depth", "canny"])
+     parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
+     parser.add_argument("--control_guidance_start", type=float, default=0.0)
+     parser.add_argument("--control_guidance_end", type=float, default=1.0)
+     parser.add_argument("--gif", type=str, default=None)
+     parser.add_argument("--precision", type=str, default='f16', choices=[
+         'f16', 'f32', 'bf16'
+     ])
+     parser.add_argument("--autocast", type=str, default=None, choices=[
+         'f16', 'bf16'
+     ])
+
+     return parser.parse_args()
+
+
+ to_pil = transforms.ToPILImage()
+
+
+ def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
+     video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
+     bsz = video_frames.shape[0]
+     images = []
+     for i in range(bsz):
+         video = video_frames[i]
+         for j in range(video.shape[0]):
+             if output_type == "pil":
+                 images.append(to_pil(video[j]))
+             else:
+                 images.append(video[j])
+     return images
+
+ @contextmanager
+ def maybe_auto_cast(data_type):
+     if data_type:
+         with autocast("cuda", dtype=data_type):
+             yield
+     else:
+         yield
+
+
+ def main():
+     args = parse_args()
+
+     if args.control_type and not args.gif:
+         raise ValueError("Controlnet specified but you didn't specify a gif!")
+
+     if args.gif and not args.control_type:
+         print("warning: gif was specified but no control type was specified. gif will be ignored.")
+
+     output_dir = os.path.dirname(args.output)
+     if output_dir:
+         os.makedirs(output_dir, exist_ok=True)
+
+     device = torch.device("cuda")
+
+     control_net_model_pretrained_path = None
+     if args.control_type:
+         control_type_to_model_map = {
+             "canny": "diffusers/controlnet-canny-sdxl-1.0",
+             "depth": "diffusers/controlnet-depth-sdxl-1.0",
+         }
+         control_net_model_pretrained_path = control_type_to_model_map[args.control_type]
+
+     data_type = torch.float32
+
+     if args.precision == 'f16':
+         data_type = torch.half
+     elif args.precision == 'f32':
+         data_type = torch.float32
+     elif args.precision == 'bf16':
+         data_type = torch.bfloat16
+
+     pipe_line_args = {
+         "torch_dtype": data_type,
+         "use_safetensors": True
+     }
+
+     PipelineClass = HotshotXLPipeline
+
+     if control_net_model_pretrained_path:
+         PipelineClass = HotshotXLControlNetPipeline
+         pipe_line_args['controlnet'] = \
+             ControlNetModel.from_pretrained(control_net_model_pretrained_path, torch_dtype=data_type)
+
+     if args.spatial_unet_base:
+
+         unet_3d = UNet3DConditionModel.from_pretrained(args.pretrained_path, subfolder="unet", torch_dtype=data_type).to(device)
+
+         unet = UNet3DConditionModel.from_pretrained_spatial(args.spatial_unet_base).to(device, dtype=data_type)
+
+         temporal_layers = {}
+         unet_3d_sd = unet_3d.state_dict()
+
+         for k, v in unet_3d_sd.items():
+             if 'temporal' in k:
+                 temporal_layers[k] = v
+
+         unet.load_state_dict(temporal_layers, strict=False)
+
+         pipe_line_args['unet'] = unet
+
+         del unet_3d_sd
+         del unet_3d
+         del temporal_layers
+
+     pipe = PipelineClass.from_pretrained(args.pretrained_path, **pipe_line_args).to(device)
+
+     if args.lora:
+         pipe.load_lora_weights(args.lora)
+
+     SchedulerClass = SCHEDULERS[args.scheduler]
+     if SchedulerClass is not None:
+         pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)
+
+     if args.xformers:
+         pipe.enable_xformers_memory_efficient_attention()
+
+     generator = torch.Generator().manual_seed(args.seed) if args.seed else None
+
+     autocast_type = None
+     if args.autocast == 'f16':
+         autocast_type = torch.half
+     elif args.autocast == 'bf16':
+         autocast_type = torch.bfloat16
+
+     if type(pipe) is HotshotXLControlNetPipeline:
+         kwargs = {}
+     else:
+         kwargs = {
+             "low_vram_mode": args.low_vram_mode
+         }
+
+     if args.gif and type(pipe) is HotshotXLControlNetPipeline:
+         kwargs['control_images'] = [
+             scale_aspect_fill(img, args.width, args.height).convert("RGB") \
+             for img in
+             extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
+         ]
+         kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
+         kwargs['control_guidance_start'] = args.control_guidance_start
+         kwargs['control_guidance_end'] = args.control_guidance_end
+
+     with maybe_auto_cast(autocast_type):
+
+         images = pipe(args.prompt,
+                       negative_prompt=args.negative_prompt,
+                       width=args.width,
+                       height=args.height,
+                       original_size=(args.og_width, args.og_height),
+                       target_size=(args.target_width, args.target_height),
+                       num_inference_steps=args.steps,
+                       video_length=args.video_length,
+                       generator=generator,
+                       output_type="tensor", **kwargs).videos
+
+     images = to_pil_images(images, output_type="pil")
+
+     if args.video_length > 1:
+         if args.output.split(".")[-1] == "gif":
+             save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
+         else:
+             save_as_mp4(images, args.output, duration=args.video_duration // args.video_length)
+     else:
+         images[0].save(args.output, format='JPEG', quality=95)
+
+
+ if __name__ == "__main__":
+     main()
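Editor's note (illustrative, not part of the commit): the flags below all come from parse_args above; the prompt text and file paths are placeholders. Two example invocations, assuming the script is run from the repository root on a CUDA-capable machine:

# Plain text-to-GIF run (only --output is required; everything else falls back to the defaults above).
python inference.py \
    --prompt "a bulldog in the captains chair of a spaceship, hd, high quality" \
    --output outputs/bulldog.gif \
    --precision f16 \
    --xformers

# ControlNet run: --control_type requires --gif, whose frames are used as the depth conditioning signal.
python inference.py \
    --prompt "a bulldog in the captains chair of a spaceship, hd, high quality" \
    --gif inputs/source.gif \
    --control_type depth \
    --output outputs/bulldog_depth.gif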
requirements.txt ADDED
@@ -0,0 +1,62 @@
+ accelerate==0.23.0
+ appdirs==1.4.4
+ certifi==2023.7.22
+ charset-normalizer==3.3.0
+ click==8.1.7
+ cmake==3.27.6
+ decorator==4.4.2
+ diffusers==0.21.4
+ docker-pycreds==0.4.0
+ einops==0.7.0
+ filelock==3.12.4
+ fsspec==2023.9.2
+ gitdb==4.0.10
+ GitPython==3.1.37
+ huggingface-hub==0.16.4
+ idna==3.4
+ imageio==2.31.5
+ imageio-ffmpeg==0.4.9
+ importlib-metadata==6.8.0
+ Jinja2==3.1.2
+ lit==17.0.2
+ MarkupSafe==2.1.3
+ moviepy==1.0.3
+ mpmath==1.3.0
+ networkx==3.1
+ numpy==1.26.0
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ packaging==23.2
+ pathtools==0.1.2
+ Pillow==10.0.1
+ proglog==0.1.10
+ protobuf==4.24.3
+ psutil==5.9.5
+ PyYAML==6.0.1
+ regex==2023.10.3
+ requests==2.31.0
+ safetensors==0.3.3
+ sentry-sdk==1.31.0
+ setproctitle==1.3.3
+ six==1.16.0
+ smmap==5.0.1
+ sympy==1.12
+ tokenizers==0.14.0
+ torch==2.0.1
+ torchvision==0.15.2
+ tqdm==4.66.1
+ transformers==4.34.0
+ triton==2.0.0
+ typing_extensions==4.8.0
+ urllib3==2.0.6
+ wandb==0.15.11
+ zipp==3.17.0
setup.py ADDED
@@ -0,0 +1,15 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name='hotshot_xl',
+     version='1.0',
+     packages=find_packages(include=['hotshot_xl*',]),
+     author="Natural Synthetics Inc",
+     install_requires=[
+         "torch>=2.0.1",
+         "torchvision>=0.15.2",
+         "diffusers>=0.21.4",
+         "transformers>=4.33.3",
+         "einops"
+     ],
+ )
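Editor's note (illustrative, not part of the commit): a typical way to set up the environment from the two files above, assuming a fresh virtualenv at the repository root:

pip install -r requirements.txt   # exact pins listed above
pip install -e .                  # editable install of the hotshot_xl package defined in setup.py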