Spaces:
Running
on
Zero
Running
on
Zero
init commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +160 -0
- LICENSE +201 -0
- app.py +50 -0
- checkpoints/stablemakeup/pytorch_model.bin +3 -0
- checkpoints/stablemakeup/pytorch_model_1.bin +3 -0
- checkpoints/stablemakeup/pytorch_model_2.bin +3 -0
- detail_encoder/.DS_Store +0 -0
- detail_encoder/__init__.py +0 -0
- detail_encoder/_clip.py +1349 -0
- detail_encoder/attention_processor.py +687 -0
- detail_encoder/encoder_plus.py +113 -0
- detail_encoder/resampler.py +112 -0
- diffusers/.DS_Store +0 -0
- diffusers/__init__.py +734 -0
- diffusers/commands/__init__.py +27 -0
- diffusers/commands/diffusers_cli.py +43 -0
- diffusers/commands/env.py +84 -0
- diffusers/commands/fp16_safetensors.py +133 -0
- diffusers/configuration_utils.py +694 -0
- diffusers/dependency_versions_check.py +35 -0
- diffusers/dependency_versions_table.py +46 -0
- diffusers/experimental/README.md +5 -0
- diffusers/experimental/__init__.py +1 -0
- diffusers/experimental/rl/__init__.py +1 -0
- diffusers/experimental/rl/value_guided_sampling.py +154 -0
- diffusers/image_processor.py +476 -0
- diffusers/loaders.py +0 -0
- diffusers/models/README.md +3 -0
- diffusers/models/__init__.py +77 -0
- diffusers/models/activations.py +120 -0
- diffusers/models/adapter.py +584 -0
- diffusers/models/attention.py +396 -0
- diffusers/models/attention_flax.py +486 -0
- diffusers/models/attention_processor.py +2020 -0
- diffusers/models/autoencoder_asym_kl.py +181 -0
- diffusers/models/autoencoder_kl.py +465 -0
- diffusers/models/autoencoder_tiny.py +349 -0
- diffusers/models/consistency_decoder_vae.py +430 -0
- diffusers/models/controlnet.py +844 -0
- diffusers/models/controlnet_flax.py +394 -0
- diffusers/models/dual_transformer_2d.py +155 -0
- diffusers/models/embeddings.py +792 -0
- diffusers/models/embeddings_flax.py +95 -0
- diffusers/models/lora.py +304 -0
- diffusers/models/modeling_flax_pytorch_utils.py +134 -0
- diffusers/models/modeling_flax_utils.py +560 -0
- diffusers/models/modeling_pytorch_flax_utils.py +161 -0
- diffusers/models/modeling_utils.py +1158 -0
- diffusers/models/normalization.py +148 -0
- diffusers/models/prior_transformer.py +382 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import gradio as gr
|
3 |
+
from inference_utils import inference
|
4 |
+
|
5 |
+
|
6 |
+
@spaces.GPU
|
7 |
+
def send_to_model(id_image, makeup_image, guidance_scale):
|
8 |
+
if guidance_scale is None:
|
9 |
+
# when creating example caches.
|
10 |
+
guidance_scale = 1.6
|
11 |
+
return inference(id_image, makeup_image, guidance_scale, size=512)
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
with gr.Blocks() as demo:
|
15 |
+
gr.HTML(
|
16 |
+
"""
|
17 |
+
<h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
|
18 |
+
Stable-Makeup: When Real-World Makeup Transfer Meets Diffusion Model
|
19 |
+
</h1>
|
20 |
+
<p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
|
21 |
+
<a style="text-align: center; display:inline-block"
|
22 |
+
href="https://xiaojiu-z.github.io/Stable-Makeup.github.io/">
|
23 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
|
24 |
+
alt="Paper Page">
|
25 |
+
</a>
|
26 |
+
<a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Stable-Makeup-unofficial?duplicate=true">
|
27 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
|
28 |
+
</a>
|
29 |
+
</p>
|
30 |
+
"""
|
31 |
+
)
|
32 |
+
gr.Interface(
|
33 |
+
fn=send_to_model,
|
34 |
+
inputs=[
|
35 |
+
gr.Image(type="pil", label="id_image", height=512, width=512),
|
36 |
+
gr.Image(type="pil", label="makeup_image", height=512, width=512),
|
37 |
+
gr.Slider(minimum=1.01, maximum=3, value=1.6, step=0.05, label="guidance_scale", info="1.05-1.15 is suggested for light makeup and 2 for heavy makeup."),
|
38 |
+
],
|
39 |
+
outputs="image",
|
40 |
+
allow_flagging="never",
|
41 |
+
description="This is an unofficial demo for the paper 'Stable-Makeup: When Real-World Makeup Transfer Meets Diffusion Model'.",
|
42 |
+
examples=[
|
43 |
+
["./test_imgs/id/1.jpg", "./test_imgs/makeup/1.jpg"],
|
44 |
+
["./test_imgs/id/2.jpg", "./test_imgs/makeup/2.jpg"],
|
45 |
+
["./test_imgs/id/3.jpg", "./test_imgs/makeup/3.jpg"],
|
46 |
+
["./test_imgs/id/4.jpg", "./test_imgs/makeup/4.png"],
|
47 |
+
],
|
48 |
+
cache_examples=True,
|
49 |
+
)
|
50 |
+
demo.queue(max_size=10).launch()
|
checkpoints/stablemakeup/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da8272fdbb74cb70714272e3b4ff381958944b4f308df16977dfe8893dfc7f64
|
3 |
+
size 1373905877
|
checkpoints/stablemakeup/pytorch_model_1.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:05573314b4a16e592456e4c5dc3932fe705e1e35d8609ea886cb98ac2deadf47
|
3 |
+
size 1445256905
|
checkpoints/stablemakeup/pytorch_model_2.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:143d489e9e1f3c5014ae97f735b2c37bbc7c487b7a50f1f6062afc593ab9da40
|
3 |
+
size 1445256905
|
detail_encoder/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
detail_encoder/__init__.py
ADDED
File without changes
|
detail_encoder/_clip.py
ADDED
@@ -0,0 +1,1349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch CLIP model."""
|
16 |
+
|
17 |
+
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.utils.checkpoint
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
27 |
+
from transformers.modeling_utils import PreTrainedModel
|
28 |
+
from transformers.utils import (
|
29 |
+
ModelOutput,
|
30 |
+
add_start_docstrings,
|
31 |
+
add_start_docstrings_to_model_forward,
|
32 |
+
logging,
|
33 |
+
replace_return_docstrings,
|
34 |
+
)
|
35 |
+
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__)
|
39 |
+
|
40 |
+
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
41 |
+
|
42 |
+
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
43 |
+
"openai/clip-vit-base-patch32",
|
44 |
+
# See all CLIP models at https://huggingface.co/models?filter=clip
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
49 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
50 |
+
"""
|
51 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
52 |
+
"""
|
53 |
+
bsz, src_len = mask.size()
|
54 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
55 |
+
|
56 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
57 |
+
|
58 |
+
inverted_mask = 1.0 - expanded_mask
|
59 |
+
|
60 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
61 |
+
|
62 |
+
|
63 |
+
# contrastive loss function, adapted from
|
64 |
+
# https://sachinruk.github.io/blog/2021-03-07-clip.html
|
65 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
66 |
+
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
67 |
+
|
68 |
+
|
69 |
+
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
70 |
+
caption_loss = contrastive_loss(similarity)
|
71 |
+
image_loss = contrastive_loss(similarity.t())
|
72 |
+
return (caption_loss + image_loss) / 2.0
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class CLIPVisionModelOutput(ModelOutput):
|
77 |
+
"""
|
78 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
82 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
83 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
84 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
85 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
86 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
87 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
88 |
+
|
89 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
90 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
91 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
92 |
+
sequence_length)`.
|
93 |
+
|
94 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
95 |
+
heads.
|
96 |
+
"""
|
97 |
+
|
98 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
99 |
+
last_hidden_state: torch.FloatTensor = None
|
100 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
101 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class CLIPTextModelOutput(ModelOutput):
|
106 |
+
"""
|
107 |
+
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
111 |
+
The text embeddings obtained by applying the projection layer to the pooler_output.
|
112 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
113 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
114 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
115 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
116 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
117 |
+
|
118 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
119 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
120 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
121 |
+
sequence_length)`.
|
122 |
+
|
123 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
124 |
+
heads.
|
125 |
+
"""
|
126 |
+
|
127 |
+
text_embeds: Optional[torch.FloatTensor] = None
|
128 |
+
last_hidden_state: torch.FloatTensor = None
|
129 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
130 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
131 |
+
|
132 |
+
|
133 |
+
@dataclass
|
134 |
+
class CLIPOutput(ModelOutput):
|
135 |
+
"""
|
136 |
+
Args:
|
137 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
138 |
+
Contrastive loss for image-text similarity.
|
139 |
+
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
140 |
+
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
141 |
+
similarity scores.
|
142 |
+
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
143 |
+
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
144 |
+
similarity scores.
|
145 |
+
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
146 |
+
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
147 |
+
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
148 |
+
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
149 |
+
text_model_output(`BaseModelOutputWithPooling`):
|
150 |
+
The output of the [`CLIPTextModel`].
|
151 |
+
vision_model_output(`BaseModelOutputWithPooling`):
|
152 |
+
The output of the [`CLIPVisionModel`].
|
153 |
+
"""
|
154 |
+
|
155 |
+
loss: Optional[torch.FloatTensor] = None
|
156 |
+
logits_per_image: torch.FloatTensor = None
|
157 |
+
logits_per_text: torch.FloatTensor = None
|
158 |
+
text_embeds: torch.FloatTensor = None
|
159 |
+
image_embeds: torch.FloatTensor = None
|
160 |
+
text_model_output: BaseModelOutputWithPooling = None
|
161 |
+
vision_model_output: BaseModelOutputWithPooling = None
|
162 |
+
|
163 |
+
def to_tuple(self) -> Tuple[Any]:
|
164 |
+
return tuple(
|
165 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
166 |
+
for k in self.keys()
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
class CLIPVisionEmbeddings(nn.Module):
|
171 |
+
def __init__(self, config: CLIPVisionConfig):
|
172 |
+
super().__init__()
|
173 |
+
self.config = config
|
174 |
+
self.embed_dim = config.hidden_size
|
175 |
+
self.image_size = config.image_size
|
176 |
+
self.patch_size = config.patch_size
|
177 |
+
|
178 |
+
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
179 |
+
|
180 |
+
self.patch_embedding = nn.Conv2d(
|
181 |
+
in_channels=config.num_channels,
|
182 |
+
out_channels=self.embed_dim,
|
183 |
+
kernel_size=self.patch_size,
|
184 |
+
stride=self.patch_size,
|
185 |
+
bias=False,
|
186 |
+
)
|
187 |
+
|
188 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
189 |
+
self.num_positions = self.num_patches + 1
|
190 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
191 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
192 |
+
|
193 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
194 |
+
batch_size = pixel_values.shape[0]
|
195 |
+
target_dtype = self.patch_embedding.weight.dtype
|
196 |
+
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
197 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
198 |
+
|
199 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
200 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
201 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
202 |
+
return embeddings
|
203 |
+
|
204 |
+
|
205 |
+
class CLIPTextEmbeddings(nn.Module):
|
206 |
+
def __init__(self, config: CLIPTextConfig):
|
207 |
+
super().__init__()
|
208 |
+
embed_dim = config.hidden_size
|
209 |
+
|
210 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
211 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
212 |
+
|
213 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
214 |
+
self.register_buffer(
|
215 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(
|
219 |
+
self,
|
220 |
+
input_ids: Optional[torch.LongTensor] = None,
|
221 |
+
position_ids: Optional[torch.LongTensor] = None,
|
222 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
223 |
+
) -> torch.Tensor:
|
224 |
+
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
225 |
+
|
226 |
+
if position_ids is None:
|
227 |
+
position_ids = self.position_ids[:, :seq_length]
|
228 |
+
|
229 |
+
if inputs_embeds is None:
|
230 |
+
inputs_embeds = self.token_embedding(input_ids)
|
231 |
+
|
232 |
+
position_embeddings = self.position_embedding(position_ids)
|
233 |
+
embeddings = inputs_embeds + position_embeddings
|
234 |
+
|
235 |
+
return embeddings
|
236 |
+
|
237 |
+
|
238 |
+
class CLIPAttention(nn.Module):
|
239 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
240 |
+
|
241 |
+
def __init__(self, config):
|
242 |
+
super().__init__()
|
243 |
+
self.config = config
|
244 |
+
self.embed_dim = config.hidden_size
|
245 |
+
self.num_heads = config.num_attention_heads
|
246 |
+
self.head_dim = self.embed_dim // self.num_heads
|
247 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
248 |
+
raise ValueError(
|
249 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
250 |
+
f" {self.num_heads})."
|
251 |
+
)
|
252 |
+
self.scale = self.head_dim**-0.5
|
253 |
+
self.dropout = config.attention_dropout
|
254 |
+
|
255 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
256 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
257 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
258 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
259 |
+
|
260 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
261 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
262 |
+
|
263 |
+
def forward(
|
264 |
+
self,
|
265 |
+
hidden_states: torch.Tensor,
|
266 |
+
attention_mask: Optional[torch.Tensor] = None,
|
267 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
268 |
+
output_attentions: Optional[bool] = False,
|
269 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
270 |
+
"""Input shape: Batch x Time x Channel"""
|
271 |
+
|
272 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
273 |
+
|
274 |
+
# get query proj
|
275 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
276 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
277 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
278 |
+
|
279 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
280 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
281 |
+
key_states = key_states.view(*proj_shape)
|
282 |
+
value_states = value_states.view(*proj_shape)
|
283 |
+
|
284 |
+
src_len = key_states.size(1)
|
285 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
286 |
+
|
287 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
288 |
+
raise ValueError(
|
289 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
290 |
+
f" {attn_weights.size()}"
|
291 |
+
)
|
292 |
+
|
293 |
+
# apply the causal_attention_mask first
|
294 |
+
if causal_attention_mask is not None:
|
295 |
+
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
296 |
+
raise ValueError(
|
297 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
298 |
+
f" {causal_attention_mask.size()}"
|
299 |
+
)
|
300 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
301 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
302 |
+
|
303 |
+
if attention_mask is not None:
|
304 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
305 |
+
raise ValueError(
|
306 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
307 |
+
)
|
308 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
309 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
310 |
+
|
311 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
312 |
+
|
313 |
+
if output_attentions:
|
314 |
+
# this operation is a bit akward, but it's required to
|
315 |
+
# make sure that attn_weights keeps its gradient.
|
316 |
+
# In order to do so, attn_weights have to reshaped
|
317 |
+
# twice and have to be reused in the following
|
318 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
319 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
320 |
+
else:
|
321 |
+
attn_weights_reshaped = None
|
322 |
+
|
323 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
324 |
+
|
325 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
326 |
+
|
327 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
328 |
+
raise ValueError(
|
329 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
330 |
+
f" {attn_output.size()}"
|
331 |
+
)
|
332 |
+
|
333 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
334 |
+
attn_output = attn_output.transpose(1, 2)
|
335 |
+
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
336 |
+
|
337 |
+
attn_output = self.out_proj(attn_output)
|
338 |
+
|
339 |
+
return attn_output, attn_weights_reshaped
|
340 |
+
|
341 |
+
|
342 |
+
class CLIPMLP(nn.Module):
|
343 |
+
def __init__(self, config):
|
344 |
+
super().__init__()
|
345 |
+
self.config = config
|
346 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
347 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
348 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
349 |
+
|
350 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
351 |
+
hidden_states = self.fc1(hidden_states)
|
352 |
+
hidden_states = self.activation_fn(hidden_states)
|
353 |
+
hidden_states = self.fc2(hidden_states)
|
354 |
+
return hidden_states
|
355 |
+
|
356 |
+
|
357 |
+
class CLIPEncoderLayer(nn.Module):
|
358 |
+
def __init__(self, config: CLIPConfig):
|
359 |
+
super().__init__()
|
360 |
+
self.embed_dim = config.hidden_size
|
361 |
+
self.self_attn = CLIPAttention(config)
|
362 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
363 |
+
self.mlp = CLIPMLP(config)
|
364 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
365 |
+
|
366 |
+
def forward(
|
367 |
+
self,
|
368 |
+
hidden_states: torch.Tensor,
|
369 |
+
attention_mask: torch.Tensor,
|
370 |
+
causal_attention_mask: torch.Tensor,
|
371 |
+
output_attentions: Optional[bool] = False,
|
372 |
+
) -> Tuple[torch.FloatTensor]:
|
373 |
+
"""
|
374 |
+
Args:
|
375 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
376 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
377 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
378 |
+
`(config.encoder_attention_heads,)`.
|
379 |
+
output_attentions (`bool`, *optional*):
|
380 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
381 |
+
returned tensors for more detail.
|
382 |
+
"""
|
383 |
+
residual = hidden_states
|
384 |
+
|
385 |
+
hidden_states = self.layer_norm1(hidden_states)
|
386 |
+
hidden_states, attn_weights = self.self_attn(
|
387 |
+
hidden_states=hidden_states,
|
388 |
+
attention_mask=attention_mask,
|
389 |
+
causal_attention_mask=causal_attention_mask,
|
390 |
+
output_attentions=output_attentions,
|
391 |
+
)
|
392 |
+
hidden_states = residual + hidden_states
|
393 |
+
|
394 |
+
residual = hidden_states
|
395 |
+
hidden_states = self.layer_norm2(hidden_states)
|
396 |
+
hidden_states = self.mlp(hidden_states)
|
397 |
+
hidden_states = residual + hidden_states
|
398 |
+
|
399 |
+
outputs = (hidden_states,)
|
400 |
+
|
401 |
+
if output_attentions:
|
402 |
+
outputs += (attn_weights,)
|
403 |
+
|
404 |
+
return outputs
|
405 |
+
|
406 |
+
|
407 |
+
class CLIPPreTrainedModel(PreTrainedModel):
|
408 |
+
"""
|
409 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
410 |
+
models.
|
411 |
+
"""
|
412 |
+
|
413 |
+
config_class = CLIPConfig
|
414 |
+
base_model_prefix = "clip"
|
415 |
+
supports_gradient_checkpointing = True
|
416 |
+
|
417 |
+
def _init_weights(self, module):
|
418 |
+
"""Initialize the weights"""
|
419 |
+
factor = self.config.initializer_factor
|
420 |
+
if isinstance(module, CLIPTextEmbeddings):
|
421 |
+
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
422 |
+
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
423 |
+
elif isinstance(module, CLIPVisionEmbeddings):
|
424 |
+
factor = self.config.initializer_factor
|
425 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
426 |
+
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
427 |
+
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
428 |
+
elif isinstance(module, CLIPAttention):
|
429 |
+
factor = self.config.initializer_factor
|
430 |
+
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
431 |
+
out_proj_std = (module.embed_dim**-0.5) * factor
|
432 |
+
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
433 |
+
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
434 |
+
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
435 |
+
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
436 |
+
elif isinstance(module, CLIPMLP):
|
437 |
+
factor = self.config.initializer_factor
|
438 |
+
in_proj_std = (
|
439 |
+
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
440 |
+
)
|
441 |
+
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
442 |
+
nn.init.normal_(module.fc1.weight, std=fc_std)
|
443 |
+
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
444 |
+
elif isinstance(module, CLIPModel):
|
445 |
+
nn.init.normal_(
|
446 |
+
module.text_projection.weight,
|
447 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
448 |
+
)
|
449 |
+
nn.init.normal_(
|
450 |
+
module.visual_projection.weight,
|
451 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
452 |
+
)
|
453 |
+
elif isinstance(module, CLIPVisionModelWithProjection):
|
454 |
+
nn.init.normal_(
|
455 |
+
module.visual_projection.weight,
|
456 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
457 |
+
)
|
458 |
+
elif isinstance(module, CLIPTextModelWithProjection):
|
459 |
+
nn.init.normal_(
|
460 |
+
module.text_projection.weight,
|
461 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
462 |
+
)
|
463 |
+
|
464 |
+
if isinstance(module, nn.LayerNorm):
|
465 |
+
module.bias.data.zero_()
|
466 |
+
module.weight.data.fill_(1.0)
|
467 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
468 |
+
module.bias.data.zero_()
|
469 |
+
|
470 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
471 |
+
if isinstance(module, CLIPEncoder):
|
472 |
+
module.gradient_checkpointing = value
|
473 |
+
|
474 |
+
|
475 |
+
CLIP_START_DOCSTRING = r"""
|
476 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
477 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
478 |
+
etc.)
|
479 |
+
|
480 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
481 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
482 |
+
and behavior.
|
483 |
+
|
484 |
+
Parameters:
|
485 |
+
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
486 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
487 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
488 |
+
"""
|
489 |
+
|
490 |
+
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
491 |
+
Args:
|
492 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
493 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

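# A minimal sketch of running this encoder on pre-computed embeddings with activation checkpointing enabled
# (it assumes a CLIPTextConfig `config`, token embeddings `token_embeds` produced by the CLIPTextEmbeddings
# module defined earlier in this file, and a `causal_mask`; checkpointing only applies in training mode):
#
#     encoder = CLIPEncoder(config)
#     encoder.gradient_checkpointing = True  # recompute each CLIPEncoderLayer in the backward pass to save memory
#     out = encoder(inputs_embeds=token_embeds, causal_attention_mask=causal_mask, return_dict=True)
#     out.last_hidden_state  # (batch_size, sequence_length, hidden_size)
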
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

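# For reference: with input_ids_shape == (1, 3), past_key_values_length == 0 and float32, the returned mask has
# shape (1, 1, 3, 3) and its (tgt_len, tgt_len) block is
#
#     [[0.0,  min,  min],
#      [0.0,  0.0,  min],
#      [0.0,  0.0,  0.0]]
#
# where `min` is torch.finfo(torch.float32).min, so once the mask is added to the attention scores, position i
# can only attend to positions j <= i.
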
class CLIPTextTransformer(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: keep what has been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

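# For reference, the pooling above selects one token per sequence from `last_hidden_state`:
#   * legacy configs (eos_token_id == 2): `input_ids.argmax(-1)` locates the EOS/EOT token because it carries the
#     highest id in CLIP's original vocabulary;
#   * updated configs: `(input_ids == eos_token_id).int().argmax(-1)` locates the first position whose id equals
#     `eos_token_id`, which also works when extra tokens with larger ids have been added to the vocabulary.
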
@add_start_docstrings(
    """The text model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

class CLIPVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = None

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # pooled_output = last_hidden_state[:, 0, :]
        # pooled_output = self.post_layernorm(pooled_output)
        pooled_output = None

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

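# Note: unlike the stock transformers implementation, this vision transformer keeps `post_layernorm = None` and
# returns `pooler_output = None`, so callers are expected to read per-patch features from `last_hidden_state`
# (or from `hidden_states` when `output_hidden_states=True`) rather than a pooled CLS vector. A minimal sketch,
# assuming `pixel_values` has shape (batch_size, 3, image_size, image_size):
#
#     vision = CLIPVisionTransformer(vision_config)
#     feats = vision(pixel_values=pixel_values, return_dict=True).last_hidden_state
#     # feats: (batch_size, 1 + num_patches, hidden_size) -- class token first, then patch tokens
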
@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

@add_start_docstrings(CLIP_START_DOCSTRING)
class CLIPModel(CLIPPreTrainedModel):
    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = CLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)

        return text_features

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )

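# For reference, `logits_per_text[i, j]` is the scaled cosine similarity between caption i and image j, and
# `logits_per_image` is simply its transpose. A minimal zero-shot readout on top of a forward pass (assuming
# `model` and `processor` are loaded as in the docstring example above):
#
#     with torch.no_grad():
#         out = model(**processor(text=["a cat", "a dog"], images=image, return_tensors="pt", padding=True))
#     probs = out.logits_per_image.softmax(dim=-1)  # (num_images, num_texts); each row sums to 1
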
@add_start_docstrings(
    """
    CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLIP_START_DOCSTRING,
)
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)

        self.text_model = CLIPTextTransformer(config)

        self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection

        >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1]

        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )

@add_start_docstrings(
    """
    CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    CLIP_START_DOCSTRING,
)
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection

        >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> image_embeds = outputs.image_embeds
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.visual_projection(pooled_output)

        if not return_dict:
            outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )
detail_encoder/attention_processor.py  ADDED  @@ -0,0 +1,687 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils.import_utils import is_xformers_available
from torchvision import transforms

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class SSRAttnProcessor(nn.Module):
    r"""
    Attention processor for SSR-Adapter.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        # self.to_q_SSR = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # query = self.to_q_SSR(hidden_states)
        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        _hidden_states = encoder_hidden_states
        _key = self.to_k_SSR(_hidden_states)
        _value = self.to_v_SSR(_hidden_states)
        _key = attn.head_to_batch_dim(_key)
        _value = attn.head_to_batch_dim(_value)
        _attention_probs = attn.get_attention_scores(query, _key, None)
        _hidden_states = torch.bmm(_attention_probs, _value)
        _hidden_states = attn.batch_to_head_dim(_hidden_states)
        hidden_states = self.scale * _hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

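# A minimal sketch, following the IP-Adapter-style pattern from diffusers, of how processors like the ones in
# this file are usually registered on a UNet2DConditionModel (the exact wiring this repo uses may differ; the
# hidden sizes below are derived from the UNet config the same way diffusers' examples do it):
#
#     attn_procs = {}
#     for name in unet.attn_processors.keys():
#         if name.endswith("attn1.processor"):  # self-attention: keep a plain processor
#             attn_procs[name] = AttnProcessor()
#         else:  # cross-attention: route the extra context through the SSR projections
#             if name.startswith("mid_block"):
#                 hidden_size = unet.config.block_out_channels[-1]
#             elif name.startswith("up_blocks"):
#                 block_id = int(name[len("up_blocks.")])
#                 hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
#             else:  # down_blocks
#                 block_id = int(name[len("down_blocks.")])
#                 hidden_size = unet.config.block_out_channels[block_id]
#             attn_procs[name] = SSRAttnProcessor(hidden_size, cross_attention_dim=unet.config.cross_attention_dim)
#     unet.set_attn_processor(attn_procs)
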
class SSRAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for SSR-Adapter for PyTorch 2.0.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        # self.to_q_SSR = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # query = self.to_q_SSR(hidden_states)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # split hidden states
        _hidden_states = encoder_hidden_states

        _key = self.to_k_SSR(_hidden_states)
        _value = self.to_v_SSR(_hidden_states)
        inner_dim = _key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        _key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        _value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        _hidden_states = F.scaled_dot_product_attention(
            query, _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        _hidden_states = _hidden_states.to(query.dtype)

        hidden_states = self.scale * _hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

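# SSRAttnProcessor2_0 computes the same thing as SSRAttnProcessor above, but through F.scaled_dot_product_attention
# (PyTorch 2.0's fused kernel) instead of materialising the softmax(QK^T) matrix via attn.get_attention_scores.
# The two processors below (AttnProcessor2_0 / AttnProcessor) are the stock diffusers implementations, presumably
# kept here so that layers which should not receive SSR features can fall back to plain attention.
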
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class ConvAttnProcessor:
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        ## map to 2D
        if len(hidden_states.shape) == 4:
            shape = hidden_states.shape
            hidden_states = torch.reshape(hidden_states, (shape[0], shape[1], shape[2] * shape[3]))
            hidden_states = hidden_states.permute(0, 2, 1)
        if encoder_hidden_states is not None:
            if len(encoder_hidden_states.shape) == 4:
                kv_shape = encoder_hidden_states.shape
                encoder_hidden_states = torch.reshape(
                    encoder_hidden_states, (kv_shape[0], kv_shape[1], kv_shape[2] * kv_shape[3])
                )
                encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)

        # the same as standard attn
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # map back to 4D
        if len(hidden_states.shape) == 3:
            hidden_states = hidden_states.permute(0, 2, 1)
            hidden_states = torch.reshape(hidden_states, (shape[0], shape[1], shape[2], shape[3]))

        return hidden_states

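# Shape walkthrough for ConvAttnProcessor: a (B, C, H, W) query feature map is flattened to (B, H*W, C) so every
# spatial location becomes one token, a 4D encoder_hidden_states tensor is flattened the same way, attention then
# runs exactly as in AttnProcessor, and the (B, H*W, C) result is reshaped back to (B, C, H, W) using the `shape`
# captured from the 4D input (so the processor assumes the query feature map is 4D).
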
class SSRAttnProcessor_text(nn.Module):
    r"""
    Attention processor for SSR-Adapter.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1):
        super().__init__()
        self.text_context_len = 77
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # split hidden states
        encoder_hidden_states, _hidden_states = (
            encoder_hidden_states[:, : self.text_context_len, :],
            encoder_hidden_states[:, self.text_context_len :, :],
        )
        encoder_hidden_states = encoder_hidden_states[:, :, :768]
        # for text
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for image
        _key = self.to_k_SSR(_hidden_states)
        _value = self.to_v_SSR(_hidden_states)
        _key = attn.head_to_batch_dim(_key)
        _value = attn.head_to_batch_dim(_value)
        _attention_probs = attn.get_attention_scores(query, _key, None)
        _hidden_states = torch.bmm(_attention_probs, _value)
        _hidden_states = attn.batch_to_head_dim(_hidden_states)
        hidden_states = self.scale * _hidden_states + hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

493 |
+
class SSRAttnProcessor2_0_text(torch.nn.Module):
|
494 |
+
r"""
|
495 |
+
Attention processor for SSR-Adapater for PyTorch 2.0.
|
496 |
+
"""
|
497 |
+
|
498 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
499 |
+
super().__init__()
|
500 |
+
|
501 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
502 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
503 |
+
self.text_context_len = 77
|
504 |
+
self.hidden_size = hidden_size
|
505 |
+
self.cross_attention_dim = cross_attention_dim
|
506 |
+
self.scale = scale
|
507 |
+
self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
508 |
+
self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
509 |
+
|
510 |
+
def __call__(
|
511 |
+
self,
|
512 |
+
attn,
|
513 |
+
hidden_states,
|
514 |
+
encoder_hidden_states=None,
|
515 |
+
attention_mask=None,
|
516 |
+
temb=None,
|
517 |
+
):
|
518 |
+
residual = hidden_states
|
519 |
+
|
520 |
+
if attn.spatial_norm is not None:
|
521 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
522 |
+
|
523 |
+
input_ndim = hidden_states.ndim
|
524 |
+
|
525 |
+
if input_ndim == 4:
|
526 |
+
batch_size, channel, height, width = hidden_states.shape
|
527 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
528 |
+
|
529 |
+
batch_size, sequence_length, _ = (
|
530 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
531 |
+
)
|
532 |
+
|
533 |
+
if attention_mask is not None:
|
534 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
535 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
536 |
+
# (batch, heads, source_length, target_length)
|
537 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
538 |
+
|
539 |
+
if attn.group_norm is not None:
|
540 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
541 |
+
|
542 |
+
query = attn.to_q(hidden_states)
|
543 |
+
|
544 |
+
if encoder_hidden_states is None:
|
545 |
+
encoder_hidden_states = hidden_states
|
546 |
+
elif attn.norm_cross:
|
547 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
548 |
+
|
549 |
+
# split hidden states
|
550 |
+
encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :self.text_context_len,
|
551 |
+
:], encoder_hidden_states[:, self.text_context_len:, :]
|
552 |
+
|
553 |
+
encoder_hidden_states = encoder_hidden_states[:, :, :768]
|
554 |
+
# for text
|
555 |
+
key = attn.to_k(encoder_hidden_states)
|
556 |
+
value = attn.to_v(encoder_hidden_states)
|
557 |
+
inner_dim = key.shape[-1]
|
558 |
+
head_dim = inner_dim // attn.heads
|
559 |
+
|
560 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
561 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
562 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
563 |
+
|
564 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
565 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
566 |
+
hidden_states = F.scaled_dot_product_attention(
|
567 |
+
query, key, value, attn_mask=attention_mask, dropout_p = 0.0, is_causal = False
|
568 |
+
)
|
569 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
570 |
+
hidden_states = hidden_states.to(query.dtype)
|
571 |
+
|
572 |
+
# for image
|
573 |
+
_key = self.to_k_SSR(_hidden_states)
|
574 |
+
_value = self.to_v_SSR(_hidden_states)
|
575 |
+
inner_dim = _key.shape[-1]
|
576 |
+
head_dim = inner_dim // attn.heads
|
577 |
+
|
578 |
+
_key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
579 |
+
_value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
580 |
+
|
581 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
582 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
583 |
+
_hidden_states = F.scaled_dot_product_attention(
|
584 |
+
query, _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False
|
585 |
+
)
|
586 |
+
|
587 |
+
_hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
588 |
+
_hidden_states = _hidden_states.to(query.dtype)
|
589 |
+
|
590 |
+
hidden_states = self.scale * _hidden_states + hidden_states
|
591 |
+
|
592 |
+
# linear proj
|
593 |
+
hidden_states = attn.to_out[0](hidden_states)
|
594 |
+
# dropout
|
595 |
+
hidden_states = attn.to_out[1](hidden_states)
|
596 |
+
|
597 |
+
if input_ndim == 4:
|
598 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
599 |
+
|
600 |
+
if attn.residual_connection:
|
601 |
+
hidden_states = hidden_states + residual
|
602 |
+
|
603 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
604 |
+
|
605 |
+
return hidden_states
|
606 |
+
|
607 |
+
|
608 |
+
class SSRAttnProcessor_visual(nn.Module):
|
609 |
+
r"""
|
610 |
+
Attention processor for attn visualization.
|
611 |
+
"""
|
612 |
+
|
613 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1, attnstore=None, place_in_unet=None):
|
614 |
+
super().__init__()
|
615 |
+
self.hidden_size = hidden_size
|
616 |
+
self.cross_attention_dim = cross_attention_dim
|
617 |
+
self.scale = scale
|
618 |
+
self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
|
619 |
+
self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
|
620 |
+
self.attnstore = attnstore
|
621 |
+
self.place_in_unet = place_in_unet
|
622 |
+
|
623 |
+
def __call__(
|
624 |
+
self,
|
625 |
+
attn,
|
626 |
+
hidden_states,
|
627 |
+
encoder_hidden_states=None,
|
628 |
+
attention_mask=None,
|
629 |
+
temb=None,
|
630 |
+
):
|
631 |
+
residual = hidden_states
|
632 |
+
|
633 |
+
if attn.spatial_norm is not None:
|
634 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
635 |
+
|
636 |
+
input_ndim = hidden_states.ndim
|
637 |
+
|
638 |
+
if input_ndim == 4:
|
639 |
+
batch_size, channel, height, width = hidden_states.shape
|
640 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
641 |
+
|
642 |
+
batch_size, sequence_length, _ = (
|
643 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
644 |
+
)
|
645 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
646 |
+
|
647 |
+
if attn.group_norm is not None:
|
648 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
649 |
+
|
650 |
+
# query = self.to_q_SSR(hidden_states)
|
651 |
+
query = attn.to_q(hidden_states)
|
652 |
+
query = attn.head_to_batch_dim(query)
|
653 |
+
|
654 |
+
if encoder_hidden_states is None:
|
655 |
+
encoder_hidden_states = hidden_states
|
656 |
+
elif attn.norm_cross:
|
657 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
658 |
+
|
659 |
+
_hidden_states = encoder_hidden_states
|
660 |
+
_key = self.to_k_SSR(_hidden_states)
|
661 |
+
_value = self.to_v_SSR(_hidden_states)
|
662 |
+
_key = attn.head_to_batch_dim(_key)
|
663 |
+
_value = attn.head_to_batch_dim(_value)
|
664 |
+
_attention_probs = attn.get_attention_scores(query, _key, None)
|
665 |
+
|
666 |
+
# store attention maps
|
667 |
+
is_cross = encoder_hidden_states is not None
|
668 |
+
self.attnstore(_attention_probs, is_cross, self.place_in_unet)
|
669 |
+
|
670 |
+
_hidden_states = torch.bmm(_attention_probs, _value)
|
671 |
+
_hidden_states = attn.batch_to_head_dim(_hidden_states)
|
672 |
+
hidden_states = self.scale * _hidden_states
|
673 |
+
|
674 |
+
# linear proj
|
675 |
+
hidden_states = attn.to_out[0](hidden_states)
|
676 |
+
# dropout
|
677 |
+
hidden_states = attn.to_out[1](hidden_states)
|
678 |
+
|
679 |
+
if input_ndim == 4:
|
680 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
681 |
+
|
682 |
+
if attn.residual_connection:
|
683 |
+
hidden_states = hidden_states + residual
|
684 |
+
|
685 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
686 |
+
|
687 |
+
return hidden_states
|
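These processors all rely on one convention: `encoder_hidden_states` is the 77 text tokens followed by the reference-image tokens, with only the first 768 channels of the text part fed to the frozen `to_k`/`to_v`, while the image part is routed through the extra `to_k_SSR`/`to_v_SSR` projections. A minimal sketch of that split; the tensor sizes here are illustrative assumptions, not values taken from this repo:

import torch

# assumed sizes: batch 2, 77 text tokens plus 257*12 CLIP image tokens, width 1024
encoder_hidden_states = torch.randn(2, 77 + 257 * 12, 1024)
text_context_len = 77

text_tokens = encoder_hidden_states[:, :text_context_len, :768]   # goes through attn.to_k / attn.to_v
image_tokens = encoder_hidden_states[:, text_context_len:, :]     # goes through to_k_SSR / to_v_SSR

print(text_tokens.shape)   # torch.Size([2, 77, 768])
print(image_tokens.shape)  # torch.Size([2, 3084, 1024])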
detail_encoder/encoder_plus.py
ADDED
@@ -0,0 +1,113 @@
from typing import List
import torch
from torchvision import transforms
from transformers import CLIPImageProcessor
from transformers import CLIPVisionModel as OriginalCLIPVisionModel
from ._clip import CLIPVisionModel
from PIL import Image
import torch.nn.functional as F
import torch.nn as nn
import os


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


if is_torch2_available():
    from .attention_processor import SSRAttnProcessor2_0 as SSRAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from .attention_processor import SSRAttnProcessor, AttnProcessor
from .resampler import Resampler


class detail_encoder(torch.nn.Module):
    """from SSR-encoder"""

    def __init__(self, unet, image_encoder_path, device="cuda", dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype

        # load image encoder
        clip_encoder = OriginalCLIPVisionModel.from_pretrained(image_encoder_path)
        self.image_encoder = CLIPVisionModel(clip_encoder.config)
        state_dict = clip_encoder.state_dict()
        self.image_encoder.load_state_dict(state_dict, strict=False)
        self.image_encoder.to(self.device, self.dtype)
        del clip_encoder
        self.clip_image_processor = CLIPImageProcessor()

        # load SSR layers
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = SSRAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=1024, scale=1
                ).to(self.device, dtype=self.dtype)
        unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
        self.SSR_layers = adapter_modules
        self.SSR_layers.to(self.device, dtype=self.dtype)
        self.resampler = self.init_proj()

    def init_proj(self):
        resampler = Resampler().to(self.device, dtype=self.dtype)
        return resampler

    def forward(self, img):
        image_embeds = self.image_encoder(img, output_hidden_states=True)['hidden_states'][2::2]
        image_embeds = torch.cat(image_embeds, dim=1)
        image_embeds = self.resampler(image_embeds)
        return image_embeds

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = []
        for pil in pil_image:
            tensor_image = self.clip_image_processor(images=pil, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
            clip_image.append(tensor_image)
        clip_image = torch.cat(clip_image, dim=0)

        # cond
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True)['hidden_states'][2::2]  # 1, 257*12, 1024
        clip_image_embeds = torch.cat(clip_image_embeds, dim=1)
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True)['hidden_states'][2::2]
        uncond_clip_image_embeds = torch.cat(uncond_clip_image_embeds, dim=1)
        clip_image_embeds = self.resampler(clip_image_embeds)
        uncond_clip_image_embeds = self.resampler(uncond_clip_image_embeds)
        return clip_image_embeds, uncond_clip_image_embeds

    def generate(
        self,
        id_image,
        makeup_image,
        seed=None,
        guidance_scale=2,
        num_inference_steps=30,
        pipe=None,
        **kwargs,
    ):
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(makeup_image)

        prompt_embeds = image_prompt_embeds
        negative_prompt_embeds = uncond_image_prompt_embeds

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        image = pipe(
            image=id_image,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images[0]

        return image
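A minimal usage sketch for `detail_encoder`; the checkpoint names, image file names, and the `pipe` object below are assumptions for illustration (the Space's real wiring lives in `app.py`):

import torch
from PIL import Image
from diffusers import UNet2DConditionModel
from detail_encoder.encoder_plus import detail_encoder

# assumed base model and CLIP vision encoder (1024-dim hidden states match cross_attention_dim=1024)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
encoder = detail_encoder(unet, image_encoder_path="openai/clip-vit-large-patch14", device="cuda", dtype=torch.float32)

id_image = Image.open("id.png").convert("RGB")          # face image (assumed file name)
makeup_image = Image.open("makeup.png").convert("RGB")  # makeup reference (assumed file name)

# `pipe` must be an img2img-style diffusers pipeline built around the same `unet`:
# result = encoder.generate(id_image=id_image, makeup_image=makeup_image, pipe=pipe, guidance_scale=2)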
detail_encoder/resampler.py
ADDED
@@ -0,0 +1,112 @@
import torch
import math
import torch.nn.functional as F
from torch import nn, einsum
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class SelfAttention(nn.Module):
    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x):
        q = self.to_q(x)  # B*N*(H*C)
        k = self.to_k(x)  # B*N*(H*C)
        v = self.to_v(x)  # B*N*(H*C)

        B, N, HC = q.shape
        H = self.heads
        C = HC // H

        q = q.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        k = k.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        v = v.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C

        sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale  # (B*H)*N*N
        attn = sim.softmax(dim=-1)  # (B*H)*N*N

        out = torch.einsum('b i j, b j c -> b i c', attn, v)  # (B*H)*N*C
        out = out.view(B, H, N, C).permute(0, 2, 1, 3).reshape(B, N, (H * C))  # B*N*(H*C)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()

        self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
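This `Resampler` is a single pre-norm self-attention plus feed-forward block, so it keeps both the token count and the channel width unchanged; a quick shape sanity check (sizes are illustrative):

import torch
from detail_encoder.resampler import Resampler

resampler = Resampler(query_dim=1024, n_heads=8, d_head=64)
tokens = torch.randn(1, 257 * 12, 1024)  # concatenated CLIP hidden states, as built in encoder_plus.py
out = resampler(tokens)
print(out.shape)  # torch.Size([1, 3084, 1024]) -- same shape as the input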
diffusers/.DS_Store
ADDED
Binary file (8.2 kB)
diffusers/__init__.py
ADDED
@@ -0,0 +1,734 @@
__version__ = "0.23.1"

from typing import TYPE_CHECKING

from .utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
    is_scipy_available,
    is_torch_available,
    is_torchsde_available,
    is_transformers_available,
)


# Lazy Import based on
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py

# When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
# and is used to defer the actual importing for when the objects are requested.
# This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).

_import_structure = {
    "configuration_utils": ["ConfigMixin"],
    "models": [],
    "pipelines": [],
    "schedulers": [],
    "utils": [
        "OptionalDependencyNotAvailable",
        "is_flax_available",
        "is_inflect_available",
        "is_invisible_watermark_available",
        "is_k_diffusion_available",
        "is_k_diffusion_version",
        "is_librosa_available",
        "is_note_seq_available",
        "is_onnx_available",
        "is_scipy_available",
        "is_torch_available",
        "is_torchsde_available",
        "is_transformers_available",
        "is_transformers_version",
        "is_unidecode_available",
        "logging",
    ],
}

try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_onnx_objects  # noqa F403

    _import_structure["utils.dummy_onnx_objects"] = [
        name for name in dir(dummy_onnx_objects) if not name.startswith("_")
    ]

else:
    _import_structure["pipelines"].extend(["OnnxRuntimeModel"])

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_pt_objects  # noqa F403

    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
    _import_structure["models"].extend(
        [
            "AsymmetricAutoencoderKL",
            "AutoencoderKL",
            "AutoencoderTiny",
            "ConsistencyDecoderVAE",
            "ControlNetModel",
            "ModelMixin",
            "MotionAdapter",
            "MultiAdapter",
            "PriorTransformer",
            "T2IAdapter",
            "T5FilmDecoder",
            "Transformer2DModel",
            "UNet1DModel",
            "UNet2DConditionModel",
            "UNet2DModel",
            "UNet3DConditionModel",
            "UNetMotionModel",
            "VQModel",
        ]
    )
    _import_structure["optimization"] = [
        "get_constant_schedule",
        "get_constant_schedule_with_warmup",
        "get_cosine_schedule_with_warmup",
        "get_cosine_with_hard_restarts_schedule_with_warmup",
        "get_linear_schedule_with_warmup",
        "get_polynomial_decay_schedule_with_warmup",
        "get_scheduler",
    ]

    _import_structure["pipelines"].extend(
        [
            "AudioPipelineOutput",
            "AutoPipelineForImage2Image",
            "AutoPipelineForInpainting",
            "AutoPipelineForText2Image",
            "ConsistencyModelPipeline",
            "DanceDiffusionPipeline",
            "DDIMPipeline",
            "DDPMPipeline",
            "DiffusionPipeline",
            "DiTPipeline",
            "ImagePipelineOutput",
            "KarrasVePipeline",
            "LDMPipeline",
            "LDMSuperResolutionPipeline",
            "PNDMPipeline",
            "RePaintPipeline",
            "ScoreSdeVePipeline",
        ]
    )
    _import_structure["schedulers"].extend(
        [
            "CMStochasticIterativeScheduler",
            "DDIMInverseScheduler",
            "DDIMParallelScheduler",
            "DDIMScheduler",
            "DDPMParallelScheduler",
            "DDPMScheduler",
            "DDPMWuerstchenScheduler",
            "DEISMultistepScheduler",
            "DPMSolverMultistepInverseScheduler",
            "DPMSolverMultistepScheduler",
            "DPMSolverSinglestepScheduler",
            "EulerAncestralDiscreteScheduler",
            "EulerDiscreteScheduler",
            "HeunDiscreteScheduler",
            "IPNDMScheduler",
            "KarrasVeScheduler",
            "KDPM2AncestralDiscreteScheduler",
            "KDPM2DiscreteScheduler",
            "LCMScheduler",
            "PNDMScheduler",
            "RePaintScheduler",
            "SchedulerMixin",
            "ScoreSdeVeScheduler",
            "UnCLIPScheduler",
            "UniPCMultistepScheduler",
            "VQDiffusionScheduler",
        ]
    )
    _import_structure["training_utils"] = ["EMAModel"]

try:
    if not (is_torch_available() and is_scipy_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_scipy_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_scipy_objects"] = [
        name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
    ]

else:
    _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])

try:
    if not (is_torch_available() and is_torchsde_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_torchsde_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
        name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
    ]

else:
    _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])

try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_transformers_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_transformers_objects"] = [
        name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
    ]

else:
    _import_structure["pipelines"].extend(
        [
            "AltDiffusionImg2ImgPipeline",
            "AltDiffusionPipeline",
            "AnimateDiffPipeline",
            "AudioLDM2Pipeline",
            "AudioLDM2ProjectionModel",
            "AudioLDM2UNet2DConditionModel",
            "AudioLDMPipeline",
            "BlipDiffusionControlNetPipeline",
            "BlipDiffusionPipeline",
            "CLIPImageProjection",
            "CycleDiffusionPipeline",
            "IFImg2ImgPipeline",
            "IFImg2ImgSuperResolutionPipeline",
            "IFInpaintingPipeline",
            "IFInpaintingSuperResolutionPipeline",
            "IFPipeline",
            "IFSuperResolutionPipeline",
            "ImageTextPipelineOutput",
            "KandinskyCombinedPipeline",
            "KandinskyImg2ImgCombinedPipeline",
            "KandinskyImg2ImgPipeline",
            "KandinskyInpaintCombinedPipeline",
            "KandinskyInpaintPipeline",
            "KandinskyPipeline",
            "KandinskyPriorPipeline",
            "KandinskyV22CombinedPipeline",
            "KandinskyV22ControlnetImg2ImgPipeline",
            "KandinskyV22ControlnetPipeline",
            "KandinskyV22Img2ImgCombinedPipeline",
            "KandinskyV22Img2ImgPipeline",
            "KandinskyV22InpaintCombinedPipeline",
            "KandinskyV22InpaintPipeline",
            "KandinskyV22Pipeline",
            "KandinskyV22PriorEmb2EmbPipeline",
            "KandinskyV22PriorPipeline",
            "LatentConsistencyModelImg2ImgPipeline",
            "LatentConsistencyModelPipeline",
            "LDMTextToImagePipeline",
            "MusicLDMPipeline",
            "PaintByExamplePipeline",
            "PixArtAlphaPipeline",
            "SemanticStableDiffusionPipeline",
            "ShapEImg2ImgPipeline",
            "ShapEPipeline",
            "StableDiffusionAdapterPipeline",
            "StableDiffusionAttendAndExcitePipeline",
            "StableDiffusionControlNetImg2ImgPipeline",
            "StableDiffusionControlNetInpaintPipeline",
            "StableDiffusionControlNetPipeline",
            "StableDiffusionDepth2ImgPipeline",
            "StableDiffusionDiffEditPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENTextImagePipeline",
            "StableDiffusionImageVariationPipeline",
            "StableDiffusionImg2ImgPipeline",
            "StableDiffusionInpaintPipeline",
            "StableDiffusionInpaintPipelineLegacy",
            "StableDiffusionInstructPix2PixPipeline",
            "StableDiffusionLatentUpscalePipeline",
            "StableDiffusionLDM3DPipeline",
            "StableDiffusionModelEditingPipeline",
            "StableDiffusionPanoramaPipeline",
            "StableDiffusionParadigmsPipeline",
            "StableDiffusionPipeline",
            "StableDiffusionPipelineSafe",
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionSAGPipeline",
            "StableDiffusionUpscalePipeline",
            "StableDiffusionXLAdapterPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
            "StableDiffusionXLControlNetPipeline",
            "StableDiffusionXLImg2ImgPipeline",
            "StableDiffusionXLInpaintPipeline",
            "StableDiffusionXLInstructPix2PixPipeline",
            "StableDiffusionXLPipeline",
            "StableUnCLIPImg2ImgPipeline",
            "StableUnCLIPPipeline",
            "TextToVideoSDPipeline",
            "TextToVideoZeroPipeline",
            "UnCLIPImageVariationPipeline",
            "UnCLIPPipeline",
            "UniDiffuserModel",
            "UniDiffuserPipeline",
            "UniDiffuserTextDecoder",
            "VersatileDiffusionDualGuidedPipeline",
            "VersatileDiffusionImageVariationPipeline",
            "VersatileDiffusionPipeline",
            "VersatileDiffusionTextToImagePipeline",
            "VideoToVideoSDPipeline",
            "VQDiffusionPipeline",
            "WuerstchenCombinedPipeline",
            "WuerstchenDecoderPipeline",
            "WuerstchenPriorPipeline",
        ]
    )

try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
        name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
    ]

else:
    _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])

try:
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
        name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
    ]

else:
    _import_structure["pipelines"].extend(
        [
            "OnnxStableDiffusionImg2ImgPipeline",
            "OnnxStableDiffusionInpaintPipeline",
            "OnnxStableDiffusionInpaintPipelineLegacy",
            "OnnxStableDiffusionPipeline",
            "OnnxStableDiffusionUpscalePipeline",
            "StableDiffusionOnnxPipeline",
        ]
    )

try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torch_and_librosa_objects  # noqa F403

    _import_structure["utils.dummy_torch_and_librosa_objects"] = [
        name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
    ]

else:
    _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])

try:
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403

    _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
        name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
    ]


else:
    _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])

try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_flax_objects  # noqa F403

    _import_structure["utils.dummy_flax_objects"] = [
        name for name in dir(dummy_flax_objects) if not name.startswith("_")
    ]


else:
    _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
    _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
    _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
    _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
    _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
    _import_structure["schedulers"].extend(
        [
            "FlaxDDIMScheduler",
            "FlaxDDPMScheduler",
            "FlaxDPMSolverMultistepScheduler",
            "FlaxEulerDiscreteScheduler",
            "FlaxKarrasVeScheduler",
            "FlaxLMSDiscreteScheduler",
            "FlaxPNDMScheduler",
            "FlaxSchedulerMixin",
            "FlaxScoreSdeVeScheduler",
        ]
    )


try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_flax_and_transformers_objects  # noqa F403

    _import_structure["utils.dummy_flax_and_transformers_objects"] = [
        name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
    ]


else:
    _import_structure["pipelines"].extend(
        [
            "FlaxStableDiffusionControlNetPipeline",
            "FlaxStableDiffusionImg2ImgPipeline",
            "FlaxStableDiffusionInpaintPipeline",
            "FlaxStableDiffusionPipeline",
            "FlaxStableDiffusionXLPipeline",
        ]
    )

try:
    if not (is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_note_seq_objects  # noqa F403

    _import_structure["utils.dummy_note_seq_objects"] = [
        name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
    ]


else:
    _import_structure["pipelines"].extend(["MidiProcessor"])

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .configuration_utils import ConfigMixin

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_onnx_objects import *  # noqa F403
    else:
        from .pipelines import OnnxRuntimeModel

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_pt_objects import *  # noqa F403
    else:
        from .models import (
            AsymmetricAutoencoderKL,
            AutoencoderKL,
            AutoencoderTiny,
            ConsistencyDecoderVAE,
            ControlNetModel,
            ModelMixin,
            MotionAdapter,
            MultiAdapter,
            PriorTransformer,
            T2IAdapter,
            T5FilmDecoder,
            Transformer2DModel,
            UNet1DModel,
            UNet2DConditionModel,
            UNet2DModel,
            UNet3DConditionModel,
            UNetMotionModel,
            VQModel,
        )
        from .optimization import (
            get_constant_schedule,
            get_constant_schedule_with_warmup,
            get_cosine_schedule_with_warmup,
            get_cosine_with_hard_restarts_schedule_with_warmup,
            get_linear_schedule_with_warmup,
            get_polynomial_decay_schedule_with_warmup,
            get_scheduler,
        )
        from .pipelines import (
            AudioPipelineOutput,
            AutoPipelineForImage2Image,
            AutoPipelineForInpainting,
            AutoPipelineForText2Image,
            BlipDiffusionControlNetPipeline,
            BlipDiffusionPipeline,
            CLIPImageProjection,
            ConsistencyModelPipeline,
            DanceDiffusionPipeline,
            DDIMPipeline,
            DDPMPipeline,
            DiffusionPipeline,
            DiTPipeline,
            ImagePipelineOutput,
            KarrasVePipeline,
            LDMPipeline,
            LDMSuperResolutionPipeline,
            PNDMPipeline,
            RePaintPipeline,
            ScoreSdeVePipeline,
        )
        from .schedulers import (
            CMStochasticIterativeScheduler,
            DDIMInverseScheduler,
            DDIMParallelScheduler,
            DDIMScheduler,
            DDPMParallelScheduler,
            DDPMScheduler,
            DDPMWuerstchenScheduler,
            DEISMultistepScheduler,
            DPMSolverMultistepInverseScheduler,
            DPMSolverMultistepScheduler,
            DPMSolverSinglestepScheduler,
            EulerAncestralDiscreteScheduler,
            EulerDiscreteScheduler,
            HeunDiscreteScheduler,
            IPNDMScheduler,
            KarrasVeScheduler,
            KDPM2AncestralDiscreteScheduler,
            KDPM2DiscreteScheduler,
            LCMScheduler,
            PNDMScheduler,
            RePaintScheduler,
            SchedulerMixin,
            ScoreSdeVeScheduler,
            UnCLIPScheduler,
            UniPCMultistepScheduler,
            VQDiffusionScheduler,
        )
        from .training_utils import EMAModel

    try:
        if not (is_torch_available() and is_scipy_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
    else:
        from .schedulers import LMSDiscreteScheduler

    try:
        if not (is_torch_available() and is_torchsde_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_torchsde_objects import *  # noqa F403
    else:
        from .schedulers import DPMSolverSDEScheduler

    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipelines import (
            AltDiffusionImg2ImgPipeline,
            AltDiffusionPipeline,
            AnimateDiffPipeline,
            AudioLDM2Pipeline,
            AudioLDM2ProjectionModel,
            AudioLDM2UNet2DConditionModel,
            AudioLDMPipeline,
            CLIPImageProjection,
            CycleDiffusionPipeline,
            IFImg2ImgPipeline,
            IFImg2ImgSuperResolutionPipeline,
            IFInpaintingPipeline,
            IFInpaintingSuperResolutionPipeline,
            IFPipeline,
            IFSuperResolutionPipeline,
            ImageTextPipelineOutput,
            KandinskyCombinedPipeline,
            KandinskyImg2ImgCombinedPipeline,
            KandinskyImg2ImgPipeline,
            KandinskyInpaintCombinedPipeline,
            KandinskyInpaintPipeline,
            KandinskyPipeline,
            KandinskyPriorPipeline,
            KandinskyV22CombinedPipeline,
            KandinskyV22ControlnetImg2ImgPipeline,
            KandinskyV22ControlnetPipeline,
            KandinskyV22Img2ImgCombinedPipeline,
            KandinskyV22Img2ImgPipeline,
            KandinskyV22InpaintCombinedPipeline,
            KandinskyV22InpaintPipeline,
            KandinskyV22Pipeline,
            KandinskyV22PriorEmb2EmbPipeline,
            KandinskyV22PriorPipeline,
            LatentConsistencyModelImg2ImgPipeline,
            LatentConsistencyModelPipeline,
            LDMTextToImagePipeline,
            MusicLDMPipeline,
            PaintByExamplePipeline,
            PixArtAlphaPipeline,
            SemanticStableDiffusionPipeline,
            ShapEImg2ImgPipeline,
            ShapEPipeline,
            StableDiffusionAdapterPipeline,
            StableDiffusionAttendAndExcitePipeline,
            StableDiffusionControlNetImg2ImgPipeline,
            StableDiffusionControlNetInpaintPipeline,
            StableDiffusionControlNetPipeline,
            StableDiffusionDepth2ImgPipeline,
            StableDiffusionDiffEditPipeline,
            StableDiffusionGLIGENPipeline,
            StableDiffusionGLIGENTextImagePipeline,
            StableDiffusionImageVariationPipeline,
            StableDiffusionImg2ImgPipeline,
            StableDiffusionInpaintPipeline,
            StableDiffusionInpaintPipelineLegacy,
            StableDiffusionInstructPix2PixPipeline,
            StableDiffusionLatentUpscalePipeline,
            StableDiffusionLDM3DPipeline,
            StableDiffusionModelEditingPipeline,
            StableDiffusionPanoramaPipeline,
            StableDiffusionParadigmsPipeline,
            StableDiffusionPipeline,
            StableDiffusionPipelineSafe,
            StableDiffusionPix2PixZeroPipeline,
            StableDiffusionSAGPipeline,
            StableDiffusionUpscalePipeline,
            StableDiffusionXLAdapterPipeline,
            StableDiffusionXLControlNetImg2ImgPipeline,
            StableDiffusionXLControlNetInpaintPipeline,
            StableDiffusionXLControlNetPipeline,
            StableDiffusionXLImg2ImgPipeline,
            StableDiffusionXLInpaintPipeline,
            StableDiffusionXLInstructPix2PixPipeline,
            StableDiffusionXLPipeline,
            StableUnCLIPImg2ImgPipeline,
            StableUnCLIPPipeline,
            TextToVideoSDPipeline,
            TextToVideoZeroPipeline,
            UnCLIPImageVariationPipeline,
            UnCLIPPipeline,
            UniDiffuserModel,
            UniDiffuserPipeline,
            UniDiffuserTextDecoder,
            VersatileDiffusionDualGuidedPipeline,
            VersatileDiffusionImageVariationPipeline,
            VersatileDiffusionPipeline,
            VersatileDiffusionTextToImagePipeline,
            VideoToVideoSDPipeline,
            VQDiffusionPipeline,
            WuerstchenCombinedPipeline,
            WuerstchenDecoderPipeline,
            WuerstchenPriorPipeline,
        )

    try:
        if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
    else:
        from .pipelines import StableDiffusionKDiffusionPipeline

    try:
        if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
    else:
        from .pipelines import (
            OnnxStableDiffusionImg2ImgPipeline,
            OnnxStableDiffusionInpaintPipeline,
            OnnxStableDiffusionInpaintPipelineLegacy,
            OnnxStableDiffusionPipeline,
            OnnxStableDiffusionUpscalePipeline,
            StableDiffusionOnnxPipeline,
        )

    try:
        if not (is_torch_available() and is_librosa_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
    else:
        from .pipelines import AudioDiffusionPipeline, Mel

    try:
        if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
    else:
        from .pipelines import SpectrogramDiffusionPipeline

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_flax_objects import *  # noqa F403
    else:
        from .models.controlnet_flax import FlaxControlNetModel
        from .models.modeling_flax_utils import FlaxModelMixin
        from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
        from .models.vae_flax import FlaxAutoencoderKL
        from .pipelines import FlaxDiffusionPipeline
        from .schedulers import (
            FlaxDDIMScheduler,
            FlaxDDPMScheduler,
            FlaxDPMSolverMultistepScheduler,
            FlaxEulerDiscreteScheduler,
            FlaxKarrasVeScheduler,
            FlaxLMSDiscreteScheduler,
            FlaxPNDMScheduler,
            FlaxSchedulerMixin,
            FlaxScoreSdeVeScheduler,
        )

    try:
        if not (is_flax_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
    else:
        from .pipelines import (
            FlaxStableDiffusionControlNetPipeline,
            FlaxStableDiffusionImg2ImgPipeline,
            FlaxStableDiffusionInpaintPipeline,
            FlaxStableDiffusionPipeline,
            FlaxStableDiffusionXLPipeline,
        )

    try:
        if not (is_note_seq_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_note_seq_objects import *  # noqa F403
    else:
        from .pipelines import MidiProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
        extra_objects={"__version__": __version__},
    )
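The vendored `__init__.py` keeps the upstream lazy-import pattern: `_import_structure` only lists names, and the actual submodule import happens on first attribute access through `_LazyModule`. A short illustration, assuming the vendored copy is importable as `diffusers` from the repo root:

import diffusers

print(diffusers.__version__)                   # "0.23.1"
pipe_cls = diffusers.StableDiffusionPipeline   # resolved lazily on first attribute access
print(pipe_cls.__name__)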
diffusers/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseDiffusersCLICommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()
diffusers/commands/diffusers_cli.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand


def main():
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
diffusers/commands/env.py
ADDED
@@ -0,0 +1,84 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self):
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "xFormers version": xformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
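`diffusers-cli env` simply gathers version strings and prints them as a bullet list via the static `format_dict` helper; for example (the values shown are illustrative):

from diffusers.commands.env import EnvironmentCommand

print(EnvironmentCommand.format_dict({"`diffusers` version": "0.23.1", "Platform": "Linux"}))
# - `diffusers` version: 0.23.1
# - Platform: Linux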
diffusers/commands/fp16_safetensors.py
ADDED
@@ -0,0 +1,133 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage example:
    diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
"""

import glob
import json
from argparse import ArgumentParser, Namespace
from importlib import import_module

import huggingface_hub
import torch
from huggingface_hub import hf_hub_download
from packaging import version

from ..utils import logging
from . import BaseDiffusersCLICommand


def conversion_command_factory(args: Namespace):
    return FP16SafetensorsCommand(
        args.ckpt_id,
        args.fp16,
        args.use_safetensors,
        args.use_auth_token,
    )


class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

        self.use_auth_token = use_auth_token

    def run(self):
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            from huggingface_hub import create_commit
            from huggingface_hub._commit_api import CommitOperationAdd

        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
        with open(model_index, "r") as f:
|
92 |
+
pipeline_class_name = json.load(f)["_class_name"]
|
93 |
+
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
|
94 |
+
self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
|
95 |
+
|
96 |
+
# Load the appropriate pipeline. We could have use `DiffusionPipeline`
|
97 |
+
# here, but just to avoid any rough edge cases.
|
98 |
+
pipeline = pipeline_class.from_pretrained(
|
99 |
+
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
|
100 |
+
)
|
101 |
+
pipeline.save_pretrained(
|
102 |
+
self.local_ckpt_dir,
|
103 |
+
safe_serialization=True if self.use_safetensors else False,
|
104 |
+
variant="fp16" if self.fp16 else None,
|
105 |
+
)
|
106 |
+
self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
|
107 |
+
|
108 |
+
# Fetch all the paths.
|
109 |
+
if self.fp16:
|
110 |
+
modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
|
111 |
+
elif self.use_safetensors:
|
112 |
+
modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
|
113 |
+
|
114 |
+
# Prepare for the PR.
|
115 |
+
commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
|
116 |
+
operations = []
|
117 |
+
for path in modified_paths:
|
118 |
+
operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
|
119 |
+
|
120 |
+
# Open the PR.
|
121 |
+
commit_description = (
|
122 |
+
"Variables converted by the [`diffusers`' `fp16_safetensors`"
|
123 |
+
" CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
|
124 |
+
)
|
125 |
+
hub_pr_url = create_commit(
|
126 |
+
repo_id=self.ckpt_id,
|
127 |
+
operations=operations,
|
128 |
+
commit_message=commit_message,
|
129 |
+
commit_description=commit_description,
|
130 |
+
repo_type="model",
|
131 |
+
create_pr=True,
|
132 |
+
).pr_url
|
133 |
+
self.logger.info(f"PR created here: {hub_pr_url}.")
|
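The command above automates a conversion that can also be done by hand. Below is a minimal sketch of the equivalent manual steps, omitting the PR-opening part; the repo id and output directory are only illustrative assumptions:

```python
# Manual equivalent of the fp16/safetensors conversion (without opening a Hub PR).
# The repo id and output directory below are illustrative assumptions.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipeline.save_pretrained(
    "./sd15-fp16",
    safe_serialization=True,  # write *.safetensors instead of *.bin
    variant="fp16",           # weight file names get the ".fp16" variant suffix
)
```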
diffusers/configuration_utils.py
ADDED
@@ -0,0 +1,694 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixin base class and utilities."""
import dataclasses
import functools
import importlib
import inspect
import json
import os
import re
from collections import OrderedDict
from pathlib import PosixPath
from typing import Any, Dict, Tuple, Union

import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from . import __version__
from .utils import (
    DIFFUSERS_CACHE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    DummyObject,
    deprecate,
    extract_commit_hash,
    http_user_agent,
    logging,
)


logger = logging.get_logger(__name__)

_re_configuration_file = re.compile(r"config\.(.*)\.json")


class FrozenDict(OrderedDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        for key, value in self.items():
            setattr(self, key, value)

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)


class ConfigMixin:
    r"""
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    saving classes that inherit from [`ConfigMixin`].

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """
    config_name = None
    ignore_for_config = []
    has_compatibles = False

    _deprecated_kwargs = []

    def register_to_config(self, **kwargs):
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        self._internal_dict = FrozenDict(internal_dict)

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model

    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration is returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file.

        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            elif subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
                    " login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheck your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs

    @staticmethod
    def _get_init_keys(cls):
        return set(dict(inspect.signature(cls.__init__).parameters).keys())

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if (
            isinstance(orig_cls_name, str)
            and orig_cls_name != cls.__name__
            and hasattr(diffusers_library, orig_cls_name)
        ):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
        elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
            raise ValueError(
                "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
            )

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        return self._internal_dict

    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, PosixPath):
                value = str(value)
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save the configuration instance's parameters to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file to save a configuration instance's parameters.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())


def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init


def flax_register_to_config(cls):
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@flax_register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if type(field.default) == dataclasses._MISSING_TYPE:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
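To see how `ConfigMixin`, `FrozenDict`, and the `@register_to_config` decorator fit together, here is a minimal usage sketch; the class name, config filename, and arguments below are made up for illustration:

```python
# Minimal ConfigMixin usage sketch; names and values are illustrative assumptions.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # filename written by save_config()

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        # The decorator records the __init__ arguments into self.config (a FrozenDict).
        pass


scheduler = ToyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500

scheduler.save_config("./toy_scheduler")            # writes ./toy_scheduler/scheduler_config.json
config = ToyScheduler.load_config("./toy_scheduler")
restored = ToyScheduler.from_config(config)         # rebuilds the object from the config dict
```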
diffusers/dependency_versions_check.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core


# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python requests filelock numpy".split()
for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    require_version(deps[pkg], hint)
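`dep_version_check` is a thin wrapper over `require_version` keyed by the pinned requirements in `dependency_versions_table.py`. A small hedged sketch of how it might be called; the package key and hint string are only examples:

```python
# Illustrative call; "torch" is just an example key from the `deps` table.
from diffusers.dependency_versions_check import dep_version_check

dep_version_check("torch", hint="pip install --upgrade torch")
# Raises if the installed torch does not satisfy the pinned requirement.
```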
diffusers/dependency_versions_table.py
ADDED
@@ -0,0 +1,46 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.13.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1",
    "jaxlib": "jaxlib>=0.4.1",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "omegaconf": "omegaconf",
    "parameterized": "parameterized",
    "peft": "peft<=0.6.2",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.8.0",
    "ruff": "ruff==0.0.280",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
}
diffusers/experimental/README.md
ADDED
@@ -0,0 +1,5 @@
# 🧨 Diffusers Experimental

We are adding experimental code to support novel applications and usages of the Diffusers library.
Currently, the following experiments are supported:
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
diffusers/experimental/__init__.py
ADDED
@@ -0,0 +1 @@
from .rl import ValueGuidedRLPipeline
diffusers/experimental/rl/__init__.py
ADDED
@@ -0,0 +1 @@
from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/value_guided_sampling.py
ADDED
@@ -0,0 +1,154 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import tqdm

from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler
from ...utils.torch_utils import randn_tensor


class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
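A rough usage sketch for this pipeline follows. The environment and pretrained components are assumptions, not something this file provides: a D4RL-style gym `env` exposing `get_dataset()`, plus a trained `unet`, `value_function`, and `scheduler`:

```python
# Illustrative only: `env`, `unet`, `value_function`, and `scheduler` are assumed to exist.
# A D4RL-style environment is expected because __init__ calls env.get_dataset().
pipeline = ValueGuidedRLPipeline(
    value_function=value_function,  # UNet1DModel trained to score trajectories
    unet=unet,                      # UNet1DModel that denoises (action, state) trajectories
    scheduler=scheduler,            # e.g. a DDPMScheduler instance
    env=env,
)

obs = env.reset()
for _ in range(10):
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, info = env.step(action)
    if done:
        break
```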
diffusers/image_processor.py
ADDED
@@ -0,0 +1,476 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.FloatTensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.FloatTensor],
]


class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_rgb (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to RGB format.
        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to grayscale format.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_rgb: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__()
        if do_convert_rgb and do_convert_grayscale:
            raise ValueError(
                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
            )
            self.config.do_convert_rgb = False

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to NumPy arrays.
        """
        if not isinstance(images, list):
            images = [images]
        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
        images = np.stack(images, axis=0)

        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
        """
        Convert a NumPy image to a PyTorch tensor.
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
        """
        Convert a PyTorch tensor to a NumPy image.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1].
        """
        return 2.0 * images - 1.0

    @staticmethod
    def denormalize(images):
        """
        Denormalize an image array to [0,1].
        """
        return (images / 2 + 0.5).clamp(0, 1)

    @staticmethod
    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Converts a PIL image to RGB format.
        """
        image = image.convert("RGB")

        return image

    @staticmethod
    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
153 |
+
"""
|
154 |
+
Converts a PIL image to grayscale format.
|
155 |
+
"""
|
156 |
+
image = image.convert("L")
|
157 |
+
|
158 |
+
return image
|
159 |
+
|
160 |
+
def get_default_height_width(
|
161 |
+
self,
|
162 |
+
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
163 |
+
height: Optional[int] = None,
|
164 |
+
width: Optional[int] = None,
|
165 |
+
):
|
166 |
+
"""
|
167 |
+
This function return the height and width that are downscaled to the next integer multiple of
|
168 |
+
`vae_scale_factor`.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
172 |
+
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
|
173 |
+
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
|
174 |
+
have shape `[batch, channel, height, width]`.
|
175 |
+
height (`int`, *optional*, defaults to `None`):
|
176 |
+
The height in preprocessed image. If `None`, will use the height of `image` input.
|
177 |
+
width (`int`, *optional*`, defaults to `None`):
|
178 |
+
The width in preprocessed. If `None`, will use the width of the `image` input.
|
179 |
+
"""
|
180 |
+
|
181 |
+
if height is None:
|
182 |
+
if isinstance(image, PIL.Image.Image):
|
183 |
+
height = image.height
|
184 |
+
elif isinstance(image, torch.Tensor):
|
185 |
+
height = image.shape[2]
|
186 |
+
else:
|
187 |
+
height = image.shape[1]
|
188 |
+
|
189 |
+
if width is None:
|
190 |
+
if isinstance(image, PIL.Image.Image):
|
191 |
+
width = image.width
|
192 |
+
elif isinstance(image, torch.Tensor):
|
193 |
+
width = image.shape[3]
|
194 |
+
else:
|
195 |
+
width = image.shape[2]
|
196 |
+
|
197 |
+
width, height = (
|
198 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
199 |
+
) # resize to integer multiple of vae_scale_factor
|
200 |
+
|
201 |
+
return height, width
|
202 |
+
|
203 |
+
def resize(
|
204 |
+
self,
|
205 |
+
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
206 |
+
height: Optional[int] = None,
|
207 |
+
width: Optional[int] = None,
|
208 |
+
) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
|
209 |
+
"""
|
210 |
+
Resize image.
|
211 |
+
"""
|
212 |
+
if isinstance(image, PIL.Image.Image):
|
213 |
+
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
214 |
+
elif isinstance(image, torch.Tensor):
|
215 |
+
image = torch.nn.functional.interpolate(
|
216 |
+
image,
|
217 |
+
size=(height, width),
|
218 |
+
)
|
219 |
+
elif isinstance(image, np.ndarray):
|
220 |
+
image = self.numpy_to_pt(image)
|
221 |
+
image = torch.nn.functional.interpolate(
|
222 |
+
image,
|
223 |
+
size=(height, width),
|
224 |
+
)
|
225 |
+
image = self.pt_to_numpy(image)
|
226 |
+
return image
|
227 |
+
|
228 |
+
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
229 |
+
"""
|
230 |
+
create a mask
|
231 |
+
"""
|
232 |
+
image[image < 0.5] = 0
|
233 |
+
image[image >= 0.5] = 1
|
234 |
+
return image
|
235 |
+
|
236 |
+
def preprocess(
|
237 |
+
self,
|
238 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
239 |
+
height: Optional[int] = None,
|
240 |
+
width: Optional[int] = None,
|
241 |
+
) -> torch.Tensor:
|
242 |
+
"""
|
243 |
+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
244 |
+
"""
|
245 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
246 |
+
|
247 |
+
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
248 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
249 |
+
if isinstance(image, torch.Tensor):
|
250 |
+
# if image is a pytorch tensor could have 2 possible shapes:
|
251 |
+
# 1. batch x height x width: we should insert the channel dimension at position 1
|
252 |
+
# 2. channnel x height x width: we should insert batch dimension at position 0,
|
253 |
+
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
254 |
+
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
255 |
+
image = image.unsqueeze(1)
|
256 |
+
else:
|
257 |
+
# if it is a numpy array, it could have 2 possible shapes:
|
258 |
+
# 1. batch x height x width: insert channel dimension on last position
|
259 |
+
# 2. height x width x channel: insert batch dimension on first position
|
260 |
+
if image.shape[-1] == 1:
|
261 |
+
image = np.expand_dims(image, axis=0)
|
262 |
+
else:
|
263 |
+
image = np.expand_dims(image, axis=-1)
|
264 |
+
|
265 |
+
if isinstance(image, supported_formats):
|
266 |
+
image = [image]
|
267 |
+
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
|
268 |
+
raise ValueError(
|
269 |
+
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
|
270 |
+
)
|
271 |
+
|
272 |
+
if isinstance(image[0], PIL.Image.Image):
|
273 |
+
if self.config.do_convert_rgb:
|
274 |
+
image = [self.convert_to_rgb(i) for i in image]
|
275 |
+
elif self.config.do_convert_grayscale:
|
276 |
+
image = [self.convert_to_grayscale(i) for i in image]
|
277 |
+
if self.config.do_resize:
|
278 |
+
height, width = self.get_default_height_width(image[0], height, width)
|
279 |
+
image = [self.resize(i, height, width) for i in image]
|
280 |
+
image = self.pil_to_numpy(image) # to np
|
281 |
+
image = self.numpy_to_pt(image) # to pt
|
282 |
+
|
283 |
+
elif isinstance(image[0], np.ndarray):
|
284 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
285 |
+
|
286 |
+
image = self.numpy_to_pt(image)
|
287 |
+
|
288 |
+
height, width = self.get_default_height_width(image, height, width)
|
289 |
+
if self.config.do_resize:
|
290 |
+
image = self.resize(image, height, width)
|
291 |
+
|
292 |
+
elif isinstance(image[0], torch.Tensor):
|
293 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
294 |
+
|
295 |
+
if self.config.do_convert_grayscale and image.ndim == 3:
|
296 |
+
image = image.unsqueeze(1)
|
297 |
+
|
298 |
+
channel = image.shape[1]
|
299 |
+
# don't need any preprocess if the image is latents
|
300 |
+
if channel == 4:
|
301 |
+
return image
|
302 |
+
|
303 |
+
height, width = self.get_default_height_width(image, height, width)
|
304 |
+
if self.config.do_resize:
|
305 |
+
image = self.resize(image, height, width)
|
306 |
+
|
307 |
+
# expected range [0,1], normalize to [-1,1]
|
308 |
+
do_normalize = self.config.do_normalize
|
309 |
+
if image.min() < 0 and do_normalize:
|
310 |
+
warnings.warn(
|
311 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
312 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
313 |
+
FutureWarning,
|
314 |
+
)
|
315 |
+
do_normalize = False
|
316 |
+
|
317 |
+
if do_normalize:
|
318 |
+
image = self.normalize(image)
|
319 |
+
|
320 |
+
if self.config.do_binarize:
|
321 |
+
image = self.binarize(image)
|
322 |
+
|
323 |
+
return image
|
324 |
+
|
325 |
+
def postprocess(
|
326 |
+
self,
|
327 |
+
image: torch.FloatTensor,
|
328 |
+
output_type: str = "pil",
|
329 |
+
do_denormalize: Optional[List[bool]] = None,
|
330 |
+
):
|
331 |
+
if not isinstance(image, torch.Tensor):
|
332 |
+
raise ValueError(
|
333 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
334 |
+
)
|
335 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
336 |
+
deprecation_message = (
|
337 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
338 |
+
"`pil`, `np`, `pt`, `latent`"
|
339 |
+
)
|
340 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
341 |
+
output_type = "np"
|
342 |
+
|
343 |
+
if output_type == "latent":
|
344 |
+
return image
|
345 |
+
|
346 |
+
if do_denormalize is None:
|
347 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
348 |
+
|
349 |
+
image = torch.stack(
|
350 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
351 |
+
)
|
352 |
+
|
353 |
+
if output_type == "pt":
|
354 |
+
return image
|
355 |
+
|
356 |
+
image = self.pt_to_numpy(image)
|
357 |
+
|
358 |
+
if output_type == "np":
|
359 |
+
return image
|
360 |
+
|
361 |
+
if output_type == "pil":
|
362 |
+
return self.numpy_to_pil(image)
|
363 |
+
|
364 |
+
|
365 |
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
366 |
+
"""
|
367 |
+
Image processor for VAE LDM3D.
|
368 |
+
|
369 |
+
Args:
|
370 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
371 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
372 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
373 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
374 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
375 |
+
Resampling filter to use when resizing the image.
|
376 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
377 |
+
Whether to normalize the image to [-1,1].
|
378 |
+
"""
|
379 |
+
|
380 |
+
config_name = CONFIG_NAME
|
381 |
+
|
382 |
+
@register_to_config
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
do_resize: bool = True,
|
386 |
+
vae_scale_factor: int = 8,
|
387 |
+
resample: str = "lanczos",
|
388 |
+
do_normalize: bool = True,
|
389 |
+
):
|
390 |
+
super().__init__()
|
391 |
+
|
392 |
+
@staticmethod
|
393 |
+
def numpy_to_pil(images):
|
394 |
+
"""
|
395 |
+
Convert a NumPy image or a batch of images to a PIL image.
|
396 |
+
"""
|
397 |
+
if images.ndim == 3:
|
398 |
+
images = images[None, ...]
|
399 |
+
images = (images * 255).round().astype("uint8")
|
400 |
+
if images.shape[-1] == 1:
|
401 |
+
# special case for grayscale (single channel) images
|
402 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
403 |
+
else:
|
404 |
+
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
405 |
+
|
406 |
+
return pil_images
|
407 |
+
|
408 |
+
@staticmethod
|
409 |
+
def rgblike_to_depthmap(image):
|
410 |
+
"""
|
411 |
+
Args:
|
412 |
+
image: RGB-like depth image
|
413 |
+
|
414 |
+
Returns: depth map
|
415 |
+
|
416 |
+
"""
|
417 |
+
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
418 |
+
|
419 |
+
def numpy_to_depth(self, images):
|
420 |
+
"""
|
421 |
+
Convert a NumPy depth image or a batch of images to a PIL image.
|
422 |
+
"""
|
423 |
+
if images.ndim == 3:
|
424 |
+
images = images[None, ...]
|
425 |
+
images_depth = images[:, :, :, 3:]
|
426 |
+
if images.shape[-1] == 6:
|
427 |
+
images_depth = (images_depth * 255).round().astype("uint8")
|
428 |
+
pil_images = [
|
429 |
+
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
|
430 |
+
]
|
431 |
+
elif images.shape[-1] == 4:
|
432 |
+
images_depth = (images_depth * 65535.0).astype(np.uint16)
|
433 |
+
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
|
434 |
+
else:
|
435 |
+
raise Exception("Not supported")
|
436 |
+
|
437 |
+
return pil_images
|
438 |
+
|
439 |
+
def postprocess(
|
440 |
+
self,
|
441 |
+
image: torch.FloatTensor,
|
442 |
+
output_type: str = "pil",
|
443 |
+
do_denormalize: Optional[List[bool]] = None,
|
444 |
+
):
|
445 |
+
if not isinstance(image, torch.Tensor):
|
446 |
+
raise ValueError(
|
447 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
448 |
+
)
|
449 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
450 |
+
deprecation_message = (
|
451 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
452 |
+
"`pil`, `np`, `pt`, `latent`"
|
453 |
+
)
|
454 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
455 |
+
output_type = "np"
|
456 |
+
|
457 |
+
if do_denormalize is None:
|
458 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
459 |
+
|
460 |
+
image = torch.stack(
|
461 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
462 |
+
)
|
463 |
+
|
464 |
+
image = self.pt_to_numpy(image)
|
465 |
+
|
466 |
+
if output_type == "np":
|
467 |
+
if image.shape[-1] == 6:
|
468 |
+
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
469 |
+
else:
|
470 |
+
image_depth = image[:, :, :, 3:]
|
471 |
+
return image[:, :, :, :3], image_depth
|
472 |
+
|
473 |
+
if output_type == "pil":
|
474 |
+
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
475 |
+
else:
|
476 |
+
raise Exception(f"This type {output_type} is not supported")
|
diffusers/loaders.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffusers/models/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Models
|
2 |
+
|
3 |
+
For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
|
diffusers/models/__init__.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING
|
16 |
+
|
17 |
+
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
|
18 |
+
|
19 |
+
|
20 |
+
_import_structure = {}
|
21 |
+
|
22 |
+
if is_torch_available():
|
23 |
+
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
24 |
+
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
25 |
+
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
|
26 |
+
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
|
27 |
+
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
28 |
+
_import_structure["controlnet"] = ["ControlNetModel"]
|
29 |
+
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
30 |
+
_import_structure["modeling_utils"] = ["ModelMixin"]
|
31 |
+
_import_structure["prior_transformer"] = ["PriorTransformer"]
|
32 |
+
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
33 |
+
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
34 |
+
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
|
35 |
+
_import_structure["unet_1d"] = ["UNet1DModel"]
|
36 |
+
_import_structure["unet_2d"] = ["UNet2DModel"]
|
37 |
+
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
38 |
+
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
39 |
+
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
40 |
+
_import_structure["vq_model"] = ["VQModel"]
|
41 |
+
|
42 |
+
if is_flax_available():
|
43 |
+
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
44 |
+
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
45 |
+
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
46 |
+
|
47 |
+
|
48 |
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
49 |
+
if is_torch_available():
|
50 |
+
from .adapter import MultiAdapter, T2IAdapter
|
51 |
+
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
52 |
+
from .autoencoder_kl import AutoencoderKL
|
53 |
+
from .autoencoder_tiny import AutoencoderTiny
|
54 |
+
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
55 |
+
from .controlnet import ControlNetModel
|
56 |
+
from .dual_transformer_2d import DualTransformer2DModel
|
57 |
+
from .modeling_utils import ModelMixin
|
58 |
+
from .prior_transformer import PriorTransformer
|
59 |
+
from .t5_film_transformer import T5FilmDecoder
|
60 |
+
from .transformer_2d import Transformer2DModel
|
61 |
+
from .transformer_temporal import TransformerTemporalModel
|
62 |
+
from .unet_1d import UNet1DModel
|
63 |
+
from .unet_2d import UNet2DModel
|
64 |
+
from .unet_2d_condition import UNet2DConditionModel
|
65 |
+
from .unet_3d_condition import UNet3DConditionModel
|
66 |
+
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
67 |
+
from .vq_model import VQModel
|
68 |
+
|
69 |
+
if is_flax_available():
|
70 |
+
from .controlnet_flax import FlaxControlNetModel
|
71 |
+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
72 |
+
from .vae_flax import FlaxAutoencoderKL
|
73 |
+
|
74 |
+
else:
|
75 |
+
import sys
|
76 |
+
|
77 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
diffusers/models/activations.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..utils import USE_PEFT_BACKEND
|
21 |
+
from .lora import LoRACompatibleLinear
|
22 |
+
|
23 |
+
|
24 |
+
ACTIVATION_FUNCTIONS = {
|
25 |
+
"swish": nn.SiLU(),
|
26 |
+
"silu": nn.SiLU(),
|
27 |
+
"mish": nn.Mish(),
|
28 |
+
"gelu": nn.GELU(),
|
29 |
+
"relu": nn.ReLU(),
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def get_activation(act_fn: str) -> nn.Module:
|
34 |
+
"""Helper function to get activation function from string.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
act_fn (str): Name of activation function.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
nn.Module: Activation function.
|
41 |
+
"""
|
42 |
+
|
43 |
+
act_fn = act_fn.lower()
|
44 |
+
if act_fn in ACTIVATION_FUNCTIONS:
|
45 |
+
return ACTIVATION_FUNCTIONS[act_fn]
|
46 |
+
else:
|
47 |
+
raise ValueError(f"Unsupported activation function: {act_fn}")
|
48 |
+
|
49 |
+
|
50 |
+
class GELU(nn.Module):
|
51 |
+
r"""
|
52 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
53 |
+
|
54 |
+
Parameters:
|
55 |
+
dim_in (`int`): The number of channels in the input.
|
56 |
+
dim_out (`int`): The number of channels in the output.
|
57 |
+
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
61 |
+
super().__init__()
|
62 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
63 |
+
self.approximate = approximate
|
64 |
+
|
65 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
66 |
+
if gate.device.type != "mps":
|
67 |
+
return F.gelu(gate, approximate=self.approximate)
|
68 |
+
# mps: gelu is not implemented for float16
|
69 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
70 |
+
|
71 |
+
def forward(self, hidden_states):
|
72 |
+
hidden_states = self.proj(hidden_states)
|
73 |
+
hidden_states = self.gelu(hidden_states)
|
74 |
+
return hidden_states
|
75 |
+
|
76 |
+
|
77 |
+
class GEGLU(nn.Module):
|
78 |
+
r"""
|
79 |
+
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
|
80 |
+
|
81 |
+
Parameters:
|
82 |
+
dim_in (`int`): The number of channels in the input.
|
83 |
+
dim_out (`int`): The number of channels in the output.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, dim_in: int, dim_out: int):
|
87 |
+
super().__init__()
|
88 |
+
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
89 |
+
|
90 |
+
self.proj = linear_cls(dim_in, dim_out * 2)
|
91 |
+
|
92 |
+
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
93 |
+
if gate.device.type != "mps":
|
94 |
+
return F.gelu(gate)
|
95 |
+
# mps: gelu is not implemented for float16
|
96 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
97 |
+
|
98 |
+
def forward(self, hidden_states, scale: float = 1.0):
|
99 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
100 |
+
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
101 |
+
return hidden_states * self.gelu(gate)
|
102 |
+
|
103 |
+
|
104 |
+
class ApproximateGELU(nn.Module):
|
105 |
+
r"""
|
106 |
+
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
|
107 |
+
[paper](https://arxiv.org/abs/1606.08415).
|
108 |
+
|
109 |
+
Parameters:
|
110 |
+
dim_in (`int`): The number of channels in the input.
|
111 |
+
dim_out (`int`): The number of channels in the output.
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(self, dim_in: int, dim_out: int):
|
115 |
+
super().__init__()
|
116 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
117 |
+
|
118 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
119 |
+
x = self.proj(x)
|
120 |
+
return x * torch.sigmoid(1.702 * x)
|
diffusers/models/adapter.py
ADDED
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from typing import Callable, List, Optional, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils import logging
|
22 |
+
from .modeling_utils import ModelMixin
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
class MultiAdapter(ModelMixin):
|
29 |
+
r"""
|
30 |
+
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
|
31 |
+
user-assigned weighting.
|
32 |
+
|
33 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
34 |
+
implements for all the model (such as downloading or saving, etc.)
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
|
38 |
+
A list of `T2IAdapter` model instances.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, adapters: List["T2IAdapter"]):
|
42 |
+
super(MultiAdapter, self).__init__()
|
43 |
+
|
44 |
+
self.num_adapter = len(adapters)
|
45 |
+
self.adapters = nn.ModuleList(adapters)
|
46 |
+
|
47 |
+
if len(adapters) == 0:
|
48 |
+
raise ValueError("Expecting at least one adapter")
|
49 |
+
|
50 |
+
if len(adapters) == 1:
|
51 |
+
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
|
52 |
+
|
53 |
+
# The outputs from each adapter are added together with a weight.
|
54 |
+
# This means that the change in dimensions from downsampling must
|
55 |
+
# be the same for all adapters. Inductively, it also means the
|
56 |
+
# downscale_factor and total_downscale_factor must be the same for all
|
57 |
+
# adapters.
|
58 |
+
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
|
59 |
+
first_adapter_downscale_factor = adapters[0].downscale_factor
|
60 |
+
for idx in range(1, len(adapters)):
|
61 |
+
if (
|
62 |
+
adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
|
63 |
+
or adapters[idx].downscale_factor != first_adapter_downscale_factor
|
64 |
+
):
|
65 |
+
raise ValueError(
|
66 |
+
f"Expecting all adapters to have the same downscaling behavior, but got:\n"
|
67 |
+
f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
|
68 |
+
f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
|
69 |
+
f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
|
70 |
+
f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
|
71 |
+
)
|
72 |
+
|
73 |
+
self.total_downscale_factor = first_adapter_total_downscale_factor
|
74 |
+
self.downscale_factor = first_adapter_downscale_factor
|
75 |
+
|
76 |
+
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
|
77 |
+
r"""
|
78 |
+
Args:
|
79 |
+
xs (`torch.Tensor`):
|
80 |
+
(batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
|
81 |
+
`channel` should equal to `num_adapter` * "number of channel of image".
|
82 |
+
adapter_weights (`List[float]`, *optional*, defaults to None):
|
83 |
+
List of floats representing the weight which will be multiply to each adapter's output before adding
|
84 |
+
them together.
|
85 |
+
"""
|
86 |
+
if adapter_weights is None:
|
87 |
+
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
|
88 |
+
else:
|
89 |
+
adapter_weights = torch.tensor(adapter_weights)
|
90 |
+
|
91 |
+
accume_state = None
|
92 |
+
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
|
93 |
+
features = adapter(x)
|
94 |
+
if accume_state is None:
|
95 |
+
accume_state = features
|
96 |
+
for i in range(len(accume_state)):
|
97 |
+
accume_state[i] = w * accume_state[i]
|
98 |
+
else:
|
99 |
+
for i in range(len(features)):
|
100 |
+
accume_state[i] += w * features[i]
|
101 |
+
return accume_state
|
102 |
+
|
103 |
+
def save_pretrained(
|
104 |
+
self,
|
105 |
+
save_directory: Union[str, os.PathLike],
|
106 |
+
is_main_process: bool = True,
|
107 |
+
save_function: Callable = None,
|
108 |
+
safe_serialization: bool = True,
|
109 |
+
variant: Optional[str] = None,
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
113 |
+
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
|
114 |
+
|
115 |
+
Arguments:
|
116 |
+
save_directory (`str` or `os.PathLike`):
|
117 |
+
Directory to which to save. Will be created if it doesn't exist.
|
118 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
119 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
120 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
121 |
+
the main process to avoid race conditions.
|
122 |
+
save_function (`Callable`):
|
123 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
124 |
+
need to replace `torch.save` by another method. Can be configured with the environment variable
|
125 |
+
`DIFFUSERS_SAVE_MODE`.
|
126 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
127 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
128 |
+
variant (`str`, *optional*):
|
129 |
+
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
130 |
+
"""
|
131 |
+
idx = 0
|
132 |
+
model_path_to_save = save_directory
|
133 |
+
for adapter in self.adapters:
|
134 |
+
adapter.save_pretrained(
|
135 |
+
model_path_to_save,
|
136 |
+
is_main_process=is_main_process,
|
137 |
+
save_function=save_function,
|
138 |
+
safe_serialization=safe_serialization,
|
139 |
+
variant=variant,
|
140 |
+
)
|
141 |
+
|
142 |
+
idx += 1
|
143 |
+
model_path_to_save = model_path_to_save + f"_{idx}"
|
144 |
+
|
145 |
+
@classmethod
|
146 |
+
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
147 |
+
r"""
|
148 |
+
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
|
149 |
+
|
150 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
151 |
+
the model, you should first set it back in training mode with `model.train()`.
|
152 |
+
|
153 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
154 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
155 |
+
task.
|
156 |
+
|
157 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
158 |
+
weights are discarded.
|
159 |
+
|
160 |
+
Parameters:
|
161 |
+
pretrained_model_path (`os.PathLike`):
|
162 |
+
A path to a *directory* containing model weights saved using
|
163 |
+
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
|
164 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
165 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
166 |
+
will be automatically derived from the model's weights.
|
167 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
168 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
169 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
170 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
171 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
172 |
+
same device.
|
173 |
+
|
174 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
175 |
+
more information about each option see [designing a device
|
176 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
177 |
+
max_memory (`Dict`, *optional*):
|
178 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
179 |
+
GPU and the available CPU RAM if unset.
|
180 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
181 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
182 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
183 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
184 |
+
setting this argument to `True` will raise an error.
|
185 |
+
variant (`str`, *optional*):
|
186 |
+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
187 |
+
ignored when using `from_flax`.
|
188 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
189 |
+
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
190 |
+
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
191 |
+
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
192 |
+
"""
|
193 |
+
idx = 0
|
194 |
+
adapters = []
|
195 |
+
|
196 |
+
# load adapter and append to list until no adapter directory exists anymore
|
197 |
+
# first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
|
198 |
+
# second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
|
199 |
+
model_path_to_load = pretrained_model_path
|
200 |
+
while os.path.isdir(model_path_to_load):
|
201 |
+
adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
|
202 |
+
adapters.append(adapter)
|
203 |
+
|
204 |
+
idx += 1
|
205 |
+
model_path_to_load = pretrained_model_path + f"_{idx}"
|
206 |
+
|
207 |
+
logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
|
208 |
+
|
209 |
+
if len(adapters) == 0:
|
210 |
+
raise ValueError(
|
211 |
+
f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
|
212 |
+
)
|
213 |
+
|
214 |
+
return cls(adapters)
|
215 |
+
|
216 |
+
|
217 |
+
class T2IAdapter(ModelMixin, ConfigMixin):
|
218 |
+
r"""
|
219 |
+
A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
|
220 |
+
generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
|
221 |
+
architecture follows the original implementation of
|
222 |
+
[Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
|
223 |
+
and
|
224 |
+
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
|
225 |
+
|
226 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
227 |
+
implements for all the model (such as downloading or saving, etc.)
|
228 |
+
|
229 |
+
Parameters:
|
230 |
+
in_channels (`int`, *optional*, defaults to 3):
|
231 |
+
Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
|
232 |
+
image as *control image*.
|
233 |
+
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
234 |
+
The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
|
235 |
+
also determine the number of downsample blocks in the Adapter.
|
236 |
+
num_res_blocks (`int`, *optional*, defaults to 2):
|
237 |
+
Number of ResNet blocks in each downsample block.
|
238 |
+
downscale_factor (`int`, *optional*, defaults to 8):
|
239 |
+
A factor that determines the total downscale factor of the Adapter.
|
240 |
+
adapter_type (`str`, *optional*, defaults to `full_adapter`):
|
241 |
+
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
|
242 |
+
"""
|
243 |
+
|
244 |
+
@register_to_config
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
in_channels: int = 3,
|
248 |
+
channels: List[int] = [320, 640, 1280, 1280],
|
249 |
+
num_res_blocks: int = 2,
|
250 |
+
downscale_factor: int = 8,
|
251 |
+
adapter_type: str = "full_adapter",
|
252 |
+
):
|
253 |
+
super().__init__()
|
254 |
+
|
255 |
+
if adapter_type == "full_adapter":
|
256 |
+
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
257 |
+
elif adapter_type == "full_adapter_xl":
|
258 |
+
self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
|
259 |
+
elif adapter_type == "light_adapter":
|
260 |
+
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
261 |
+
else:
|
262 |
+
raise ValueError(
|
263 |
+
f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
|
264 |
+
"'full_adapter_xl' or 'light_adapter'."
|
265 |
+
)
|
266 |
+
|
267 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
268 |
+
r"""
|
269 |
+
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
|
270 |
+
each representing information extracted at a different scale from the input. The length of the list is
|
271 |
+
determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
|
272 |
+
`num_res_blocks` parameters during initialization.
|
273 |
+
"""
|
274 |
+
return self.adapter(x)
|
275 |
+
|
276 |
+
@property
|
277 |
+
def total_downscale_factor(self):
|
278 |
+
return self.adapter.total_downscale_factor
|
279 |
+
|
280 |
+
@property
|
281 |
+
def downscale_factor(self):
|
282 |
+
"""The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
|
283 |
+
not evenly divisible by the downscale_factor then an exception will be raised.
|
284 |
+
"""
|
285 |
+
return self.adapter.unshuffle.downscale_factor
|
286 |
+
|
287 |
+
|
288 |
+
# full adapter
|
289 |
+
|
290 |
+
|
291 |
+
class FullAdapter(nn.Module):
|
292 |
+
r"""
|
293 |
+
See [`T2IAdapter`] for more information.
|
294 |
+
"""
|
295 |
+
|
296 |
+
def __init__(
|
297 |
+
self,
|
298 |
+
in_channels: int = 3,
|
299 |
+
channels: List[int] = [320, 640, 1280, 1280],
|
300 |
+
num_res_blocks: int = 2,
|
301 |
+
downscale_factor: int = 8,
|
302 |
+
):
|
303 |
+
super().__init__()
|
304 |
+
|
305 |
+
in_channels = in_channels * downscale_factor**2
|
306 |
+
|
307 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
308 |
+
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
|
309 |
+
|
310 |
+
self.body = nn.ModuleList(
|
311 |
+
[
|
312 |
+
AdapterBlock(channels[0], channels[0], num_res_blocks),
|
313 |
+
*[
|
314 |
+
AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
|
315 |
+
for i in range(1, len(channels))
|
316 |
+
],
|
317 |
+
]
|
318 |
+
)
|
319 |
+
|
320 |
+
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
|
321 |
+
|
322 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
323 |
+
r"""
|
324 |
+
This method processes the input tensor `x` through the FullAdapter model and performs operations including
|
325 |
+
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
|
326 |
+
capturing information at a different stage of processing within the FullAdapter model. The number of feature
|
327 |
+
tensors in the list is determined by the number of downsample blocks specified during initialization.
|
328 |
+
"""
|
329 |
+
x = self.unshuffle(x)
|
330 |
+
x = self.conv_in(x)
|
331 |
+
|
332 |
+
features = []
|
333 |
+
|
334 |
+
for block in self.body:
|
335 |
+
x = block(x)
|
336 |
+
features.append(x)
|
337 |
+
|
338 |
+
return features
|
339 |
+
|
340 |
+
|
341 |
+
class FullAdapterXL(nn.Module):
|
342 |
+
r"""
|
343 |
+
See [`T2IAdapter`] for more information.
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(
|
347 |
+
self,
|
348 |
+
in_channels: int = 3,
|
349 |
+
channels: List[int] = [320, 640, 1280, 1280],
|
350 |
+
num_res_blocks: int = 2,
|
351 |
+
downscale_factor: int = 16,
|
352 |
+
):
|
353 |
+
super().__init__()
|
354 |
+
|
355 |
+
in_channels = in_channels * downscale_factor**2
|
356 |
+
|
357 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
358 |
+
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
|
359 |
+
|
360 |
+
self.body = []
|
361 |
+
# blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
|
362 |
+
for i in range(len(channels)):
|
363 |
+
if i == 1:
|
364 |
+
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
|
365 |
+
elif i == 2:
|
366 |
+
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
|
367 |
+
else:
|
368 |
+
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
|
369 |
+
|
370 |
+
self.body = nn.ModuleList(self.body)
|
371 |
+
# XL has only one downsampling AdapterBlock.
|
372 |
+
self.total_downscale_factor = downscale_factor * 2
|
373 |
+
|
374 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
375 |
+
r"""
|
376 |
+
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
|
377 |
+
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
|
378 |
+
"""
|
379 |
+
x = self.unshuffle(x)
|
380 |
+
x = self.conv_in(x)
|
381 |
+
|
382 |
+
features = []
|
383 |
+
|
384 |
+
for block in self.body:
|
385 |
+
x = block(x)
|
386 |
+
features.append(x)
|
387 |
+
|
388 |
+
return features
|
389 |
+
|
390 |
+
|
391 |
+
class AdapterBlock(nn.Module):
|
392 |
+
r"""
|
393 |
+
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
|
394 |
+
`FullAdapterXL` models.
|
395 |
+
|
396 |
+
Parameters:
|
397 |
+
in_channels (`int`):
|
398 |
+
Number of channels of AdapterBlock's input.
|
399 |
+
out_channels (`int`):
|
400 |
+
Number of channels of AdapterBlock's output.
|
401 |
+
num_res_blocks (`int`):
|
402 |
+
Number of ResNet blocks in the AdapterBlock.
|
403 |
+
down (`bool`, *optional*, defaults to `False`):
|
404 |
+
Whether to perform downsampling on AdapterBlock's input.
|
405 |
+
"""
|
406 |
+
|
407 |
+
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
408 |
+
super().__init__()
|
409 |
+
|
410 |
+
self.downsample = None
|
411 |
+
if down:
|
412 |
+
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
|
413 |
+
|
414 |
+
self.in_conv = None
|
415 |
+
if in_channels != out_channels:
|
416 |
+
self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
417 |
+
|
418 |
+
self.resnets = nn.Sequential(
|
419 |
+
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
|
420 |
+
)
|
421 |
+
|
422 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
423 |
+
r"""
|
424 |
+
This method takes tensor x as input and performs operations downsampling and convolutional layers if the
|
425 |
+
self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
|
426 |
+
residual blocks to the input tensor.
|
427 |
+
"""
|
428 |
+
if self.downsample is not None:
|
429 |
+
x = self.downsample(x)
|
430 |
+
|
431 |
+
if self.in_conv is not None:
|
432 |
+
x = self.in_conv(x)
|
433 |
+
|
434 |
+
x = self.resnets(x)
|
435 |
+
|
436 |
+
return x
|
437 |
+
|
438 |
+
|
439 |
+
class AdapterResnetBlock(nn.Module):
|
440 |
+
r"""
|
441 |
+
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
|
442 |
+
|
443 |
+
Parameters:
|
444 |
+
channels (`int`):
|
445 |
+
Number of channels of AdapterResnetBlock's input and output.
|
446 |
+
"""
|
447 |
+
|
448 |
+
def __init__(self, channels: int):
|
449 |
+
super().__init__()
|
450 |
+
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
451 |
+
self.act = nn.ReLU()
|
452 |
+
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
|
453 |
+
|
454 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
455 |
+
r"""
|
456 |
+
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
|
457 |
+
layer on the input tensor. It returns addition with the input tensor.
|
458 |
+
"""
|
459 |
+
|
460 |
+
h = self.act(self.block1(x))
|
461 |
+
h = self.block2(h)
|
462 |
+
|
463 |
+
return h + x
|
464 |
+
|
465 |
+
|
466 |
+
# light adapter
|
467 |
+
|
468 |
+
|
469 |
+
class LightAdapter(nn.Module):
|
470 |
+
r"""
|
471 |
+
See [`T2IAdapter`] for more information.
|
472 |
+
"""
|
473 |
+
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
in_channels: int = 3,
|
477 |
+
channels: List[int] = [320, 640, 1280],
|
478 |
+
num_res_blocks: int = 4,
|
479 |
+
downscale_factor: int = 8,
|
480 |
+
):
|
481 |
+
super().__init__()
|
482 |
+
|
483 |
+
in_channels = in_channels * downscale_factor**2
|
484 |
+
|
485 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
486 |
+
|
487 |
+
self.body = nn.ModuleList(
|
488 |
+
[
|
489 |
+
LightAdapterBlock(in_channels, channels[0], num_res_blocks),
|
490 |
+
*[
|
491 |
+
LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
|
492 |
+
for i in range(len(channels) - 1)
|
493 |
+
],
|
494 |
+
LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
|
495 |
+
]
|
496 |
+
)
|
497 |
+
|
498 |
+
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
|
499 |
+
|
500 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
501 |
+
r"""
|
502 |
+
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
|
503 |
+
feature tensor corresponds to a different level of processing within the LightAdapter.
|
504 |
+
"""
|
505 |
+
x = self.unshuffle(x)
|
506 |
+
|
507 |
+
features = []
|
508 |
+
|
509 |
+
for block in self.body:
|
510 |
+
x = block(x)
|
511 |
+
features.append(x)
|
512 |
+
|
513 |
+
return features
|
514 |
+
|
515 |
+
|
516 |
+
class LightAdapterBlock(nn.Module):
|
517 |
+
r"""
|
518 |
+
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
|
519 |
+
`LightAdapter` model.
|
520 |
+
|
521 |
+
Parameters:
|
522 |
+
in_channels (`int`):
|
523 |
+
Number of channels of LightAdapterBlock's input.
|
524 |
+
out_channels (`int`):
|
525 |
+
Number of channels of LightAdapterBlock's output.
|
526 |
+
num_res_blocks (`int`):
|
527 |
+
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
|
528 |
+
down (`bool`, *optional*, defaults to `False`):
|
529 |
+
Whether to perform downsampling on LightAdapterBlock's input.
|
530 |
+
"""
|
531 |
+
|
532 |
+
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
533 |
+
super().__init__()
|
534 |
+
mid_channels = out_channels // 4
|
535 |
+
|
536 |
+
self.downsample = None
|
537 |
+
if down:
|
538 |
+
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
|
539 |
+
|
540 |
+
self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
|
541 |
+
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
|
542 |
+
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
|
543 |
+
|
544 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
545 |
+
r"""
|
546 |
+
This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
|
547 |
+
layer, a sequence of residual blocks, and out convolutional layer.
|
548 |
+
"""
|
549 |
+
if self.downsample is not None:
|
550 |
+
x = self.downsample(x)
|
551 |
+
|
552 |
+
x = self.in_conv(x)
|
553 |
+
x = self.resnets(x)
|
554 |
+
x = self.out_conv(x)
|
555 |
+
|
556 |
+
return x
|
557 |
+
|
558 |
+
|
559 |
+
class LightAdapterResnetBlock(nn.Module):
|
560 |
+
"""
|
561 |
+
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
|
562 |
+
architecture than `AdapterResnetBlock`.
|
563 |
+
|
564 |
+
Parameters:
|
565 |
+
channels (`int`):
|
566 |
+
Number of channels of LightAdapterResnetBlock's input and output.
|
567 |
+
"""
|
568 |
+
|
569 |
+
def __init__(self, channels: int):
|
570 |
+
super().__init__()
|
571 |
+
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
572 |
+
self.act = nn.ReLU()
|
573 |
+
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
574 |
+
|
575 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
576 |
+
r"""
|
577 |
+
This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
|
578 |
+
another convolutional layer and adds it to input tensor.
|
579 |
+
"""
|
580 |
+
|
581 |
+
h = self.act(self.block1(x))
|
582 |
+
h = self.block2(h)
|
583 |
+
|
584 |
+
return h + x
|
diffusers/models/attention.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
from ..utils import USE_PEFT_BACKEND
|
20 |
+
from ..utils.torch_utils import maybe_allow_in_graph
|
21 |
+
from .activations import GEGLU, GELU, ApproximateGELU
|
22 |
+
from .attention_processor import Attention
|
23 |
+
from .embeddings import SinusoidalPositionalEmbedding
|
24 |
+
from .lora import LoRACompatibleLinear
|
25 |
+
from .normalization import AdaLayerNorm, AdaLayerNormZero
|
26 |
+
|
27 |
+
|
28 |
+
@maybe_allow_in_graph
|
29 |
+
class GatedSelfAttentionDense(nn.Module):
|
30 |
+
r"""
|
31 |
+
A gated self-attention dense layer that combines visual features and object features.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
query_dim (`int`): The number of channels in the query.
|
35 |
+
context_dim (`int`): The number of channels in the context.
|
36 |
+
n_heads (`int`): The number of heads to use for attention.
|
37 |
+
d_head (`int`): The number of channels in each head.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
# we need a linear projection since we need to concatenate the visual features and object features
|
44 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
45 |
+
|
46 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
47 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
48 |
+
|
49 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
50 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
51 |
+
|
52 |
+
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
53 |
+
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
54 |
+
|
55 |
+
self.enabled = True
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
58 |
+
if not self.enabled:
|
59 |
+
return x
|
60 |
+
|
61 |
+
n_visual = x.shape[1]
|
62 |
+
objs = self.linear(objs)
|
63 |
+
|
64 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
65 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
66 |
+
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
@maybe_allow_in_graph
|
71 |
+
class BasicTransformerBlock(nn.Module):
|
72 |
+
r"""
|
73 |
+
A basic Transformer block.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
dim (`int`): The number of channels in the input and output.
|
77 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
78 |
+
attention_head_dim (`int`): The number of channels in each head.
|
79 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
80 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
81 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
82 |
+
num_embeds_ada_norm (:
|
83 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
84 |
+
attention_bias (:
|
85 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
86 |
+
only_cross_attention (`bool`, *optional*):
|
87 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
88 |
+
double_self_attention (`bool`, *optional*):
|
89 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
90 |
+
upcast_attention (`bool`, *optional*):
|
91 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
92 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
93 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
94 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
95 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
96 |
+
final_dropout (`bool` *optional*, defaults to False):
|
97 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
98 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
99 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
100 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
101 |
+
The type of positional embeddings to apply.
|
102 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
103 |
+
The maximum number of positional embeddings to apply.
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
dim: int,
|
109 |
+
num_attention_heads: int,
|
110 |
+
attention_head_dim: int,
|
111 |
+
dropout=0.0,
|
112 |
+
cross_attention_dim: Optional[int] = None,
|
113 |
+
activation_fn: str = "geglu",
|
114 |
+
num_embeds_ada_norm: Optional[int] = None,
|
115 |
+
attention_bias: bool = False,
|
116 |
+
only_cross_attention: bool = False,
|
117 |
+
double_self_attention: bool = False,
|
118 |
+
upcast_attention: bool = False,
|
119 |
+
norm_elementwise_affine: bool = True,
|
120 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
121 |
+
norm_eps: float = 1e-5,
|
122 |
+
final_dropout: bool = False,
|
123 |
+
attention_type: str = "default",
|
124 |
+
positional_embeddings: Optional[str] = None,
|
125 |
+
num_positional_embeddings: Optional[int] = None,
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.only_cross_attention = only_cross_attention
|
129 |
+
|
130 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
131 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
132 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
133 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
134 |
+
|
135 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
136 |
+
raise ValueError(
|
137 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
138 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
139 |
+
)
|
140 |
+
|
141 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
142 |
+
raise ValueError(
|
143 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
144 |
+
)
|
145 |
+
|
146 |
+
if positional_embeddings == "sinusoidal":
|
147 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
148 |
+
else:
|
149 |
+
self.pos_embed = None
|
150 |
+
|
151 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
152 |
+
# 1. Self-Attn
|
153 |
+
if self.use_ada_layer_norm:
|
154 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
155 |
+
elif self.use_ada_layer_norm_zero:
|
156 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
157 |
+
else:
|
158 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
159 |
+
|
160 |
+
self.attn1 = Attention(
|
161 |
+
query_dim=dim,
|
162 |
+
heads=num_attention_heads,
|
163 |
+
dim_head=attention_head_dim,
|
164 |
+
dropout=dropout,
|
165 |
+
bias=attention_bias,
|
166 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
167 |
+
upcast_attention=upcast_attention,
|
168 |
+
)
|
169 |
+
|
170 |
+
# 2. Cross-Attn
|
171 |
+
if cross_attention_dim is not None or double_self_attention:
|
172 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
173 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
174 |
+
# the second cross attention block.
|
175 |
+
self.norm2 = (
|
176 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
177 |
+
if self.use_ada_layer_norm
|
178 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
179 |
+
)
|
180 |
+
self.attn2 = Attention(
|
181 |
+
query_dim=dim,
|
182 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
183 |
+
heads=num_attention_heads,
|
184 |
+
dim_head=attention_head_dim,
|
185 |
+
dropout=dropout,
|
186 |
+
bias=attention_bias,
|
187 |
+
upcast_attention=upcast_attention,
|
188 |
+
) # is self-attn if encoder_hidden_states is none
|
189 |
+
else:
|
190 |
+
self.norm2 = None
|
191 |
+
self.attn2 = None
|
192 |
+
|
193 |
+
# 3. Feed-forward
|
194 |
+
if not self.use_ada_layer_norm_single:
|
195 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
196 |
+
|
197 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
198 |
+
|
199 |
+
# 4. Fuser
|
200 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
201 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
202 |
+
|
203 |
+
# 5. Scale-shift for PixArt-Alpha.
|
204 |
+
if self.use_ada_layer_norm_single:
|
205 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
206 |
+
|
207 |
+
# let chunk size default to None
|
208 |
+
self._chunk_size = None
|
209 |
+
self._chunk_dim = 0
|
210 |
+
|
211 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
212 |
+
# Sets chunk feed-forward
|
213 |
+
self._chunk_size = chunk_size
|
214 |
+
self._chunk_dim = dim
|
215 |
+
|
216 |
+
def forward(
|
217 |
+
self,
|
218 |
+
hidden_states: torch.FloatTensor,
|
219 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
220 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
221 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
222 |
+
timestep: Optional[torch.LongTensor] = None,
|
223 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
224 |
+
class_labels: Optional[torch.LongTensor] = None,
|
225 |
+
) -> torch.FloatTensor:
|
226 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
227 |
+
# 0. Self-Attention
|
228 |
+
batch_size = hidden_states.shape[0]
|
229 |
+
|
230 |
+
if self.use_ada_layer_norm:
|
231 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
232 |
+
elif self.use_ada_layer_norm_zero:
|
233 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
234 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
235 |
+
)
|
236 |
+
elif self.use_layer_norm:
|
237 |
+
norm_hidden_states = self.norm1(hidden_states)
|
238 |
+
elif self.use_ada_layer_norm_single:
|
239 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
240 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
241 |
+
).chunk(6, dim=1)
|
242 |
+
norm_hidden_states = self.norm1(hidden_states)
|
243 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
244 |
+
norm_hidden_states = norm_hidden_states.squeeze(1)
|
245 |
+
else:
|
246 |
+
raise ValueError("Incorrect norm used")
|
247 |
+
|
248 |
+
if self.pos_embed is not None:
|
249 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
250 |
+
|
251 |
+
# 1. Retrieve lora scale.
|
252 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
253 |
+
|
254 |
+
# 2. Prepare GLIGEN inputs
|
255 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
256 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
257 |
+
|
258 |
+
attn_output = self.attn1(
|
259 |
+
norm_hidden_states, # 32 4096 320
|
260 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, # 32 77 768
|
261 |
+
attention_mask=attention_mask,
|
262 |
+
**cross_attention_kwargs,
|
263 |
+
)
|
264 |
+
if self.use_ada_layer_norm_zero:
|
265 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
266 |
+
elif self.use_ada_layer_norm_single:
|
267 |
+
attn_output = gate_msa * attn_output
|
268 |
+
|
269 |
+
hidden_states = attn_output + hidden_states
|
270 |
+
if hidden_states.ndim == 4:
|
271 |
+
hidden_states = hidden_states.squeeze(1)
|
272 |
+
|
273 |
+
# 2.5 GLIGEN Control
|
274 |
+
if gligen_kwargs is not None:
|
275 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
276 |
+
|
277 |
+
# 3. Cross-Attention
|
278 |
+
if self.attn2 is not None:
|
279 |
+
if self.use_ada_layer_norm:
|
280 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
281 |
+
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
282 |
+
norm_hidden_states = self.norm2(hidden_states)
|
283 |
+
elif self.use_ada_layer_norm_single:
|
284 |
+
# For PixArt norm2 isn't applied here:
|
285 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
286 |
+
norm_hidden_states = hidden_states
|
287 |
+
else:
|
288 |
+
raise ValueError("Incorrect norm")
|
289 |
+
|
290 |
+
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
|
291 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
292 |
+
|
293 |
+
attn_output = self.attn2(
|
294 |
+
norm_hidden_states,
|
295 |
+
encoder_hidden_states=encoder_hidden_states,
|
296 |
+
attention_mask=encoder_attention_mask,
|
297 |
+
**cross_attention_kwargs,
|
298 |
+
)
|
299 |
+
hidden_states = attn_output + hidden_states
|
300 |
+
|
301 |
+
# 4. Feed-forward
|
302 |
+
if not self.use_ada_layer_norm_single:
|
303 |
+
norm_hidden_states = self.norm3(hidden_states)
|
304 |
+
|
305 |
+
if self.use_ada_layer_norm_zero:
|
306 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
307 |
+
|
308 |
+
if self.use_ada_layer_norm_single:
|
309 |
+
norm_hidden_states = self.norm2(hidden_states)
|
310 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
311 |
+
|
312 |
+
if self._chunk_size is not None:
|
313 |
+
# "feed_forward_chunk_size" can be used to save memory
|
314 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
315 |
+
raise ValueError(
|
316 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
317 |
+
)
|
318 |
+
|
319 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
320 |
+
ff_output = torch.cat(
|
321 |
+
[
|
322 |
+
self.ff(hid_slice, scale=lora_scale)
|
323 |
+
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
324 |
+
],
|
325 |
+
dim=self._chunk_dim,
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
329 |
+
|
330 |
+
if self.use_ada_layer_norm_zero:
|
331 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
332 |
+
elif self.use_ada_layer_norm_single:
|
333 |
+
ff_output = gate_mlp * ff_output
|
334 |
+
|
335 |
+
hidden_states = ff_output + hidden_states
|
336 |
+
if hidden_states.ndim == 4:
|
337 |
+
hidden_states = hidden_states.squeeze(1)
|
338 |
+
|
339 |
+
return hidden_states
|
340 |
+
|
341 |
+
|
342 |
+
class FeedForward(nn.Module):
|
343 |
+
r"""
|
344 |
+
A feed-forward layer.
|
345 |
+
|
346 |
+
Parameters:
|
347 |
+
dim (`int`): The number of channels in the input.
|
348 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
349 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
350 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
351 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
352 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
353 |
+
"""
|
354 |
+
|
355 |
+
def __init__(
|
356 |
+
self,
|
357 |
+
dim: int,
|
358 |
+
dim_out: Optional[int] = None,
|
359 |
+
mult: int = 4,
|
360 |
+
dropout: float = 0.0,
|
361 |
+
activation_fn: str = "geglu",
|
362 |
+
final_dropout: bool = False,
|
363 |
+
):
|
364 |
+
super().__init__()
|
365 |
+
inner_dim = int(dim * mult)
|
366 |
+
dim_out = dim_out if dim_out is not None else dim
|
367 |
+
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
368 |
+
|
369 |
+
if activation_fn == "gelu":
|
370 |
+
act_fn = GELU(dim, inner_dim)
|
371 |
+
if activation_fn == "gelu-approximate":
|
372 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
373 |
+
elif activation_fn == "geglu":
|
374 |
+
act_fn = GEGLU(dim, inner_dim)
|
375 |
+
elif activation_fn == "geglu-approximate":
|
376 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
377 |
+
|
378 |
+
self.net = nn.ModuleList([])
|
379 |
+
# project in
|
380 |
+
self.net.append(act_fn)
|
381 |
+
# project dropout
|
382 |
+
self.net.append(nn.Dropout(dropout))
|
383 |
+
# project out
|
384 |
+
self.net.append(linear_cls(inner_dim, dim_out))
|
385 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. has a final dropout
|
386 |
+
if final_dropout:
|
387 |
+
self.net.append(nn.Dropout(dropout))
|
388 |
+
|
389 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
390 |
+
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
391 |
+
for module in self.net:
|
392 |
+
if isinstance(module, compatible_cls):
|
393 |
+
hidden_states = module(hidden_states, scale)
|
394 |
+
else:
|
395 |
+
hidden_states = module(hidden_states)
|
396 |
+
return hidden_states
|
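A minimal usage sketch (not part of the diff) for the BasicTransformerBlock defined above, assuming the vendored copy is importable as diffusers.models.attention. The dimensions are illustrative (320-channel hidden states, 77 text tokens) rather than prescribed.

import torch
from diffusers.models.attention import BasicTransformerBlock  # assumed vendored path

block = BasicTransformerBlock(
    dim=320,                    # channel width of the hidden states
    num_attention_heads=8,
    attention_head_dim=40,      # 8 * 40 == 320
    cross_attention_dim=768,    # e.g. CLIP text-embedding width
)
hidden_states = torch.randn(2, 4096, 320)         # (batch, image tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)   # (batch, text tokens, cross dim)

# self-attention -> cross-attention -> feed-forward, each with a residual connection
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape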
diffusers/models/attention_flax.py
ADDED
@@ -0,0 +1,486 @@
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import functools
|
16 |
+
import math
|
17 |
+
|
18 |
+
import flax.linen as nn
|
19 |
+
import jax
|
20 |
+
import jax.numpy as jnp
|
21 |
+
|
22 |
+
|
23 |
+
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
|
24 |
+
"""Multi-head dot product attention with a limited number of queries."""
|
25 |
+
num_kv, num_heads, k_features = key.shape[-3:]
|
26 |
+
v_features = value.shape[-1]
|
27 |
+
key_chunk_size = min(key_chunk_size, num_kv)
|
28 |
+
query = query / jnp.sqrt(k_features)
|
29 |
+
|
30 |
+
@functools.partial(jax.checkpoint, prevent_cse=False)
|
31 |
+
def summarize_chunk(query, key, value):
|
32 |
+
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
|
33 |
+
|
34 |
+
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
35 |
+
max_score = jax.lax.stop_gradient(max_score)
|
36 |
+
exp_weights = jnp.exp(attn_weights - max_score)
|
37 |
+
|
38 |
+
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
|
39 |
+
max_score = jnp.einsum("...qhk->...qh", max_score)
|
40 |
+
|
41 |
+
return (exp_values, exp_weights.sum(axis=-1), max_score)
|
42 |
+
|
43 |
+
def chunk_scanner(chunk_idx):
|
44 |
+
# julienne key array
|
45 |
+
key_chunk = jax.lax.dynamic_slice(
|
46 |
+
operand=key,
|
47 |
+
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
|
48 |
+
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
|
49 |
+
)
|
50 |
+
|
51 |
+
# julienne value array
|
52 |
+
value_chunk = jax.lax.dynamic_slice(
|
53 |
+
operand=value,
|
54 |
+
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
|
55 |
+
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
|
56 |
+
)
|
57 |
+
|
58 |
+
return summarize_chunk(query, key_chunk, value_chunk)
|
59 |
+
|
60 |
+
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
|
61 |
+
|
62 |
+
global_max = jnp.max(chunk_max, axis=0, keepdims=True)
|
63 |
+
max_diffs = jnp.exp(chunk_max - global_max)
|
64 |
+
|
65 |
+
chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
|
66 |
+
chunk_weights *= max_diffs
|
67 |
+
|
68 |
+
all_values = chunk_values.sum(axis=0)
|
69 |
+
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
|
70 |
+
|
71 |
+
return all_values / all_weights
|
72 |
+
|
73 |
+
|
74 |
+
def jax_memory_efficient_attention(
|
75 |
+
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
|
76 |
+
):
|
77 |
+
r"""
|
78 |
+
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
|
79 |
+
https://github.com/AminRezaei0x443/memory-efficient-attention
|
80 |
+
|
81 |
+
Args:
|
82 |
+
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
|
83 |
+
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
|
84 |
+
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
|
85 |
+
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
|
86 |
+
numerical precision for computation
|
87 |
+
query_chunk_size (`int`, *optional*, defaults to 1024):
|
88 |
+
chunk size for dividing the query array; it must divide query_length without remainder
|
89 |
+
key_chunk_size (`int`, *optional*, defaults to 4096):
|
90 |
+
chunk size for dividing the key and value arrays; it must divide key_value_length without remainder
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
|
94 |
+
"""
|
95 |
+
num_q, num_heads, q_features = query.shape[-3:]
|
96 |
+
|
97 |
+
def chunk_scanner(chunk_idx, _):
|
98 |
+
# julienne query array
|
99 |
+
query_chunk = jax.lax.dynamic_slice(
|
100 |
+
operand=query,
|
101 |
+
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
|
102 |
+
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
|
103 |
+
)
|
104 |
+
|
105 |
+
return (
|
106 |
+
chunk_idx + query_chunk_size, # unused ignore it
|
107 |
+
_query_chunk_attention(
|
108 |
+
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
|
109 |
+
),
|
110 |
+
)
|
111 |
+
|
112 |
+
_, res = jax.lax.scan(
|
113 |
+
f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
|
114 |
+
)
|
115 |
+
|
116 |
+
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
|
117 |
+
|
118 |
+
|
119 |
+
class FlaxAttention(nn.Module):
|
120 |
+
r"""
|
121 |
+
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
query_dim (:obj:`int`):
|
125 |
+
Input hidden states dimension
|
126 |
+
heads (:obj:`int`, *optional*, defaults to 8):
|
127 |
+
Number of heads
|
128 |
+
dim_head (:obj:`int`, *optional*, defaults to 64):
|
129 |
+
Hidden states dimension inside each head
|
130 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
131 |
+
Dropout rate
|
132 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
133 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
134 |
+
split_head_dim (`bool`, *optional*, defaults to `False`):
|
135 |
+
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
136 |
+
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
137 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
138 |
+
Parameters `dtype`
|
139 |
+
|
140 |
+
"""
|
141 |
+
query_dim: int
|
142 |
+
heads: int = 8
|
143 |
+
dim_head: int = 64
|
144 |
+
dropout: float = 0.0
|
145 |
+
use_memory_efficient_attention: bool = False
|
146 |
+
split_head_dim: bool = False
|
147 |
+
dtype: jnp.dtype = jnp.float32
|
148 |
+
|
149 |
+
def setup(self):
|
150 |
+
inner_dim = self.dim_head * self.heads
|
151 |
+
self.scale = self.dim_head**-0.5
|
152 |
+
|
153 |
+
# Weights were exported with old names {to_q, to_k, to_v, to_out}
|
154 |
+
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
|
155 |
+
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
|
156 |
+
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
|
157 |
+
|
158 |
+
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
|
159 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
160 |
+
|
161 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
162 |
+
batch_size, seq_len, dim = tensor.shape
|
163 |
+
head_size = self.heads
|
164 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
165 |
+
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
|
166 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
167 |
+
return tensor
|
168 |
+
|
169 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
170 |
+
batch_size, seq_len, dim = tensor.shape
|
171 |
+
head_size = self.heads
|
172 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
173 |
+
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
|
174 |
+
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
|
175 |
+
return tensor
|
176 |
+
|
177 |
+
def __call__(self, hidden_states, context=None, deterministic=True):
|
178 |
+
context = hidden_states if context is None else context
|
179 |
+
|
180 |
+
query_proj = self.query(hidden_states)
|
181 |
+
key_proj = self.key(context)
|
182 |
+
value_proj = self.value(context)
|
183 |
+
|
184 |
+
if self.split_head_dim:
|
185 |
+
b = hidden_states.shape[0]
|
186 |
+
query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
|
187 |
+
key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
|
188 |
+
value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
|
189 |
+
else:
|
190 |
+
query_states = self.reshape_heads_to_batch_dim(query_proj)
|
191 |
+
key_states = self.reshape_heads_to_batch_dim(key_proj)
|
192 |
+
value_states = self.reshape_heads_to_batch_dim(value_proj)
|
193 |
+
|
194 |
+
if self.use_memory_efficient_attention:
|
195 |
+
query_states = query_states.transpose(1, 0, 2)
|
196 |
+
key_states = key_states.transpose(1, 0, 2)
|
197 |
+
value_states = value_states.transpose(1, 0, 2)
|
198 |
+
|
199 |
+
# this if statement creates a chunk size for each layer of the unet
|
200 |
+
# the chunk size is equal to the query_length dimension of the deepest layer of the unet
|
201 |
+
|
202 |
+
flatten_latent_dim = query_states.shape[-3]
|
203 |
+
if flatten_latent_dim % 64 == 0:
|
204 |
+
query_chunk_size = int(flatten_latent_dim / 64)
|
205 |
+
elif flatten_latent_dim % 16 == 0:
|
206 |
+
query_chunk_size = int(flatten_latent_dim / 16)
|
207 |
+
elif flatten_latent_dim % 4 == 0:
|
208 |
+
query_chunk_size = int(flatten_latent_dim / 4)
|
209 |
+
else:
|
210 |
+
query_chunk_size = int(flatten_latent_dim)
|
211 |
+
|
212 |
+
hidden_states = jax_memory_efficient_attention(
|
213 |
+
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
|
214 |
+
)
|
215 |
+
|
216 |
+
hidden_states = hidden_states.transpose(1, 0, 2)
|
217 |
+
else:
|
218 |
+
# compute attentions
|
219 |
+
if self.split_head_dim:
|
220 |
+
attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
|
221 |
+
else:
|
222 |
+
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
|
223 |
+
|
224 |
+
attention_scores = attention_scores * self.scale
|
225 |
+
attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
|
226 |
+
|
227 |
+
# attend to values
|
228 |
+
if self.split_head_dim:
|
229 |
+
hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
|
230 |
+
b = hidden_states.shape[0]
|
231 |
+
hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
|
232 |
+
else:
|
233 |
+
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
|
234 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
235 |
+
|
236 |
+
hidden_states = self.proj_attn(hidden_states)
|
237 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
238 |
+
|
239 |
+
|
240 |
+
class FlaxBasicTransformerBlock(nn.Module):
|
241 |
+
r"""
|
242 |
+
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
|
243 |
+
https://arxiv.org/abs/1706.03762
|
244 |
+
|
245 |
+
|
246 |
+
Parameters:
|
247 |
+
dim (:obj:`int`):
|
248 |
+
Inner hidden states dimension
|
249 |
+
n_heads (:obj:`int`):
|
250 |
+
Number of heads
|
251 |
+
d_head (:obj:`int`):
|
252 |
+
Hidden states dimension inside each head
|
253 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
254 |
+
Dropout rate
|
255 |
+
only_cross_attention (`bool`, defaults to `False`):
|
256 |
+
Whether to only apply cross attention.
|
257 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
258 |
+
Parameters `dtype`
|
259 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
260 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
261 |
+
split_head_dim (`bool`, *optional*, defaults to `False`):
|
262 |
+
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
263 |
+
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
264 |
+
"""
|
265 |
+
dim: int
|
266 |
+
n_heads: int
|
267 |
+
d_head: int
|
268 |
+
dropout: float = 0.0
|
269 |
+
only_cross_attention: bool = False
|
270 |
+
dtype: jnp.dtype = jnp.float32
|
271 |
+
use_memory_efficient_attention: bool = False
|
272 |
+
split_head_dim: bool = False
|
273 |
+
|
274 |
+
def setup(self):
|
275 |
+
# self attention (or cross_attention if only_cross_attention is True)
|
276 |
+
self.attn1 = FlaxAttention(
|
277 |
+
self.dim,
|
278 |
+
self.n_heads,
|
279 |
+
self.d_head,
|
280 |
+
self.dropout,
|
281 |
+
self.use_memory_efficient_attention,
|
282 |
+
self.split_head_dim,
|
283 |
+
dtype=self.dtype,
|
284 |
+
)
|
285 |
+
# cross attention
|
286 |
+
self.attn2 = FlaxAttention(
|
287 |
+
self.dim,
|
288 |
+
self.n_heads,
|
289 |
+
self.d_head,
|
290 |
+
self.dropout,
|
291 |
+
self.use_memory_efficient_attention,
|
292 |
+
self.split_head_dim,
|
293 |
+
dtype=self.dtype,
|
294 |
+
)
|
295 |
+
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
|
296 |
+
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
297 |
+
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
298 |
+
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
299 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
300 |
+
|
301 |
+
def __call__(self, hidden_states, context, deterministic=True):
|
302 |
+
# self attention
|
303 |
+
residual = hidden_states
|
304 |
+
if self.only_cross_attention:
|
305 |
+
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
|
306 |
+
else:
|
307 |
+
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
|
308 |
+
hidden_states = hidden_states + residual
|
309 |
+
|
310 |
+
# cross attention
|
311 |
+
residual = hidden_states
|
312 |
+
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
|
313 |
+
hidden_states = hidden_states + residual
|
314 |
+
|
315 |
+
# feed forward
|
316 |
+
residual = hidden_states
|
317 |
+
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
|
318 |
+
hidden_states = hidden_states + residual
|
319 |
+
|
320 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
321 |
+
|
322 |
+
|
323 |
+
class FlaxTransformer2DModel(nn.Module):
|
324 |
+
r"""
|
325 |
+
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
|
326 |
+
https://arxiv.org/pdf/1506.02025.pdf
|
327 |
+
|
328 |
+
|
329 |
+
Parameters:
|
330 |
+
in_channels (:obj:`int`):
|
331 |
+
Input number of channels
|
332 |
+
n_heads (:obj:`int`):
|
333 |
+
Number of heads
|
334 |
+
d_head (:obj:`int`):
|
335 |
+
Hidden states dimension inside each head
|
336 |
+
depth (:obj:`int`, *optional*, defaults to 1):
|
337 |
+
Number of transformers block
|
338 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
339 |
+
Dropout rate
|
340 |
+
use_linear_projection (`bool`, defaults to `False`): Whether to project in/out of the transformer blocks with a linear (Dense) layer instead of a 1x1 convolution.
|
341 |
+
only_cross_attention (`bool`, defaults to `False`): Whether the first attention layer should only attend to the provided context (cross attention) instead of performing self attention.
|
342 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
343 |
+
Parameters `dtype`
|
344 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
345 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
346 |
+
split_head_dim (`bool`, *optional*, defaults to `False`):
|
347 |
+
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
348 |
+
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
349 |
+
"""
|
350 |
+
in_channels: int
|
351 |
+
n_heads: int
|
352 |
+
d_head: int
|
353 |
+
depth: int = 1
|
354 |
+
dropout: float = 0.0
|
355 |
+
use_linear_projection: bool = False
|
356 |
+
only_cross_attention: bool = False
|
357 |
+
dtype: jnp.dtype = jnp.float32
|
358 |
+
use_memory_efficient_attention: bool = False
|
359 |
+
split_head_dim: bool = False
|
360 |
+
|
361 |
+
def setup(self):
|
362 |
+
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
363 |
+
|
364 |
+
inner_dim = self.n_heads * self.d_head
|
365 |
+
if self.use_linear_projection:
|
366 |
+
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
|
367 |
+
else:
|
368 |
+
self.proj_in = nn.Conv(
|
369 |
+
inner_dim,
|
370 |
+
kernel_size=(1, 1),
|
371 |
+
strides=(1, 1),
|
372 |
+
padding="VALID",
|
373 |
+
dtype=self.dtype,
|
374 |
+
)
|
375 |
+
|
376 |
+
self.transformer_blocks = [
|
377 |
+
FlaxBasicTransformerBlock(
|
378 |
+
inner_dim,
|
379 |
+
self.n_heads,
|
380 |
+
self.d_head,
|
381 |
+
dropout=self.dropout,
|
382 |
+
only_cross_attention=self.only_cross_attention,
|
383 |
+
dtype=self.dtype,
|
384 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
385 |
+
split_head_dim=self.split_head_dim,
|
386 |
+
)
|
387 |
+
for _ in range(self.depth)
|
388 |
+
]
|
389 |
+
|
390 |
+
if self.use_linear_projection:
|
391 |
+
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
|
392 |
+
else:
|
393 |
+
self.proj_out = nn.Conv(
|
394 |
+
inner_dim,
|
395 |
+
kernel_size=(1, 1),
|
396 |
+
strides=(1, 1),
|
397 |
+
padding="VALID",
|
398 |
+
dtype=self.dtype,
|
399 |
+
)
|
400 |
+
|
401 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
402 |
+
|
403 |
+
def __call__(self, hidden_states, context, deterministic=True):
|
404 |
+
batch, height, width, channels = hidden_states.shape
|
405 |
+
residual = hidden_states
|
406 |
+
hidden_states = self.norm(hidden_states)
|
407 |
+
if self.use_linear_projection:
|
408 |
+
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
409 |
+
hidden_states = self.proj_in(hidden_states)
|
410 |
+
else:
|
411 |
+
hidden_states = self.proj_in(hidden_states)
|
412 |
+
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
413 |
+
|
414 |
+
for transformer_block in self.transformer_blocks:
|
415 |
+
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
|
416 |
+
|
417 |
+
if self.use_linear_projection:
|
418 |
+
hidden_states = self.proj_out(hidden_states)
|
419 |
+
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
420 |
+
else:
|
421 |
+
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
422 |
+
hidden_states = self.proj_out(hidden_states)
|
423 |
+
|
424 |
+
hidden_states = hidden_states + residual
|
425 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
426 |
+
|
427 |
+
|
428 |
+
class FlaxFeedForward(nn.Module):
|
429 |
+
r"""
|
430 |
+
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
|
431 |
+
[`FeedForward`] class, with the following simplifications:
|
432 |
+
- The activation function is currently hardcoded to a gated linear unit from:
|
433 |
+
https://arxiv.org/abs/2002.05202
|
434 |
+
- `dim_out` is equal to `dim`.
|
435 |
+
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
|
436 |
+
|
437 |
+
Parameters:
|
438 |
+
dim (:obj:`int`):
|
439 |
+
Inner hidden states dimension
|
440 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
441 |
+
Dropout rate
|
442 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
443 |
+
Parameters `dtype`
|
444 |
+
"""
|
445 |
+
dim: int
|
446 |
+
dropout: float = 0.0
|
447 |
+
dtype: jnp.dtype = jnp.float32
|
448 |
+
|
449 |
+
def setup(self):
|
450 |
+
# The second linear layer needs to be called
|
451 |
+
# net_2 for now to match the index of the Sequential layer
|
452 |
+
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
|
453 |
+
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
|
454 |
+
|
455 |
+
def __call__(self, hidden_states, deterministic=True):
|
456 |
+
hidden_states = self.net_0(hidden_states, deterministic=deterministic)
|
457 |
+
hidden_states = self.net_2(hidden_states)
|
458 |
+
return hidden_states
|
459 |
+
|
460 |
+
|
461 |
+
class FlaxGEGLU(nn.Module):
|
462 |
+
r"""
|
463 |
+
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
|
464 |
+
https://arxiv.org/abs/2002.05202.
|
465 |
+
|
466 |
+
Parameters:
|
467 |
+
dim (:obj:`int`):
|
468 |
+
Input hidden states dimension
|
469 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
470 |
+
Dropout rate
|
471 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
472 |
+
Parameters `dtype`
|
473 |
+
"""
|
474 |
+
dim: int
|
475 |
+
dropout: float = 0.0
|
476 |
+
dtype: jnp.dtype = jnp.float32
|
477 |
+
|
478 |
+
def setup(self):
|
479 |
+
inner_dim = self.dim * 4
|
480 |
+
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
|
481 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
482 |
+
|
483 |
+
def __call__(self, hidden_states, deterministic=True):
|
484 |
+
hidden_states = self.proj(hidden_states)
|
485 |
+
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
|
486 |
+
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
|
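A minimal usage sketch (not part of the diff) for jax_memory_efficient_attention defined above, assuming the vendored copy is importable as diffusers.models.attention_flax. The chunk sizes are illustrative and must divide the query/key lengths as the docstring requires.

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import jax_memory_efficient_attention  # assumed vendored path

rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
# shapes follow the docstring: (batch, length, heads, depth_per_head)
query = jax.random.normal(q_rng, (1, 4096, 8, 64))
key   = jax.random.normal(k_rng, (1, 4096, 8, 64))
value = jax.random.normal(v_rng, (1, 4096, 8, 64))

# queries are processed in chunks of 1024 and keys/values in chunks of 4096,
# so the full attention matrix is never materialized at once
out = jax_memory_efficient_attention(
    query, key, value, query_chunk_size=1024, key_chunk_size=4096
)
assert out.shape == (1, 4096, 8, 64)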
diffusers/models/attention_processor.py
ADDED
@@ -0,0 +1,2020 @@
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from importlib import import_module
|
15 |
+
from typing import Callable, Optional, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..utils import USE_PEFT_BACKEND, deprecate, logging
|
22 |
+
from ..utils.import_utils import is_xformers_available
|
23 |
+
from ..utils.torch_utils import maybe_allow_in_graph
|
24 |
+
from .lora import LoRACompatibleLinear, LoRALinearLayer
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
28 |
+
|
29 |
+
|
30 |
+
if is_xformers_available():
|
31 |
+
import xformers
|
32 |
+
import xformers.ops
|
33 |
+
else:
|
34 |
+
xformers = None
|
35 |
+
|
36 |
+
|
37 |
+
@maybe_allow_in_graph
|
38 |
+
class Attention(nn.Module):
|
39 |
+
r"""
|
40 |
+
A cross attention layer.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
query_dim (`int`):
|
44 |
+
The number of channels in the query.
|
45 |
+
cross_attention_dim (`int`, *optional*):
|
46 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
47 |
+
heads (`int`, *optional*, defaults to 8):
|
48 |
+
The number of heads to use for multi-head attention.
|
49 |
+
dim_head (`int`, *optional*, defaults to 64):
|
50 |
+
The number of channels in each head.
|
51 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
52 |
+
The dropout probability to use.
|
53 |
+
bias (`bool`, *optional*, defaults to False):
|
54 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
55 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
56 |
+
Set to `True` to upcast the attention computation to `float32`.
|
57 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
58 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
59 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
60 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
61 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
62 |
+
The number of groups to use for the group norm in the cross attention.
|
63 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
64 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
65 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
66 |
+
The number of groups to use for the group norm in the attention.
|
67 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
68 |
+
The number of channels to use for the spatial normalization.
|
69 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
70 |
+
Set to `True` to use a bias in the output linear layer.
|
71 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
72 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
73 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
74 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
75 |
+
`added_kv_proj_dim` is not `None`.
|
76 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
77 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
78 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
79 |
+
A factor to rescale the output by dividing it with this value.
|
80 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
81 |
+
Set to `True` to add the residual connection to the output.
|
82 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
83 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
84 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
85 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
86 |
+
`AttnProcessor` otherwise.
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
query_dim: int,
|
92 |
+
cross_attention_dim: Optional[int] = None,
|
93 |
+
heads: int = 8,
|
94 |
+
dim_head: int = 64,
|
95 |
+
dropout: float = 0.0,
|
96 |
+
bias: bool = False,
|
97 |
+
upcast_attention: bool = False,
|
98 |
+
upcast_softmax: bool = False,
|
99 |
+
cross_attention_norm: Optional[str] = None,
|
100 |
+
cross_attention_norm_num_groups: int = 32,
|
101 |
+
added_kv_proj_dim: Optional[int] = None,
|
102 |
+
norm_num_groups: Optional[int] = None,
|
103 |
+
spatial_norm_dim: Optional[int] = None,
|
104 |
+
out_bias: bool = True,
|
105 |
+
scale_qk: bool = True,
|
106 |
+
only_cross_attention: bool = False,
|
107 |
+
eps: float = 1e-5,
|
108 |
+
rescale_output_factor: float = 1.0,
|
109 |
+
residual_connection: bool = False,
|
110 |
+
_from_deprecated_attn_block: bool = False,
|
111 |
+
processor: Optional["AttnProcessor"] = None,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
self.inner_dim = dim_head * heads
|
115 |
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
116 |
+
self.upcast_attention = upcast_attention
|
117 |
+
self.upcast_softmax = upcast_softmax
|
118 |
+
self.rescale_output_factor = rescale_output_factor
|
119 |
+
self.residual_connection = residual_connection
|
120 |
+
self.dropout = dropout
|
121 |
+
|
122 |
+
# we make use of this private variable to know whether this class is loaded
|
123 |
+
# with a deprecated state dict so that we can convert it on the fly
|
124 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
125 |
+
|
126 |
+
self.scale_qk = scale_qk
|
127 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
128 |
+
|
129 |
+
self.heads = heads
|
130 |
+
# for slice_size > 0 the attention score computation
|
131 |
+
# is split across the batch axis to save memory
|
132 |
+
# You can set slice_size with `set_attention_slice`
|
133 |
+
self.sliceable_head_dim = heads
|
134 |
+
|
135 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
136 |
+
self.only_cross_attention = only_cross_attention
|
137 |
+
|
138 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
139 |
+
raise ValueError(
|
140 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
141 |
+
)
|
142 |
+
|
143 |
+
if norm_num_groups is not None:
|
144 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
145 |
+
else:
|
146 |
+
self.group_norm = None
|
147 |
+
|
148 |
+
if spatial_norm_dim is not None:
|
149 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
150 |
+
else:
|
151 |
+
self.spatial_norm = None
|
152 |
+
|
153 |
+
if cross_attention_norm is None:
|
154 |
+
self.norm_cross = None
|
155 |
+
elif cross_attention_norm == "layer_norm":
|
156 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
157 |
+
elif cross_attention_norm == "group_norm":
|
158 |
+
if self.added_kv_proj_dim is not None:
|
159 |
+
# The given `encoder_hidden_states` are initially of shape
|
160 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
161 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
162 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
163 |
+
# the number of channels for the group norm.
|
164 |
+
norm_cross_num_channels = added_kv_proj_dim
|
165 |
+
else:
|
166 |
+
norm_cross_num_channels = self.cross_attention_dim
|
167 |
+
|
168 |
+
self.norm_cross = nn.GroupNorm(
|
169 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
raise ValueError(
|
173 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
174 |
+
)
|
175 |
+
|
176 |
+
if USE_PEFT_BACKEND:
|
177 |
+
linear_cls = nn.Linear
|
178 |
+
else:
|
179 |
+
linear_cls = LoRACompatibleLinear
|
180 |
+
|
181 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
182 |
+
|
183 |
+
if not self.only_cross_attention:
|
184 |
+
# only relevant for the `AddedKVProcessor` classes
|
185 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
186 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
187 |
+
else:
|
188 |
+
self.to_k = None
|
189 |
+
self.to_v = None
|
190 |
+
|
191 |
+
if self.added_kv_proj_dim is not None:
|
192 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
193 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
194 |
+
|
195 |
+
self.to_out = nn.ModuleList([])
|
196 |
+
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
|
197 |
+
self.to_out.append(nn.Dropout(dropout))
|
198 |
+
|
199 |
+
# set attention processor
|
200 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
201 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
202 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
203 |
+
if processor is None:
|
204 |
+
processor = (
|
205 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
206 |
+
)
|
207 |
+
self.set_processor(processor)
|
208 |
+
|
209 |
+
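The constructor above wires `heads * dim_head` into `inner_dim` and falls back to the fused PyTorch 2.x processor when it is available. A minimal construction sketch (not part of the original file; the dimensions are invented and the module path is assumed to be `diffusers.models.attention_processor`):

import torch
from diffusers.models.attention_processor import Attention

# Self-attention: cross_attention_dim defaults to query_dim.
self_attn = Attention(query_dim=320, heads=8, dim_head=40)  # inner_dim = 8 * 40 = 320

# Cross-attention against text embeddings of width 768 (widths are illustrative).
cross_attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

hidden_states = torch.randn(2, 64, 320)          # (batch, seq_len, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text_len, cross_attention_dim)
out = cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])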
def set_use_memory_efficient_attention_xformers(
|
210 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
211 |
+
) -> None:
|
212 |
+
r"""
|
213 |
+
Set whether to use memory efficient attention from `xformers` or not.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
use_memory_efficient_attention_xformers (`bool`):
|
217 |
+
Whether to use memory efficient attention from `xformers` or not.
|
218 |
+
attention_op (`Callable`, *optional*):
|
219 |
+
The attention operation to use. Defaults to `None` which uses the default attention operation from
|
220 |
+
`xformers`.
|
221 |
+
"""
|
222 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
223 |
+
self.processor,
|
224 |
+
LORA_ATTENTION_PROCESSORS,
|
225 |
+
)
|
226 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
227 |
+
self.processor,
|
228 |
+
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
|
229 |
+
)
|
230 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
231 |
+
self.processor,
|
232 |
+
(
|
233 |
+
AttnAddedKVProcessor,
|
234 |
+
AttnAddedKVProcessor2_0,
|
235 |
+
SlicedAttnAddedKVProcessor,
|
236 |
+
XFormersAttnAddedKVProcessor,
|
237 |
+
LoRAAttnAddedKVProcessor,
|
238 |
+
),
|
239 |
+
)
|
240 |
+
|
241 |
+
if use_memory_efficient_attention_xformers:
|
242 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
243 |
+
raise NotImplementedError(
|
244 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
|
245 |
+
)
|
246 |
+
if not is_xformers_available():
|
247 |
+
raise ModuleNotFoundError(
|
248 |
+
(
|
249 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
250 |
+
" xformers"
|
251 |
+
),
|
252 |
+
name="xformers",
|
253 |
+
)
|
254 |
+
elif not torch.cuda.is_available():
|
255 |
+
raise ValueError(
|
256 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
257 |
+
" only available for GPU "
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
try:
|
261 |
+
# Make sure we can run the memory efficient attention
|
262 |
+
_ = xformers.ops.memory_efficient_attention(
|
263 |
+
torch.randn((1, 2, 40), device="cuda"),
|
264 |
+
torch.randn((1, 2, 40), device="cuda"),
|
265 |
+
torch.randn((1, 2, 40), device="cuda"),
|
266 |
+
)
|
267 |
+
except Exception as e:
|
268 |
+
raise e
|
269 |
+
|
270 |
+
if is_lora:
|
271 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
272 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
273 |
+
processor = LoRAXFormersAttnProcessor(
|
274 |
+
hidden_size=self.processor.hidden_size,
|
275 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
276 |
+
rank=self.processor.rank,
|
277 |
+
attention_op=attention_op,
|
278 |
+
)
|
279 |
+
processor.load_state_dict(self.processor.state_dict())
|
280 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
281 |
+
elif is_custom_diffusion:
|
282 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
283 |
+
train_kv=self.processor.train_kv,
|
284 |
+
train_q_out=self.processor.train_q_out,
|
285 |
+
hidden_size=self.processor.hidden_size,
|
286 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
287 |
+
attention_op=attention_op,
|
288 |
+
)
|
289 |
+
processor.load_state_dict(self.processor.state_dict())
|
290 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
291 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
292 |
+
elif is_added_kv_processor:
|
293 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
294 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
295 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
296 |
+
# throw warning
|
297 |
+
logger.info(
|
298 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
299 |
+
)
|
300 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
301 |
+
else:
|
302 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
303 |
+
else:
|
304 |
+
if is_lora:
|
305 |
+
attn_processor_class = (
|
306 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
307 |
+
)
|
308 |
+
processor = attn_processor_class(
|
309 |
+
hidden_size=self.processor.hidden_size,
|
310 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
311 |
+
rank=self.processor.rank,
|
312 |
+
)
|
313 |
+
processor.load_state_dict(self.processor.state_dict())
|
314 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
315 |
+
elif is_custom_diffusion:
|
316 |
+
attn_processor_class = (
|
317 |
+
CustomDiffusionAttnProcessor2_0
|
318 |
+
if hasattr(F, "scaled_dot_product_attention")
|
319 |
+
else CustomDiffusionAttnProcessor
|
320 |
+
)
|
321 |
+
processor = attn_processor_class(
|
322 |
+
train_kv=self.processor.train_kv,
|
323 |
+
train_q_out=self.processor.train_q_out,
|
324 |
+
hidden_size=self.processor.hidden_size,
|
325 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
326 |
+
)
|
327 |
+
processor.load_state_dict(self.processor.state_dict())
|
328 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
329 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
330 |
+
else:
|
331 |
+
# set attention processor
|
332 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
333 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
334 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
335 |
+
processor = (
|
336 |
+
AttnProcessor2_0()
|
337 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
338 |
+
else AttnProcessor()
|
339 |
+
)
|
340 |
+
|
341 |
+
self.set_processor(processor)
|
342 |
+
|
343 |
+
def set_attention_slice(self, slice_size: int) -> None:
|
344 |
+
r"""
|
345 |
+
Set the slice size for attention computation.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
slice_size (`int`):
|
349 |
+
The slice size for attention computation.
|
350 |
+
"""
|
351 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
352 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
353 |
+
|
354 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
355 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
356 |
+
elif slice_size is not None:
|
357 |
+
processor = SlicedAttnProcessor(slice_size)
|
358 |
+
elif self.added_kv_proj_dim is not None:
|
359 |
+
processor = AttnAddedKVProcessor()
|
360 |
+
else:
|
361 |
+
# set attention processor
|
362 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
363 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
364 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
365 |
+
processor = (
|
366 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
367 |
+
)
|
368 |
+
|
369 |
+
self.set_processor(processor)
|
370 |
+
|
371 |
+
def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
|
372 |
+
r"""
|
373 |
+
Set the attention processor to use.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
processor (`AttnProcessor`):
|
377 |
+
The attention processor to use.
|
378 |
+
_remove_lora (`bool`, *optional*, defaults to `False`):
|
379 |
+
Set to `True` to remove LoRA layers from the model.
|
380 |
+
"""
|
381 |
+
if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
|
382 |
+
deprecate(
|
383 |
+
"set_processor to offload LoRA",
|
384 |
+
"0.26.0",
|
385 |
+
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
|
386 |
+
)
|
387 |
+
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
|
388 |
+
# We need to remove all LoRA layers
|
389 |
+
# Don't forget to remove ALL `_remove_lora` from the codebase
|
390 |
+
for module in self.modules():
|
391 |
+
if hasattr(module, "set_lora_layer"):
|
392 |
+
module.set_lora_layer(None)
|
393 |
+
|
394 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
395 |
+
# pop `processor` from `self._modules`
|
396 |
+
if (
|
397 |
+
hasattr(self, "processor")
|
398 |
+
and isinstance(self.processor, torch.nn.Module)
|
399 |
+
and not isinstance(processor, torch.nn.Module)
|
400 |
+
):
|
401 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
402 |
+
self._modules.pop("processor")
|
403 |
+
|
404 |
+
self.processor = processor
|
405 |
+
|
406 |
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
407 |
+
r"""
|
408 |
+
Get the attention processor in use.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
412 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
"AttentionProcessor": The attention processor in use.
|
416 |
+
"""
|
417 |
+
if not return_deprecated_lora:
|
418 |
+
return self.processor
|
419 |
+
|
420 |
+
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
|
421 |
+
# serialization format for LoRA Attention Processors. It should be deleted once the integration
|
422 |
+
# with PEFT is completed.
|
423 |
+
is_lora_activated = {
|
424 |
+
name: module.lora_layer is not None
|
425 |
+
for name, module in self.named_modules()
|
426 |
+
if hasattr(module, "lora_layer")
|
427 |
+
}
|
428 |
+
|
429 |
+
# 1. if no layer has a LoRA activated we can return the processor as usual
|
430 |
+
if not any(is_lora_activated.values()):
|
431 |
+
return self.processor
|
432 |
+
|
433 |
+
# LoRA may not be applied to `add_k_proj` or `add_v_proj`, so exclude them from the check
|
434 |
+
is_lora_activated.pop("add_k_proj", None)
|
435 |
+
is_lora_activated.pop("add_v_proj", None)
|
436 |
+
# 2. else it is not possible that only some layers have LoRA activated
|
437 |
+
if not all(is_lora_activated.values()):
|
438 |
+
raise ValueError(
|
439 |
+
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
|
440 |
+
)
|
441 |
+
|
442 |
+
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
443 |
+
non_lora_processor_cls_name = self.processor.__class__.__name__
|
444 |
+
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
|
445 |
+
|
446 |
+
hidden_size = self.inner_dim
|
447 |
+
|
448 |
+
# now create a LoRA attention processor from the LoRA layers
|
449 |
+
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
|
450 |
+
kwargs = {
|
451 |
+
"cross_attention_dim": self.cross_attention_dim,
|
452 |
+
"rank": self.to_q.lora_layer.rank,
|
453 |
+
"network_alpha": self.to_q.lora_layer.network_alpha,
|
454 |
+
"q_rank": self.to_q.lora_layer.rank,
|
455 |
+
"q_hidden_size": self.to_q.lora_layer.out_features,
|
456 |
+
"k_rank": self.to_k.lora_layer.rank,
|
457 |
+
"k_hidden_size": self.to_k.lora_layer.out_features,
|
458 |
+
"v_rank": self.to_v.lora_layer.rank,
|
459 |
+
"v_hidden_size": self.to_v.lora_layer.out_features,
|
460 |
+
"out_rank": self.to_out[0].lora_layer.rank,
|
461 |
+
"out_hidden_size": self.to_out[0].lora_layer.out_features,
|
462 |
+
}
|
463 |
+
|
464 |
+
if hasattr(self.processor, "attention_op"):
|
465 |
+
kwargs["attention_op"] = self.processor.attention_op
|
466 |
+
|
467 |
+
lora_processor = lora_processor_cls(hidden_size, **kwargs)
|
468 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
469 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
470 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
471 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
472 |
+
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
473 |
+
lora_processor = lora_processor_cls(
|
474 |
+
hidden_size,
|
475 |
+
cross_attention_dim=self.add_k_proj.weight.shape[0],
|
476 |
+
rank=self.to_q.lora_layer.rank,
|
477 |
+
network_alpha=self.to_q.lora_layer.network_alpha,
|
478 |
+
)
|
479 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
480 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
481 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
482 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
483 |
+
|
484 |
+
# only save if used
|
485 |
+
if self.add_k_proj.lora_layer is not None:
|
486 |
+
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
|
487 |
+
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
|
488 |
+
else:
|
489 |
+
lora_processor.add_k_proj_lora = None
|
490 |
+
lora_processor.add_v_proj_lora = None
|
491 |
+
else:
|
492 |
+
raise ValueError(f"{lora_processor_cls} does not exist.")
|
493 |
+
|
494 |
+
return lora_processor
|
495 |
+
|
496 |
+
def forward(
|
497 |
+
self,
|
498 |
+
hidden_states: torch.FloatTensor,
|
499 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
500 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
501 |
+
**cross_attention_kwargs,
|
502 |
+
) -> torch.Tensor:
|
503 |
+
r"""
|
504 |
+
The forward method of the `Attention` class.
|
505 |
+
|
506 |
+
Args:
|
507 |
+
hidden_states (`torch.Tensor`):
|
508 |
+
The hidden states of the query.
|
509 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
510 |
+
The hidden states of the encoder.
|
511 |
+
attention_mask (`torch.Tensor`, *optional*):
|
512 |
+
The attention mask to use. If `None`, no mask is applied.
|
513 |
+
**cross_attention_kwargs:
|
514 |
+
Additional keyword arguments to pass along to the cross attention.
|
515 |
+
|
516 |
+
Returns:
|
517 |
+
`torch.Tensor`: The output of the attention layer.
|
518 |
+
"""
|
519 |
+
# The `Attention` class can call different attention processors / attention functions
|
520 |
+
# here we simply pass along all tensors to the selected processor class
|
521 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
522 |
+
return self.processor(
|
523 |
+
self,
|
524 |
+
hidden_states,
|
525 |
+
encoder_hidden_states=encoder_hidden_states,
|
526 |
+
attention_mask=attention_mask,
|
527 |
+
**cross_attention_kwargs,
|
528 |
+
)
|
529 |
+
|
530 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
531 |
+
r"""
|
532 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
533 |
+
is the number of heads initialized while constructing the `Attention` class.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
537 |
+
|
538 |
+
Returns:
|
539 |
+
`torch.Tensor`: The reshaped tensor.
|
540 |
+
"""
|
541 |
+
head_size = self.heads
|
542 |
+
batch_size, seq_len, dim = tensor.shape
|
543 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
544 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
545 |
+
return tensor
|
546 |
+
|
547 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
548 |
+
r"""
|
549 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
|
550 |
+
the number of heads initialized while constructing the `Attention` class.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
554 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
555 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
556 |
+
|
557 |
+
Returns:
|
558 |
+
`torch.Tensor`: The reshaped tensor.
|
559 |
+
"""
|
560 |
+
head_size = self.heads
|
561 |
+
batch_size, seq_len, dim = tensor.shape
|
562 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
563 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
564 |
+
|
565 |
+
if out_dim == 3:
|
566 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
567 |
+
|
568 |
+
return tensor
|
569 |
+
|
570 |
+
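A quick shape sanity-check for the two reshaping helpers above (illustrative only; `attn` stands for any Attention instance with `heads=8` and the sizes are arbitrary):

x = torch.randn(2, 64, 320)                 # (batch, seq_len, inner_dim)
folded = attn.head_to_batch_dim(x)          # (2 * 8, 64, 320 // 8) == (16, 64, 40)
restored = attn.batch_to_head_dim(folded)   # back to (2, 64, 320)
assert restored.shape == x.shape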
def get_attention_scores(
|
571 |
+
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
|
572 |
+
) -> torch.Tensor:
|
573 |
+
r"""
|
574 |
+
Compute the attention scores.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
query (`torch.Tensor`): The query tensor.
|
578 |
+
key (`torch.Tensor`): The key tensor.
|
579 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
580 |
+
|
581 |
+
Returns:
|
582 |
+
`torch.Tensor`: The attention probabilities/scores.
|
583 |
+
"""
|
584 |
+
dtype = query.dtype
|
585 |
+
if self.upcast_attention:
|
586 |
+
query = query.float()
|
587 |
+
key = key.float()
|
588 |
+
|
589 |
+
if attention_mask is None:
|
590 |
+
baddbmm_input = torch.empty(
|
591 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
592 |
+
)
|
593 |
+
beta = 0
|
594 |
+
else:
|
595 |
+
baddbmm_input = attention_mask
|
596 |
+
beta = 1
|
597 |
+
|
598 |
+
attention_scores = torch.baddbmm(
|
599 |
+
baddbmm_input,
|
600 |
+
query,
|
601 |
+
key.transpose(-1, -2),
|
602 |
+
beta=beta,
|
603 |
+
alpha=self.scale,
|
604 |
+
)
|
605 |
+
del baddbmm_input
|
606 |
+
|
607 |
+
if self.upcast_softmax:
|
608 |
+
attention_scores = attention_scores.float()
|
609 |
+
|
610 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
611 |
+
del attention_scores
|
612 |
+
|
613 |
+
attention_probs = attention_probs.to(dtype)
|
614 |
+
|
615 |
+
return attention_probs
|
616 |
+
|
617 |
+
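For reference, the fused `torch.baddbmm` call in `get_attention_scores` computes `softmax(Q @ K^T * scale + mask)`; a rough, unfused equivalent (ignoring the upcast flags) looks like this sketch:

def naive_attention_scores(attn, query, key, attention_mask=None):
    # attention_mask is an additive bias (zeros / large negatives), as prepared below.
    scores = torch.bmm(query, key.transpose(-1, -2)) * attn.scale
    if attention_mask is not None:
        scores = scores + attention_mask
    return scores.softmax(dim=-1)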
def prepare_attention_mask(
|
618 |
+
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
619 |
+
) -> torch.Tensor:
|
620 |
+
r"""
|
621 |
+
Prepare the attention mask for the attention computation.
|
622 |
+
|
623 |
+
Args:
|
624 |
+
attention_mask (`torch.Tensor`):
|
625 |
+
The attention mask to prepare.
|
626 |
+
target_length (`int`):
|
627 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
628 |
+
batch_size (`int`):
|
629 |
+
The batch size, which is used to repeat the attention mask.
|
630 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
631 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
`torch.Tensor`: The prepared attention mask.
|
635 |
+
"""
|
636 |
+
head_size = self.heads
|
637 |
+
if attention_mask is None:
|
638 |
+
return attention_mask
|
639 |
+
|
640 |
+
current_length: int = attention_mask.shape[-1]
|
641 |
+
if current_length != target_length:
|
642 |
+
if attention_mask.device.type == "mps":
|
643 |
+
# HACK: MPS does not support padding by more than the dimension of the input tensor.
|
644 |
+
# Instead, we can manually construct the padding tensor.
|
645 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
646 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
647 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
648 |
+
else:
|
649 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
650 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
651 |
+
# remaining_length: int = target_length - current_length
|
652 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
653 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
654 |
+
|
655 |
+
if out_dim == 3:
|
656 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
657 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
658 |
+
elif out_dim == 4:
|
659 |
+
attention_mask = attention_mask.unsqueeze(1)
|
660 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
661 |
+
|
662 |
+
return attention_mask
|
663 |
+
|
664 |
+
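How `prepare_attention_mask` reshapes a typical additive mask, sketched with invented sizes (batch of 2, 77 key tokens, `heads=8`); only the shapes matter here:

mask = torch.zeros(2, 1, 77)                                                   # (batch, 1, key_len)
prepared_3d = attn.prepare_attention_mask(mask, 77, batch_size=2)              # (2 * 8, 1, 77)
prepared_4d = attn.prepare_attention_mask(mask, 77, batch_size=2, out_dim=4)   # (2, 8, 1, 77)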
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
665 |
+
r"""
|
666 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
667 |
+
`Attention` class.
|
668 |
+
|
669 |
+
Args:
|
670 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
671 |
+
|
672 |
+
Returns:
|
673 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
674 |
+
"""
|
675 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
676 |
+
|
677 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
678 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
679 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
680 |
+
# Group norm norms along the channels dimension and expects
|
681 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
682 |
+
# to norm along the hidden dimension, so we need to move
|
683 |
+
# (batch_size, sequence_length, hidden_size) ->
|
684 |
+
# (batch_size, hidden_size, sequence_length)
|
685 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
686 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
687 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
688 |
+
else:
|
689 |
+
assert False
|
690 |
+
|
691 |
+
return encoder_hidden_states
|
692 |
+
|
693 |
+
|
694 |
+
class AttnProcessor:
|
695 |
+
r"""
|
696 |
+
Default processor for performing attention-related computations.
|
697 |
+
"""
|
698 |
+
|
699 |
+
def __call__(
|
700 |
+
self,
|
701 |
+
attn: Attention,
|
702 |
+
hidden_states: torch.FloatTensor,
|
703 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
704 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
705 |
+
temb: Optional[torch.FloatTensor] = None,
|
706 |
+
scale: float = 1.0,
|
707 |
+
) -> torch.Tensor:
|
708 |
+
residual = hidden_states
|
709 |
+
|
710 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
711 |
+
|
712 |
+
if attn.spatial_norm is not None:
|
713 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
714 |
+
|
715 |
+
input_ndim = hidden_states.ndim
|
716 |
+
|
717 |
+
if input_ndim == 4:
|
718 |
+
batch_size, channel, height, width = hidden_states.shape
|
719 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
720 |
+
|
721 |
+
batch_size, sequence_length, _ = (
|
722 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
723 |
+
)
|
724 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
725 |
+
|
726 |
+
if attn.group_norm is not None:
|
727 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
728 |
+
|
729 |
+
query = attn.to_q(hidden_states, *args)
|
730 |
+
|
731 |
+
if encoder_hidden_states is None:
|
732 |
+
encoder_hidden_states = hidden_states
|
733 |
+
elif attn.norm_cross:
|
734 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
735 |
+
|
736 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
737 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
738 |
+
|
739 |
+
query = attn.head_to_batch_dim(query)
|
740 |
+
key = attn.head_to_batch_dim(key)
|
741 |
+
value = attn.head_to_batch_dim(value)
|
742 |
+
|
743 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
744 |
+
hidden_states = torch.bmm(attention_probs, value)
|
745 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
746 |
+
|
747 |
+
# linear proj
|
748 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
749 |
+
# dropout
|
750 |
+
hidden_states = attn.to_out[1](hidden_states)
|
751 |
+
|
752 |
+
if input_ndim == 4:
|
753 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
754 |
+
|
755 |
+
if attn.residual_connection:
|
756 |
+
hidden_states = hidden_states + residual
|
757 |
+
|
758 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
759 |
+
|
760 |
+
return hidden_states
|
761 |
+
|
762 |
+
|
763 |
+
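Processors like the one above are plain callables that the `Attention` block dispatches to; swapping implementations is just a `set_processor` call. A minimal sketch, assuming the classes defined in this file:

attn.set_processor(AttnProcessor())            # dependency-free reference path
if hasattr(F, "scaled_dot_product_attention"):
    attn.set_processor(AttnProcessor2_0())     # fused PyTorch 2.x path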
class CustomDiffusionAttnProcessor(nn.Module):
|
764 |
+
r"""
|
765 |
+
Processor for implementing attention for the Custom Diffusion method.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
train_kv (`bool`, defaults to `True`):
|
769 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
770 |
+
train_q_out (`bool`, defaults to `True`):
|
771 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
772 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
773 |
+
The hidden size of the attention layer.
|
774 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
775 |
+
The number of channels in the `encoder_hidden_states`.
|
776 |
+
out_bias (`bool`, defaults to `True`):
|
777 |
+
Whether to include the bias parameter in `train_q_out`.
|
778 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
779 |
+
The dropout probability to use.
|
780 |
+
"""
|
781 |
+
|
782 |
+
def __init__(
|
783 |
+
self,
|
784 |
+
train_kv: bool = True,
|
785 |
+
train_q_out: bool = True,
|
786 |
+
hidden_size: Optional[int] = None,
|
787 |
+
cross_attention_dim: Optional[int] = None,
|
788 |
+
out_bias: bool = True,
|
789 |
+
dropout: float = 0.0,
|
790 |
+
):
|
791 |
+
super().__init__()
|
792 |
+
self.train_kv = train_kv
|
793 |
+
self.train_q_out = train_q_out
|
794 |
+
|
795 |
+
self.hidden_size = hidden_size
|
796 |
+
self.cross_attention_dim = cross_attention_dim
|
797 |
+
|
798 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
799 |
+
if self.train_kv:
|
800 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
801 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
802 |
+
if self.train_q_out:
|
803 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
804 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
805 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
806 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
807 |
+
|
808 |
+
def __call__(
|
809 |
+
self,
|
810 |
+
attn: Attention,
|
811 |
+
hidden_states: torch.FloatTensor,
|
812 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
813 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
814 |
+
) -> torch.Tensor:
|
815 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
816 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
817 |
+
if self.train_q_out:
|
818 |
+
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
819 |
+
else:
|
820 |
+
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
|
821 |
+
|
822 |
+
if encoder_hidden_states is None:
|
823 |
+
crossattn = False
|
824 |
+
encoder_hidden_states = hidden_states
|
825 |
+
else:
|
826 |
+
crossattn = True
|
827 |
+
if attn.norm_cross:
|
828 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
829 |
+
|
830 |
+
if self.train_kv:
|
831 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
832 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
833 |
+
key = key.to(attn.to_q.weight.dtype)
|
834 |
+
value = value.to(attn.to_q.weight.dtype)
|
835 |
+
else:
|
836 |
+
key = attn.to_k(encoder_hidden_states)
|
837 |
+
value = attn.to_v(encoder_hidden_states)
|
838 |
+
|
839 |
+
if crossattn:
|
840 |
+
detach = torch.ones_like(key)
|
841 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
842 |
+
key = detach * key + (1 - detach) * key.detach()
|
843 |
+
value = detach * value + (1 - detach) * value.detach()
|
844 |
+
|
845 |
+
query = attn.head_to_batch_dim(query)
|
846 |
+
key = attn.head_to_batch_dim(key)
|
847 |
+
value = attn.head_to_batch_dim(value)
|
848 |
+
|
849 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
850 |
+
hidden_states = torch.bmm(attention_probs, value)
|
851 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
852 |
+
|
853 |
+
if self.train_q_out:
|
854 |
+
# linear proj
|
855 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
856 |
+
# dropout
|
857 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
858 |
+
else:
|
859 |
+
# linear proj
|
860 |
+
hidden_states = attn.to_out[0](hidden_states)
|
861 |
+
# dropout
|
862 |
+
hidden_states = attn.to_out[1](hidden_states)
|
863 |
+
|
864 |
+
return hidden_states
|
865 |
+
|
866 |
+
|
867 |
+
class AttnAddedKVProcessor:
|
868 |
+
r"""
|
869 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
870 |
+
encoder.
|
871 |
+
"""
|
872 |
+
|
873 |
+
def __call__(
|
874 |
+
self,
|
875 |
+
attn: Attention,
|
876 |
+
hidden_states: torch.FloatTensor,
|
877 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
878 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
879 |
+
scale: float = 1.0,
|
880 |
+
) -> torch.Tensor:
|
881 |
+
residual = hidden_states
|
882 |
+
|
883 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
884 |
+
|
885 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
886 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
887 |
+
|
888 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
889 |
+
|
890 |
+
if encoder_hidden_states is None:
|
891 |
+
encoder_hidden_states = hidden_states
|
892 |
+
elif attn.norm_cross:
|
893 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
894 |
+
|
895 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
896 |
+
|
897 |
+
query = attn.to_q(hidden_states, *args)
|
898 |
+
query = attn.head_to_batch_dim(query)
|
899 |
+
|
900 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
|
901 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
|
902 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
903 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
904 |
+
|
905 |
+
if not attn.only_cross_attention:
|
906 |
+
key = attn.to_k(hidden_states, *args)
|
907 |
+
value = attn.to_v(hidden_states, *args)
|
908 |
+
key = attn.head_to_batch_dim(key)
|
909 |
+
value = attn.head_to_batch_dim(value)
|
910 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
911 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
912 |
+
else:
|
913 |
+
key = encoder_hidden_states_key_proj
|
914 |
+
value = encoder_hidden_states_value_proj
|
915 |
+
|
916 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
917 |
+
hidden_states = torch.bmm(attention_probs, value)
|
918 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
919 |
+
|
920 |
+
# linear proj
|
921 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
922 |
+
# dropout
|
923 |
+
hidden_states = attn.to_out[1](hidden_states)
|
924 |
+
|
925 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
926 |
+
hidden_states = hidden_states + residual
|
927 |
+
|
928 |
+
return hidden_states
|
929 |
+
|
930 |
+
|
931 |
+
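The added-KV processors expect an `Attention` block built with `added_kv_proj_dim` (and a group norm), so the text features get their own learnable key/value projections that are concatenated with, or replace, the self-attention keys/values. A hedged construction sketch with invented sizes:

added_kv_attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    added_kv_proj_dim=768,      # extra K/V projections for the text features
    norm_num_groups=32,         # the added-KV processors apply attn.group_norm unconditionally
    heads=8,
    dim_head=40,
    processor=AttnAddedKVProcessor(),
)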
class AttnAddedKVProcessor2_0:
|
932 |
+
r"""
|
933 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
934 |
+
learnable key and value matrices for the text encoder.
|
935 |
+
"""
|
936 |
+
|
937 |
+
def __init__(self):
|
938 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
939 |
+
raise ImportError(
|
940 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
941 |
+
)
|
942 |
+
|
943 |
+
def __call__(
|
944 |
+
self,
|
945 |
+
attn: Attention,
|
946 |
+
hidden_states: torch.FloatTensor,
|
947 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
948 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
949 |
+
scale: float = 1.0,
|
950 |
+
) -> torch.Tensor:
|
951 |
+
residual = hidden_states
|
952 |
+
|
953 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
954 |
+
|
955 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
956 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
957 |
+
|
958 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
959 |
+
|
960 |
+
if encoder_hidden_states is None:
|
961 |
+
encoder_hidden_states = hidden_states
|
962 |
+
elif attn.norm_cross:
|
963 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
964 |
+
|
965 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
966 |
+
|
967 |
+
query = attn.to_q(hidden_states, *args)
|
968 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
969 |
+
|
970 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
971 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
972 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
973 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
974 |
+
|
975 |
+
if not attn.only_cross_attention:
|
976 |
+
key = attn.to_k(hidden_states, *args)
|
977 |
+
value = attn.to_v(hidden_states, *args)
|
978 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
979 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
980 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
981 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
982 |
+
else:
|
983 |
+
key = encoder_hidden_states_key_proj
|
984 |
+
value = encoder_hidden_states_value_proj
|
985 |
+
|
986 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
987 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
988 |
+
hidden_states = F.scaled_dot_product_attention(
|
989 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
990 |
+
)
|
991 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
992 |
+
|
993 |
+
# linear proj
|
994 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
995 |
+
# dropout
|
996 |
+
hidden_states = attn.to_out[1](hidden_states)
|
997 |
+
|
998 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
999 |
+
hidden_states = hidden_states + residual
|
1000 |
+
|
1001 |
+
return hidden_states
|
1002 |
+
|
1003 |
+
|
1004 |
+
class XFormersAttnAddedKVProcessor:
|
1005 |
+
r"""
|
1006 |
+
Processor for implementing memory efficient attention using xFormers.
|
1007 |
+
|
1008 |
+
Args:
|
1009 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1010 |
+
The base
|
1011 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1012 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1013 |
+
operator.
|
1014 |
+
"""
|
1015 |
+
|
1016 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1017 |
+
self.attention_op = attention_op
|
1018 |
+
|
1019 |
+
def __call__(
|
1020 |
+
self,
|
1021 |
+
attn: Attention,
|
1022 |
+
hidden_states: torch.FloatTensor,
|
1023 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1024 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1025 |
+
) -> torch.Tensor:
|
1026 |
+
residual = hidden_states
|
1027 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1028 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1029 |
+
|
1030 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1031 |
+
|
1032 |
+
if encoder_hidden_states is None:
|
1033 |
+
encoder_hidden_states = hidden_states
|
1034 |
+
elif attn.norm_cross:
|
1035 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1036 |
+
|
1037 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1038 |
+
|
1039 |
+
query = attn.to_q(hidden_states)
|
1040 |
+
query = attn.head_to_batch_dim(query)
|
1041 |
+
|
1042 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1043 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1044 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1045 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1046 |
+
|
1047 |
+
if not attn.only_cross_attention:
|
1048 |
+
key = attn.to_k(hidden_states)
|
1049 |
+
value = attn.to_v(hidden_states)
|
1050 |
+
key = attn.head_to_batch_dim(key)
|
1051 |
+
value = attn.head_to_batch_dim(value)
|
1052 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1053 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1054 |
+
else:
|
1055 |
+
key = encoder_hidden_states_key_proj
|
1056 |
+
value = encoder_hidden_states_value_proj
|
1057 |
+
|
1058 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1059 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1060 |
+
)
|
1061 |
+
hidden_states = hidden_states.to(query.dtype)
|
1062 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1063 |
+
|
1064 |
+
# linear proj
|
1065 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1066 |
+
# dropout
|
1067 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1068 |
+
|
1069 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1070 |
+
hidden_states = hidden_states + residual
|
1071 |
+
|
1072 |
+
return hidden_states
|
1073 |
+
|
1074 |
+
|
1075 |
+
class XFormersAttnProcessor:
|
1076 |
+
r"""
|
1077 |
+
Processor for implementing memory efficient attention using xFormers.
|
1078 |
+
|
1079 |
+
Args:
|
1080 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1081 |
+
The base
|
1082 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1083 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1084 |
+
operator.
|
1085 |
+
"""
|
1086 |
+
|
1087 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1088 |
+
self.attention_op = attention_op
|
1089 |
+
|
1090 |
+
def __call__(
|
1091 |
+
self,
|
1092 |
+
attn: Attention,
|
1093 |
+
hidden_states: torch.FloatTensor,
|
1094 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1095 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1096 |
+
temb: Optional[torch.FloatTensor] = None,
|
1097 |
+
scale: float = 1.0,
|
1098 |
+
) -> torch.FloatTensor:
|
1099 |
+
residual = hidden_states
|
1100 |
+
|
1101 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
1102 |
+
|
1103 |
+
if attn.spatial_norm is not None:
|
1104 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1105 |
+
|
1106 |
+
input_ndim = hidden_states.ndim
|
1107 |
+
|
1108 |
+
if input_ndim == 4:
|
1109 |
+
batch_size, channel, height, width = hidden_states.shape
|
1110 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1111 |
+
|
1112 |
+
batch_size, key_tokens, _ = (
|
1113 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1114 |
+
)
|
1115 |
+
|
1116 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1117 |
+
if attention_mask is not None:
|
1118 |
+
# expand our mask's singleton query_tokens dimension:
|
1119 |
+
# [batch*heads, 1, key_tokens] ->
|
1120 |
+
# [batch*heads, query_tokens, key_tokens]
|
1121 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1122 |
+
# [batch*heads, query_tokens, key_tokens]
|
1123 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1124 |
+
_, query_tokens, _ = hidden_states.shape
|
1125 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1126 |
+
|
1127 |
+
if attn.group_norm is not None:
|
1128 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1129 |
+
|
1130 |
+
query = attn.to_q(hidden_states, *args)
|
1131 |
+
|
1132 |
+
if encoder_hidden_states is None:
|
1133 |
+
encoder_hidden_states = hidden_states
|
1134 |
+
elif attn.norm_cross:
|
1135 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1136 |
+
|
1137 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
1138 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
1139 |
+
|
1140 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1141 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1142 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1143 |
+
|
1144 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1145 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1146 |
+
)
|
1147 |
+
hidden_states = hidden_states.to(query.dtype)
|
1148 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1149 |
+
|
1150 |
+
# linear proj
|
1151 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
1152 |
+
# dropout
|
1153 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1154 |
+
|
1155 |
+
if input_ndim == 4:
|
1156 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1157 |
+
|
1158 |
+
if attn.residual_connection:
|
1159 |
+
hidden_states = hidden_states + residual
|
1160 |
+
|
1161 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1162 |
+
|
1163 |
+
return hidden_states
|
1164 |
+
|
1165 |
+
|
1166 |
+
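Rather than instantiating `XFormersAttnProcessor` directly, callers usually go through `set_use_memory_efficient_attention_xformers`, which picks the matching xformers variant for the current processor. A usage sketch (requires a CUDA device and the `xformers` package):

attn.set_use_memory_efficient_attention_xformers(True)    # route through xformers
attn.set_use_memory_efficient_attention_xformers(False)   # back to the default processors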
class AttnProcessor2_0:
|
1167 |
+
r"""
|
1168 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1169 |
+
"""
|
1170 |
+
|
1171 |
+
def __init__(self):
|
1172 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1173 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1174 |
+
|
1175 |
+
def __call__(
|
1176 |
+
self,
|
1177 |
+
attn: Attention,
|
1178 |
+
hidden_states: torch.FloatTensor,
|
1179 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1180 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1181 |
+
temb: Optional[torch.FloatTensor] = None,
|
1182 |
+
scale: float = 1.0,
|
1183 |
+
) -> torch.FloatTensor:
|
1184 |
+
residual = hidden_states
|
1185 |
+
|
1186 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
1187 |
+
|
1188 |
+
if attn.spatial_norm is not None:
|
1189 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1190 |
+
|
1191 |
+
input_ndim = hidden_states.ndim
|
1192 |
+
|
1193 |
+
if input_ndim == 4:
|
1194 |
+
batch_size, channel, height, width = hidden_states.shape
|
1195 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1196 |
+
|
1197 |
+
batch_size, sequence_length, _ = (
|
1198 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1199 |
+
)
|
1200 |
+
|
1201 |
+
if attention_mask is not None:
|
1202 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1203 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1204 |
+
# (batch, heads, source_length, target_length)
|
1205 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1206 |
+
|
1207 |
+
if attn.group_norm is not None:
|
1208 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1209 |
+
|
1210 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
1211 |
+
query = attn.to_q(hidden_states, *args)
|
1212 |
+
|
1213 |
+
if encoder_hidden_states is None:
|
1214 |
+
encoder_hidden_states = hidden_states
|
1215 |
+
elif attn.norm_cross:
|
1216 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1217 |
+
|
1218 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
1219 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
1220 |
+
|
1221 |
+
inner_dim = key.shape[-1]
|
1222 |
+
head_dim = inner_dim // attn.heads
|
1223 |
+
|
1224 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1225 |
+
|
1226 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1227 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1228 |
+
|
1229 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1230 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1231 |
+
hidden_states = F.scaled_dot_product_attention(
|
1232 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1233 |
+
)
|
1234 |
+
|
1235 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1236 |
+
hidden_states = hidden_states.to(query.dtype)
|
1237 |
+
|
1238 |
+
# linear proj
|
1239 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
1240 |
+
# dropout
|
1241 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1242 |
+
|
1243 |
+
if input_ndim == 4:
|
1244 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1245 |
+
|
1246 |
+
if attn.residual_connection:
|
1247 |
+
hidden_states = hidden_states + residual
|
1248 |
+
|
1249 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1250 |
+
|
1251 |
+
return hidden_states
|
1252 |
+
|
1253 |
+
|
1254 |
+
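`AttnProcessor2_0` is meant as a drop-in replacement for `AttnProcessor` that defers to `torch.nn.functional.scaled_dot_product_attention`; an illustrative way to check that the two paths agree numerically (reusing the `self_attn` toy block from the earlier sketch):

x = torch.randn(1, 16, 320)
self_attn.set_processor(AttnProcessor())
ref = self_attn(x)
self_attn.set_processor(AttnProcessor2_0())   # requires PyTorch 2.x
fused = self_attn(x)
print(torch.allclose(ref, fused, atol=1e-5))  # expected: True, up to float tolerance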
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1255 |
+
r"""
|
1256 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1257 |
+
|
1258 |
+
Args:
|
1259 |
+
train_kv (`bool`, defaults to `True`):
|
1260 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1261 |
+
train_q_out (`bool`, defaults to `True`):
|
1262 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1263 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1264 |
+
The hidden size of the attention layer.
|
1265 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1266 |
+
The number of channels in the `encoder_hidden_states`.
|
1267 |
+
out_bias (`bool`, defaults to `True`):
|
1268 |
+
Whether to include the bias parameter in `train_q_out`.
|
1269 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1270 |
+
The dropout probability to use.
|
1271 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1272 |
+
The base
|
1273 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1274 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1275 |
+
"""
|
1276 |
+
|
1277 |
+
def __init__(
|
1278 |
+
self,
|
1279 |
+
train_kv: bool = True,
|
1280 |
+
train_q_out: bool = False,
|
1281 |
+
hidden_size: Optional[int] = None,
|
1282 |
+
cross_attention_dim: Optional[int] = None,
|
1283 |
+
out_bias: bool = True,
|
1284 |
+
dropout: float = 0.0,
|
1285 |
+
attention_op: Optional[Callable] = None,
|
1286 |
+
):
|
1287 |
+
super().__init__()
|
1288 |
+
self.train_kv = train_kv
|
1289 |
+
self.train_q_out = train_q_out
|
1290 |
+
|
1291 |
+
self.hidden_size = hidden_size
|
1292 |
+
self.cross_attention_dim = cross_attention_dim
|
1293 |
+
self.attention_op = attention_op
|
1294 |
+
|
1295 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1296 |
+
if self.train_kv:
|
1297 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1298 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1299 |
+
if self.train_q_out:
|
1300 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1301 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1302 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1303 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1304 |
+
|
1305 |
+
def __call__(
|
1306 |
+
self,
|
1307 |
+
attn: Attention,
|
1308 |
+
hidden_states: torch.FloatTensor,
|
1309 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1310 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1311 |
+
) -> torch.FloatTensor:
|
1312 |
+
batch_size, sequence_length, _ = (
|
1313 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1317 |
+
|
1318 |
+
if self.train_q_out:
|
1319 |
+
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
1320 |
+
else:
|
1321 |
+
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
|
1322 |
+
|
1323 |
+
if encoder_hidden_states is None:
|
1324 |
+
crossattn = False
|
1325 |
+
encoder_hidden_states = hidden_states
|
1326 |
+
else:
|
1327 |
+
crossattn = True
|
1328 |
+
if attn.norm_cross:
|
1329 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1330 |
+
|
1331 |
+
if self.train_kv:
|
1332 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
1333 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
1334 |
+
key = key.to(attn.to_q.weight.dtype)
|
1335 |
+
value = value.to(attn.to_q.weight.dtype)
|
1336 |
+
else:
|
1337 |
+
key = attn.to_k(encoder_hidden_states)
|
1338 |
+
value = attn.to_v(encoder_hidden_states)
|
1339 |
+
|
1340 |
+
if crossattn:
|
1341 |
+
detach = torch.ones_like(key)
|
1342 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1343 |
+
key = detach * key + (1 - detach) * key.detach()
|
1344 |
+
value = detach * value + (1 - detach) * value.detach()
|
1345 |
+
|
1346 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1347 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1348 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1349 |
+
|
1350 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1351 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1352 |
+
)
|
1353 |
+
hidden_states = hidden_states.to(query.dtype)
|
1354 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1355 |
+
|
1356 |
+
if self.train_q_out:
|
1357 |
+
# linear proj
|
1358 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1359 |
+
# dropout
|
1360 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1361 |
+
else:
|
1362 |
+
# linear proj
|
1363 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1364 |
+
# dropout
|
1365 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1366 |
+
|
1367 |
+
return hidden_states
|
1368 |
+
|
1369 |
+
|
1370 |
+
class CustomDiffusionAttnProcessor2_0(nn.Module):
|
1371 |
+
r"""
|
1372 |
+
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
|
1373 |
+
dot-product attention.
|
1374 |
+
|
1375 |
+
Args:
|
1376 |
+
train_kv (`bool`, defaults to `True`):
|
1377 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1378 |
+
train_q_out (`bool`, defaults to `True`):
|
1379 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1380 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1381 |
+
The hidden size of the attention layer.
|
1382 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1383 |
+
The number of channels in the `encoder_hidden_states`.
|
1384 |
+
out_bias (`bool`, defaults to `True`):
|
1385 |
+
Whether to include the bias parameter in `train_q_out`.
|
1386 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1387 |
+
The dropout probability to use.
|
1388 |
+
"""
|
1389 |
+
|
1390 |
+
def __init__(
|
1391 |
+
self,
|
1392 |
+
train_kv: bool = True,
|
1393 |
+
train_q_out: bool = True,
|
1394 |
+
hidden_size: Optional[int] = None,
|
1395 |
+
cross_attention_dim: Optional[int] = None,
|
1396 |
+
out_bias: bool = True,
|
1397 |
+
dropout: float = 0.0,
|
1398 |
+
):
|
1399 |
+
super().__init__()
|
1400 |
+
self.train_kv = train_kv
|
1401 |
+
self.train_q_out = train_q_out
|
1402 |
+
|
1403 |
+
self.hidden_size = hidden_size
|
1404 |
+
self.cross_attention_dim = cross_attention_dim
|
1405 |
+
|
1406 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1407 |
+
if self.train_kv:
|
1408 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1409 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1410 |
+
if self.train_q_out:
|
1411 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1412 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1413 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1414 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1415 |
+
|
1416 |
+
def __call__(
|
1417 |
+
self,
|
1418 |
+
attn: Attention,
|
1419 |
+
hidden_states: torch.FloatTensor,
|
1420 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1421 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1422 |
+
) -> torch.FloatTensor:
|
1423 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1424 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1425 |
+
if self.train_q_out:
|
1426 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
1427 |
+
else:
|
1428 |
+
query = attn.to_q(hidden_states)
|
1429 |
+
|
1430 |
+
if encoder_hidden_states is None:
|
1431 |
+
crossattn = False
|
1432 |
+
encoder_hidden_states = hidden_states
|
1433 |
+
else:
|
1434 |
+
crossattn = True
|
1435 |
+
if attn.norm_cross:
|
1436 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1437 |
+
|
1438 |
+
if self.train_kv:
|
1439 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
1440 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
1441 |
+
key = key.to(attn.to_q.weight.dtype)
|
1442 |
+
value = value.to(attn.to_q.weight.dtype)
|
1443 |
+
|
1444 |
+
else:
|
1445 |
+
key = attn.to_k(encoder_hidden_states)
|
1446 |
+
value = attn.to_v(encoder_hidden_states)
|
1447 |
+
|
1448 |
+
if crossattn:
|
1449 |
+
detach = torch.ones_like(key)
|
1450 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1451 |
+
key = detach * key + (1 - detach) * key.detach()
|
1452 |
+
value = detach * value + (1 - detach) * value.detach()
|
1453 |
+
|
1454 |
+
inner_dim = hidden_states.shape[-1]
|
1455 |
+
|
1456 |
+
head_dim = inner_dim // attn.heads
|
1457 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1458 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1459 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1460 |
+
|
1461 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1462 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1463 |
+
hidden_states = F.scaled_dot_product_attention(
|
1464 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1465 |
+
)
|
1466 |
+
|
1467 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1468 |
+
hidden_states = hidden_states.to(query.dtype)
|
1469 |
+
|
1470 |
+
if self.train_q_out:
|
1471 |
+
# linear proj
|
1472 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1473 |
+
# dropout
|
1474 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1475 |
+
else:
|
1476 |
+
# linear proj
|
1477 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1478 |
+
# dropout
|
1479 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1480 |
+
|
1481 |
+
return hidden_states
|
1482 |
+
|
1483 |
+
|
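For readers comparing this processor with the xFormers variant above: instead of `head_to_batch_dim`, the PyTorch 2.0 path splits the heads into their own dimension before calling `F.scaled_dot_product_attention` and merges them back afterwards. A minimal sketch of that reshape round trip with toy shapes and no trained weights:

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 64, 8, 40
x = torch.randn(batch, seq, heads * head_dim)

# (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
q = x.view(batch, -1, heads, head_dim).transpose(1, 2)
k, v = q.clone(), q.clone()

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# (batch, heads, seq, head_dim) -> (batch, seq, inner_dim)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
print(out.shape)   # torch.Size([2, 64, 320])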
1484 |
+
class SlicedAttnProcessor:
|
1485 |
+
r"""
|
1486 |
+
Processor for implementing sliced attention.
|
1487 |
+
|
1488 |
+
Args:
|
1489 |
+
slice_size (`int`, *optional*):
|
1490 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1491 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1492 |
+
"""
|
1493 |
+
|
1494 |
+
def __init__(self, slice_size: int):
|
1495 |
+
self.slice_size = slice_size
|
1496 |
+
|
1497 |
+
def __call__(
|
1498 |
+
self,
|
1499 |
+
attn: Attention,
|
1500 |
+
hidden_states: torch.FloatTensor,
|
1501 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1502 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1503 |
+
) -> torch.FloatTensor:
|
1504 |
+
residual = hidden_states
|
1505 |
+
|
1506 |
+
input_ndim = hidden_states.ndim
|
1507 |
+
|
1508 |
+
if input_ndim == 4:
|
1509 |
+
batch_size, channel, height, width = hidden_states.shape
|
1510 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1511 |
+
|
1512 |
+
batch_size, sequence_length, _ = (
|
1513 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1514 |
+
)
|
1515 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1516 |
+
|
1517 |
+
if attn.group_norm is not None:
|
1518 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1519 |
+
|
1520 |
+
query = attn.to_q(hidden_states)
|
1521 |
+
dim = query.shape[-1]
|
1522 |
+
query = attn.head_to_batch_dim(query)
|
1523 |
+
|
1524 |
+
if encoder_hidden_states is None:
|
1525 |
+
encoder_hidden_states = hidden_states
|
1526 |
+
elif attn.norm_cross:
|
1527 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1528 |
+
|
1529 |
+
key = attn.to_k(encoder_hidden_states)
|
1530 |
+
value = attn.to_v(encoder_hidden_states)
|
1531 |
+
key = attn.head_to_batch_dim(key)
|
1532 |
+
value = attn.head_to_batch_dim(value)
|
1533 |
+
|
1534 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1535 |
+
hidden_states = torch.zeros(
|
1536 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1537 |
+
)
|
1538 |
+
|
1539 |
+
for i in range(batch_size_attention // self.slice_size):
|
1540 |
+
start_idx = i * self.slice_size
|
1541 |
+
end_idx = (i + 1) * self.slice_size
|
1542 |
+
|
1543 |
+
query_slice = query[start_idx:end_idx]
|
1544 |
+
key_slice = key[start_idx:end_idx]
|
1545 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1546 |
+
|
1547 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1548 |
+
|
1549 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1550 |
+
|
1551 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1552 |
+
|
1553 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1554 |
+
|
1555 |
+
# linear proj
|
1556 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1557 |
+
# dropout
|
1558 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1559 |
+
|
1560 |
+
if input_ndim == 4:
|
1561 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1562 |
+
|
1563 |
+
if attn.residual_connection:
|
1564 |
+
hidden_states = hidden_states + residual
|
1565 |
+
|
1566 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1567 |
+
|
1568 |
+
return hidden_states
|
1569 |
+
|
1570 |
+
|
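The point of `SlicedAttnProcessor` is memory, not speed: after `head_to_batch_dim` the query has shape `(batch * heads, tokens, head_dim)`, and the loop above scores only `slice_size` of those rows at a time, so the largest attention matrix ever materialized is `(slice_size, tokens, tokens)`. A back-of-the-envelope sketch with toy numbers:

batch, heads, tokens = 2, 8, 4096
slice_size = 4

rows = batch * heads                          # 16 rows after head_to_batch_dim
num_slices = rows // slice_size               # 4 passes through the slicing loop

full_elems = rows * tokens * tokens           # attention elements without slicing
sliced_elems = slice_size * tokens * tokens   # peak elements with slicing
print(num_slices, full_elems // sliced_elems) # 4 4 -> roughly 4x smaller peak buffer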
1571 |
+
class SlicedAttnAddedKVProcessor:
|
1572 |
+
r"""
|
1573 |
+
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
|
1574 |
+
|
1575 |
+
Args:
|
1576 |
+
slice_size (`int`, *optional*):
|
1577 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1578 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1579 |
+
"""
|
1580 |
+
|
1581 |
+
def __init__(self, slice_size):
|
1582 |
+
self.slice_size = slice_size
|
1583 |
+
|
1584 |
+
def __call__(
|
1585 |
+
self,
|
1586 |
+
attn: "Attention",
|
1587 |
+
hidden_states: torch.FloatTensor,
|
1588 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1589 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1590 |
+
temb: Optional[torch.FloatTensor] = None,
|
1591 |
+
) -> torch.FloatTensor:
|
1592 |
+
residual = hidden_states
|
1593 |
+
|
1594 |
+
if attn.spatial_norm is not None:
|
1595 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1596 |
+
|
1597 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1598 |
+
|
1599 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1600 |
+
|
1601 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1602 |
+
|
1603 |
+
if encoder_hidden_states is None:
|
1604 |
+
encoder_hidden_states = hidden_states
|
1605 |
+
elif attn.norm_cross:
|
1606 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1607 |
+
|
1608 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1609 |
+
|
1610 |
+
query = attn.to_q(hidden_states)
|
1611 |
+
dim = query.shape[-1]
|
1612 |
+
query = attn.head_to_batch_dim(query)
|
1613 |
+
|
1614 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1615 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1616 |
+
|
1617 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1618 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1619 |
+
|
1620 |
+
if not attn.only_cross_attention:
|
1621 |
+
key = attn.to_k(hidden_states)
|
1622 |
+
value = attn.to_v(hidden_states)
|
1623 |
+
key = attn.head_to_batch_dim(key)
|
1624 |
+
value = attn.head_to_batch_dim(value)
|
1625 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1626 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1627 |
+
else:
|
1628 |
+
key = encoder_hidden_states_key_proj
|
1629 |
+
value = encoder_hidden_states_value_proj
|
1630 |
+
|
1631 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1632 |
+
hidden_states = torch.zeros(
|
1633 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1634 |
+
)
|
1635 |
+
|
1636 |
+
for i in range(batch_size_attention // self.slice_size):
|
1637 |
+
start_idx = i * self.slice_size
|
1638 |
+
end_idx = (i + 1) * self.slice_size
|
1639 |
+
|
1640 |
+
query_slice = query[start_idx:end_idx]
|
1641 |
+
key_slice = key[start_idx:end_idx]
|
1642 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1643 |
+
|
1644 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1645 |
+
|
1646 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1647 |
+
|
1648 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1649 |
+
|
1650 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1651 |
+
|
1652 |
+
# linear proj
|
1653 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1654 |
+
# dropout
|
1655 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1656 |
+
|
1657 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1658 |
+
hidden_states = hidden_states + residual
|
1659 |
+
|
1660 |
+
return hidden_states
|
1661 |
+
|
1662 |
+
|
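The "added KV" processors, including the sliced variant above, concatenate the extra text-encoder key/value projections with the self-attention ones along the token axis, so each query attends over both streams in a single pass. A toy sketch of that concatenation (shapes are illustrative only):

import torch

image_tokens, text_tokens, dim = 256, 77, 320
key_self = torch.randn(2, image_tokens, dim)   # stands in for attn.to_k(hidden_states)
key_text = torch.randn(2, text_tokens, dim)    # stands in for attn.add_k_proj(encoder_hidden_states)

key = torch.cat([key_text, key_self], dim=1)
print(key.shape)   # torch.Size([2, 333, 320]) -> queries see text and image tokens together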
1663 |
+
class SpatialNorm(nn.Module):
|
1664 |
+
"""
|
1665 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
|
1666 |
+
|
1667 |
+
Args:
|
1668 |
+
f_channels (`int`):
|
1669 |
+
The number of channels of the input to the group normalization layer and of the output of the spatial norm layer.
|
1670 |
+
zq_channels (`int`):
|
1671 |
+
The number of channels for the quantized vector as described in the paper.
|
1672 |
+
"""
|
1673 |
+
|
1674 |
+
def __init__(
|
1675 |
+
self,
|
1676 |
+
f_channels: int,
|
1677 |
+
zq_channels: int,
|
1678 |
+
):
|
1679 |
+
super().__init__()
|
1680 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
1681 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1682 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1683 |
+
|
1684 |
+
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
|
1685 |
+
f_size = f.shape[-2:]
|
1686 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
1687 |
+
norm_f = self.norm_layer(f)
|
1688 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
1689 |
+
return new_f
|
1690 |
+
|
1691 |
+
|
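A minimal sketch of `SpatialNorm` in use, assuming only the class defined above: the conditioning `zq` is resized to the feature map's spatial size with nearest-neighbour interpolation, then scales and shifts the group-normalized features.

import torch

norm = SpatialNorm(f_channels=64, zq_channels=4)

f = torch.randn(1, 64, 32, 32)   # decoder feature map
zq = torch.randn(1, 4, 8, 8)     # low-resolution quantized latents

out = norm(f, zq)
print(out.shape)                 # torch.Size([1, 64, 32, 32])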
1692 |
+
## Deprecated
|
1693 |
+
class LoRAAttnProcessor(nn.Module):
|
1694 |
+
r"""
|
1695 |
+
Processor for implementing the LoRA attention mechanism.
|
1696 |
+
|
1697 |
+
Args:
|
1698 |
+
hidden_size (`int`, *optional*):
|
1699 |
+
The hidden size of the attention layer.
|
1700 |
+
cross_attention_dim (`int`, *optional*):
|
1701 |
+
The number of channels in the `encoder_hidden_states`.
|
1702 |
+
rank (`int`, defaults to 4):
|
1703 |
+
The dimension of the LoRA update matrices.
|
1704 |
+
network_alpha (`int`, *optional*):
|
1705 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1706 |
+
kwargs (`dict`):
|
1707 |
+
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
|
1708 |
+
"""
|
1709 |
+
|
1710 |
+
def __init__(
|
1711 |
+
self,
|
1712 |
+
hidden_size: int,
|
1713 |
+
cross_attention_dim: Optional[int] = None,
|
1714 |
+
rank: int = 4,
|
1715 |
+
network_alpha: Optional[int] = None,
|
1716 |
+
**kwargs,
|
1717 |
+
):
|
1718 |
+
super().__init__()
|
1719 |
+
|
1720 |
+
self.hidden_size = hidden_size
|
1721 |
+
self.cross_attention_dim = cross_attention_dim
|
1722 |
+
self.rank = rank
|
1723 |
+
|
1724 |
+
q_rank = kwargs.pop("q_rank", None)
|
1725 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
1726 |
+
q_rank = q_rank if q_rank is not None else rank
|
1727 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
1728 |
+
|
1729 |
+
v_rank = kwargs.pop("v_rank", None)
|
1730 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
1731 |
+
v_rank = v_rank if v_rank is not None else rank
|
1732 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
1733 |
+
|
1734 |
+
out_rank = kwargs.pop("out_rank", None)
|
1735 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
1736 |
+
out_rank = out_rank if out_rank is not None else rank
|
1737 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
1738 |
+
|
1739 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
1740 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1741 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
1742 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
1743 |
+
|
1744 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
1745 |
+
self_cls_name = self.__class__.__name__
|
1746 |
+
deprecate(
|
1747 |
+
self_cls_name,
|
1748 |
+
"0.26.0",
|
1749 |
+
(
|
1750 |
+
f"Make sure use {self_cls_name[4:]} instead by setting"
|
1751 |
+
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
1752 |
+
" `LoraLoaderMixin.load_lora_weights`"
|
1753 |
+
),
|
1754 |
+
)
|
1755 |
+
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
|
1756 |
+
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
|
1757 |
+
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
|
1758 |
+
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
|
1759 |
+
|
1760 |
+
attn._modules.pop("processor")
|
1761 |
+
attn.processor = AttnProcessor()
|
1762 |
+
return attn.processor(attn, hidden_states, *args, **kwargs)
|
1763 |
+
|
1764 |
+
|
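These deprecated processors all wire `LoRALinearLayer` updates onto the attention projections; the update itself is just `y = W x + (network_alpha / rank) * up(down(x))`. The sketch below uses a simplified stand-in module to show that arithmetic; it is not the library's `LoRALinearLayer` implementation.

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)    # "A": project to low rank
        self.up = nn.Linear(rank, out_features, bias=False)     # "B": project back up
        self.scale = (network_alpha / rank) if network_alpha is not None else 1.0

    def forward(self, x):
        return self.up(self.down(x)) * self.scale

base = nn.Linear(320, 320)                      # frozen base projection
lora = ToyLoRALinear(320, 320, rank=4, network_alpha=4)
x = torch.randn(1, 77, 320)

y = base(x) + lora(x)                           # base output plus trainable low-rank delta
print(y.shape)                                  # torch.Size([1, 77, 320])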
1765 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1766 |
+
r"""
|
1767 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1768 |
+
attention.
|
1769 |
+
|
1770 |
+
Args:
|
1771 |
+
hidden_size (`int`):
|
1772 |
+
The hidden size of the attention layer.
|
1773 |
+
cross_attention_dim (`int`, *optional*):
|
1774 |
+
The number of channels in the `encoder_hidden_states`.
|
1775 |
+
rank (`int`, defaults to 4):
|
1776 |
+
The dimension of the LoRA update matrices.
|
1777 |
+
network_alpha (`int`, *optional*):
|
1778 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1779 |
+
kwargs (`dict`):
|
1780 |
+
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
|
1781 |
+
"""
|
1782 |
+
|
1783 |
+
def __init__(
|
1784 |
+
self,
|
1785 |
+
hidden_size: int,
|
1786 |
+
cross_attention_dim: Optional[int] = None,
|
1787 |
+
rank: int = 4,
|
1788 |
+
network_alpha: Optional[int] = None,
|
1789 |
+
**kwargs,
|
1790 |
+
):
|
1791 |
+
super().__init__()
|
1792 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1793 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1794 |
+
|
1795 |
+
self.hidden_size = hidden_size
|
1796 |
+
self.cross_attention_dim = cross_attention_dim
|
1797 |
+
self.rank = rank
|
1798 |
+
|
1799 |
+
q_rank = kwargs.pop("q_rank", None)
|
1800 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
1801 |
+
q_rank = q_rank if q_rank is not None else rank
|
1802 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
1803 |
+
|
1804 |
+
v_rank = kwargs.pop("v_rank", None)
|
1805 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
1806 |
+
v_rank = v_rank if v_rank is not None else rank
|
1807 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
1808 |
+
|
1809 |
+
out_rank = kwargs.pop("out_rank", None)
|
1810 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
1811 |
+
out_rank = out_rank if out_rank is not None else rank
|
1812 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
1813 |
+
|
1814 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
1815 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1816 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
1817 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
1818 |
+
|
1819 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
1820 |
+
self_cls_name = self.__class__.__name__
|
1821 |
+
deprecate(
|
1822 |
+
self_cls_name,
|
1823 |
+
"0.26.0",
|
1824 |
+
(
|
1825 |
+
f"Make sure use {self_cls_name[4:]} instead by setting"
|
1826 |
+
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
1827 |
+
" `LoraLoaderMixin.load_lora_weights`"
|
1828 |
+
),
|
1829 |
+
)
|
1830 |
+
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
|
1831 |
+
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
|
1832 |
+
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
|
1833 |
+
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
|
1834 |
+
|
1835 |
+
attn._modules.pop("processor")
|
1836 |
+
attn.processor = AttnProcessor2_0()
|
1837 |
+
return attn.processor(attn, hidden_states, *args, **kwargs)
|
1838 |
+
|
1839 |
+
|
1840 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1841 |
+
r"""
|
1842 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1843 |
+
|
1844 |
+
Args:
|
1845 |
+
hidden_size (`int`, *optional*):
|
1846 |
+
The hidden size of the attention layer.
|
1847 |
+
cross_attention_dim (`int`, *optional*):
|
1848 |
+
The number of channels in the `encoder_hidden_states`.
|
1849 |
+
rank (`int`, defaults to 4):
|
1850 |
+
The dimension of the LoRA update matrices.
|
1851 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1852 |
+
The base
|
1853 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1854 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1855 |
+
operator.
|
1856 |
+
network_alpha (`int`, *optional*):
|
1857 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1858 |
+
kwargs (`dict`):
|
1859 |
+
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
|
1860 |
+
"""
|
1861 |
+
|
1862 |
+
def __init__(
|
1863 |
+
self,
|
1864 |
+
hidden_size: int,
|
1865 |
+
cross_attention_dim: int,
|
1866 |
+
rank: int = 4,
|
1867 |
+
attention_op: Optional[Callable] = None,
|
1868 |
+
network_alpha: Optional[int] = None,
|
1869 |
+
**kwargs,
|
1870 |
+
):
|
1871 |
+
super().__init__()
|
1872 |
+
|
1873 |
+
self.hidden_size = hidden_size
|
1874 |
+
self.cross_attention_dim = cross_attention_dim
|
1875 |
+
self.rank = rank
|
1876 |
+
self.attention_op = attention_op
|
1877 |
+
|
1878 |
+
q_rank = kwargs.pop("q_rank", None)
|
1879 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
1880 |
+
q_rank = q_rank if q_rank is not None else rank
|
1881 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
1882 |
+
|
1883 |
+
v_rank = kwargs.pop("v_rank", None)
|
1884 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
1885 |
+
v_rank = v_rank if v_rank is not None else rank
|
1886 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
1887 |
+
|
1888 |
+
out_rank = kwargs.pop("out_rank", None)
|
1889 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
1890 |
+
out_rank = out_rank if out_rank is not None else rank
|
1891 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
1892 |
+
|
1893 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
1894 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1895 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
1896 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
1897 |
+
|
1898 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
1899 |
+
self_cls_name = self.__class__.__name__
|
1900 |
+
deprecate(
|
1901 |
+
self_cls_name,
|
1902 |
+
"0.26.0",
|
1903 |
+
(
|
1904 |
+
f"Make sure use {self_cls_name[4:]} instead by setting"
|
1905 |
+
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
1906 |
+
" `LoraLoaderMixin.load_lora_weights`"
|
1907 |
+
),
|
1908 |
+
)
|
1909 |
+
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
|
1910 |
+
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
|
1911 |
+
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
|
1912 |
+
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
|
1913 |
+
|
1914 |
+
attn._modules.pop("processor")
|
1915 |
+
attn.processor = XFormersAttnProcessor()
|
1916 |
+
return attn.processor(attn, hidden_states, *args, **kwargs)
|
1917 |
+
|
1918 |
+
|
1919 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
1920 |
+
r"""
|
1921 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
1922 |
+
encoder.
|
1923 |
+
|
1924 |
+
Args:
|
1925 |
+
hidden_size (`int`, *optional*):
|
1926 |
+
The hidden size of the attention layer.
|
1927 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1928 |
+
The number of channels in the `encoder_hidden_states`.
|
1929 |
+
rank (`int`, defaults to 4):
|
1930 |
+
The dimension of the LoRA update matrices.
|
1931 |
+
network_alpha (`int`, *optional*):
|
1932 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1933 |
+
kwargs (`dict`):
|
1934 |
+
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
|
1935 |
+
"""
|
1936 |
+
|
1937 |
+
def __init__(
|
1938 |
+
self,
|
1939 |
+
hidden_size: int,
|
1940 |
+
cross_attention_dim: Optional[int] = None,
|
1941 |
+
rank: int = 4,
|
1942 |
+
network_alpha: Optional[int] = None,
|
1943 |
+
):
|
1944 |
+
super().__init__()
|
1945 |
+
|
1946 |
+
self.hidden_size = hidden_size
|
1947 |
+
self.cross_attention_dim = cross_attention_dim
|
1948 |
+
self.rank = rank
|
1949 |
+
|
1950 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1951 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1952 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1953 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1954 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1955 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1956 |
+
|
1957 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
1958 |
+
self_cls_name = self.__class__.__name__
|
1959 |
+
deprecate(
|
1960 |
+
self_cls_name,
|
1961 |
+
"0.26.0",
|
1962 |
+
(
|
1963 |
+
f"Make sure use {self_cls_name[4:]} instead by setting"
|
1964 |
+
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
1965 |
+
" `LoraLoaderMixin.load_lora_weights`"
|
1966 |
+
),
|
1967 |
+
)
|
1968 |
+
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
|
1969 |
+
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
|
1970 |
+
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
|
1971 |
+
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
|
1972 |
+
|
1973 |
+
attn._modules.pop("processor")
|
1974 |
+
attn.processor = AttnAddedKVProcessor()
|
1975 |
+
return attn.processor(attn, hidden_states, *args, **kwargs)
|
1976 |
+
|
1977 |
+
|
1978 |
+
LORA_ATTENTION_PROCESSORS = (
|
1979 |
+
LoRAAttnProcessor,
|
1980 |
+
LoRAAttnProcessor2_0,
|
1981 |
+
LoRAXFormersAttnProcessor,
|
1982 |
+
LoRAAttnAddedKVProcessor,
|
1983 |
+
)
|
1984 |
+
|
1985 |
+
ADDED_KV_ATTENTION_PROCESSORS = (
|
1986 |
+
AttnAddedKVProcessor,
|
1987 |
+
SlicedAttnAddedKVProcessor,
|
1988 |
+
AttnAddedKVProcessor2_0,
|
1989 |
+
XFormersAttnAddedKVProcessor,
|
1990 |
+
LoRAAttnAddedKVProcessor,
|
1991 |
+
)
|
1992 |
+
|
1993 |
+
CROSS_ATTENTION_PROCESSORS = (
|
1994 |
+
AttnProcessor,
|
1995 |
+
AttnProcessor2_0,
|
1996 |
+
XFormersAttnProcessor,
|
1997 |
+
SlicedAttnProcessor,
|
1998 |
+
LoRAAttnProcessor,
|
1999 |
+
LoRAAttnProcessor2_0,
|
2000 |
+
LoRAXFormersAttnProcessor,
|
2001 |
+
)
|
2002 |
+
|
2003 |
+
AttentionProcessor = Union[
|
2004 |
+
AttnProcessor,
|
2005 |
+
AttnProcessor2_0,
|
2006 |
+
XFormersAttnProcessor,
|
2007 |
+
SlicedAttnProcessor,
|
2008 |
+
AttnAddedKVProcessor,
|
2009 |
+
SlicedAttnAddedKVProcessor,
|
2010 |
+
AttnAddedKVProcessor2_0,
|
2011 |
+
XFormersAttnAddedKVProcessor,
|
2012 |
+
CustomDiffusionAttnProcessor,
|
2013 |
+
CustomDiffusionXFormersAttnProcessor,
|
2014 |
+
CustomDiffusionAttnProcessor2_0,
|
2015 |
+
# deprecated
|
2016 |
+
LoRAAttnProcessor,
|
2017 |
+
LoRAAttnProcessor2_0,
|
2018 |
+
LoRAXFormersAttnProcessor,
|
2019 |
+
LoRAAttnAddedKVProcessor,
|
2020 |
+
]
|
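The tuples and the `AttentionProcessor` union above exist so that callers can dispatch on processor type, for example when picking a default implementation for a loaded model. A hedged usage sketch follows; the checkpoint name is a placeholder, and `AttnProcessor`/`AttnProcessor2_0` are the classes defined earlier in this file.

import torch.nn.functional as F
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

if hasattr(F, "scaled_dot_product_attention"):
    unet.set_attn_processor(AttnProcessor2_0())   # PyTorch 2.x fused attention
else:
    unet.set_attn_processor(AttnProcessor())      # framework-agnostic fallback

print(len(unet.attn_processors), "attention layers configured")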
diffusers/models/autoencoder_asym_kl.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from ..utils.accelerate_utils import apply_forward_hook
|
21 |
+
from .autoencoder_kl import AutoencoderKLOutput
|
22 |
+
from .modeling_utils import ModelMixin
|
23 |
+
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
24 |
+
|
25 |
+
|
26 |
+
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
27 |
+
r"""
|
28 |
+
Designing a Better Asymmetric VQGAN for Stable Diffusion (https://arxiv.org/abs/2306.04632). A VAE model with KL loss
|
29 |
+
for encoding images into latents and decoding latent representations into images.
|
30 |
+
|
31 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
32 |
+
for all models (such as downloading or saving).
|
33 |
+
|
34 |
+
Parameters:
|
35 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
36 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
37 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
38 |
+
Tuple of downsample block types.
|
39 |
+
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
40 |
+
Tuple of down block output channels.
|
41 |
+
layers_per_down_block (`int`, *optional*, defaults to `1`):
|
42 |
+
Number of layers per down block.
|
43 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
44 |
+
Tuple of upsample block types.
|
45 |
+
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
46 |
+
Tuple of up block output channels.
|
47 |
+
layers_per_up_block (`int`, *optional*, defaults to `1`):
|
48 |
+
Number of layers per up block.
|
49 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
50 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
51 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
52 |
+
norm_num_groups (`int`, *optional*, defaults to `32`):
|
53 |
+
Number of groups to use for the first normalization layer in ResNet blocks.
|
54 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
55 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
56 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
57 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
58 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
59 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
60 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
61 |
+
"""
|
62 |
+
|
63 |
+
@register_to_config
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
in_channels: int = 3,
|
67 |
+
out_channels: int = 3,
|
68 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
69 |
+
down_block_out_channels: Tuple[int] = (64,),
|
70 |
+
layers_per_down_block: int = 1,
|
71 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
72 |
+
up_block_out_channels: Tuple[int] = (64,),
|
73 |
+
layers_per_up_block: int = 1,
|
74 |
+
act_fn: str = "silu",
|
75 |
+
latent_channels: int = 4,
|
76 |
+
norm_num_groups: int = 32,
|
77 |
+
sample_size: int = 32,
|
78 |
+
scaling_factor: float = 0.18215,
|
79 |
+
) -> None:
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
# pass init params to Encoder
|
83 |
+
self.encoder = Encoder(
|
84 |
+
in_channels=in_channels,
|
85 |
+
out_channels=latent_channels,
|
86 |
+
down_block_types=down_block_types,
|
87 |
+
block_out_channels=down_block_out_channels,
|
88 |
+
layers_per_block=layers_per_down_block,
|
89 |
+
act_fn=act_fn,
|
90 |
+
norm_num_groups=norm_num_groups,
|
91 |
+
double_z=True,
|
92 |
+
)
|
93 |
+
|
94 |
+
# pass init params to Decoder
|
95 |
+
self.decoder = MaskConditionDecoder(
|
96 |
+
in_channels=latent_channels,
|
97 |
+
out_channels=out_channels,
|
98 |
+
up_block_types=up_block_types,
|
99 |
+
block_out_channels=up_block_out_channels,
|
100 |
+
layers_per_block=layers_per_up_block,
|
101 |
+
act_fn=act_fn,
|
102 |
+
norm_num_groups=norm_num_groups,
|
103 |
+
)
|
104 |
+
|
105 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
106 |
+
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
107 |
+
|
108 |
+
self.use_slicing = False
|
109 |
+
self.use_tiling = False
|
110 |
+
|
111 |
+
@apply_forward_hook
|
112 |
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
113 |
+
h = self.encoder(x)
|
114 |
+
moments = self.quant_conv(h)
|
115 |
+
posterior = DiagonalGaussianDistribution(moments)
|
116 |
+
|
117 |
+
if not return_dict:
|
118 |
+
return (posterior,)
|
119 |
+
|
120 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
121 |
+
|
122 |
+
def _decode(
|
123 |
+
self,
|
124 |
+
z: torch.FloatTensor,
|
125 |
+
image: Optional[torch.FloatTensor] = None,
|
126 |
+
mask: Optional[torch.FloatTensor] = None,
|
127 |
+
return_dict: bool = True,
|
128 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
129 |
+
z = self.post_quant_conv(z)
|
130 |
+
dec = self.decoder(z, image, mask)
|
131 |
+
|
132 |
+
if not return_dict:
|
133 |
+
return (dec,)
|
134 |
+
|
135 |
+
return DecoderOutput(sample=dec)
|
136 |
+
|
137 |
+
@apply_forward_hook
|
138 |
+
def decode(
|
139 |
+
self,
|
140 |
+
z: torch.FloatTensor,
|
141 |
+
generator: Optional[torch.Generator] = None,
|
142 |
+
image: Optional[torch.FloatTensor] = None,
|
143 |
+
mask: Optional[torch.FloatTensor] = None,
|
144 |
+
return_dict: bool = True,
|
145 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
146 |
+
decoded = self._decode(z, image, mask).sample
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (decoded,)
|
150 |
+
|
151 |
+
return DecoderOutput(sample=decoded)
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
sample: torch.FloatTensor,
|
156 |
+
mask: Optional[torch.FloatTensor] = None,
|
157 |
+
sample_posterior: bool = False,
|
158 |
+
return_dict: bool = True,
|
159 |
+
generator: Optional[torch.Generator] = None,
|
160 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
161 |
+
r"""
|
162 |
+
Args:
|
163 |
+
sample (`torch.FloatTensor`): Input sample.
|
164 |
+
mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
|
165 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
166 |
+
Whether to sample from the posterior.
|
167 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
168 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
169 |
+
"""
|
170 |
+
x = sample
|
171 |
+
posterior = self.encode(x).latent_dist
|
172 |
+
if sample_posterior:
|
173 |
+
z = posterior.sample(generator=generator)
|
174 |
+
else:
|
175 |
+
z = posterior.mode()
|
176 |
+
dec = self.decode(z, generator=generator, image=sample, mask=mask).sample
|
177 |
+
|
178 |
+
if not return_dict:
|
179 |
+
return (dec,)
|
180 |
+
|
181 |
+
return DecoderOutput(sample=dec)
|
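A minimal round-trip sketch for the class above, using random tensors in place of real images and masks; the mask convention (1 = keep, 0 = inpaint) is an assumption of the example, and the default single-block configuration keeps the tensors small.

import torch

vae = AsymmetricAutoencoderKL()        # default single-block configuration
image = torch.randn(1, 3, 64, 64)
mask = torch.ones(1, 1, 64, 64)        # assumed convention: 1 = keep, 0 = inpaint

posterior = vae.encode(image).latent_dist
z = posterior.mode()                   # deterministic latents
decoded = vae.decode(z, image=image, mask=mask).sample
print(decoded.shape)                   # torch.Size([1, 3, 64, 64])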
diffusers/models/autoencoder_kl.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Dict, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..loaders import FromOriginalVAEMixin
|
22 |
+
from ..utils import BaseOutput
|
23 |
+
from ..utils.accelerate_utils import apply_forward_hook
|
24 |
+
from .attention_processor import (
|
25 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
26 |
+
CROSS_ATTENTION_PROCESSORS,
|
27 |
+
AttentionProcessor,
|
28 |
+
AttnAddedKVProcessor,
|
29 |
+
AttnProcessor,
|
30 |
+
)
|
31 |
+
from .modeling_utils import ModelMixin
|
32 |
+
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class AutoencoderKLOutput(BaseOutput):
|
37 |
+
"""
|
38 |
+
Output of AutoencoderKL encoding method.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
latent_dist (`DiagonalGaussianDistribution`):
|
42 |
+
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
43 |
+
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
44 |
+
"""
|
45 |
+
|
46 |
+
latent_dist: "DiagonalGaussianDistribution"
|
47 |
+
|
48 |
+
|
49 |
+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
50 |
+
r"""
|
51 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
52 |
+
|
53 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
54 |
+
for all models (such as downloading or saving).
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
58 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
59 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
60 |
+
Tuple of downsample block types.
|
61 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
62 |
+
Tuple of upsample block types.
|
63 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
64 |
+
Tuple of block output channels.
|
65 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
66 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
67 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
68 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
69 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
70 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
71 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
72 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
73 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
74 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
75 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
76 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
77 |
+
can be fine-tuned / trained to a lower range without losing too much precision, in which case
|
78 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
79 |
+
"""
|
80 |
+
|
81 |
+
_supports_gradient_checkpointing = True
|
82 |
+
|
83 |
+
@register_to_config
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
in_channels: int = 3,
|
87 |
+
out_channels: int = 3,
|
88 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
89 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
90 |
+
block_out_channels: Tuple[int] = (64,),
|
91 |
+
layers_per_block: int = 1,
|
92 |
+
act_fn: str = "silu",
|
93 |
+
latent_channels: int = 4,
|
94 |
+
norm_num_groups: int = 32,
|
95 |
+
sample_size: int = 32,
|
96 |
+
scaling_factor: float = 0.18215,
|
97 |
+
force_upcast: bool = True,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
# pass init params to Encoder
|
102 |
+
self.encoder = Encoder(
|
103 |
+
in_channels=in_channels,
|
104 |
+
out_channels=latent_channels,
|
105 |
+
down_block_types=down_block_types,
|
106 |
+
block_out_channels=block_out_channels,
|
107 |
+
layers_per_block=layers_per_block,
|
108 |
+
act_fn=act_fn,
|
109 |
+
norm_num_groups=norm_num_groups,
|
110 |
+
double_z=True,
|
111 |
+
)
|
112 |
+
|
113 |
+
# pass init params to Decoder
|
114 |
+
self.decoder = Decoder(
|
115 |
+
in_channels=latent_channels,
|
116 |
+
out_channels=out_channels,
|
117 |
+
up_block_types=up_block_types,
|
118 |
+
block_out_channels=block_out_channels,
|
119 |
+
layers_per_block=layers_per_block,
|
120 |
+
norm_num_groups=norm_num_groups,
|
121 |
+
act_fn=act_fn,
|
122 |
+
)
|
123 |
+
|
124 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
125 |
+
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
126 |
+
|
127 |
+
self.use_slicing = False
|
128 |
+
self.use_tiling = False
|
129 |
+
|
130 |
+
# only relevant if vae tiling is enabled
|
131 |
+
self.tile_sample_min_size = self.config.sample_size
|
132 |
+
sample_size = (
|
133 |
+
self.config.sample_size[0]
|
134 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
135 |
+
else self.config.sample_size
|
136 |
+
)
|
137 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
138 |
+
self.tile_overlap_factor = 0.25
|
139 |
+
|
140 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
141 |
+
if isinstance(module, (Encoder, Decoder)):
|
142 |
+
module.gradient_checkpointing = value
|
143 |
+
|
144 |
+
def enable_tiling(self, use_tiling: bool = True):
|
145 |
+
r"""
|
146 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
147 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
148 |
+
processing larger images.
|
149 |
+
"""
|
150 |
+
self.use_tiling = use_tiling
|
151 |
+
|
152 |
+
def disable_tiling(self):
|
153 |
+
r"""
|
154 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
155 |
+
decoding in one step.
|
156 |
+
"""
|
157 |
+
self.enable_tiling(False)
|
158 |
+
|
159 |
+
def enable_slicing(self):
|
160 |
+
r"""
|
161 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
162 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
163 |
+
"""
|
164 |
+
self.use_slicing = True
|
165 |
+
|
166 |
+
def disable_slicing(self):
|
167 |
+
r"""
|
168 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
169 |
+
decoding in one step.
|
170 |
+
"""
|
171 |
+
self.use_slicing = False
|
172 |
+
|
173 |
+
@property
|
174 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
175 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
176 |
+
r"""
|
177 |
+
Returns:
|
178 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
179 |
+
indexed by its weight name.
|
180 |
+
"""
|
181 |
+
# set recursively
|
182 |
+
processors = {}
|
183 |
+
|
184 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
185 |
+
if hasattr(module, "get_processor"):
|
186 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
187 |
+
|
188 |
+
for sub_name, child in module.named_children():
|
189 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
190 |
+
|
191 |
+
return processors
|
192 |
+
|
193 |
+
for name, module in self.named_children():
|
194 |
+
fn_recursive_add_processors(name, module, processors)
|
195 |
+
|
196 |
+
return processors
|
197 |
+
|
198 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
199 |
+
def set_attn_processor(
|
200 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
201 |
+
):
|
202 |
+
r"""
|
203 |
+
Sets the attention processor to use to compute attention.
|
204 |
+
|
205 |
+
Parameters:
|
206 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
207 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
208 |
+
for **all** `Attention` layers.
|
209 |
+
|
210 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
211 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
212 |
+
|
213 |
+
"""
|
214 |
+
count = len(self.attn_processors.keys())
|
215 |
+
|
216 |
+
if isinstance(processor, dict) and len(processor) != count:
|
217 |
+
raise ValueError(
|
218 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
219 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
220 |
+
)
|
221 |
+
|
222 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
223 |
+
if hasattr(module, "set_processor"):
|
224 |
+
if not isinstance(processor, dict):
|
225 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
226 |
+
else:
|
227 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
228 |
+
|
229 |
+
for sub_name, child in module.named_children():
|
230 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
231 |
+
|
232 |
+
for name, module in self.named_children():
|
233 |
+
fn_recursive_attn_processor(name, module, processor)
|
234 |
+
|
235 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
236 |
+
def set_default_attn_processor(self):
|
237 |
+
"""
|
238 |
+
Disables custom attention processors and sets the default attention implementation.
|
239 |
+
"""
|
240 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
241 |
+
processor = AttnAddedKVProcessor()
|
242 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
243 |
+
processor = AttnProcessor()
|
244 |
+
else:
|
245 |
+
raise ValueError(
|
246 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
247 |
+
)
|
248 |
+
|
249 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
250 |
+
|
251 |
+
@apply_forward_hook
|
252 |
+
def encode(
|
253 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
254 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
255 |
+
"""
|
256 |
+
Encode a batch of images into latents.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
x (`torch.FloatTensor`): Input batch of images.
|
260 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
261 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
265 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
266 |
+
"""
|
267 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
268 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
269 |
+
|
270 |
+
if self.use_slicing and x.shape[0] > 1:
|
271 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
272 |
+
h = torch.cat(encoded_slices)
|
273 |
+
else:
|
274 |
+
h = self.encoder(x)
|
275 |
+
|
276 |
+
moments = self.quant_conv(h)
|
277 |
+
posterior = DiagonalGaussianDistribution(moments)
|
278 |
+
|
279 |
+
if not return_dict:
|
280 |
+
return (posterior,)
|
281 |
+
|
282 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
283 |
+
|
284 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
285 |
+
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
286 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
287 |
+
|
288 |
+
z = self.post_quant_conv(z)
|
289 |
+
dec = self.decoder(z)
|
290 |
+
|
291 |
+
if not return_dict:
|
292 |
+
return (dec,)
|
293 |
+
|
294 |
+
return DecoderOutput(sample=dec)
|
295 |
+
|
296 |
+
@apply_forward_hook
|
297 |
+
def decode(
|
298 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
299 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
300 |
+
"""
|
301 |
+
Decode a batch of images.
|
302 |
+
|
303 |
+
Args:
|
304 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
305 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
306 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
310 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
311 |
+
returned.
|
312 |
+
|
313 |
+
"""
|
314 |
+
if self.use_slicing and z.shape[0] > 1:
|
315 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
316 |
+
decoded = torch.cat(decoded_slices)
|
317 |
+
else:
|
318 |
+
decoded = self._decode(z).sample
|
319 |
+
|
320 |
+
if not return_dict:
|
321 |
+
return (decoded,)
|
322 |
+
|
323 |
+
return DecoderOutput(sample=decoded)
|
324 |
+
|
325 |
+
def blend_v(self, a, b, blend_extent):
|
326 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
327 |
+
for y in range(blend_extent):
|
328 |
+
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
329 |
+
return b
|
330 |
+
|
331 |
+
def blend_h(self, a, b, blend_extent):
|
332 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
333 |
+
for x in range(blend_extent):
|
334 |
+
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
335 |
+
return b
|
336 |
+
|
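The two helpers above implement a linear cross-fade over the overlapping border between adjacent tiles; the weight ramps from the previous tile to the current one, so tile seams are smoothed out. A tiny numeric sketch of the `blend_v` loop on a 4-pixel overlap:

import torch

a = torch.zeros(1, 1, 4, 1)   # bottom rows of the tile above
b = torch.ones(1, 1, 4, 1)    # top rows of the current tile
blend_extent = 4

for y in range(blend_extent):
    b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)

print(b.flatten())            # tensor([0.0000, 0.2500, 0.5000, 0.7500]) -> linear ramp from a into b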
337 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
338 |
+
r"""Encode a batch of images using a tiled encoder.
|
339 |
+
|
340 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
341 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
342 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
343 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
344 |
+
output, but they should be much less noticeable.
|
345 |
+
|
346 |
+
Args:
|
347 |
+
x (`torch.FloatTensor`): Input batch of images.
|
348 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
349 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
350 |
+
|
351 |
+
Returns:
|
352 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
353 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
354 |
+
`tuple` is returned.
|
355 |
+
"""
|
356 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
357 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
358 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
359 |
+
|
360 |
+
# Split the image into 512x512 tiles and encode them separately.
|
361 |
+
rows = []
|
362 |
+
for i in range(0, x.shape[2], overlap_size):
|
363 |
+
row = []
|
364 |
+
for j in range(0, x.shape[3], overlap_size):
|
365 |
+
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
366 |
+
tile = self.encoder(tile)
|
367 |
+
tile = self.quant_conv(tile)
|
368 |
+
row.append(tile)
|
369 |
+
rows.append(row)
|
370 |
+
result_rows = []
|
371 |
+
for i, row in enumerate(rows):
|
372 |
+
result_row = []
|
373 |
+
for j, tile in enumerate(row):
|
374 |
+
# blend the above tile and the left tile
|
375 |
+
# to the current tile and add the current tile to the result row
|
376 |
+
if i > 0:
|
377 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
378 |
+
if j > 0:
|
379 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
380 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
381 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
382 |
+
|
383 |
+
moments = torch.cat(result_rows, dim=2)
|
384 |
+
posterior = DiagonalGaussianDistribution(moments)
|
385 |
+
|
386 |
+
if not return_dict:
|
387 |
+
return (posterior,)
|
388 |
+
|
389 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
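A short arithmetic sketch of the tile grid that tiled_encode builds. The concrete numbers (tile_sample_min_size=512, tile_latent_min_size=64, tile_overlap_factor=0.25, a 1024-pixel-tall input) are assumptions for illustration; the real values come from the model config:

# Assumed values for illustration only.
tile_sample_min_size = 512
tile_latent_min_size = 64
tile_overlap_factor = 0.25
vae_scale = tile_sample_min_size // tile_latent_min_size               # 8x spatial compression

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))   # 384: stride between tile origins
blend_extent = int(tile_latent_min_size * tile_overlap_factor)         # 16 latent rows/cols get cross-faded
row_limit = tile_latent_min_size - blend_extent                        # 48 rows/cols kept from each tile

height = 1024                                  # hypothetical input height
starts = list(range(0, height, overlap_size))  # [0, 384, 768]
kept = []
for s in starts:
    tile_latent_rows = min(tile_sample_min_size, height - s) // vae_scale
    kept.append(min(row_limit, tile_latent_rows))
print(starts, kept, sum(kept))                 # [0, 384, 768] [48, 48, 32] 128 == height // 8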
390 |
+
|
391 |
+
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
392 |
+
r"""
|
393 |
+
Decode a batch of images using a tiled decoder.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
398 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
402 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
403 |
+
returned.
|
404 |
+
"""
|
405 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
406 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
407 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
408 |
+
|
409 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
410 |
+
# The tiles have an overlap to avoid seams between tiles.
|
411 |
+
rows = []
|
412 |
+
for i in range(0, z.shape[2], overlap_size):
|
413 |
+
row = []
|
414 |
+
for j in range(0, z.shape[3], overlap_size):
|
415 |
+
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
416 |
+
tile = self.post_quant_conv(tile)
|
417 |
+
decoded = self.decoder(tile)
|
418 |
+
row.append(decoded)
|
419 |
+
rows.append(row)
|
420 |
+
result_rows = []
|
421 |
+
for i, row in enumerate(rows):
|
422 |
+
result_row = []
|
423 |
+
for j, tile in enumerate(row):
|
424 |
+
# blend the above tile and the left tile
|
425 |
+
# to the current tile and add the current tile to the result row
|
426 |
+
if i > 0:
|
427 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
428 |
+
if j > 0:
|
429 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
430 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
431 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
432 |
+
|
433 |
+
dec = torch.cat(result_rows, dim=2)
|
434 |
+
if not return_dict:
|
435 |
+
return (dec,)
|
436 |
+
|
437 |
+
return DecoderOutput(sample=dec)
|
438 |
+
|
439 |
+
def forward(
|
440 |
+
self,
|
441 |
+
sample: torch.FloatTensor,
|
442 |
+
sample_posterior: bool = False,
|
443 |
+
return_dict: bool = True,
|
444 |
+
generator: Optional[torch.Generator] = None,
|
445 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
446 |
+
r"""
|
447 |
+
Args:
|
448 |
+
sample (`torch.FloatTensor`): Input sample.
|
449 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
450 |
+
Whether to sample from the posterior.
|
451 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
452 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
453 |
+
"""
|
454 |
+
x = sample
|
455 |
+
posterior = self.encode(x).latent_dist
|
456 |
+
if sample_posterior:
|
457 |
+
z = posterior.sample(generator=generator)
|
458 |
+
else:
|
459 |
+
z = posterior.mode()
|
460 |
+
dec = self.decode(z).sample
|
461 |
+
|
462 |
+
if not return_dict:
|
463 |
+
return (dec,)
|
464 |
+
|
465 |
+
return DecoderOutput(sample=dec)
|
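A hedged usage sketch of the slicing/tiling code paths defined above; the checkpoint id and latent shape are illustrative assumptions, not something this commit pins down:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
vae.enable_tiling()   # route large inputs through tiled_encode / tiled_decode
vae.enable_slicing()  # additionally process batch elements one at a time

latents = torch.randn(2, 4, 256, 256)  # large enough to trigger the tiled decoder
with torch.no_grad():
    images = vae.decode(latents).sample  # (2, 3, 2048, 2048)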
diffusers/models/autoencoder_tiny.py
ADDED
@@ -0,0 +1,349 @@
1 |
+
# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..utils import BaseOutput
|
23 |
+
from ..utils.accelerate_utils import apply_forward_hook
|
24 |
+
from .modeling_utils import ModelMixin
|
25 |
+
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class AutoencoderTinyOutput(BaseOutput):
|
30 |
+
"""
|
31 |
+
Output of AutoencoderTiny encoding method.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
|
35 |
+
|
36 |
+
"""
|
37 |
+
|
38 |
+
latents: torch.Tensor
|
39 |
+
|
40 |
+
|
41 |
+
class AutoencoderTiny(ModelMixin, ConfigMixin):
|
42 |
+
r"""
|
43 |
+
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
|
44 |
+
|
45 |
+
[`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
|
46 |
+
|
47 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
48 |
+
all models (such as downloading or saving).
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
|
52 |
+
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
53 |
+
encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
54 |
+
Tuple of integers representing the number of output channels for each encoder block. The length of the
|
55 |
+
tuple should be equal to the number of encoder blocks.
|
56 |
+
decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
57 |
+
Tuple of integers representing the number of output channels for each decoder block. The length of the
|
58 |
+
tuple should be equal to the number of decoder blocks.
|
59 |
+
act_fn (`str`, *optional*, defaults to `"relu"`):
|
60 |
+
Activation function to be used throughout the model.
|
61 |
+
latent_channels (`int`, *optional*, defaults to 4):
|
62 |
+
Number of channels in the latent representation. The latent space acts as a compressed representation of
|
63 |
+
the input image.
|
64 |
+
upsampling_scaling_factor (`int`, *optional*, defaults to 2):
|
65 |
+
Scaling factor for upsampling in the decoder. It determines the size of the output image during the
|
66 |
+
upsampling process.
|
67 |
+
num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
|
68 |
+
Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
|
69 |
+
length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
|
70 |
+
number of encoder blocks.
|
71 |
+
num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
|
72 |
+
Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
|
73 |
+
length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
|
74 |
+
number of decoder blocks.
|
75 |
+
latent_magnitude (`float`, *optional*, defaults to 3.0):
|
76 |
+
Magnitude of the latent representation. This parameter scales the latent representation values to control
|
77 |
+
the extent of information preservation.
|
78 |
+
latent_shift (float, *optional*, defaults to 0.5):
|
79 |
+
Shift applied to the latent representation. This parameter controls the center of the latent space.
|
80 |
+
scaling_factor (`float`, *optional*, defaults to 1.0):
|
81 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
82 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
83 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
84 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
85 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
86 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
|
87 |
+
however, no such scaling factor was used, hence the value of 1.0 as the default.
|
88 |
+
force_upcast (`bool`, *optional*, default to `False`):
|
89 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
90 |
+
can be fine-tuned / trained to a lower range without losing too much precision, in which case
|
91 |
+
`force_upcast` can be set to `False` (see this fp16-friendly
|
92 |
+
[AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
93 |
+
"""
|
94 |
+
_supports_gradient_checkpointing = True
|
95 |
+
|
96 |
+
@register_to_config
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
in_channels=3,
|
100 |
+
out_channels=3,
|
101 |
+
encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
|
102 |
+
decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
|
103 |
+
act_fn: str = "relu",
|
104 |
+
latent_channels: int = 4,
|
105 |
+
upsampling_scaling_factor: int = 2,
|
106 |
+
num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
|
107 |
+
num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
|
108 |
+
latent_magnitude: int = 3,
|
109 |
+
latent_shift: float = 0.5,
|
110 |
+
force_upcast: bool = False,
|
111 |
+
scaling_factor: float = 1.0,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
if len(encoder_block_out_channels) != len(num_encoder_blocks):
|
116 |
+
raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
|
117 |
+
if len(decoder_block_out_channels) != len(num_decoder_blocks):
|
118 |
+
raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
|
119 |
+
|
120 |
+
self.encoder = EncoderTiny(
|
121 |
+
in_channels=in_channels,
|
122 |
+
out_channels=latent_channels,
|
123 |
+
num_blocks=num_encoder_blocks,
|
124 |
+
block_out_channels=encoder_block_out_channels,
|
125 |
+
act_fn=act_fn,
|
126 |
+
)
|
127 |
+
|
128 |
+
self.decoder = DecoderTiny(
|
129 |
+
in_channels=latent_channels,
|
130 |
+
out_channels=out_channels,
|
131 |
+
num_blocks=num_decoder_blocks,
|
132 |
+
block_out_channels=decoder_block_out_channels,
|
133 |
+
upsampling_scaling_factor=upsampling_scaling_factor,
|
134 |
+
act_fn=act_fn,
|
135 |
+
)
|
136 |
+
|
137 |
+
self.latent_magnitude = latent_magnitude
|
138 |
+
self.latent_shift = latent_shift
|
139 |
+
self.scaling_factor = scaling_factor
|
140 |
+
|
141 |
+
self.use_slicing = False
|
142 |
+
self.use_tiling = False
|
143 |
+
|
144 |
+
# only relevant if vae tiling is enabled
|
145 |
+
self.spatial_scale_factor = 2**out_channels
|
146 |
+
self.tile_overlap_factor = 0.125
|
147 |
+
self.tile_sample_min_size = 512
|
148 |
+
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
|
149 |
+
|
150 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
151 |
+
if isinstance(module, (EncoderTiny, DecoderTiny)):
|
152 |
+
module.gradient_checkpointing = value
|
153 |
+
|
154 |
+
def scale_latents(self, x):
|
155 |
+
"""raw latents -> [0, 1]"""
|
156 |
+
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
|
157 |
+
|
158 |
+
def unscale_latents(self, x):
|
159 |
+
"""[0, 1] -> raw latents"""
|
160 |
+
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
|
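The two helpers above exist so latents can be stored as ordinary uint8 image data. A self-contained sketch of the round trip (latent_magnitude=3 and latent_shift=0.5 mirror the defaults above; illustration only):

import torch

latent_magnitude, latent_shift = 3, 0.5

def scale(x):    # raw latents -> [0, 1]
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale(x):  # [0, 1] -> raw latents
    return x.sub(latent_shift).mul(2 * latent_magnitude)

latents = torch.randn(1, 4, 8, 8).clamp(-2.5, 2.5)   # keep values inside the representable range
quantized = scale(latents).mul(255).round().byte()   # storable as an RGBA uint8 image
restored = unscale(quantized.float() / 255.0)
print((latents - restored).abs().max())              # <= 6 / 255 / 2, about 0.012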
161 |
+
|
162 |
+
def enable_slicing(self):
|
163 |
+
r"""
|
164 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
165 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
166 |
+
"""
|
167 |
+
self.use_slicing = True
|
168 |
+
|
169 |
+
def disable_slicing(self):
|
170 |
+
r"""
|
171 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
172 |
+
decoding in one step.
|
173 |
+
"""
|
174 |
+
self.use_slicing = False
|
175 |
+
|
176 |
+
def enable_tiling(self, use_tiling: bool = True):
|
177 |
+
r"""
|
178 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
179 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
180 |
+
processing larger images.
|
181 |
+
"""
|
182 |
+
self.use_tiling = use_tiling
|
183 |
+
|
184 |
+
def disable_tiling(self):
|
185 |
+
r"""
|
186 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
187 |
+
decoding in one step.
|
188 |
+
"""
|
189 |
+
self.enable_tiling(False)
|
190 |
+
|
191 |
+
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
192 |
+
r"""Encode a batch of images using a tiled encoder.
|
193 |
+
|
194 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
195 |
+
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
196 |
+
tiles overlap and are blended together to form a smooth output.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
x (`torch.FloatTensor`): Input batch of images.
|
200 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
201 |
+
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
[`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
|
205 |
+
If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
|
206 |
+
plain `tuple` is returned.
|
207 |
+
"""
|
208 |
+
# scale of encoder output relative to input
|
209 |
+
sf = self.spatial_scale_factor
|
210 |
+
tile_size = self.tile_sample_min_size
|
211 |
+
|
212 |
+
# number of pixels to blend and to traverse between tiles
|
213 |
+
blend_size = int(tile_size * self.tile_overlap_factor)
|
214 |
+
traverse_size = tile_size - blend_size
|
215 |
+
|
216 |
+
# tiles index (up/left)
|
217 |
+
ti = range(0, x.shape[-2], traverse_size)
|
218 |
+
tj = range(0, x.shape[-1], traverse_size)
|
219 |
+
|
220 |
+
# mask for blending
|
221 |
+
blend_masks = torch.stack(
|
222 |
+
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
|
223 |
+
)
|
224 |
+
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
225 |
+
|
226 |
+
# output array
|
227 |
+
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
|
228 |
+
for i in ti:
|
229 |
+
for j in tj:
|
230 |
+
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
231 |
+
# tile result
|
232 |
+
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
|
233 |
+
tile = self.encoder(tile_in)
|
234 |
+
h, w = tile.shape[-2], tile.shape[-1]
|
235 |
+
# blend tile result into output
|
236 |
+
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
237 |
+
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
238 |
+
blend_mask = blend_mask_i * blend_mask_j
|
239 |
+
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
|
240 |
+
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
241 |
+
return out
|
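The blend mask built above via torch.meshgrid is just a linear ramp over the overlap region, clamped to 1 elsewhere. A small illustration using the default-sized numbers (64-pixel latent tile with an 8-pixel overlap, i.e. tile_size / sf and blend_size / sf for the values used earlier in this file):

import torch

ramp = (torch.arange(64) / (8 - 1)).clamp(0, 1)  # rises 0 -> 1 across the 8 overlapping pixels, then stays at 1
mask = ramp[:, None] * ramp[None, :]             # product of the two meshgrid axes used in the code above
print(ramp[:9])
print(mask.shape)                                # torch.Size([64, 64]), multiplied into each incoming tile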
242 |
+
|
243 |
+
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
244 |
+
r"""Encode a batch of images using a tiled encoder.
|
245 |
+
|
246 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
|
247 |
+
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
248 |
+
tiles overlap and are blended together to form a smooth output.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
x (`torch.FloatTensor`): Input batch of latent samples.
|
252 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
253 |
+
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
257 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
258 |
+
returned.
|
259 |
+
"""
|
260 |
+
# scale of decoder output relative to input
|
261 |
+
sf = self.spatial_scale_factor
|
262 |
+
tile_size = self.tile_latent_min_size
|
263 |
+
|
264 |
+
# number of pixels to blend and to traverse between tiles
|
265 |
+
blend_size = int(tile_size * self.tile_overlap_factor)
|
266 |
+
traverse_size = tile_size - blend_size
|
267 |
+
|
268 |
+
# tiles index (up/left)
|
269 |
+
ti = range(0, x.shape[-2], traverse_size)
|
270 |
+
tj = range(0, x.shape[-1], traverse_size)
|
271 |
+
|
272 |
+
# mask for blending
|
273 |
+
blend_masks = torch.stack(
|
274 |
+
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
|
275 |
+
)
|
276 |
+
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
277 |
+
|
278 |
+
# output array
|
279 |
+
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
|
280 |
+
for i in ti:
|
281 |
+
for j in tj:
|
282 |
+
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
283 |
+
# tile result
|
284 |
+
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
|
285 |
+
tile = self.decoder(tile_in)
|
286 |
+
h, w = tile.shape[-2], tile.shape[-1]
|
287 |
+
# blend tile result into output
|
288 |
+
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
289 |
+
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
290 |
+
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
|
291 |
+
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
292 |
+
return out
|
293 |
+
|
294 |
+
@apply_forward_hook
|
295 |
+
def encode(
|
296 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
297 |
+
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
|
298 |
+
if self.use_slicing and x.shape[0] > 1:
|
299 |
+
output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
|
300 |
+
output = torch.cat(output)
|
301 |
+
else:
|
302 |
+
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
|
303 |
+
|
304 |
+
if not return_dict:
|
305 |
+
return (output,)
|
306 |
+
|
307 |
+
return AutoencoderTinyOutput(latents=output)
|
308 |
+
|
309 |
+
@apply_forward_hook
|
310 |
+
def decode(
|
311 |
+
self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
|
312 |
+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
313 |
+
if self.use_slicing and x.shape[0] > 1:
|
314 |
+
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
|
315 |
+
output = torch.cat(output)
|
316 |
+
else:
|
317 |
+
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
318 |
+
|
319 |
+
if not return_dict:
|
320 |
+
return (output,)
|
321 |
+
|
322 |
+
return DecoderOutput(sample=output)
|
323 |
+
|
324 |
+
def forward(
|
325 |
+
self,
|
326 |
+
sample: torch.FloatTensor,
|
327 |
+
return_dict: bool = True,
|
328 |
+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
329 |
+
r"""
|
330 |
+
Args:
|
331 |
+
sample (`torch.FloatTensor`): Input sample.
|
332 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
333 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
334 |
+
"""
|
335 |
+
enc = self.encode(sample).latents
|
336 |
+
|
337 |
+
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
|
338 |
+
# as if we were storing the latents in an RGBA uint8 image.
|
339 |
+
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
|
340 |
+
|
341 |
+
# unquantize latents back into [0, 1], then unscale latents back to their original range,
|
342 |
+
# as if we were loading the latents from an RGBA uint8 image.
|
343 |
+
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
344 |
+
|
345 |
+
dec = self.decode(unscaled_enc).sample
|
346 |
+
|
347 |
+
if not return_dict:
|
348 |
+
return (dec,)
|
349 |
+
return DecoderOutput(sample=dec)
|
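A hedged usage sketch for the class defined above, treating it as a fast preview autoencoder; the checkpoint id is an assumption and images are expected in [-1, 1] as elsewhere in the library:

import torch
from diffusers import AutoencoderTiny

taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd")  # assumed checkpoint

image = torch.rand(1, 3, 512, 512) * 2 - 1     # fake image in [-1, 1]
with torch.no_grad():
    latents = taesd.encode(image).latents       # (1, 4, 64, 64)
    recon = taesd.decode(latents).sample        # (1, 3, 512, 512)
print(latents.shape, recon.shape)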
diffusers/models/consistency_decoder_vae.py
ADDED
@@ -0,0 +1,430 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Dict, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..schedulers import ConsistencyDecoderScheduler
|
23 |
+
from ..utils import BaseOutput
|
24 |
+
from ..utils.accelerate_utils import apply_forward_hook
|
25 |
+
from ..utils.torch_utils import randn_tensor
|
26 |
+
from .attention_processor import (
|
27 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
28 |
+
CROSS_ATTENTION_PROCESSORS,
|
29 |
+
AttentionProcessor,
|
30 |
+
AttnAddedKVProcessor,
|
31 |
+
AttnProcessor,
|
32 |
+
)
|
33 |
+
from .modeling_utils import ModelMixin
|
34 |
+
from .unet_2d import UNet2DModel
|
35 |
+
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class ConsistencyDecoderVAEOutput(BaseOutput):
|
40 |
+
"""
|
41 |
+
Output of encoding method.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
latent_dist (`DiagonalGaussianDistribution`):
|
45 |
+
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
46 |
+
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
47 |
+
"""
|
48 |
+
|
49 |
+
latent_dist: "DiagonalGaussianDistribution"
|
50 |
+
|
51 |
+
|
52 |
+
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
53 |
+
r"""
|
54 |
+
The consistency decoder used with DALL-E 3.
|
55 |
+
|
56 |
+
Examples:
|
57 |
+
```py
|
58 |
+
>>> import torch
|
59 |
+
>>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
|
60 |
+
|
61 |
+
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=pipe.torch_dtype)
|
62 |
+
>>> pipe = StableDiffusionPipeline.from_pretrained(
|
63 |
+
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
|
64 |
+
... ).to("cuda")
|
65 |
+
|
66 |
+
>>> pipe("horse", generator=torch.manual_seed(0)).images
|
67 |
+
```
|
68 |
+
"""
|
69 |
+
|
70 |
+
@register_to_config
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
scaling_factor=0.18215,
|
74 |
+
latent_channels=4,
|
75 |
+
encoder_act_fn="silu",
|
76 |
+
encoder_block_out_channels=(128, 256, 512, 512),
|
77 |
+
encoder_double_z=True,
|
78 |
+
encoder_down_block_types=(
|
79 |
+
"DownEncoderBlock2D",
|
80 |
+
"DownEncoderBlock2D",
|
81 |
+
"DownEncoderBlock2D",
|
82 |
+
"DownEncoderBlock2D",
|
83 |
+
),
|
84 |
+
encoder_in_channels=3,
|
85 |
+
encoder_layers_per_block=2,
|
86 |
+
encoder_norm_num_groups=32,
|
87 |
+
encoder_out_channels=4,
|
88 |
+
decoder_add_attention=False,
|
89 |
+
decoder_block_out_channels=(320, 640, 1024, 1024),
|
90 |
+
decoder_down_block_types=(
|
91 |
+
"ResnetDownsampleBlock2D",
|
92 |
+
"ResnetDownsampleBlock2D",
|
93 |
+
"ResnetDownsampleBlock2D",
|
94 |
+
"ResnetDownsampleBlock2D",
|
95 |
+
),
|
96 |
+
decoder_downsample_padding=1,
|
97 |
+
decoder_in_channels=7,
|
98 |
+
decoder_layers_per_block=3,
|
99 |
+
decoder_norm_eps=1e-05,
|
100 |
+
decoder_norm_num_groups=32,
|
101 |
+
decoder_num_train_timesteps=1024,
|
102 |
+
decoder_out_channels=6,
|
103 |
+
decoder_resnet_time_scale_shift="scale_shift",
|
104 |
+
decoder_time_embedding_type="learned",
|
105 |
+
decoder_up_block_types=(
|
106 |
+
"ResnetUpsampleBlock2D",
|
107 |
+
"ResnetUpsampleBlock2D",
|
108 |
+
"ResnetUpsampleBlock2D",
|
109 |
+
"ResnetUpsampleBlock2D",
|
110 |
+
),
|
111 |
+
):
|
112 |
+
super().__init__()
|
113 |
+
self.encoder = Encoder(
|
114 |
+
act_fn=encoder_act_fn,
|
115 |
+
block_out_channels=encoder_block_out_channels,
|
116 |
+
double_z=encoder_double_z,
|
117 |
+
down_block_types=encoder_down_block_types,
|
118 |
+
in_channels=encoder_in_channels,
|
119 |
+
layers_per_block=encoder_layers_per_block,
|
120 |
+
norm_num_groups=encoder_norm_num_groups,
|
121 |
+
out_channels=encoder_out_channels,
|
122 |
+
)
|
123 |
+
|
124 |
+
self.decoder_unet = UNet2DModel(
|
125 |
+
add_attention=decoder_add_attention,
|
126 |
+
block_out_channels=decoder_block_out_channels,
|
127 |
+
down_block_types=decoder_down_block_types,
|
128 |
+
downsample_padding=decoder_downsample_padding,
|
129 |
+
in_channels=decoder_in_channels,
|
130 |
+
layers_per_block=decoder_layers_per_block,
|
131 |
+
norm_eps=decoder_norm_eps,
|
132 |
+
norm_num_groups=decoder_norm_num_groups,
|
133 |
+
num_train_timesteps=decoder_num_train_timesteps,
|
134 |
+
out_channels=decoder_out_channels,
|
135 |
+
resnet_time_scale_shift=decoder_resnet_time_scale_shift,
|
136 |
+
time_embedding_type=decoder_time_embedding_type,
|
137 |
+
up_block_types=decoder_up_block_types,
|
138 |
+
)
|
139 |
+
self.decoder_scheduler = ConsistencyDecoderScheduler()
|
140 |
+
self.register_to_config(block_out_channels=encoder_block_out_channels)
|
141 |
+
self.register_buffer(
|
142 |
+
"means",
|
143 |
+
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
|
144 |
+
persistent=False,
|
145 |
+
)
|
146 |
+
self.register_buffer(
|
147 |
+
"stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
|
148 |
+
)
|
149 |
+
|
150 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
151 |
+
|
152 |
+
self.use_slicing = False
|
153 |
+
self.use_tiling = False
|
154 |
+
|
155 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
|
156 |
+
def enable_tiling(self, use_tiling: bool = True):
|
157 |
+
r"""
|
158 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
159 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
160 |
+
processing larger images.
|
161 |
+
"""
|
162 |
+
self.use_tiling = use_tiling
|
163 |
+
|
164 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
|
165 |
+
def disable_tiling(self):
|
166 |
+
r"""
|
167 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
168 |
+
decoding in one step.
|
169 |
+
"""
|
170 |
+
self.enable_tiling(False)
|
171 |
+
|
172 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
|
173 |
+
def enable_slicing(self):
|
174 |
+
r"""
|
175 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
176 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
177 |
+
"""
|
178 |
+
self.use_slicing = True
|
179 |
+
|
180 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
|
181 |
+
def disable_slicing(self):
|
182 |
+
r"""
|
183 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
184 |
+
decoding in one step.
|
185 |
+
"""
|
186 |
+
self.use_slicing = False
|
187 |
+
|
188 |
+
@property
|
189 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
190 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
191 |
+
r"""
|
192 |
+
Returns:
|
193 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
194 |
+
indexed by its weight name.
|
195 |
+
"""
|
196 |
+
# set recursively
|
197 |
+
processors = {}
|
198 |
+
|
199 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
200 |
+
if hasattr(module, "get_processor"):
|
201 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
202 |
+
|
203 |
+
for sub_name, child in module.named_children():
|
204 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
205 |
+
|
206 |
+
return processors
|
207 |
+
|
208 |
+
for name, module in self.named_children():
|
209 |
+
fn_recursive_add_processors(name, module, processors)
|
210 |
+
|
211 |
+
return processors
|
212 |
+
|
213 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
214 |
+
def set_attn_processor(
|
215 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
216 |
+
):
|
217 |
+
r"""
|
218 |
+
Sets the attention processor to use to compute attention.
|
219 |
+
|
220 |
+
Parameters:
|
221 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
222 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
223 |
+
for **all** `Attention` layers.
|
224 |
+
|
225 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
226 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
227 |
+
|
228 |
+
"""
|
229 |
+
count = len(self.attn_processors.keys())
|
230 |
+
|
231 |
+
if isinstance(processor, dict) and len(processor) != count:
|
232 |
+
raise ValueError(
|
233 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
234 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
235 |
+
)
|
236 |
+
|
237 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
238 |
+
if hasattr(module, "set_processor"):
|
239 |
+
if not isinstance(processor, dict):
|
240 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
241 |
+
else:
|
242 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
243 |
+
|
244 |
+
for sub_name, child in module.named_children():
|
245 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
246 |
+
|
247 |
+
for name, module in self.named_children():
|
248 |
+
fn_recursive_attn_processor(name, module, processor)
|
249 |
+
|
250 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
251 |
+
def set_default_attn_processor(self):
|
252 |
+
"""
|
253 |
+
Disables custom attention processors and sets the default attention implementation.
|
254 |
+
"""
|
255 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
256 |
+
processor = AttnAddedKVProcessor()
|
257 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
258 |
+
processor = AttnProcessor()
|
259 |
+
else:
|
260 |
+
raise ValueError(
|
261 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
262 |
+
)
|
263 |
+
|
264 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
265 |
+
|
266 |
+
@apply_forward_hook
|
267 |
+
def encode(
|
268 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
269 |
+
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
|
270 |
+
"""
|
271 |
+
Encode a batch of images into latents.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
x (`torch.FloatTensor`): Input batch of images.
|
275 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
276 |
+
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
|
277 |
+
tuple.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
281 |
+
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
|
282 |
+
is returned.
|
283 |
+
"""
|
284 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
285 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
286 |
+
|
287 |
+
if self.use_slicing and x.shape[0] > 1:
|
288 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
289 |
+
h = torch.cat(encoded_slices)
|
290 |
+
else:
|
291 |
+
h = self.encoder(x)
|
292 |
+
|
293 |
+
moments = self.quant_conv(h)
|
294 |
+
posterior = DiagonalGaussianDistribution(moments)
|
295 |
+
|
296 |
+
if not return_dict:
|
297 |
+
return (posterior,)
|
298 |
+
|
299 |
+
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
|
300 |
+
|
301 |
+
@apply_forward_hook
|
302 |
+
def decode(
|
303 |
+
self,
|
304 |
+
z: torch.FloatTensor,
|
305 |
+
generator: Optional[torch.Generator] = None,
|
306 |
+
return_dict: bool = True,
|
307 |
+
num_inference_steps=2,
|
308 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
309 |
+
z = (z * self.config.scaling_factor - self.means) / self.stds
|
310 |
+
|
311 |
+
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
|
312 |
+
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
|
313 |
+
|
314 |
+
batch_size, _, height, width = z.shape
|
315 |
+
|
316 |
+
self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
|
317 |
+
|
318 |
+
x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
|
319 |
+
(batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
|
320 |
+
)
|
321 |
+
|
322 |
+
for t in self.decoder_scheduler.timesteps:
|
323 |
+
model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
|
324 |
+
model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
|
325 |
+
prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
|
326 |
+
x_t = prev_sample
|
327 |
+
|
328 |
+
x_0 = x_t
|
329 |
+
|
330 |
+
if not return_dict:
|
331 |
+
return (x_0,)
|
332 |
+
|
333 |
+
return DecoderOutput(sample=x_0)
|
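Before the consistency loop above runs, decode() whitens the latents with the registered per-channel statistics and nearest-upsamples them to image resolution so they can be concatenated with the noisy sample. A standalone sketch of that pre-processing, reusing the constants registered in __init__ (illustration only):

import torch
import torch.nn.functional as F

scaling_factor = 0.18215
means = torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None]
stds = torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None]

z = torch.randn(1, 4, 64, 64)
z = (z * scaling_factor - means) / stds
z = F.interpolate(z, mode="nearest", scale_factor=8)  # 2 ** (len(block_out_channels) - 1)
print(z.shape)  # torch.Size([1, 4, 512, 512]); concatenated with the 3-channel x_t -> the 7-channel UNet input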
334 |
+
|
335 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
|
336 |
+
def blend_v(self, a, b, blend_extent):
|
337 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
338 |
+
for y in range(blend_extent):
|
339 |
+
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
340 |
+
return b
|
341 |
+
|
342 |
+
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
|
343 |
+
def blend_h(self, a, b, blend_extent):
|
344 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
345 |
+
for x in range(blend_extent):
|
346 |
+
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
347 |
+
return b
|
348 |
+
|
349 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
|
350 |
+
r"""Encode a batch of images using a tiled encoder.
|
351 |
+
|
352 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
353 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
354 |
+
different from non-tiled encoding because each tile is encoded independently, so the encoder sees different context at tile borders. To avoid tiling artifacts, the
|
355 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
356 |
+
output, but they should be much less noticeable.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
x (`torch.FloatTensor`): Input batch of images.
|
360 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
361 |
+
Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
|
362 |
+
plain tuple.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
|
366 |
+
If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
|
367 |
+
otherwise a plain `tuple` is returned.
|
368 |
+
"""
|
369 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
370 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
371 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
372 |
+
|
373 |
+
# Split the image into 512x512 tiles and encode them separately.
|
374 |
+
rows = []
|
375 |
+
for i in range(0, x.shape[2], overlap_size):
|
376 |
+
row = []
|
377 |
+
for j in range(0, x.shape[3], overlap_size):
|
378 |
+
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
379 |
+
tile = self.encoder(tile)
|
380 |
+
tile = self.quant_conv(tile)
|
381 |
+
row.append(tile)
|
382 |
+
rows.append(row)
|
383 |
+
result_rows = []
|
384 |
+
for i, row in enumerate(rows):
|
385 |
+
result_row = []
|
386 |
+
for j, tile in enumerate(row):
|
387 |
+
# blend the above tile and the left tile
|
388 |
+
# to the current tile and add the current tile to the result row
|
389 |
+
if i > 0:
|
390 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
391 |
+
if j > 0:
|
392 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
393 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
394 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
395 |
+
|
396 |
+
moments = torch.cat(result_rows, dim=2)
|
397 |
+
posterior = DiagonalGaussianDistribution(moments)
|
398 |
+
|
399 |
+
if not return_dict:
|
400 |
+
return (posterior,)
|
401 |
+
|
402 |
+
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
|
403 |
+
|
404 |
+
def forward(
|
405 |
+
self,
|
406 |
+
sample: torch.FloatTensor,
|
407 |
+
sample_posterior: bool = False,
|
408 |
+
return_dict: bool = True,
|
409 |
+
generator: Optional[torch.Generator] = None,
|
410 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
411 |
+
r"""
|
412 |
+
Args:
|
413 |
+
sample (`torch.FloatTensor`): Input sample.
|
414 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
415 |
+
Whether to sample from the posterior.
|
416 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
417 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
418 |
+
"""
|
419 |
+
x = sample
|
420 |
+
posterior = self.encode(x).latent_dist
|
421 |
+
if sample_posterior:
|
422 |
+
z = posterior.sample(generator=generator)
|
423 |
+
else:
|
424 |
+
z = posterior.mode()
|
425 |
+
dec = self.decode(z, generator=generator).sample
|
426 |
+
|
427 |
+
if not return_dict:
|
428 |
+
return (dec,)
|
429 |
+
|
430 |
+
return DecoderOutput(sample=dec)
|
diffusers/models/controlnet.py
ADDED
@@ -0,0 +1,844 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..loaders import FromOriginalControlnetMixin
|
23 |
+
from ..utils import BaseOutput, logging
|
24 |
+
from .attention_processor import (
|
25 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
26 |
+
CROSS_ATTENTION_PROCESSORS,
|
27 |
+
AttentionProcessor,
|
28 |
+
AttnAddedKVProcessor,
|
29 |
+
AttnProcessor,
|
30 |
+
)
|
31 |
+
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
32 |
+
from .modeling_utils import ModelMixin
|
33 |
+
from .unet_2d_blocks import (
|
34 |
+
CrossAttnDownBlock2D,
|
35 |
+
DownBlock2D,
|
36 |
+
UNetMidBlock2DCrossAttn,
|
37 |
+
get_down_block,
|
38 |
+
)
|
39 |
+
from .unet_2d_condition import UNet2DConditionModel
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class ControlNetOutput(BaseOutput):
|
47 |
+
"""
|
48 |
+
The output of [`ControlNetModel`].
|
49 |
+
|
50 |
+
Args:
|
51 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
52 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
53 |
+
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
|
54 |
+
used to condition the original UNet's downsampling activations.
|
55 |
+
mid_block_res_sample (`torch.Tensor`):
|
56 |
+
The activation of the middle block (the lowest sample resolution). The tensor should be of shape
|
57 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
58 |
+
Output can be used to condition the original UNet's middle block activation.
|
59 |
+
"""
|
60 |
+
|
61 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
62 |
+
mid_block_res_sample: torch.Tensor
|
63 |
+
|
64 |
+
|
65 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
66 |
+
"""
|
67 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
68 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
69 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
70 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
71 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
72 |
+
model) to encode image-space conditions ... into feature maps ..."
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
conditioning_embedding_channels: int,
|
78 |
+
conditioning_channels: int = 3,
|
79 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
84 |
+
|
85 |
+
self.blocks = nn.ModuleList([])
|
86 |
+
|
87 |
+
for i in range(len(block_out_channels) - 1):
|
88 |
+
channel_in = block_out_channels[i]
|
89 |
+
channel_out = block_out_channels[i + 1]
|
90 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
91 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
92 |
+
|
93 |
+
self.conv_out = zero_module(
|
94 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, conditioning):
|
98 |
+
embedding = self.conv_in(conditioning)
|
99 |
+
embedding = F.silu(embedding)
|
100 |
+
|
101 |
+
for block in self.blocks:
|
102 |
+
embedding = block(embedding)
|
103 |
+
embedding = F.silu(embedding)
|
104 |
+
|
105 |
+
embedding = self.conv_out(embedding)
|
106 |
+
|
107 |
+
return embedding
|
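A standalone re-creation (activations and the zero-initialised conv_out omitted) showing why the default block_out_channels reduce a 512x512 conditioning image by 8x: every block after conv_in ends in a stride-2 convolution, so three of them bring the condition down to the UNet's 64x64 latent resolution. Illustration only, not the class above:

import torch
from torch import nn

channels = (16, 32, 96, 256)
layers = [nn.Conv2d(3, channels[0], kernel_size=3, padding=1)]
for c_in, c_out in zip(channels[:-1], channels[1:]):
    layers += [
        nn.Conv2d(c_in, c_in, kernel_size=3, padding=1),
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2),  # each pair halves the resolution
    ]
net = nn.Sequential(*layers)

cond = torch.rand(1, 3, 512, 512)
print(net(cond).shape)  # torch.Size([1, 256, 64, 64])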
108 |
+
|
109 |
+
|
110 |
+
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
111 |
+
"""
|
112 |
+
A ControlNet model.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
in_channels (`int`, defaults to 4):
|
116 |
+
The number of channels in the input sample.
|
117 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
118 |
+
Whether to flip the sin to cos in the time embedding.
|
119 |
+
freq_shift (`int`, defaults to 0):
|
120 |
+
The frequency shift to apply to the time embedding.
|
121 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
122 |
+
The tuple of downsample blocks to use.
|
123 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
124 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
125 |
+
The tuple of output channels for each block.
|
126 |
+
layers_per_block (`int`, defaults to 2):
|
127 |
+
The number of layers per block.
|
128 |
+
downsample_padding (`int`, defaults to 1):
|
129 |
+
The padding to use for the downsampling convolution.
|
130 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
131 |
+
The scale factor to use for the mid block.
|
132 |
+
act_fn (`str`, defaults to "silu"):
|
133 |
+
The activation function to use.
|
134 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
135 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
136 |
+
in post-processing.
|
137 |
+
norm_eps (`float`, defaults to 1e-5):
|
138 |
+
The epsilon to use for the normalization.
|
139 |
+
cross_attention_dim (`int`, defaults to 1280):
|
140 |
+
The dimension of the cross attention features.
|
141 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
142 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
143 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
144 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
145 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
146 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
147 |
+
dimension to `cross_attention_dim`.
|
148 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
149 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
150 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
151 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
152 |
+
The dimension of the attention heads.
|
153 |
+
use_linear_projection (`bool`, defaults to `False`):
|
154 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
155 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
156 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
157 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
158 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
159 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
160 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
161 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
162 |
+
class conditioning with `class_embed_type` equal to `None`.
|
163 |
+
upcast_attention (`bool`, defaults to `False`):
|
164 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
165 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
166 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
167 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
168 |
+
`class_embed_type="projection"`.
|
169 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
170 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
171 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
172 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
173 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
174 |
+
"""
|
175 |
+
|
176 |
+
_supports_gradient_checkpointing = True
|
177 |
+
|
178 |
+
@register_to_config
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
in_channels: int = 4,
|
182 |
+
conditioning_channels: int = 3,
|
183 |
+
flip_sin_to_cos: bool = True,
|
184 |
+
freq_shift: int = 0,
|
185 |
+
down_block_types: Tuple[str] = (
|
186 |
+
"CrossAttnDownBlock2D",
|
187 |
+
"CrossAttnDownBlock2D",
|
188 |
+
"CrossAttnDownBlock2D",
|
189 |
+
"DownBlock2D",
|
190 |
+
),
|
191 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
192 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
193 |
+
layers_per_block: int = 2,
|
194 |
+
downsample_padding: int = 1,
|
195 |
+
mid_block_scale_factor: float = 1,
|
196 |
+
act_fn: str = "silu",
|
197 |
+
norm_num_groups: Optional[int] = 32,
|
198 |
+
norm_eps: float = 1e-5,
|
199 |
+
cross_attention_dim: int = 1280,
|
200 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
201 |
+
encoder_hid_dim: Optional[int] = None,
|
202 |
+
encoder_hid_dim_type: Optional[str] = None,
|
203 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
204 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
205 |
+
use_linear_projection: bool = False,
|
206 |
+
class_embed_type: Optional[str] = None,
|
207 |
+
addition_embed_type: Optional[str] = None,
|
208 |
+
addition_time_embed_dim: Optional[int] = None,
|
209 |
+
num_class_embeds: Optional[int] = None,
|
210 |
+
upcast_attention: bool = False,
|
211 |
+
resnet_time_scale_shift: str = "default",
|
212 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
213 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
214 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
215 |
+
global_pool_conditions: bool = False,
|
216 |
+
addition_embed_type_num_heads=64,
|
217 |
+
):
|
218 |
+
super().__init__()
|
219 |
+
|
220 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
221 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
222 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
223 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
224 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
225 |
+
# which is why we correct for the naming here.
|
226 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
227 |
+
|
228 |
+
# Check inputs
|
229 |
+
if len(block_out_channels) != len(down_block_types):
|
230 |
+
raise ValueError(
|
231 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
232 |
+
)
|
233 |
+
|
234 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
235 |
+
raise ValueError(
|
236 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
237 |
+
)
|
238 |
+
|
239 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
240 |
+
raise ValueError(
|
241 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
242 |
+
)
|
243 |
+
|
244 |
+
if isinstance(transformer_layers_per_block, int):
|
245 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
246 |
+
|
247 |
+
# input
|
248 |
+
conv_in_kernel = 3
|
249 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
250 |
+
self.conv_in = nn.Conv2d(
|
251 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
252 |
+
)
|
253 |
+
|
254 |
+
# time
|
255 |
+
time_embed_dim = block_out_channels[0] * 4
|
256 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
257 |
+
timestep_input_dim = block_out_channels[0]
|
258 |
+
self.time_embedding = TimestepEmbedding(
|
259 |
+
timestep_input_dim,
|
260 |
+
time_embed_dim,
|
261 |
+
act_fn=act_fn,
|
262 |
+
)
|
263 |
+
|
264 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
265 |
+
encoder_hid_dim_type = "text_proj"
|
266 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
267 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
268 |
+
|
269 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
270 |
+
raise ValueError(
|
271 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
272 |
+
)
|
273 |
+
|
274 |
+
if encoder_hid_dim_type == "text_proj":
|
275 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
276 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
277 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
278 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
279 |
+
            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
|
280 |
+
self.encoder_hid_proj = TextImageProjection(
|
281 |
+
text_embed_dim=encoder_hid_dim,
|
282 |
+
image_embed_dim=cross_attention_dim,
|
283 |
+
cross_attention_dim=cross_attention_dim,
|
284 |
+
)
|
285 |
+
|
286 |
+
elif encoder_hid_dim_type is not None:
|
287 |
+
raise ValueError(
|
288 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
self.encoder_hid_proj = None
|
292 |
+
|
293 |
+
# class embedding
|
294 |
+
if class_embed_type is None and num_class_embeds is not None:
|
295 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
296 |
+
elif class_embed_type == "timestep":
|
297 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
298 |
+
elif class_embed_type == "identity":
|
299 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
300 |
+
elif class_embed_type == "projection":
|
301 |
+
if projection_class_embeddings_input_dim is None:
|
302 |
+
raise ValueError(
|
303 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
304 |
+
)
|
305 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
306 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
307 |
+
# 2. it projects from an arbitrary input dimension.
|
308 |
+
#
|
309 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
310 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
311 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
312 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
313 |
+
else:
|
314 |
+
self.class_embedding = None
|
315 |
+
|
316 |
+
if addition_embed_type == "text":
|
317 |
+
if encoder_hid_dim is not None:
|
318 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
319 |
+
else:
|
320 |
+
text_time_embedding_from_dim = cross_attention_dim
|
321 |
+
|
322 |
+
self.add_embedding = TextTimeEmbedding(
|
323 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
324 |
+
)
|
325 |
+
elif addition_embed_type == "text_image":
|
326 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
327 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
328 |
+
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
329 |
+
self.add_embedding = TextImageTimeEmbedding(
|
330 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
331 |
+
)
|
332 |
+
elif addition_embed_type == "text_time":
|
333 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
334 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
335 |
+
|
336 |
+
elif addition_embed_type is not None:
|
337 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
338 |
+
|
339 |
+
# control net conditioning embedding
|
340 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
341 |
+
conditioning_embedding_channels=block_out_channels[0],
|
342 |
+
block_out_channels=conditioning_embedding_out_channels,
|
343 |
+
conditioning_channels=conditioning_channels,
|
344 |
+
)
|
345 |
+
|
346 |
+
self.down_blocks = nn.ModuleList([])
|
347 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
348 |
+
|
349 |
+
if isinstance(only_cross_attention, bool):
|
350 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
351 |
+
|
352 |
+
if isinstance(attention_head_dim, int):
|
353 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
354 |
+
|
355 |
+
if isinstance(num_attention_heads, int):
|
356 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
357 |
+
|
358 |
+
# down
|
359 |
+
output_channel = block_out_channels[0]
|
360 |
+
|
361 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
362 |
+
controlnet_block = zero_module(controlnet_block)
|
363 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
364 |
+
|
365 |
+
for i, down_block_type in enumerate(down_block_types):
|
366 |
+
input_channel = output_channel
|
367 |
+
output_channel = block_out_channels[i]
|
368 |
+
is_final_block = i == len(block_out_channels) - 1
|
369 |
+
|
370 |
+
down_block = get_down_block(
|
371 |
+
down_block_type,
|
372 |
+
num_layers=layers_per_block,
|
373 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
374 |
+
in_channels=input_channel,
|
375 |
+
out_channels=output_channel,
|
376 |
+
temb_channels=time_embed_dim,
|
377 |
+
add_downsample=not is_final_block,
|
378 |
+
resnet_eps=norm_eps,
|
379 |
+
resnet_act_fn=act_fn,
|
380 |
+
resnet_groups=norm_num_groups,
|
381 |
+
cross_attention_dim=cross_attention_dim,
|
382 |
+
num_attention_heads=num_attention_heads[i],
|
383 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
384 |
+
downsample_padding=downsample_padding,
|
385 |
+
use_linear_projection=use_linear_projection,
|
386 |
+
only_cross_attention=only_cross_attention[i],
|
387 |
+
upcast_attention=upcast_attention,
|
388 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
389 |
+
)
|
390 |
+
self.down_blocks.append(down_block)
|
391 |
+
|
392 |
+
for _ in range(layers_per_block):
|
393 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
394 |
+
controlnet_block = zero_module(controlnet_block)
|
395 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
396 |
+
|
397 |
+
if not is_final_block:
|
398 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
399 |
+
controlnet_block = zero_module(controlnet_block)
|
400 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
401 |
+
|
402 |
+
# mid
|
403 |
+
mid_block_channel = block_out_channels[-1]
|
404 |
+
|
405 |
+
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
406 |
+
controlnet_block = zero_module(controlnet_block)
|
407 |
+
self.controlnet_mid_block = controlnet_block
|
408 |
+
|
409 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
410 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
411 |
+
in_channels=mid_block_channel,
|
412 |
+
temb_channels=time_embed_dim,
|
413 |
+
resnet_eps=norm_eps,
|
414 |
+
resnet_act_fn=act_fn,
|
415 |
+
output_scale_factor=mid_block_scale_factor,
|
416 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
417 |
+
cross_attention_dim=cross_attention_dim,
|
418 |
+
num_attention_heads=num_attention_heads[-1],
|
419 |
+
resnet_groups=norm_num_groups,
|
420 |
+
use_linear_projection=use_linear_projection,
|
421 |
+
upcast_attention=upcast_attention,
|
422 |
+
)
|
423 |
+
|
424 |
+
@classmethod
|
425 |
+
def from_unet(
|
426 |
+
cls,
|
427 |
+
unet: UNet2DConditionModel,
|
428 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
429 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
430 |
+
load_weights_from_unet: bool = True,
|
431 |
+
):
|
432 |
+
r"""
|
433 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
434 |
+
|
435 |
+
Parameters:
|
436 |
+
unet (`UNet2DConditionModel`):
|
437 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
438 |
+
where applicable.
|
439 |
+
"""
|
440 |
+
transformer_layers_per_block = (
|
441 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
442 |
+
)
|
443 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
444 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
445 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
446 |
+
addition_time_embed_dim = (
|
447 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
448 |
+
)
|
449 |
+
|
450 |
+
controlnet = cls(
|
451 |
+
encoder_hid_dim=encoder_hid_dim,
|
452 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
453 |
+
addition_embed_type=addition_embed_type,
|
454 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
455 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
456 |
+
in_channels=unet.config.in_channels,
|
457 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
458 |
+
freq_shift=unet.config.freq_shift,
|
459 |
+
down_block_types=unet.config.down_block_types,
|
460 |
+
only_cross_attention=unet.config.only_cross_attention,
|
461 |
+
block_out_channels=unet.config.block_out_channels,
|
462 |
+
layers_per_block=unet.config.layers_per_block,
|
463 |
+
downsample_padding=unet.config.downsample_padding,
|
464 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
465 |
+
act_fn=unet.config.act_fn,
|
466 |
+
norm_num_groups=unet.config.norm_num_groups,
|
467 |
+
norm_eps=unet.config.norm_eps,
|
468 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
469 |
+
attention_head_dim=unet.config.attention_head_dim,
|
470 |
+
num_attention_heads=unet.config.num_attention_heads,
|
471 |
+
use_linear_projection=unet.config.use_linear_projection,
|
472 |
+
class_embed_type=unet.config.class_embed_type,
|
473 |
+
num_class_embeds=unet.config.num_class_embeds,
|
474 |
+
upcast_attention=unet.config.upcast_attention,
|
475 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
476 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
477 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
478 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
479 |
+
)
|
480 |
+
|
481 |
+
if load_weights_from_unet:
|
482 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
483 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
484 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
485 |
+
|
486 |
+
if controlnet.class_embedding:
|
487 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
488 |
+
|
489 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
490 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
491 |
+
|
492 |
+
return controlnet
|
493 |
+
|
494 |
+
@property
|
495 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
496 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
497 |
+
r"""
|
498 |
+
Returns:
|
499 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
500 |
+
indexed by its weight name.
|
501 |
+
"""
|
502 |
+
# set recursively
|
503 |
+
processors = {}
|
504 |
+
|
505 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
506 |
+
if hasattr(module, "get_processor"):
|
507 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
508 |
+
|
509 |
+
for sub_name, child in module.named_children():
|
510 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
511 |
+
|
512 |
+
return processors
|
513 |
+
|
514 |
+
for name, module in self.named_children():
|
515 |
+
fn_recursive_add_processors(name, module, processors)
|
516 |
+
|
517 |
+
return processors
|
518 |
+
|
519 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
520 |
+
def set_attn_processor(
|
521 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
522 |
+
):
|
523 |
+
r"""
|
524 |
+
Sets the attention processor to use to compute attention.
|
525 |
+
|
526 |
+
Parameters:
|
527 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
528 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
529 |
+
for **all** `Attention` layers.
|
530 |
+
|
531 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
532 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
533 |
+
|
534 |
+
"""
|
535 |
+
count = len(self.attn_processors.keys())
|
536 |
+
|
537 |
+
if isinstance(processor, dict) and len(processor) != count:
|
538 |
+
raise ValueError(
|
539 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
540 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
541 |
+
)
|
542 |
+
|
543 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
544 |
+
if hasattr(module, "set_processor"):
|
545 |
+
if not isinstance(processor, dict):
|
546 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
547 |
+
else:
|
548 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
549 |
+
|
550 |
+
for sub_name, child in module.named_children():
|
551 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
552 |
+
|
553 |
+
for name, module in self.named_children():
|
554 |
+
fn_recursive_attn_processor(name, module, processor)
|
555 |
+
|
556 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
557 |
+
def set_default_attn_processor(self):
|
558 |
+
"""
|
559 |
+
Disables custom attention processors and sets the default attention implementation.
|
560 |
+
"""
|
561 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
562 |
+
processor = AttnAddedKVProcessor()
|
563 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
564 |
+
processor = AttnProcessor()
|
565 |
+
else:
|
566 |
+
raise ValueError(
|
567 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
568 |
+
)
|
569 |
+
|
570 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
571 |
+
|
572 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
573 |
+
def set_attention_slice(self, slice_size):
|
574 |
+
r"""
|
575 |
+
Enable sliced attention computation.
|
576 |
+
|
577 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
578 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
582 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
583 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
584 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
585 |
+
must be a multiple of `slice_size`.
|
586 |
+
"""
|
587 |
+
sliceable_head_dims = []
|
588 |
+
|
589 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
590 |
+
if hasattr(module, "set_attention_slice"):
|
591 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
592 |
+
|
593 |
+
for child in module.children():
|
594 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
595 |
+
|
596 |
+
# retrieve number of attention layers
|
597 |
+
for module in self.children():
|
598 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
599 |
+
|
600 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
601 |
+
|
602 |
+
if slice_size == "auto":
|
603 |
+
# half the attention head size is usually a good trade-off between
|
604 |
+
# speed and memory
|
605 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
606 |
+
elif slice_size == "max":
|
607 |
+
# make smallest slice possible
|
608 |
+
slice_size = num_sliceable_layers * [1]
|
609 |
+
|
610 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
611 |
+
|
612 |
+
if len(slice_size) != len(sliceable_head_dims):
|
613 |
+
raise ValueError(
|
614 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
615 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
616 |
+
)
|
617 |
+
|
618 |
+
for i in range(len(slice_size)):
|
619 |
+
size = slice_size[i]
|
620 |
+
dim = sliceable_head_dims[i]
|
621 |
+
if size is not None and size > dim:
|
622 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
623 |
+
|
624 |
+
# Recursively walk through all the children.
|
625 |
+
# Any children which exposes the set_attention_slice method
|
626 |
+
# gets the message
|
627 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
628 |
+
if hasattr(module, "set_attention_slice"):
|
629 |
+
module.set_attention_slice(slice_size.pop())
|
630 |
+
|
631 |
+
for child in module.children():
|
632 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
633 |
+
|
634 |
+
reversed_slice_size = list(reversed(slice_size))
|
635 |
+
for module in self.children():
|
636 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
637 |
+
|
638 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
639 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
640 |
+
module.gradient_checkpointing = value
|
641 |
+
|
642 |
+
def forward(
|
643 |
+
self,
|
644 |
+
sample: torch.FloatTensor,
|
645 |
+
timestep: Union[torch.Tensor, float, int],
|
646 |
+
encoder_hidden_states: torch.Tensor,
|
647 |
+
controlnet_cond: torch.FloatTensor,
|
648 |
+
conditioning_scale: float = 1.0,
|
649 |
+
class_labels: Optional[torch.Tensor] = None,
|
650 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
651 |
+
attention_mask: Optional[torch.Tensor] = None,
|
652 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
653 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
654 |
+
guess_mode: bool = False,
|
655 |
+
return_dict: bool = True,
|
656 |
+
) -> Union[ControlNetOutput, Tuple]:
|
657 |
+
"""
|
658 |
+
The [`ControlNetModel`] forward method.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
sample (`torch.FloatTensor`):
|
662 |
+
The noisy input tensor.
|
663 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
664 |
+
The number of timesteps to denoise an input.
|
665 |
+
encoder_hidden_states (`torch.Tensor`):
|
666 |
+
The encoder hidden states.
|
667 |
+
controlnet_cond (`torch.FloatTensor`):
|
668 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
669 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
670 |
+
The scale factor for ControlNet outputs.
|
671 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
672 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
673 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
674 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
675 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
676 |
+
embeddings.
|
677 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
678 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
679 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
680 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
681 |
+
added_cond_kwargs (`dict`):
|
682 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
683 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
684 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
685 |
+
guess_mode (`bool`, defaults to `False`):
|
686 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
687 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
688 |
+
return_dict (`bool`, defaults to `True`):
|
689 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
690 |
+
|
691 |
+
Returns:
|
692 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
693 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
694 |
+
returned where the first element is the sample tensor.
|
695 |
+
"""
|
696 |
+
# check channel order
|
697 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
698 |
+
|
699 |
+
if channel_order == "rgb":
|
700 |
+
# in rgb order by default
|
701 |
+
...
|
702 |
+
elif channel_order == "bgr":
|
703 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
704 |
+
else:
|
705 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
706 |
+
|
707 |
+
# prepare attention_mask
|
708 |
+
if attention_mask is not None:
|
709 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
710 |
+
attention_mask = attention_mask.unsqueeze(1)
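            # e.g. a "keep" entry of 1 maps to (1 - 1) * -10000.0 = 0.0, while a "discard" entry of 0
            # maps to (1 - 0) * -10000.0 = -10000.0, which drives that token's attention weight to ~0
            # after the softmax.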
|
711 |
+
|
712 |
+
# 1. time
|
713 |
+
timesteps = timestep
|
714 |
+
if not torch.is_tensor(timesteps):
|
715 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
716 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
717 |
+
is_mps = sample.device.type == "mps"
|
718 |
+
if isinstance(timestep, float):
|
719 |
+
dtype = torch.float32 if is_mps else torch.float64
|
720 |
+
else:
|
721 |
+
dtype = torch.int32 if is_mps else torch.int64
|
722 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
723 |
+
elif len(timesteps.shape) == 0:
|
724 |
+
timesteps = timesteps[None].to(sample.device)
|
725 |
+
|
726 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
727 |
+
timesteps = timesteps.expand(sample.shape[0])
|
728 |
+
|
729 |
+
t_emb = self.time_proj(timesteps)
|
730 |
+
|
731 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
732 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
733 |
+
# there might be better ways to encapsulate this.
|
734 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
735 |
+
|
736 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
737 |
+
aug_emb = None
|
738 |
+
|
739 |
+
if self.class_embedding is not None:
|
740 |
+
if class_labels is None:
|
741 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
742 |
+
|
743 |
+
if self.config.class_embed_type == "timestep":
|
744 |
+
class_labels = self.time_proj(class_labels)
|
745 |
+
|
746 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
747 |
+
emb = emb + class_emb
|
748 |
+
|
749 |
+
if self.config.addition_embed_type is not None:
|
750 |
+
if self.config.addition_embed_type == "text":
|
751 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
752 |
+
|
753 |
+
elif self.config.addition_embed_type == "text_time":
|
754 |
+
if "text_embeds" not in added_cond_kwargs:
|
755 |
+
raise ValueError(
|
756 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
757 |
+
)
|
758 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
759 |
+
if "time_ids" not in added_cond_kwargs:
|
760 |
+
raise ValueError(
|
761 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
762 |
+
)
|
763 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
764 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
765 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
766 |
+
|
767 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
768 |
+
add_embeds = add_embeds.to(emb.dtype)
|
769 |
+
aug_emb = self.add_embedding(add_embeds)
|
770 |
+
|
771 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
772 |
+
|
773 |
+
# 2. pre-process
|
774 |
+
sample = self.conv_in(sample)
|
775 |
+
|
776 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
777 |
+
sample = sample + controlnet_cond
|
778 |
+
|
779 |
+
# 3. down
|
780 |
+
down_block_res_samples = (sample,)
|
781 |
+
for downsample_block in self.down_blocks:
|
782 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
783 |
+
sample, res_samples = downsample_block(
|
784 |
+
hidden_states=sample,
|
785 |
+
temb=emb,
|
786 |
+
encoder_hidden_states=encoder_hidden_states,
|
787 |
+
attention_mask=attention_mask,
|
788 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
789 |
+
)
|
790 |
+
else:
|
791 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
792 |
+
|
793 |
+
down_block_res_samples += res_samples
|
794 |
+
|
795 |
+
# 4. mid
|
796 |
+
if self.mid_block is not None:
|
797 |
+
sample = self.mid_block(
|
798 |
+
sample,
|
799 |
+
emb,
|
800 |
+
encoder_hidden_states=encoder_hidden_states,
|
801 |
+
attention_mask=attention_mask,
|
802 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
803 |
+
)
|
804 |
+
|
805 |
+
# 5. Control net blocks
|
806 |
+
|
807 |
+
controlnet_down_block_res_samples = ()
|
808 |
+
|
809 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
810 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
811 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
812 |
+
|
813 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
814 |
+
|
815 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
816 |
+
|
817 |
+
# 6. scaling
|
818 |
+
if guess_mode and not self.config.global_pool_conditions:
|
819 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
820 |
+
scales = scales * conditioning_scale
|
821 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
822 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
823 |
+
else:
|
824 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
825 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
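        # In guess mode the per-resolution multipliers from `torch.logspace(-1, 0, N + 1)` grow
        # geometrically from 0.1 to 1.0, so the shallowest down-block residual is damped the most while
        # the mid-block residual (the last scale) keeps full strength before `conditioning_scale` is applied.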
|
826 |
+
|
827 |
+
if self.config.global_pool_conditions:
|
828 |
+
down_block_res_samples = [
|
829 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
830 |
+
]
|
831 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
832 |
+
|
833 |
+
if not return_dict:
|
834 |
+
return (down_block_res_samples, mid_block_res_sample)
|
835 |
+
|
836 |
+
return ControlNetOutput(
|
837 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
838 |
+
)
|
839 |
+
|
840 |
+
|
841 |
+
def zero_module(module):
|
842 |
+
for p in module.parameters():
|
843 |
+
nn.init.zeros_(p)
|
844 |
+
return module
|
diffusers/models/controlnet_flax.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple, Union
|
15 |
+
|
16 |
+
import flax
|
17 |
+
import flax.linen as nn
|
18 |
+
import jax
|
19 |
+
import jax.numpy as jnp
|
20 |
+
from flax.core.frozen_dict import FrozenDict
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
23 |
+
from ..utils import BaseOutput
|
24 |
+
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
25 |
+
from .modeling_flax_utils import FlaxModelMixin
|
26 |
+
from .unet_2d_blocks_flax import (
|
27 |
+
FlaxCrossAttnDownBlock2D,
|
28 |
+
FlaxDownBlock2D,
|
29 |
+
FlaxUNetMidBlock2DCrossAttn,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
@flax.struct.dataclass
|
34 |
+
class FlaxControlNetOutput(BaseOutput):
|
35 |
+
"""
|
36 |
+
The output of [`FlaxControlNetModel`].
|
37 |
+
|
38 |
+
Args:
|
39 |
+
down_block_res_samples (`jnp.ndarray`):
|
40 |
+
mid_block_res_sample (`jnp.ndarray`):
|
41 |
+
"""
|
42 |
+
|
43 |
+
down_block_res_samples: jnp.ndarray
|
44 |
+
mid_block_res_sample: jnp.ndarray
|
45 |
+
|
46 |
+
|
47 |
+
class FlaxControlNetConditioningEmbedding(nn.Module):
|
48 |
+
conditioning_embedding_channels: int
|
49 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256)
|
50 |
+
dtype: jnp.dtype = jnp.float32
|
51 |
+
|
52 |
+
def setup(self):
|
53 |
+
self.conv_in = nn.Conv(
|
54 |
+
self.block_out_channels[0],
|
55 |
+
kernel_size=(3, 3),
|
56 |
+
padding=((1, 1), (1, 1)),
|
57 |
+
dtype=self.dtype,
|
58 |
+
)
|
59 |
+
|
60 |
+
blocks = []
|
61 |
+
for i in range(len(self.block_out_channels) - 1):
|
62 |
+
channel_in = self.block_out_channels[i]
|
63 |
+
channel_out = self.block_out_channels[i + 1]
|
64 |
+
conv1 = nn.Conv(
|
65 |
+
channel_in,
|
66 |
+
kernel_size=(3, 3),
|
67 |
+
padding=((1, 1), (1, 1)),
|
68 |
+
dtype=self.dtype,
|
69 |
+
)
|
70 |
+
blocks.append(conv1)
|
71 |
+
conv2 = nn.Conv(
|
72 |
+
channel_out,
|
73 |
+
kernel_size=(3, 3),
|
74 |
+
strides=(2, 2),
|
75 |
+
padding=((1, 1), (1, 1)),
|
76 |
+
dtype=self.dtype,
|
77 |
+
)
|
78 |
+
blocks.append(conv2)
|
79 |
+
self.blocks = blocks
|
80 |
+
|
81 |
+
self.conv_out = nn.Conv(
|
82 |
+
self.conditioning_embedding_channels,
|
83 |
+
kernel_size=(3, 3),
|
84 |
+
padding=((1, 1), (1, 1)),
|
85 |
+
kernel_init=nn.initializers.zeros_init(),
|
86 |
+
bias_init=nn.initializers.zeros_init(),
|
87 |
+
dtype=self.dtype,
|
88 |
+
)
|
89 |
+
|
90 |
+
def __call__(self, conditioning):
|
91 |
+
embedding = self.conv_in(conditioning)
|
92 |
+
embedding = nn.silu(embedding)
|
93 |
+
|
94 |
+
for block in self.blocks:
|
95 |
+
embedding = block(embedding)
|
96 |
+
embedding = nn.silu(embedding)
|
97 |
+
|
98 |
+
embedding = self.conv_out(embedding)
|
99 |
+
|
100 |
+
return embedding
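        # Each of the three stride-(2, 2) convs above halves the spatial resolution, so a conditioning
        # image of H x W comes out at H/8 x W/8 -- matching the latent resolution the ControlNet
        # operates on (see the `sample_size * 8` conditioning shape used in `init_weights`).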
|
101 |
+
|
102 |
+
|
103 |
+
@flax_register_to_config
|
104 |
+
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
105 |
+
r"""
|
106 |
+
A ControlNet model.
|
107 |
+
|
108 |
+
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
|
109 |
+
implemented for all models (such as downloading or saving).
|
110 |
+
|
111 |
+
This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
112 |
+
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
|
113 |
+
general usage and behavior.
|
114 |
+
|
115 |
+
Inherent JAX features such as the following are supported:
|
116 |
+
|
117 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
118 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
119 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
120 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
121 |
+
|
122 |
+
Parameters:
|
123 |
+
sample_size (`int`, *optional*):
|
124 |
+
The size of the input sample.
|
125 |
+
in_channels (`int`, *optional*, defaults to 4):
|
126 |
+
The number of channels in the input sample.
|
127 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
128 |
+
The tuple of downsample blocks to use.
|
129 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
130 |
+
The tuple of output channels for each block.
|
131 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
132 |
+
The number of layers per block.
|
133 |
+
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
134 |
+
The dimension of the attention heads.
|
135 |
+
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
136 |
+
The number of attention heads.
|
137 |
+
cross_attention_dim (`int`, *optional*, defaults to 768):
|
138 |
+
The dimension of the cross attention features.
|
139 |
+
dropout (`float`, *optional*, defaults to 0):
|
140 |
+
Dropout probability for down, up and bottleneck blocks.
|
141 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to flip the sin to cos in the time embedding.
|
143 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
144 |
+
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
|
145 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
146 |
+
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
|
147 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
148 |
+
"""
|
149 |
+
sample_size: int = 32
|
150 |
+
in_channels: int = 4
|
151 |
+
down_block_types: Tuple[str] = (
|
152 |
+
"CrossAttnDownBlock2D",
|
153 |
+
"CrossAttnDownBlock2D",
|
154 |
+
"CrossAttnDownBlock2D",
|
155 |
+
"DownBlock2D",
|
156 |
+
)
|
157 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False
|
158 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
159 |
+
layers_per_block: int = 2
|
160 |
+
attention_head_dim: Union[int, Tuple[int]] = 8
|
161 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
|
162 |
+
cross_attention_dim: int = 1280
|
163 |
+
dropout: float = 0.0
|
164 |
+
use_linear_projection: bool = False
|
165 |
+
dtype: jnp.dtype = jnp.float32
|
166 |
+
flip_sin_to_cos: bool = True
|
167 |
+
freq_shift: int = 0
|
168 |
+
controlnet_conditioning_channel_order: str = "rgb"
|
169 |
+
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
|
170 |
+
|
171 |
+
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
172 |
+
# init input tensors
|
173 |
+
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
174 |
+
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
175 |
+
timesteps = jnp.ones((1,), dtype=jnp.int32)
|
176 |
+
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
|
177 |
+
controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
|
178 |
+
controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
|
179 |
+
|
180 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
181 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
182 |
+
|
183 |
+
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
|
184 |
+
|
185 |
+
def setup(self):
|
186 |
+
block_out_channels = self.block_out_channels
|
187 |
+
time_embed_dim = block_out_channels[0] * 4
|
188 |
+
|
189 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
190 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
191 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
192 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
193 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
194 |
+
# which is why we correct for the naming here.
|
195 |
+
num_attention_heads = self.num_attention_heads or self.attention_head_dim
|
196 |
+
|
197 |
+
# input
|
198 |
+
self.conv_in = nn.Conv(
|
199 |
+
block_out_channels[0],
|
200 |
+
kernel_size=(3, 3),
|
201 |
+
strides=(1, 1),
|
202 |
+
padding=((1, 1), (1, 1)),
|
203 |
+
dtype=self.dtype,
|
204 |
+
)
|
205 |
+
|
206 |
+
# time
|
207 |
+
self.time_proj = FlaxTimesteps(
|
208 |
+
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
|
209 |
+
)
|
210 |
+
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
211 |
+
|
212 |
+
self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
|
213 |
+
conditioning_embedding_channels=block_out_channels[0],
|
214 |
+
block_out_channels=self.conditioning_embedding_out_channels,
|
215 |
+
)
|
216 |
+
|
217 |
+
only_cross_attention = self.only_cross_attention
|
218 |
+
if isinstance(only_cross_attention, bool):
|
219 |
+
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
220 |
+
|
221 |
+
if isinstance(num_attention_heads, int):
|
222 |
+
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
|
223 |
+
|
224 |
+
# down
|
225 |
+
down_blocks = []
|
226 |
+
controlnet_down_blocks = []
|
227 |
+
|
228 |
+
output_channel = block_out_channels[0]
|
229 |
+
|
230 |
+
controlnet_block = nn.Conv(
|
231 |
+
output_channel,
|
232 |
+
kernel_size=(1, 1),
|
233 |
+
padding="VALID",
|
234 |
+
kernel_init=nn.initializers.zeros_init(),
|
235 |
+
bias_init=nn.initializers.zeros_init(),
|
236 |
+
dtype=self.dtype,
|
237 |
+
)
|
238 |
+
controlnet_down_blocks.append(controlnet_block)
|
239 |
+
|
240 |
+
for i, down_block_type in enumerate(self.down_block_types):
|
241 |
+
input_channel = output_channel
|
242 |
+
output_channel = block_out_channels[i]
|
243 |
+
is_final_block = i == len(block_out_channels) - 1
|
244 |
+
|
245 |
+
if down_block_type == "CrossAttnDownBlock2D":
|
246 |
+
down_block = FlaxCrossAttnDownBlock2D(
|
247 |
+
in_channels=input_channel,
|
248 |
+
out_channels=output_channel,
|
249 |
+
dropout=self.dropout,
|
250 |
+
num_layers=self.layers_per_block,
|
251 |
+
num_attention_heads=num_attention_heads[i],
|
252 |
+
add_downsample=not is_final_block,
|
253 |
+
use_linear_projection=self.use_linear_projection,
|
254 |
+
only_cross_attention=only_cross_attention[i],
|
255 |
+
dtype=self.dtype,
|
256 |
+
)
|
257 |
+
else:
|
258 |
+
down_block = FlaxDownBlock2D(
|
259 |
+
in_channels=input_channel,
|
260 |
+
out_channels=output_channel,
|
261 |
+
dropout=self.dropout,
|
262 |
+
num_layers=self.layers_per_block,
|
263 |
+
add_downsample=not is_final_block,
|
264 |
+
dtype=self.dtype,
|
265 |
+
)
|
266 |
+
|
267 |
+
down_blocks.append(down_block)
|
268 |
+
|
269 |
+
for _ in range(self.layers_per_block):
|
270 |
+
controlnet_block = nn.Conv(
|
271 |
+
output_channel,
|
272 |
+
kernel_size=(1, 1),
|
273 |
+
padding="VALID",
|
274 |
+
kernel_init=nn.initializers.zeros_init(),
|
275 |
+
bias_init=nn.initializers.zeros_init(),
|
276 |
+
dtype=self.dtype,
|
277 |
+
)
|
278 |
+
controlnet_down_blocks.append(controlnet_block)
|
279 |
+
|
280 |
+
if not is_final_block:
|
281 |
+
controlnet_block = nn.Conv(
|
282 |
+
output_channel,
|
283 |
+
kernel_size=(1, 1),
|
284 |
+
padding="VALID",
|
285 |
+
kernel_init=nn.initializers.zeros_init(),
|
286 |
+
bias_init=nn.initializers.zeros_init(),
|
287 |
+
dtype=self.dtype,
|
288 |
+
)
|
289 |
+
controlnet_down_blocks.append(controlnet_block)
|
290 |
+
|
291 |
+
self.down_blocks = down_blocks
|
292 |
+
self.controlnet_down_blocks = controlnet_down_blocks
|
293 |
+
|
294 |
+
# mid
|
295 |
+
mid_block_channel = block_out_channels[-1]
|
296 |
+
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
297 |
+
in_channels=mid_block_channel,
|
298 |
+
dropout=self.dropout,
|
299 |
+
num_attention_heads=num_attention_heads[-1],
|
300 |
+
use_linear_projection=self.use_linear_projection,
|
301 |
+
dtype=self.dtype,
|
302 |
+
)
|
303 |
+
|
304 |
+
self.controlnet_mid_block = nn.Conv(
|
305 |
+
mid_block_channel,
|
306 |
+
kernel_size=(1, 1),
|
307 |
+
padding="VALID",
|
308 |
+
kernel_init=nn.initializers.zeros_init(),
|
309 |
+
bias_init=nn.initializers.zeros_init(),
|
310 |
+
dtype=self.dtype,
|
311 |
+
)
|
312 |
+
|
313 |
+
def __call__(
|
314 |
+
self,
|
315 |
+
sample,
|
316 |
+
timesteps,
|
317 |
+
encoder_hidden_states,
|
318 |
+
controlnet_cond,
|
319 |
+
conditioning_scale: float = 1.0,
|
320 |
+
return_dict: bool = True,
|
321 |
+
train: bool = False,
|
322 |
+
) -> Union[FlaxControlNetOutput, Tuple]:
|
323 |
+
r"""
|
324 |
+
Args:
|
325 |
+
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
326 |
+
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
327 |
+
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
|
328 |
+
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
329 |
+
conditioning_scale: (`float`) the scale factor for controlnet outputs
|
330 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
331 |
+
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
332 |
+
plain tuple.
|
333 |
+
train (`bool`, *optional*, defaults to `False`):
|
334 |
+
Use deterministic functions and disable dropout when not training.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
338 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
339 |
+
When returning a tuple, the first element is the sample tensor.
|
340 |
+
"""
|
341 |
+
channel_order = self.controlnet_conditioning_channel_order
|
342 |
+
if channel_order == "bgr":
|
343 |
+
controlnet_cond = jnp.flip(controlnet_cond, axis=1)
|
344 |
+
|
345 |
+
# 1. time
|
346 |
+
if not isinstance(timesteps, jnp.ndarray):
|
347 |
+
timesteps = jnp.array([timesteps], dtype=jnp.int32)
|
348 |
+
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
|
349 |
+
timesteps = timesteps.astype(dtype=jnp.float32)
|
350 |
+
timesteps = jnp.expand_dims(timesteps, 0)
|
351 |
+
|
352 |
+
t_emb = self.time_proj(timesteps)
|
353 |
+
t_emb = self.time_embedding(t_emb)
|
354 |
+
|
355 |
+
# 2. pre-process
|
356 |
+
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
357 |
+
sample = self.conv_in(sample)
|
358 |
+
|
359 |
+
controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
|
360 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
361 |
+
sample += controlnet_cond
|
362 |
+
|
363 |
+
# 3. down
|
364 |
+
down_block_res_samples = (sample,)
|
365 |
+
for down_block in self.down_blocks:
|
366 |
+
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
|
367 |
+
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
368 |
+
else:
|
369 |
+
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
|
370 |
+
down_block_res_samples += res_samples
|
371 |
+
|
372 |
+
# 4. mid
|
373 |
+
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
374 |
+
|
375 |
+
        # 5. controlnet blocks
|
376 |
+
controlnet_down_block_res_samples = ()
|
377 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
378 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
379 |
+
controlnet_down_block_res_samples += (down_block_res_sample,)
|
380 |
+
|
381 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
382 |
+
|
383 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
384 |
+
|
385 |
+
# 6. scaling
|
386 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
387 |
+
mid_block_res_sample *= conditioning_scale
|
388 |
+
|
389 |
+
if not return_dict:
|
390 |
+
return (down_block_res_samples, mid_block_res_sample)
|
391 |
+
|
392 |
+
return FlaxControlNetOutput(
|
393 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
394 |
+
)
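# --- Illustrative usage sketch (editorial addition, not part of the original file) ---
# A minimal, hedged example of initializing and calling the Flax ControlNet; shapes follow
# `init_weights` above and are assumptions for illustration only.
#
#   import jax
#   import jax.numpy as jnp
#
#   model = FlaxControlNetModel(sample_size=64)
#   params = model.init_weights(jax.random.PRNGKey(0))
#
#   sample = jnp.zeros((1, 4, 64, 64), dtype=jnp.float32)
#   timesteps = jnp.array([10], dtype=jnp.int32)
#   encoder_hidden_states = jnp.zeros((1, 77, model.cross_attention_dim), dtype=jnp.float32)
#   controlnet_cond = jnp.zeros((1, 3, 512, 512), dtype=jnp.float32)  # 512 = sample_size * 8
#
#   out = model.apply({"params": params}, sample, timesteps, encoder_hidden_states, controlnet_cond)
#   # out.down_block_res_samples / out.mid_block_res_sample feed the Flax UNet, scaled by conditioning_scale.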
|
diffusers/models/dual_transformer_2d.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class DualTransformer2DModel(nn.Module):
|
22 |
+
"""
|
23 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
27 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
28 |
+
in_channels (`int`, *optional*):
|
29 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
30 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
31 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
32 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
33 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
34 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
35 |
+
`ImagePositionalEmbeddings`.
|
36 |
+
num_vector_embeds (`int`, *optional*):
|
37 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
38 |
+
Includes the class for the masked latent pixel.
|
39 |
+
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than `num_embeds_ada_norm` steps.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states.
            encoder_hidden_states (`torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.long`, *optional*):
                Optional timestep to be applied as an embedding in `AdaLayerNorm` layers. Used to indicate the
                denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Optional attention mask to be applied in Attention.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)
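A minimal sketch of the condition-splitting and output-mixing arithmetic in `DualTransformer2DModel.forward` above; the lambdas are placeholders standing in for the two `Transformer2DModel` blocks, and the tensor shapes are illustrative only:

import torch

mix_ratio = 0.5
condition_lengths = [77, 257]            # text-encoder tokens + image-encoder tokens
transformer_index_for_condition = [1, 0]

hidden = torch.randn(2, 320, 32, 32)     # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, sum(condition_lengths), 768)

# Stand-in "transformers": each returns its input plus a small perturbation.
transformers = [lambda h, c: h + 0.1 * torch.randn_like(h) for _ in range(2)]

encoded, start = [], 0
for i in range(2):
    cond = encoder_hidden_states[:, start : start + condition_lengths[i]]
    out = transformers[transformer_index_for_condition[i]](hidden, cond)
    encoded.append(out - hidden)         # keep only the residual each branch adds
    start += condition_lengths[i]

mixed = encoded[0] * mix_ratio + encoded[1] * (1 - mix_ratio) + hidden
print(mixed.shape)                       # torch.Size([2, 320, 32, 32])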
diffusers/models/embeddings.py
ADDED
@@ -0,0 +1,792 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..utils import USE_PEFT_BACKEND
|
22 |
+
from .activations import get_activation
|
23 |
+
from .lora import LoRACompatibleLinear
|
24 |
+
|
25 |
+
|
26 |
+
def get_timestep_embedding(
|
27 |
+
timesteps: torch.Tensor,
|
28 |
+
embedding_dim: int,
|
29 |
+
flip_sin_to_cos: bool = False,
|
30 |
+
downscale_freq_shift: float = 1,
|
31 |
+
scale: float = 1,
|
32 |
+
max_period: int = 10000,
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
36 |
+
|
37 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
38 |
+
These may be fractional.
|
39 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
40 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
41 |
+
"""
|
42 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
43 |
+
|
44 |
+
half_dim = embedding_dim // 2
|
45 |
+
exponent = -math.log(max_period) * torch.arange(
|
46 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
47 |
+
)
|
48 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
49 |
+
|
50 |
+
emb = torch.exp(exponent)
|
51 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
52 |
+
|
53 |
+
# scale embeddings
|
54 |
+
emb = scale * emb
|
55 |
+
|
56 |
+
# concat sine and cosine embeddings
|
57 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
58 |
+
|
59 |
+
# flip sine and cosine embeddings
|
60 |
+
if flip_sin_to_cos:
|
61 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
62 |
+
|
63 |
+
# zero pad
|
64 |
+
if embedding_dim % 2 == 1:
|
65 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
66 |
+
return emb
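A small usage sketch for `get_timestep_embedding` above, assuming the vendored package is importable as `diffusers`:

import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.tensor([0, 10, 999])          # one timestep per batch element
emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True)
print(emb.shape)                                # torch.Size([3, 320])
# Without flipping, the first half of the embedding is sin and the second half cos;
# an odd embedding_dim is zero-padded by one column.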
|
67 |
+
|
68 |
+
|
69 |
+
def get_2d_sincos_pos_embed(
|
70 |
+
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
74 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
75 |
+
"""
|
76 |
+
if isinstance(grid_size, int):
|
77 |
+
grid_size = (grid_size, grid_size)
|
78 |
+
|
79 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
80 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
81 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
82 |
+
grid = np.stack(grid, axis=0)
|
83 |
+
|
84 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
85 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
86 |
+
if cls_token and extra_tokens > 0:
|
87 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
88 |
+
return pos_embed
|
89 |
+
|
90 |
+
|
91 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
92 |
+
if embed_dim % 2 != 0:
|
93 |
+
raise ValueError("embed_dim must be divisible by 2")
|
94 |
+
|
95 |
+
# use half of dimensions to encode grid_h
|
96 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
97 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
98 |
+
|
99 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
100 |
+
return emb
|
101 |
+
|
102 |
+
|
103 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
104 |
+
"""
|
105 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
106 |
+
"""
|
107 |
+
if embed_dim % 2 != 0:
|
108 |
+
raise ValueError("embed_dim must be divisible by 2")
|
109 |
+
|
110 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
111 |
+
omega /= embed_dim / 2.0
|
112 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
113 |
+
|
114 |
+
pos = pos.reshape(-1) # (M,)
|
115 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
116 |
+
|
117 |
+
emb_sin = np.sin(out) # (M, D/2)
|
118 |
+
emb_cos = np.cos(out) # (M, D/2)
|
119 |
+
|
120 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
121 |
+
return emb
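A shape sketch for the sincos position-embedding helpers above, under the same import assumption:

from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, base_size=16)
print(pos.shape)   # (256, 768): one 768-dim embedding per cell of a 16x16 grid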
|
122 |
+
|
123 |
+
|
124 |
+
class PatchEmbed(nn.Module):
|
125 |
+
"""2D Image to Patch Embedding"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
height=224,
|
130 |
+
width=224,
|
131 |
+
patch_size=16,
|
132 |
+
in_channels=3,
|
133 |
+
embed_dim=768,
|
134 |
+
layer_norm=False,
|
135 |
+
flatten=True,
|
136 |
+
bias=True,
|
137 |
+
interpolation_scale=1,
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
142 |
+
self.flatten = flatten
|
143 |
+
self.layer_norm = layer_norm
|
144 |
+
|
145 |
+
self.proj = nn.Conv2d(
|
146 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
147 |
+
)
|
148 |
+
if layer_norm:
|
149 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
150 |
+
else:
|
151 |
+
self.norm = None
|
152 |
+
|
153 |
+
self.patch_size = patch_size
|
154 |
+
# See:
|
155 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
156 |
+
self.height, self.width = height // patch_size, width // patch_size
|
157 |
+
self.base_size = height // patch_size
|
158 |
+
self.interpolation_scale = interpolation_scale
|
159 |
+
pos_embed = get_2d_sincos_pos_embed(
|
160 |
+
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
161 |
+
)
|
162 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
163 |
+
|
164 |
+
def forward(self, latent):
|
165 |
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
166 |
+
|
167 |
+
latent = self.proj(latent)
|
168 |
+
if self.flatten:
|
169 |
+
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
170 |
+
if self.layer_norm:
|
171 |
+
latent = self.norm(latent)
|
172 |
+
|
173 |
+
# Interpolate positional embeddings if needed.
|
174 |
+
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
175 |
+
if self.height != height or self.width != width:
|
176 |
+
pos_embed = get_2d_sincos_pos_embed(
|
177 |
+
embed_dim=self.pos_embed.shape[-1],
|
178 |
+
grid_size=(height, width),
|
179 |
+
base_size=self.base_size,
|
180 |
+
interpolation_scale=self.interpolation_scale,
|
181 |
+
)
|
182 |
+
pos_embed = torch.from_numpy(pos_embed)
|
183 |
+
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
184 |
+
else:
|
185 |
+
pos_embed = self.pos_embed
|
186 |
+
|
187 |
+
return (latent + pos_embed).to(latent.dtype)
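A usage sketch for `PatchEmbed` above; the latent size and channel count are illustrative, and the import path assumes the vendored module:

import torch
from diffusers.models.embeddings import PatchEmbed

patchify = PatchEmbed(height=64, width=64, patch_size=8, in_channels=4, embed_dim=768)
latent = torch.randn(1, 4, 64, 64)
tokens = patchify(latent)
print(tokens.shape)   # torch.Size([1, 64, 768]): (64/8) * (64/8) = 64 patch tokens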
|
188 |
+
|
189 |
+
|
190 |
+
class TimestepEmbedding(nn.Module):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
in_channels: int,
|
194 |
+
time_embed_dim: int,
|
195 |
+
act_fn: str = "silu",
|
196 |
+
out_dim: int = None,
|
197 |
+
post_act_fn: Optional[str] = None,
|
198 |
+
cond_proj_dim=None,
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
202 |
+
|
203 |
+
self.linear_1 = linear_cls(in_channels, time_embed_dim)
|
204 |
+
|
205 |
+
if cond_proj_dim is not None:
|
206 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
207 |
+
else:
|
208 |
+
self.cond_proj = None
|
209 |
+
|
210 |
+
self.act = get_activation(act_fn)
|
211 |
+
|
212 |
+
if out_dim is not None:
|
213 |
+
time_embed_dim_out = out_dim
|
214 |
+
else:
|
215 |
+
time_embed_dim_out = time_embed_dim
|
216 |
+
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
|
217 |
+
|
218 |
+
if post_act_fn is None:
|
219 |
+
self.post_act = None
|
220 |
+
else:
|
221 |
+
self.post_act = get_activation(post_act_fn)
|
222 |
+
|
223 |
+
def forward(self, sample, condition=None):
|
224 |
+
if condition is not None:
|
225 |
+
sample = sample + self.cond_proj(condition)
|
226 |
+
sample = self.linear_1(sample)
|
227 |
+
|
228 |
+
if self.act is not None:
|
229 |
+
sample = self.act(sample)
|
230 |
+
|
231 |
+
sample = self.linear_2(sample)
|
232 |
+
|
233 |
+
if self.post_act is not None:
|
234 |
+
sample = self.post_act(sample)
|
235 |
+
return sample
|
236 |
+
|
237 |
+
|
238 |
+
class Timesteps(nn.Module):
|
239 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
240 |
+
super().__init__()
|
241 |
+
self.num_channels = num_channels
|
242 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
243 |
+
self.downscale_freq_shift = downscale_freq_shift
|
244 |
+
|
245 |
+
def forward(self, timesteps):
|
246 |
+
t_emb = get_timestep_embedding(
|
247 |
+
timesteps,
|
248 |
+
self.num_channels,
|
249 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
250 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
251 |
+
)
|
252 |
+
return t_emb
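The usual pairing of `Timesteps` and `TimestepEmbedding` above, sketched under the same import assumption; the channel sizes are illustrative:

import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

# Timesteps produces fixed sinusoidal features, TimestepEmbedding projects them with a small MLP.
time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

t = torch.tensor([981, 961])        # one denoising step per batch element
temb = time_embedding(time_proj(t).to(torch.float32))
print(temb.shape)                   # torch.Size([2, 1280])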
|
253 |
+
|
254 |
+
|
255 |
+
class GaussianFourierProjection(nn.Module):
|
256 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
257 |
+
|
258 |
+
def __init__(
|
259 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
263 |
+
self.log = log
|
264 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
265 |
+
|
266 |
+
if set_W_to_weight:
|
267 |
+
# to delete later
|
268 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
269 |
+
|
270 |
+
self.weight = self.W
|
271 |
+
|
272 |
+
def forward(self, x):
|
273 |
+
if self.log:
|
274 |
+
x = torch.log(x)
|
275 |
+
|
276 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
277 |
+
|
278 |
+
if self.flip_sin_to_cos:
|
279 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
280 |
+
else:
|
281 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
282 |
+
return out
|
283 |
+
|
284 |
+
|
285 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
286 |
+
"""Apply positional information to a sequence of embeddings.
|
287 |
+
|
288 |
+
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
|
289 |
+
them
|
290 |
+
|
291 |
+
Args:
|
292 |
+
embed_dim: (int): Dimension of the positional embedding.
|
293 |
+
max_seq_length: Maximum sequence length to apply positional embeddings
|
294 |
+
|
295 |
+
"""
|
296 |
+
|
297 |
+
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
298 |
+
super().__init__()
|
299 |
+
position = torch.arange(max_seq_length).unsqueeze(1)
|
300 |
+
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
|
301 |
+
pe = torch.zeros(1, max_seq_length, embed_dim)
|
302 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
303 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
304 |
+
self.register_buffer("pe", pe)
|
305 |
+
|
306 |
+
def forward(self, x):
|
307 |
+
_, seq_length, _ = x.shape
|
308 |
+
x = x + self.pe[:, :seq_length]
|
309 |
+
return x
|
310 |
+
|
311 |
+
|
312 |
+
class ImagePositionalEmbeddings(nn.Module):
|
313 |
+
"""
|
314 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
315 |
+
height and width of the latent space.
|
316 |
+
|
317 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
318 |
+
|
319 |
+
For VQ-diffusion:
|
320 |
+
|
321 |
+
Output vector embeddings are used as input for the transformer.
|
322 |
+
|
323 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
num_embed (`int`):
|
327 |
+
Number of embeddings for the latent pixels embeddings.
|
328 |
+
height (`int`):
|
329 |
+
Height of the latent image i.e. the number of height embeddings.
|
330 |
+
width (`int`):
|
331 |
+
Width of the latent image i.e. the number of width embeddings.
|
332 |
+
embed_dim (`int`):
|
333 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
334 |
+
"""
|
335 |
+
|
336 |
+
def __init__(
|
337 |
+
self,
|
338 |
+
num_embed: int,
|
339 |
+
height: int,
|
340 |
+
width: int,
|
341 |
+
embed_dim: int,
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
|
345 |
+
self.height = height
|
346 |
+
self.width = width
|
347 |
+
self.num_embed = num_embed
|
348 |
+
self.embed_dim = embed_dim
|
349 |
+
|
350 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
351 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
352 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
353 |
+
|
354 |
+
def forward(self, index):
|
355 |
+
emb = self.emb(index)
|
356 |
+
|
357 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
358 |
+
|
359 |
+
# 1 x H x D -> 1 x H x 1 x D
|
360 |
+
height_emb = height_emb.unsqueeze(2)
|
361 |
+
|
362 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
363 |
+
|
364 |
+
# 1 x W x D -> 1 x 1 x W x D
|
365 |
+
width_emb = width_emb.unsqueeze(1)
|
366 |
+
|
367 |
+
pos_emb = height_emb + width_emb
|
368 |
+
|
369 |
+
# 1 x H x W x D -> 1 x L x D
|
370 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
371 |
+
|
372 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
373 |
+
|
374 |
+
return emb
|
375 |
+
|
376 |
+
|
377 |
+
class LabelEmbedding(nn.Module):
|
378 |
+
"""
|
379 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
num_classes (`int`): The number of classes.
|
383 |
+
hidden_size (`int`): The size of the vector embeddings.
|
384 |
+
dropout_prob (`float`): The probability of dropping a label.
|
385 |
+
"""
|
386 |
+
|
387 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
388 |
+
super().__init__()
|
389 |
+
use_cfg_embedding = dropout_prob > 0
|
390 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
391 |
+
self.num_classes = num_classes
|
392 |
+
self.dropout_prob = dropout_prob
|
393 |
+
|
394 |
+
def token_drop(self, labels, force_drop_ids=None):
|
395 |
+
"""
|
396 |
+
Drops labels to enable classifier-free guidance.
|
397 |
+
"""
|
398 |
+
if force_drop_ids is None:
|
399 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
400 |
+
else:
|
401 |
+
drop_ids = torch.tensor(force_drop_ids == 1)
|
402 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
403 |
+
return labels
|
404 |
+
|
405 |
+
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
406 |
+
use_dropout = self.dropout_prob > 0
|
407 |
+
if (self.training and use_dropout) or (force_drop_ids is not None):
|
408 |
+
labels = self.token_drop(labels, force_drop_ids)
|
409 |
+
embeddings = self.embedding_table(labels)
|
410 |
+
return embeddings
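A sketch of the classifier-free-guidance label dropping performed by `LabelEmbedding` above; index `num_classes` acts as the learned null class (import path assumed as before):

import torch
from diffusers.models.embeddings import LabelEmbedding

label_emb = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
labels = torch.tensor([3, 7, 42])

label_emb.train()
emb_train = label_emb(labels)       # in training, each label has a 10% chance of being replaced by the null class
emb_forced = label_emb(labels, force_drop_ids=torch.tensor([1, 0, 1]))  # explicitly drop the 1st and 3rd labels
print(emb_train.shape, emb_forced.shape)   # torch.Size([3, 768]) twice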
|
411 |
+
|
412 |
+
|
413 |
+
class TextImageProjection(nn.Module):
|
414 |
+
def __init__(
|
415 |
+
self,
|
416 |
+
text_embed_dim: int = 1024,
|
417 |
+
image_embed_dim: int = 768,
|
418 |
+
cross_attention_dim: int = 768,
|
419 |
+
num_image_text_embeds: int = 10,
|
420 |
+
):
|
421 |
+
super().__init__()
|
422 |
+
|
423 |
+
self.num_image_text_embeds = num_image_text_embeds
|
424 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
425 |
+
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
426 |
+
|
427 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
428 |
+
batch_size = text_embeds.shape[0]
|
429 |
+
|
430 |
+
# image
|
431 |
+
image_text_embeds = self.image_embeds(image_embeds)
|
432 |
+
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
433 |
+
|
434 |
+
# text
|
435 |
+
text_embeds = self.text_proj(text_embeds)
|
436 |
+
|
437 |
+
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
438 |
+
|
439 |
+
|
440 |
+
class ImageProjection(nn.Module):
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
image_embed_dim: int = 768,
|
444 |
+
cross_attention_dim: int = 768,
|
445 |
+
num_image_text_embeds: int = 32,
|
446 |
+
):
|
447 |
+
super().__init__()
|
448 |
+
|
449 |
+
self.num_image_text_embeds = num_image_text_embeds
|
450 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
451 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
452 |
+
|
453 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
454 |
+
batch_size = image_embeds.shape[0]
|
455 |
+
|
456 |
+
# image
|
457 |
+
image_embeds = self.image_embeds(image_embeds)
|
458 |
+
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
459 |
+
image_embeds = self.norm(image_embeds)
|
460 |
+
return image_embeds
|
461 |
+
|
462 |
+
|
463 |
+
class CombinedTimestepLabelEmbeddings(nn.Module):
|
464 |
+
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
465 |
+
super().__init__()
|
466 |
+
|
467 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
468 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
469 |
+
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
470 |
+
|
471 |
+
def forward(self, timestep, class_labels, hidden_dtype=None):
|
472 |
+
timesteps_proj = self.time_proj(timestep)
|
473 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
474 |
+
|
475 |
+
class_labels = self.class_embedder(class_labels) # (N, D)
|
476 |
+
|
477 |
+
conditioning = timesteps_emb + class_labels # (N, D)
|
478 |
+
|
479 |
+
return conditioning
|
480 |
+
|
481 |
+
|
482 |
+
class TextTimeEmbedding(nn.Module):
|
483 |
+
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
484 |
+
super().__init__()
|
485 |
+
self.norm1 = nn.LayerNorm(encoder_dim)
|
486 |
+
self.pool = AttentionPooling(num_heads, encoder_dim)
|
487 |
+
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
488 |
+
self.norm2 = nn.LayerNorm(time_embed_dim)
|
489 |
+
|
490 |
+
def forward(self, hidden_states):
|
491 |
+
hidden_states = self.norm1(hidden_states)
|
492 |
+
hidden_states = self.pool(hidden_states)
|
493 |
+
hidden_states = self.proj(hidden_states)
|
494 |
+
hidden_states = self.norm2(hidden_states)
|
495 |
+
return hidden_states
|
496 |
+
|
497 |
+
|
498 |
+
class TextImageTimeEmbedding(nn.Module):
|
499 |
+
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
500 |
+
super().__init__()
|
501 |
+
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
502 |
+
self.text_norm = nn.LayerNorm(time_embed_dim)
|
503 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
504 |
+
|
505 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
506 |
+
# text
|
507 |
+
time_text_embeds = self.text_proj(text_embeds)
|
508 |
+
time_text_embeds = self.text_norm(time_text_embeds)
|
509 |
+
|
510 |
+
# image
|
511 |
+
time_image_embeds = self.image_proj(image_embeds)
|
512 |
+
|
513 |
+
return time_image_embeds + time_text_embeds
|
514 |
+
|
515 |
+
|
516 |
+
class ImageTimeEmbedding(nn.Module):
|
517 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
518 |
+
super().__init__()
|
519 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
520 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
521 |
+
|
522 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
523 |
+
# image
|
524 |
+
time_image_embeds = self.image_proj(image_embeds)
|
525 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
526 |
+
return time_image_embeds
|
527 |
+
|
528 |
+
|
529 |
+
class ImageHintTimeEmbedding(nn.Module):
|
530 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
531 |
+
super().__init__()
|
532 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
533 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
534 |
+
self.input_hint_block = nn.Sequential(
|
535 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
536 |
+
nn.SiLU(),
|
537 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
538 |
+
nn.SiLU(),
|
539 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
540 |
+
nn.SiLU(),
|
541 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
542 |
+
nn.SiLU(),
|
543 |
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
544 |
+
nn.SiLU(),
|
545 |
+
nn.Conv2d(96, 96, 3, padding=1),
|
546 |
+
nn.SiLU(),
|
547 |
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
548 |
+
nn.SiLU(),
|
549 |
+
nn.Conv2d(256, 4, 3, padding=1),
|
550 |
+
)
|
551 |
+
|
552 |
+
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
553 |
+
# image
|
554 |
+
time_image_embeds = self.image_proj(image_embeds)
|
555 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
556 |
+
hint = self.input_hint_block(hint)
|
557 |
+
return time_image_embeds, hint
|
558 |
+
|
559 |
+
|
560 |
+
class AttentionPooling(nn.Module):
|
561 |
+
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
562 |
+
|
563 |
+
def __init__(self, num_heads, embed_dim, dtype=None):
|
564 |
+
super().__init__()
|
565 |
+
self.dtype = dtype
|
566 |
+
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
567 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
568 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
569 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
570 |
+
self.num_heads = num_heads
|
571 |
+
self.dim_per_head = embed_dim // self.num_heads
|
572 |
+
|
573 |
+
def forward(self, x):
|
574 |
+
bs, length, width = x.size()
|
575 |
+
|
576 |
+
def shape(x):
|
577 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
578 |
+
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
579 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
580 |
+
x = x.transpose(1, 2)
|
581 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
582 |
+
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
583 |
+
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
584 |
+
x = x.transpose(1, 2)
|
585 |
+
return x
|
586 |
+
|
587 |
+
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
588 |
+
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
589 |
+
|
590 |
+
# (bs*n_heads, class_token_length, dim_per_head)
|
591 |
+
q = shape(self.q_proj(class_token))
|
592 |
+
# (bs*n_heads, length+class_token_length, dim_per_head)
|
593 |
+
k = shape(self.k_proj(x))
|
594 |
+
v = shape(self.v_proj(x))
|
595 |
+
|
596 |
+
# (bs*n_heads, class_token_length, length+class_token_length):
|
597 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
598 |
+
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
599 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
600 |
+
|
601 |
+
# (bs*n_heads, dim_per_head, class_token_length)
|
602 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
603 |
+
|
604 |
+
# (bs, length+1, width)
|
605 |
+
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
606 |
+
|
607 |
+
return a[:, 0, :] # cls_token
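A shape sketch for `AttentionPooling` above, which reduces a token sequence to a single vector by attending from a mean-pooled class token (import path assumed):

import torch
from diffusers.models.embeddings import AttentionPooling

pool = AttentionPooling(num_heads=8, embed_dim=768)
tokens = torch.randn(2, 77, 768)    # e.g. per-token text-encoder hidden states
pooled = pool(tokens)
print(pooled.shape)                 # torch.Size([2, 768])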
|
608 |
+
|
609 |
+
|
610 |
+
class FourierEmbedder(nn.Module):
|
611 |
+
def __init__(self, num_freqs=64, temperature=100):
|
612 |
+
super().__init__()
|
613 |
+
|
614 |
+
self.num_freqs = num_freqs
|
615 |
+
self.temperature = temperature
|
616 |
+
|
617 |
+
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
618 |
+
freq_bands = freq_bands[None, None, None]
|
619 |
+
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
620 |
+
|
621 |
+
def __call__(self, x):
|
622 |
+
x = self.freq_bands * x.unsqueeze(-1)
|
623 |
+
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
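A shape sketch for `FourierEmbedder` above; for xyxy boxes it yields `fourier_freqs * 2 * 4` features per box, matching the `position_dim` computed in `PositionNet` below (import path assumed):

import torch
from diffusers.models.embeddings import FourierEmbedder

embedder = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 5, 4)          # (batch, num_boxes, xyxy coordinates in [0, 1])
feats = embedder(boxes)
print(feats.shape)                   # torch.Size([2, 5, 64]) = 8 freqs * 2 (sin/cos) * 4 coords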
|
624 |
+
|
625 |
+
|
626 |
+
class PositionNet(nn.Module):
|
627 |
+
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
|
628 |
+
super().__init__()
|
629 |
+
self.positive_len = positive_len
|
630 |
+
self.out_dim = out_dim
|
631 |
+
|
632 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
633 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
634 |
+
|
635 |
+
if isinstance(out_dim, tuple):
|
636 |
+
out_dim = out_dim[0]
|
637 |
+
|
638 |
+
if feature_type == "text-only":
|
639 |
+
self.linears = nn.Sequential(
|
640 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
641 |
+
nn.SiLU(),
|
642 |
+
nn.Linear(512, 512),
|
643 |
+
nn.SiLU(),
|
644 |
+
nn.Linear(512, out_dim),
|
645 |
+
)
|
646 |
+
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
647 |
+
|
648 |
+
elif feature_type == "text-image":
|
649 |
+
self.linears_text = nn.Sequential(
|
650 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
651 |
+
nn.SiLU(),
|
652 |
+
nn.Linear(512, 512),
|
653 |
+
nn.SiLU(),
|
654 |
+
nn.Linear(512, out_dim),
|
655 |
+
)
|
656 |
+
self.linears_image = nn.Sequential(
|
657 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
658 |
+
nn.SiLU(),
|
659 |
+
nn.Linear(512, 512),
|
660 |
+
nn.SiLU(),
|
661 |
+
nn.Linear(512, out_dim),
|
662 |
+
)
|
663 |
+
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
664 |
+
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
665 |
+
|
666 |
+
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
|
667 |
+
|
668 |
+
def forward(
|
669 |
+
self,
|
670 |
+
boxes,
|
671 |
+
masks,
|
672 |
+
positive_embeddings=None,
|
673 |
+
phrases_masks=None,
|
674 |
+
image_masks=None,
|
675 |
+
phrases_embeddings=None,
|
676 |
+
image_embeddings=None,
|
677 |
+
):
|
678 |
+
masks = masks.unsqueeze(-1)
|
679 |
+
|
680 |
+
# embed the positions (they may include padding used as a placeholder)
|
681 |
+
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
682 |
+
|
683 |
+
# learnable null embedding
|
684 |
+
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
685 |
+
|
686 |
+
# replace padding with learnable null embedding
|
687 |
+
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
688 |
+
|
689 |
+
# PositionNet with text-only information
|
690 |
+
if positive_embeddings is not None:
|
691 |
+
# learnable null embedding
|
692 |
+
positive_null = self.null_positive_feature.view(1, 1, -1)
|
693 |
+
|
694 |
+
# replace padding with learnable null embedding
|
695 |
+
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
|
696 |
+
|
697 |
+
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
698 |
+
|
699 |
+
# PositionNet with text and image information
|
700 |
+
else:
|
701 |
+
phrases_masks = phrases_masks.unsqueeze(-1)
|
702 |
+
image_masks = image_masks.unsqueeze(-1)
|
703 |
+
|
704 |
+
# learnable null embedding
|
705 |
+
text_null = self.null_text_feature.view(1, 1, -1)
|
706 |
+
image_null = self.null_image_feature.view(1, 1, -1)
|
707 |
+
|
708 |
+
# replace padding with learnable null embedding
|
709 |
+
phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
|
710 |
+
image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
|
711 |
+
|
712 |
+
objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
|
713 |
+
objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
|
714 |
+
objs = torch.cat([objs_text, objs_image], dim=1)
|
715 |
+
|
716 |
+
return objs
|
717 |
+
|
718 |
+
|
719 |
+
class CombinedTimestepSizeEmbeddings(nn.Module):
|
720 |
+
"""
|
721 |
+
For PixArt-Alpha.
|
722 |
+
|
723 |
+
Reference:
|
724 |
+
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
725 |
+
"""
|
726 |
+
|
727 |
+
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
728 |
+
super().__init__()
|
729 |
+
|
730 |
+
self.outdim = size_emb_dim
|
731 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
732 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
733 |
+
|
734 |
+
self.use_additional_conditions = use_additional_conditions
|
735 |
+
if use_additional_conditions:
|
736 |
+
self.use_additional_conditions = True
|
737 |
+
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
738 |
+
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
739 |
+
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
740 |
+
|
741 |
+
def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
|
742 |
+
if size.ndim == 1:
|
743 |
+
size = size[:, None]
|
744 |
+
|
745 |
+
if size.shape[0] != batch_size:
|
746 |
+
size = size.repeat(batch_size // size.shape[0], 1)
|
747 |
+
if size.shape[0] != batch_size:
|
748 |
+
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
|
749 |
+
|
750 |
+
current_batch_size, dims = size.shape[0], size.shape[1]
|
751 |
+
size = size.reshape(-1)
|
752 |
+
size_freq = self.additional_condition_proj(size).to(size.dtype)
|
753 |
+
|
754 |
+
size_emb = embedder(size_freq)
|
755 |
+
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
|
756 |
+
return size_emb
|
757 |
+
|
758 |
+
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
759 |
+
timesteps_proj = self.time_proj(timestep)
|
760 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
761 |
+
|
762 |
+
if self.use_additional_conditions:
|
763 |
+
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
|
764 |
+
aspect_ratio = self.apply_condition(
|
765 |
+
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
|
766 |
+
)
|
767 |
+
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
|
768 |
+
else:
|
769 |
+
conditioning = timesteps_emb
|
770 |
+
|
771 |
+
return conditioning
|
772 |
+
|
773 |
+
|
774 |
+
class CaptionProjection(nn.Module):
|
775 |
+
"""
|
776 |
+
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
777 |
+
|
778 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
779 |
+
"""
|
780 |
+
|
781 |
+
def __init__(self, in_features, hidden_size, num_tokens=120):
|
782 |
+
super().__init__()
|
783 |
+
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
784 |
+
self.act_1 = nn.GELU(approximate="tanh")
|
785 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
|
786 |
+
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
|
787 |
+
|
788 |
+
def forward(self, caption, force_drop_ids=None):
|
789 |
+
hidden_states = self.linear_1(caption)
|
790 |
+
hidden_states = self.act_1(hidden_states)
|
791 |
+
hidden_states = self.linear_2(hidden_states)
|
792 |
+
return hidden_states
|
diffusers/models/embeddings_flax.py
ADDED
@@ -0,0 +1,95 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        embedding_dim: The number of output channels.
        min_timescale: The smallest time unit (should probably be 0.0).
        max_timescale: The largest time unit.
    Returns:
        a Tensor of timing signals [N, num_channels]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
    num_timescales = float(embedding_dim // 2)
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)

    # scale embeddings
    scaled_time = scale * emb

    if flip_sin_to_cos:
        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
    else:
        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
    return signal


class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module. Learns embeddings for input time steps.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb


class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension
    """
    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(
            timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
        )
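A JAX/Flax usage sketch for the module above, assuming it is importable as `diffusers.models.embeddings_flax`:

import jax
import jax.numpy as jnp
from diffusers.models.embeddings_flax import get_sinusoidal_embeddings, FlaxTimestepEmbedding

timesteps = jnp.array([0, 500, 999])
# Fixed sinusoidal features, then a two-layer Dense projection.
t_proj = get_sinusoidal_embeddings(timesteps, embedding_dim=320, flip_sin_to_cos=True)

embed = FlaxTimestepEmbedding(time_embed_dim=1280)
params = embed.init(jax.random.PRNGKey(0), t_proj)
temb = embed.apply(params, t_proj)
print(temb.shape)   # (3, 1280)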
diffusers/models/lora.py
ADDED
@@ -0,0 +1,304 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
22 |
+
from ..utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26 |
+
|
27 |
+
|
28 |
+
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
29 |
+
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
30 |
+
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
31 |
+
attn_module.q_proj.lora_scale = lora_scale
|
32 |
+
attn_module.k_proj.lora_scale = lora_scale
|
33 |
+
attn_module.v_proj.lora_scale = lora_scale
|
34 |
+
attn_module.out_proj.lora_scale = lora_scale
|
35 |
+
|
36 |
+
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
37 |
+
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
38 |
+
mlp_module.fc1.lora_scale = lora_scale
|
39 |
+
mlp_module.fc2.lora_scale = lora_scale
|
40 |
+
|
41 |
+
|
42 |
+
class LoRALinearLayer(nn.Module):
|
43 |
+
r"""
|
44 |
+
A linear layer that is used with LoRA.
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
in_features (`int`):
|
48 |
+
Number of input features.
|
49 |
+
out_features (`int`):
|
50 |
+
Number of output features.
|
51 |
+
rank (`int`, `optional`, defaults to 4):
|
52 |
+
The rank of the LoRA layer.
|
53 |
+
network_alpha (`float`, `optional`, defaults to `None`):
|
54 |
+
The value of the network alpha used for stable learning and preventing underflow. This value has the same
|
55 |
+
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
|
56 |
+
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
57 |
+
device (`torch.device`, `optional`, defaults to `None`):
|
58 |
+
The device to use for the layer's weights.
|
59 |
+
dtype (`torch.dtype`, `optional`, defaults to `None`):
|
60 |
+
The dtype to use for the layer's weights.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
in_features: int,
|
66 |
+
out_features: int,
|
67 |
+
rank: int = 4,
|
68 |
+
network_alpha: Optional[float] = None,
|
69 |
+
device: Optional[Union[torch.device, str]] = None,
|
70 |
+
dtype: Optional[torch.dtype] = None,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
75 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
76 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
77 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
78 |
+
self.network_alpha = network_alpha
|
79 |
+
self.rank = rank
|
80 |
+
self.out_features = out_features
|
81 |
+
self.in_features = in_features
|
82 |
+
|
83 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
84 |
+
nn.init.zeros_(self.up.weight)
|
85 |
+
|
86 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
87 |
+
orig_dtype = hidden_states.dtype
|
88 |
+
dtype = self.down.weight.dtype
|
89 |
+
|
90 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
91 |
+
up_hidden_states = self.up(down_hidden_states)
|
92 |
+
|
93 |
+
if self.network_alpha is not None:
|
94 |
+
up_hidden_states *= self.network_alpha / self.rank
|
95 |
+
|
96 |
+
return up_hidden_states.to(orig_dtype)
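A sketch of the low-rank update computed by `LoRALinearLayer` above when attached to a `LoRACompatibleLinear` (defined further down in this file); a freshly initialized LoRA branch is a no-op because `up` starts at zero:

import torch
from diffusers.models.lora import LoRALinearLayer, LoRACompatibleLinear

lora = LoRALinearLayer(in_features=320, out_features=320, rank=4, network_alpha=4)
base = LoRACompatibleLinear(320, 320, lora_layer=lora)

x = torch.randn(2, 77, 320)
# forward = base_linear(x) + scale * up(down(x)); up is zero-initialised, so the
# untrained LoRA branch contributes nothing until weights are trained or loaded.
out = base(x, scale=1.0)
print(out.shape, torch.allclose(out, base(x, scale=0.0)))   # torch.Size([2, 77, 320]) True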
|
97 |
+
|
98 |
+
|
99 |
+
class LoRAConv2dLayer(nn.Module):
|
100 |
+
r"""
|
101 |
+
A convolutional layer that is used with LoRA.
|
102 |
+
|
103 |
+
Parameters:
|
104 |
+
in_features (`int`):
|
105 |
+
Number of input features.
|
106 |
+
out_features (`int`):
|
107 |
+
Number of output features.
|
108 |
+
rank (`int`, `optional`, defaults to 4):
|
109 |
+
The rank of the LoRA layer.
|
110 |
+
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
|
111 |
+
The kernel size of the convolution.
|
112 |
+
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
|
113 |
+
The stride of the convolution.
|
114 |
+
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
|
115 |
+
The padding of the convolution.
|
116 |
+
network_alpha (`float`, `optional`, defaults to `None`):
|
117 |
+
The value of the network alpha used for stable learning and preventing underflow. This value has the same
|
118 |
+
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
|
119 |
+
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
in_features: int,
|
125 |
+
out_features: int,
|
126 |
+
rank: int = 4,
|
127 |
+
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
|
128 |
+
stride: Union[int, Tuple[int, int]] = (1, 1),
|
129 |
+
padding: Union[int, Tuple[int, int], str] = 0,
|
130 |
+
network_alpha: Optional[float] = None,
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
135 |
+
# according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
|
136 |
+
# see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
137 |
+
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
138 |
+
|
139 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
140 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
141 |
+
self.network_alpha = network_alpha
|
142 |
+
self.rank = rank
|
143 |
+
|
144 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
145 |
+
nn.init.zeros_(self.up.weight)
|
146 |
+
|
147 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
148 |
+
orig_dtype = hidden_states.dtype
|
149 |
+
dtype = self.down.weight.dtype
|
150 |
+
|
151 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
152 |
+
up_hidden_states = self.up(down_hidden_states)
|
153 |
+
|
154 |
+
if self.network_alpha is not None:
|
155 |
+
up_hidden_states *= self.network_alpha / self.rank
|
156 |
+
|
157 |
+
return up_hidden_states.to(orig_dtype)
|
158 |
+
|
159 |
+
|
160 |
+
class LoRACompatibleConv(nn.Conv2d):
|
161 |
+
"""
|
162 |
+
A convolutional layer that can be used with LoRA.
|
163 |
+
"""
|
164 |
+
|
165 |
+
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
|
166 |
+
super().__init__(*args, **kwargs)
|
167 |
+
self.lora_layer = lora_layer
|
168 |
+
|
169 |
+
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
170 |
+
self.lora_layer = lora_layer
|
171 |
+
|
172 |
+
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
173 |
+
if self.lora_layer is None:
|
174 |
+
return
|
175 |
+
|
176 |
+
dtype, device = self.weight.data.dtype, self.weight.data.device
|
177 |
+
|
178 |
+
w_orig = self.weight.data.float()
|
179 |
+
w_up = self.lora_layer.up.weight.data.float()
|
180 |
+
w_down = self.lora_layer.down.weight.data.float()
|
181 |
+
|
182 |
+
if self.lora_layer.network_alpha is not None:
|
183 |
+
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
184 |
+
|
185 |
+
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
|
186 |
+
fusion = fusion.reshape((w_orig.shape))
|
187 |
+
fused_weight = w_orig + (lora_scale * fusion)
|
188 |
+
|
189 |
+
if safe_fusing and torch.isnan(fused_weight).any().item():
|
190 |
+
raise ValueError(
|
191 |
+
"This LoRA weight seems to be broken. "
|
192 |
+
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
193 |
+
"LoRA weights will not be fused."
|
194 |
+
)
|
195 |
+
|
196 |
+
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
197 |
+
|
198 |
+
# we can drop the lora layer now
|
199 |
+
self.lora_layer = None
|
200 |
+
|
201 |
+
# offload the up and down matrices to CPU to not blow the memory
|
202 |
+
self.w_up = w_up.cpu()
|
203 |
+
self.w_down = w_down.cpu()
|
204 |
+
self._lora_scale = lora_scale
|
205 |
+
|
206 |
+
def _unfuse_lora(self):
|
207 |
+
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
|
208 |
+
return
|
209 |
+
|
210 |
+
fused_weight = self.weight.data
|
211 |
+
dtype, device = fused_weight.data.dtype, fused_weight.data.device
|
212 |
+
|
213 |
+
self.w_up = self.w_up.to(device=device).float()
|
214 |
+
self.w_down = self.w_down.to(device).float()
|
215 |
+
|
216 |
+
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
|
217 |
+
fusion = fusion.reshape((fused_weight.shape))
|
218 |
+
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
|
219 |
+
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
220 |
+
|
221 |
+
self.w_up = None
|
222 |
+
self.w_down = None
|
223 |
+
|
224 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
225 |
+
if self.lora_layer is None:
|
226 |
+
# make sure to use the functional F.conv2d call, as otherwise torch.compile's graph will break
|
227 |
+
# see: https://github.com/huggingface/diffusers/pull/4315
|
228 |
+
return F.conv2d(
|
229 |
+
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
original_outputs = F.conv2d(
|
233 |
+
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
234 |
+
)
|
235 |
+
return original_outputs + (scale * self.lora_layer(hidden_states))
|
236 |
+
|
237 |
+
|
238 |
+
class LoRACompatibleLinear(nn.Linear):
|
239 |
+
"""
|
240 |
+
A Linear layer that can be used with LoRA.
|
241 |
+
"""
|
242 |
+
|
243 |
+
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
|
244 |
+
super().__init__(*args, **kwargs)
|
245 |
+
self.lora_layer = lora_layer
|
246 |
+
|
247 |
+
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
248 |
+
self.lora_layer = lora_layer
|
249 |
+
|
250 |
+
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
251 |
+
if self.lora_layer is None:
|
252 |
+
return
|
253 |
+
|
254 |
+
dtype, device = self.weight.data.dtype, self.weight.data.device
|
255 |
+
|
256 |
+
w_orig = self.weight.data.float()
|
257 |
+
w_up = self.lora_layer.up.weight.data.float()
|
258 |
+
w_down = self.lora_layer.down.weight.data.float()
|
259 |
+
|
260 |
+
if self.lora_layer.network_alpha is not None:
|
261 |
+
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
262 |
+
|
263 |
+
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
264 |
+
|
265 |
+
if safe_fusing and torch.isnan(fused_weight).any().item():
|
266 |
+
raise ValueError(
|
267 |
+
"This LoRA weight seems to be broken. "
|
268 |
+
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
269 |
+
"LoRA weights will not be fused."
|
270 |
+
)
|
271 |
+
|
272 |
+
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
273 |
+
|
274 |
+
# we can drop the lora layer now
|
275 |
+
self.lora_layer = None
|
276 |
+
|
277 |
+
# offload the up and down matrices to CPU to not blow the memory
|
278 |
+
self.w_up = w_up.cpu()
|
279 |
+
self.w_down = w_down.cpu()
|
280 |
+
self._lora_scale = lora_scale
|
281 |
+
|
282 |
+
def _unfuse_lora(self):
|
283 |
+
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
|
284 |
+
return
|
285 |
+
|
286 |
+
fused_weight = self.weight.data
|
287 |
+
dtype, device = fused_weight.dtype, fused_weight.device
|
288 |
+
|
289 |
+
w_up = self.w_up.to(device=device).float()
|
290 |
+
w_down = self.w_down.to(device).float()
|
291 |
+
|
292 |
+
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
293 |
+
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
294 |
+
|
295 |
+
self.w_up = None
|
296 |
+
self.w_down = None
|
297 |
+
|
298 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
299 |
+
if self.lora_layer is None:
|
300 |
+
out = super().forward(hidden_states)
|
301 |
+
return out
|
302 |
+
else:
|
303 |
+
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
304 |
+
return out
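A sketch of the `_fuse_lora` / `_unfuse_lora` round trip on `LoRACompatibleLinear` above; these are private helpers, shown only to illustrate that fusing folds `up @ down` into the base weight:

import torch
from diffusers.models.lora import LoRALinearLayer, LoRACompatibleLinear

lora = LoRALinearLayer(in_features=64, out_features=64, rank=4)
torch.nn.init.normal_(lora.up.weight)            # give the LoRA branch non-trivial weights
layer = LoRACompatibleLinear(64, 64, lora_layer=lora)

x = torch.randn(3, 64)
before = layer(x, scale=1.0)                     # base + scale * up(down(x))

layer._fuse_lora(lora_scale=1.0)                 # folds up @ down into layer.weight, drops lora_layer
after = layer(x)
print(torch.allclose(before, after, atol=1e-5))  # True (up to float error)

layer._unfuse_lora()                             # restores the original base weight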
|
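As a quick orientation for the fuse/unfuse logic above, a minimal usage sketch follows. It assumes the vendored `diffusers.models.lora` module is importable and picks arbitrary layer sizes; it is an illustration, not code taken from this commit.

```python
import torch

from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# Hypothetical sizes, chosen only for illustration.
linear = LoRACompatibleLinear(320, 320, lora_layer=LoRALinearLayer(320, 320, rank=4))

x = torch.randn(2, 77, 320)
with torch.no_grad():
    out_unfused = linear(x, scale=1.0)   # base Linear output + scaled LoRA delta
    linear._fuse_lora(lora_scale=1.0)    # folds up @ down into `weight`, drops the lora_layer
    out_fused = linear(x)                # now a plain nn.Linear forward
    linear._unfuse_lora()                # restores the original weight from w_up / w_down

# Fusing is mathematically a no-op on the outputs, up to floating-point error.
assert torch.allclose(out_unfused, out_fused, atol=1e-5)
```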
diffusers/models/modeling_flax_pytorch_utils.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Flax general utilities."""
|
16 |
+
import re
|
17 |
+
|
18 |
+
import jax.numpy as jnp
|
19 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
20 |
+
from jax.random import PRNGKey
|
21 |
+
|
22 |
+
from ..utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def rename_key(key):
|
29 |
+
regex = r"\w+[.]\d+"
|
30 |
+
pats = re.findall(regex, key)
|
31 |
+
for pat in pats:
|
32 |
+
key = key.replace(pat, "_".join(pat.split(".")))
|
33 |
+
return key
|
34 |
+
|
35 |
+
|
36 |
+
#####################
|
37 |
+
# PyTorch => Flax #
|
38 |
+
#####################
|
39 |
+
|
40 |
+
|
41 |
+
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
42 |
+
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
43 |
+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
44 |
+
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
|
45 |
+
# conv norm or layer norm
|
46 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
47 |
+
|
48 |
+
# rename attention layers
|
49 |
+
if len(pt_tuple_key) > 1:
|
50 |
+
for rename_from, rename_to in (
|
51 |
+
("to_out_0", "proj_attn"),
|
52 |
+
("to_k", "key"),
|
53 |
+
("to_v", "value"),
|
54 |
+
("to_q", "query"),
|
55 |
+
):
|
56 |
+
if pt_tuple_key[-2] == rename_from:
|
57 |
+
weight_name = pt_tuple_key[-1]
|
58 |
+
weight_name = "kernel" if weight_name == "weight" else weight_name
|
59 |
+
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
|
60 |
+
if renamed_pt_tuple_key in random_flax_state_dict:
|
61 |
+
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
|
62 |
+
return renamed_pt_tuple_key, pt_tensor.T
|
63 |
+
|
64 |
+
if (
|
65 |
+
any("norm" in str_ for str_ in pt_tuple_key)
|
66 |
+
and (pt_tuple_key[-1] == "bias")
|
67 |
+
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
|
68 |
+
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
|
69 |
+
):
|
70 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
71 |
+
return renamed_pt_tuple_key, pt_tensor
|
72 |
+
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
73 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
74 |
+
return renamed_pt_tuple_key, pt_tensor
|
75 |
+
|
76 |
+
# embedding
|
77 |
+
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
78 |
+
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
79 |
+
return renamed_pt_tuple_key, pt_tensor
|
80 |
+
|
81 |
+
# conv layer
|
82 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
83 |
+
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
|
84 |
+
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
85 |
+
return renamed_pt_tuple_key, pt_tensor
|
86 |
+
|
87 |
+
# linear layer
|
88 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
89 |
+
if pt_tuple_key[-1] == "weight":
|
90 |
+
pt_tensor = pt_tensor.T
|
91 |
+
return renamed_pt_tuple_key, pt_tensor
|
92 |
+
|
93 |
+
# old PyTorch layer norm weight
|
94 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
95 |
+
if pt_tuple_key[-1] == "gamma":
|
96 |
+
return renamed_pt_tuple_key, pt_tensor
|
97 |
+
|
98 |
+
# old PyTorch layer norm bias
|
99 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
100 |
+
if pt_tuple_key[-1] == "beta":
|
101 |
+
return renamed_pt_tuple_key, pt_tensor
|
102 |
+
|
103 |
+
return pt_tuple_key, pt_tensor
|
104 |
+
|
105 |
+
|
106 |
+
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
|
107 |
+
# Step 1: Convert pytorch tensor to numpy
|
108 |
+
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
109 |
+
|
110 |
+
# Step 2: Since the model is stateless, get random Flax params
|
111 |
+
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
|
112 |
+
|
113 |
+
random_flax_state_dict = flatten_dict(random_flax_params)
|
114 |
+
flax_state_dict = {}
|
115 |
+
|
116 |
+
# Need to change some parameters name to match Flax names
|
117 |
+
for pt_key, pt_tensor in pt_state_dict.items():
|
118 |
+
renamed_pt_key = rename_key(pt_key)
|
119 |
+
pt_tuple_key = tuple(renamed_pt_key.split("."))
|
120 |
+
|
121 |
+
# Correctly rename weight parameters
|
122 |
+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
|
123 |
+
|
124 |
+
if flax_key in random_flax_state_dict:
|
125 |
+
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
|
126 |
+
raise ValueError(
|
127 |
+
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
128 |
+
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
129 |
+
)
|
130 |
+
|
131 |
+
# also add unexpected weight so that warning is thrown
|
132 |
+
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
133 |
+
|
134 |
+
return unflatten_dict(flax_state_dict)
|
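The converter above is normally reached through `FlaxModelMixin.from_pretrained(..., from_pt=True)`, defined in the next file. A hedged sketch, assuming network access to the Hub checkpoint already used in the docstrings:

```python
from diffusers import FlaxUNet2DConditionModel
from diffusers.models.modeling_flax_pytorch_utils import rename_key

# High-level path: from_pt=True downloads the PyTorch weights and runs
# convert_pytorch_state_dict_to_flax on them before returning Flax params.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", from_pt=True
)

# Low-level helper: PyTorch "block.0" style indices become Flax "block_0" names.
assert rename_key("down_blocks.0.resnets.1.conv1.weight") == "down_blocks_0.resnets_1.conv1.weight"
```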
diffusers/models/modeling_flax_utils.py
ADDED
@@ -0,0 +1,560 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
from pickle import UnpicklingError
|
18 |
+
from typing import Any, Dict, Union
|
19 |
+
|
20 |
+
import jax
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import msgpack.exceptions
|
23 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
24 |
+
from flax.serialization import from_bytes, to_bytes
|
25 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
26 |
+
from huggingface_hub import create_repo, hf_hub_download
|
27 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
28 |
+
from requests import HTTPError
|
29 |
+
|
30 |
+
from .. import __version__, is_torch_available
|
31 |
+
from ..utils import (
|
32 |
+
CONFIG_NAME,
|
33 |
+
DIFFUSERS_CACHE,
|
34 |
+
FLAX_WEIGHTS_NAME,
|
35 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
36 |
+
WEIGHTS_NAME,
|
37 |
+
PushToHubMixin,
|
38 |
+
logging,
|
39 |
+
)
|
40 |
+
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
class FlaxModelMixin(PushToHubMixin):
|
47 |
+
r"""
|
48 |
+
Base class for all Flax models.
|
49 |
+
|
50 |
+
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
51 |
+
saving models.
|
52 |
+
|
53 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
|
54 |
+
"""
|
55 |
+
config_name = CONFIG_NAME
|
56 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
57 |
+
_flax_internal_args = ["name", "parent", "dtype"]
|
58 |
+
|
59 |
+
@classmethod
|
60 |
+
def _from_config(cls, config, **kwargs):
|
61 |
+
"""
|
62 |
+
All context managers that the model should be initialized under go here.
|
63 |
+
"""
|
64 |
+
return cls(config, **kwargs)
|
65 |
+
|
66 |
+
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
67 |
+
"""
|
68 |
+
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
69 |
+
"""
|
70 |
+
|
71 |
+
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
72 |
+
def conditional_cast(param):
|
73 |
+
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
74 |
+
param = param.astype(dtype)
|
75 |
+
return param
|
76 |
+
|
77 |
+
if mask is None:
|
78 |
+
return jax.tree_map(conditional_cast, params)
|
79 |
+
|
80 |
+
flat_params = flatten_dict(params)
|
81 |
+
flat_mask, _ = jax.tree_flatten(mask)
|
82 |
+
|
83 |
+
for masked, key in zip(flat_mask, flat_params.keys()):
|
84 |
+
if masked:
|
85 |
+
param = flat_params[key]
|
86 |
+
flat_params[key] = conditional_cast(param)
|
87 |
+
|
88 |
+
return unflatten_dict(flat_params)
|
89 |
+
|
90 |
+
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
91 |
+
r"""
|
92 |
+
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
93 |
+
the `params` in place.
|
94 |
+
|
95 |
+
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
96 |
+
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
params (`Union[Dict, FrozenDict]`):
|
100 |
+
A `PyTree` of model parameters.
|
101 |
+
mask (`Union[Dict, FrozenDict]`):
|
102 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
103 |
+
for params you want to cast, and `False` for those you want to skip.
|
104 |
+
|
105 |
+
Examples:
|
106 |
+
|
107 |
+
```python
|
108 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
109 |
+
|
110 |
+
>>> # load model
|
111 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
112 |
+
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
|
113 |
+
>>> params = model.to_bf16(params)
|
114 |
+
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
|
115 |
+
>>> # then pass the mask as follows
|
116 |
+
>>> from flax import traverse_util
|
117 |
+
|
118 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
119 |
+
>>> flat_params = traverse_util.flatten_dict(params)
|
120 |
+
>>> mask = {
|
121 |
+
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
122 |
+
... for path in flat_params
|
123 |
+
... }
|
124 |
+
>>> mask = traverse_util.unflatten_dict(mask)
|
125 |
+
>>> params = model.to_bf16(params, mask)
|
126 |
+
```"""
|
127 |
+
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
128 |
+
|
129 |
+
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
130 |
+
r"""
|
131 |
+
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
|
132 |
+
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
|
133 |
+
|
134 |
+
Arguments:
|
135 |
+
params (`Union[Dict, FrozenDict]`):
|
136 |
+
A `PyTree` of model parameters.
|
137 |
+
mask (`Union[Dict, FrozenDict]`):
|
138 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
139 |
+
for params you want to cast, and `False` for those you want to skip.
|
140 |
+
|
141 |
+
Examples:
|
142 |
+
|
143 |
+
```python
|
144 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
145 |
+
|
146 |
+
>>> # Download model and configuration from huggingface.co
|
147 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
148 |
+
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
|
149 |
+
>>> # we'll first cast to fp16 and back to fp32
|
150 |
+
>>> params = model.to_fp16(params)
|
151 |
+
>>> # now cast back to fp32
|
152 |
+
>>> params = model.to_fp32(params)
|
153 |
+
```"""
|
154 |
+
return self._cast_floating_to(params, jnp.float32, mask)
|
155 |
+
|
156 |
+
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
157 |
+
r"""
|
158 |
+
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
159 |
+
`params` in place.
|
160 |
+
|
161 |
+
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
|
162 |
+
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
|
163 |
+
|
164 |
+
Arguments:
|
165 |
+
params (`Union[Dict, FrozenDict]`):
|
166 |
+
A `PyTree` of model parameters.
|
167 |
+
mask (`Union[Dict, FrozenDict]`):
|
168 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
169 |
+
for params you want to cast, and `False` for those you want to skip.
|
170 |
+
|
171 |
+
Examples:
|
172 |
+
|
173 |
+
```python
|
174 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
175 |
+
|
176 |
+
>>> # load model
|
177 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
178 |
+
>>> # By default, the model params will be in fp32, to cast these to float16
|
179 |
+
>>> params = model.to_fp16(params)
|
180 |
+
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
|
181 |
+
>>> # then pass the mask as follows
|
182 |
+
>>> from flax import traverse_util
|
183 |
+
|
184 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
185 |
+
>>> flat_params = traverse_util.flatten_dict(params)
|
186 |
+
>>> mask = {
|
187 |
+
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
188 |
+
... for path in flat_params
|
189 |
+
... }
|
190 |
+
>>> mask = traverse_util.unflatten_dict(mask)
|
191 |
+
>>> params = model.to_fp16(params, mask)
|
192 |
+
```"""
|
193 |
+
return self._cast_floating_to(params, jnp.float16, mask)
|
194 |
+
|
195 |
+
def init_weights(self, rng: jax.Array) -> Dict:
|
196 |
+
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
197 |
+
|
198 |
+
@classmethod
|
199 |
+
def from_pretrained(
|
200 |
+
cls,
|
201 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
202 |
+
dtype: jnp.dtype = jnp.float32,
|
203 |
+
*model_args,
|
204 |
+
**kwargs,
|
205 |
+
):
|
206 |
+
r"""
|
207 |
+
Instantiate a pretrained Flax model from a pretrained model configuration.
|
208 |
+
|
209 |
+
Parameters:
|
210 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
211 |
+
Can be either:
|
212 |
+
|
213 |
+
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
|
214 |
+
hosted on the Hub.
|
215 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
216 |
+
using [`~FlaxModelMixin.save_pretrained`].
|
217 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
218 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
219 |
+
`jax.numpy.bfloat16` (on TPUs).
|
220 |
+
|
221 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
222 |
+
specified, all the computation will be performed with the given `dtype`.
|
223 |
+
|
224 |
+
<Tip>
|
225 |
+
|
226 |
+
This only specifies the dtype of the *computation* and does not influence the dtype of model
|
227 |
+
parameters.
|
228 |
+
|
229 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
|
230 |
+
[`~FlaxModelMixin.to_bf16`].
|
231 |
+
|
232 |
+
</Tip>
|
233 |
+
|
234 |
+
model_args (sequence of positional arguments, *optional*):
|
235 |
+
All remaining positional arguments are passed to the underlying model's `__init__` method.
|
236 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
237 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
238 |
+
is not used.
|
239 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
240 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
241 |
+
cached versions if they exist.
|
242 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
243 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
244 |
+
incompletely downloaded files are deleted.
|
245 |
+
proxies (`Dict[str, str]`, *optional*):
|
246 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
247 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
248 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
249 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
250 |
+
won't be downloaded from the Hub.
|
251 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
252 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
253 |
+
allowed by Git.
|
254 |
+
from_pt (`bool`, *optional*, defaults to `False`):
|
255 |
+
Load the model weights from a PyTorch checkpoint save file.
|
256 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
257 |
+
Can be used to update the configuration object (after it is loaded) and initiate the model (for
|
258 |
+
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
259 |
+
automatically loaded:
|
260 |
+
|
261 |
+
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
|
262 |
+
model's `__init__` method (we assume all relevant updates to the configuration have already been
|
263 |
+
done).
|
264 |
+
- If a configuration is not provided, `kwargs` are first passed to the configuration class
|
265 |
+
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
|
266 |
+
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
|
267 |
+
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
|
268 |
+
model's `__init__` function.
|
269 |
+
|
270 |
+
Examples:
|
271 |
+
|
272 |
+
```python
|
273 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
274 |
+
|
275 |
+
>>> # Download model and configuration from huggingface.co and cache.
|
276 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
277 |
+
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
|
278 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
|
279 |
+
```
|
280 |
+
|
281 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
282 |
+
|
283 |
+
```bash
|
284 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
285 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
286 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
287 |
+
```
|
288 |
+
"""
|
289 |
+
config = kwargs.pop("config", None)
|
290 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
291 |
+
force_download = kwargs.pop("force_download", False)
|
292 |
+
from_pt = kwargs.pop("from_pt", False)
|
293 |
+
resume_download = kwargs.pop("resume_download", False)
|
294 |
+
proxies = kwargs.pop("proxies", None)
|
295 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
296 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
297 |
+
revision = kwargs.pop("revision", None)
|
298 |
+
subfolder = kwargs.pop("subfolder", None)
|
299 |
+
|
300 |
+
user_agent = {
|
301 |
+
"diffusers": __version__,
|
302 |
+
"file_type": "model",
|
303 |
+
"framework": "flax",
|
304 |
+
}
|
305 |
+
|
306 |
+
# Load config if we don't provide one
|
307 |
+
if config is None:
|
308 |
+
config, unused_kwargs = cls.load_config(
|
309 |
+
pretrained_model_name_or_path,
|
310 |
+
cache_dir=cache_dir,
|
311 |
+
return_unused_kwargs=True,
|
312 |
+
force_download=force_download,
|
313 |
+
resume_download=resume_download,
|
314 |
+
proxies=proxies,
|
315 |
+
local_files_only=local_files_only,
|
316 |
+
use_auth_token=use_auth_token,
|
317 |
+
revision=revision,
|
318 |
+
subfolder=subfolder,
|
319 |
+
**kwargs,
|
320 |
+
)
|
321 |
+
|
322 |
+
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
|
323 |
+
|
324 |
+
# Load model
|
325 |
+
pretrained_path_with_subfolder = (
|
326 |
+
pretrained_model_name_or_path
|
327 |
+
if subfolder is None
|
328 |
+
else os.path.join(pretrained_model_name_or_path, subfolder)
|
329 |
+
)
|
330 |
+
if os.path.isdir(pretrained_path_with_subfolder):
|
331 |
+
if from_pt:
|
332 |
+
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
333 |
+
raise EnvironmentError(
|
334 |
+
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
|
335 |
+
)
|
336 |
+
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
|
337 |
+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
|
338 |
+
# Load from a Flax checkpoint
|
339 |
+
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
|
340 |
+
# Check if pytorch weights exist instead
|
341 |
+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
342 |
+
raise EnvironmentError(
|
343 |
+
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
|
344 |
+
" using `from_pt=True`."
|
345 |
+
)
|
346 |
+
else:
|
347 |
+
raise EnvironmentError(
|
348 |
+
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
349 |
+
f"{pretrained_path_with_subfolder}."
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
try:
|
353 |
+
model_file = hf_hub_download(
|
354 |
+
pretrained_model_name_or_path,
|
355 |
+
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
|
356 |
+
cache_dir=cache_dir,
|
357 |
+
force_download=force_download,
|
358 |
+
proxies=proxies,
|
359 |
+
resume_download=resume_download,
|
360 |
+
local_files_only=local_files_only,
|
361 |
+
use_auth_token=use_auth_token,
|
362 |
+
user_agent=user_agent,
|
363 |
+
subfolder=subfolder,
|
364 |
+
revision=revision,
|
365 |
+
)
|
366 |
+
|
367 |
+
except RepositoryNotFoundError:
|
368 |
+
raise EnvironmentError(
|
369 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
370 |
+
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
371 |
+
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
372 |
+
"login`."
|
373 |
+
)
|
374 |
+
except RevisionNotFoundError:
|
375 |
+
raise EnvironmentError(
|
376 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
377 |
+
"this model name. Check the model page at "
|
378 |
+
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
379 |
+
)
|
380 |
+
except EntryNotFoundError:
|
381 |
+
raise EnvironmentError(
|
382 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
383 |
+
)
|
384 |
+
except HTTPError as err:
|
385 |
+
raise EnvironmentError(
|
386 |
+
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
387 |
+
f"{err}"
|
388 |
+
)
|
389 |
+
except ValueError:
|
390 |
+
raise EnvironmentError(
|
391 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
392 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
393 |
+
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
|
394 |
+
" internet connection or see how to run the library in offline mode at"
|
395 |
+
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
396 |
+
)
|
397 |
+
except EnvironmentError:
|
398 |
+
raise EnvironmentError(
|
399 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
400 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
401 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
402 |
+
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
403 |
+
)
|
404 |
+
|
405 |
+
if from_pt:
|
406 |
+
if is_torch_available():
|
407 |
+
from .modeling_utils import load_state_dict
|
408 |
+
else:
|
409 |
+
raise EnvironmentError(
|
410 |
+
"Can't load the model in PyTorch format because PyTorch is not installed. "
|
411 |
+
"Please, install PyTorch or use native Flax weights."
|
412 |
+
)
|
413 |
+
|
414 |
+
# Step 1: Get the pytorch file
|
415 |
+
pytorch_model_file = load_state_dict(model_file)
|
416 |
+
|
417 |
+
# Step 2: Convert the weights
|
418 |
+
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
|
419 |
+
else:
|
420 |
+
try:
|
421 |
+
with open(model_file, "rb") as state_f:
|
422 |
+
state = from_bytes(cls, state_f.read())
|
423 |
+
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
424 |
+
try:
|
425 |
+
with open(model_file) as f:
|
426 |
+
if f.read().startswith("version"):
|
427 |
+
raise OSError(
|
428 |
+
"You seem to have cloned a repository without having git-lfs installed. Please"
|
429 |
+
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
430 |
+
" folder you cloned."
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
raise ValueError from e
|
434 |
+
except (UnicodeDecodeError, ValueError):
|
435 |
+
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
436 |
+
# make sure all arrays are stored as jnp.ndarray
|
437 |
+
# NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
|
438 |
+
# https://github.com/google/flax/issues/1261
|
439 |
+
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
440 |
+
|
441 |
+
# flatten dicts
|
442 |
+
state = flatten_dict(state)
|
443 |
+
|
444 |
+
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
|
445 |
+
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
446 |
+
|
447 |
+
shape_state = flatten_dict(unfreeze(params_shape_tree))
|
448 |
+
|
449 |
+
missing_keys = required_params - set(state.keys())
|
450 |
+
unexpected_keys = set(state.keys()) - required_params
|
451 |
+
|
452 |
+
if missing_keys:
|
453 |
+
logger.warning(
|
454 |
+
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
455 |
+
"Make sure to call model.init_weights to initialize the missing weights."
|
456 |
+
)
|
457 |
+
cls._missing_keys = missing_keys
|
458 |
+
|
459 |
+
for key in state.keys():
|
460 |
+
if key in shape_state and state[key].shape != shape_state[key].shape:
|
461 |
+
raise ValueError(
|
462 |
+
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
|
463 |
+
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
|
464 |
+
)
|
465 |
+
|
466 |
+
# remove unexpected keys to not be saved again
|
467 |
+
for unexpected_key in unexpected_keys:
|
468 |
+
del state[unexpected_key]
|
469 |
+
|
470 |
+
if len(unexpected_keys) > 0:
|
471 |
+
logger.warning(
|
472 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
473 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
474 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
475 |
+
" with another architecture."
|
476 |
+
)
|
477 |
+
else:
|
478 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
479 |
+
|
480 |
+
if len(missing_keys) > 0:
|
481 |
+
logger.warning(
|
482 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
483 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
484 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
485 |
+
)
|
486 |
+
else:
|
487 |
+
logger.info(
|
488 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
489 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
490 |
+
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
491 |
+
" training."
|
492 |
+
)
|
493 |
+
|
494 |
+
return model, unflatten_dict(state)
|
495 |
+
|
496 |
+
def save_pretrained(
|
497 |
+
self,
|
498 |
+
save_directory: Union[str, os.PathLike],
|
499 |
+
params: Union[Dict, FrozenDict],
|
500 |
+
is_main_process: bool = True,
|
501 |
+
push_to_hub: bool = False,
|
502 |
+
**kwargs,
|
503 |
+
):
|
504 |
+
"""
|
505 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
506 |
+
[`~FlaxModelMixin.from_pretrained`] class method.
|
507 |
+
|
508 |
+
Arguments:
|
509 |
+
save_directory (`str` or `os.PathLike`):
|
510 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
511 |
+
params (`Union[Dict, FrozenDict]`):
|
512 |
+
A `PyTree` of model parameters.
|
513 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
514 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
515 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
516 |
+
process to avoid race conditions.
|
517 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
518 |
+
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
519 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
520 |
+
namespace).
|
521 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
522 |
+
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
523 |
+
"""
|
524 |
+
if os.path.isfile(save_directory):
|
525 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
526 |
+
return
|
527 |
+
|
528 |
+
os.makedirs(save_directory, exist_ok=True)
|
529 |
+
|
530 |
+
if push_to_hub:
|
531 |
+
commit_message = kwargs.pop("commit_message", None)
|
532 |
+
private = kwargs.pop("private", False)
|
533 |
+
create_pr = kwargs.pop("create_pr", False)
|
534 |
+
token = kwargs.pop("token", None)
|
535 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
536 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
537 |
+
|
538 |
+
model_to_save = self
|
539 |
+
|
540 |
+
# Attach architecture to the config
|
541 |
+
# Save the config
|
542 |
+
if is_main_process:
|
543 |
+
model_to_save.save_config(save_directory)
|
544 |
+
|
545 |
+
# save model
|
546 |
+
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
547 |
+
with open(output_model_file, "wb") as f:
|
548 |
+
model_bytes = to_bytes(params)
|
549 |
+
f.write(model_bytes)
|
550 |
+
|
551 |
+
logger.info(f"Model weights saved in {output_model_file}")
|
552 |
+
|
553 |
+
if push_to_hub:
|
554 |
+
self._upload_folder(
|
555 |
+
save_directory,
|
556 |
+
repo_id,
|
557 |
+
token=token,
|
558 |
+
commit_message=commit_message,
|
559 |
+
create_pr=create_pr,
|
560 |
+
)
|
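To round off the mixin above, a small sketch of the dtype-cast and save/reload cycle it implements; the model id comes from the docstrings, the output directory is illustrative.

```python
from flax.traverse_util import flatten_dict

from diffusers import FlaxUNet2DConditionModel

# Download (or read from cache), cast the parameters, and serialize them back to msgpack.
model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
params = model.to_bf16(params)  # returns a new PyTree; nothing is cast in place
model.save_pretrained("./sd15-unet-flax", params=params)

# Reloading from the local directory exercises the native Flax (msgpack) branch of from_pretrained.
model2, params2 = FlaxUNet2DConditionModel.from_pretrained("./sd15-unet-flax")
assert set(flatten_dict(params2).keys()) == set(flatten_dict(params).keys())
```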
diffusers/models/modeling_pytorch_flax_utils.py
ADDED
@@ -0,0 +1,161 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Flax general utilities."""
|
16 |
+
|
17 |
+
from pickle import UnpicklingError
|
18 |
+
|
19 |
+
import jax
|
20 |
+
import jax.numpy as jnp
|
21 |
+
import numpy as np
|
22 |
+
from flax.serialization import from_bytes
|
23 |
+
from flax.traverse_util import flatten_dict
|
24 |
+
|
25 |
+
from ..utils import logging
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
#####################
|
32 |
+
# Flax => PyTorch #
|
33 |
+
#####################
|
34 |
+
|
35 |
+
|
36 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
|
37 |
+
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
|
38 |
+
try:
|
39 |
+
with open(model_file, "rb") as flax_state_f:
|
40 |
+
flax_state = from_bytes(None, flax_state_f.read())
|
41 |
+
except UnpicklingError as e:
|
42 |
+
try:
|
43 |
+
with open(model_file) as f:
|
44 |
+
if f.read().startswith("version"):
|
45 |
+
raise OSError(
|
46 |
+
"You seem to have cloned a repository without having git-lfs installed. Please"
|
47 |
+
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
48 |
+
" folder you cloned."
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
raise ValueError from e
|
52 |
+
except (UnicodeDecodeError, ValueError):
|
53 |
+
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
54 |
+
|
55 |
+
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
|
56 |
+
|
57 |
+
|
58 |
+
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
59 |
+
"""Load flax checkpoints in a PyTorch model"""
|
60 |
+
|
61 |
+
try:
|
62 |
+
import torch # noqa: F401
|
63 |
+
except ImportError:
|
64 |
+
logger.error(
|
65 |
+
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
|
66 |
+
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
67 |
+
" instructions."
|
68 |
+
)
|
69 |
+
raise
|
70 |
+
|
71 |
+
# check if we have bf16 weights
|
72 |
+
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
|
73 |
+
if any(is_type_bf16):
|
74 |
+
# convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
|
75 |
+
|
76 |
+
# and bf16 is not fully supported in PT yet.
|
77 |
+
logger.warning(
|
78 |
+
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
|
79 |
+
"before loading those in PyTorch model."
|
80 |
+
)
|
81 |
+
flax_state = jax.tree_util.tree_map(
|
82 |
+
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
|
83 |
+
)
|
84 |
+
|
85 |
+
pt_model.base_model_prefix = ""
|
86 |
+
|
87 |
+
flax_state_dict = flatten_dict(flax_state, sep=".")
|
88 |
+
pt_model_dict = pt_model.state_dict()
|
89 |
+
|
90 |
+
# keep track of unexpected & missing keys
|
91 |
+
unexpected_keys = []
|
92 |
+
missing_keys = set(pt_model_dict.keys())
|
93 |
+
|
94 |
+
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
95 |
+
flax_key_tuple_array = flax_key_tuple.split(".")
|
96 |
+
|
97 |
+
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
|
98 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
99 |
+
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
|
100 |
+
elif flax_key_tuple_array[-1] == "kernel":
|
101 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
102 |
+
flax_tensor = flax_tensor.T
|
103 |
+
elif flax_key_tuple_array[-1] == "scale":
|
104 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
105 |
+
|
106 |
+
if "time_embedding" not in flax_key_tuple_array:
|
107 |
+
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
|
108 |
+
flax_key_tuple_array[i] = (
|
109 |
+
flax_key_tuple_string.replace("_0", ".0")
|
110 |
+
.replace("_1", ".1")
|
111 |
+
.replace("_2", ".2")
|
112 |
+
.replace("_3", ".3")
|
113 |
+
.replace("_4", ".4")
|
114 |
+
.replace("_5", ".5")
|
115 |
+
.replace("_6", ".6")
|
116 |
+
.replace("_7", ".7")
|
117 |
+
.replace("_8", ".8")
|
118 |
+
.replace("_9", ".9")
|
119 |
+
)
|
120 |
+
|
121 |
+
flax_key = ".".join(flax_key_tuple_array)
|
122 |
+
|
123 |
+
if flax_key in pt_model_dict:
|
124 |
+
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
125 |
+
raise ValueError(
|
126 |
+
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
|
127 |
+
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
# add weight to pytorch dict
|
131 |
+
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
132 |
+
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
133 |
+
# remove from missing keys
|
134 |
+
missing_keys.remove(flax_key)
|
135 |
+
else:
|
136 |
+
# weight is not expected by PyTorch model
|
137 |
+
unexpected_keys.append(flax_key)
|
138 |
+
|
139 |
+
pt_model.load_state_dict(pt_model_dict)
|
140 |
+
|
141 |
+
# re-transform missing_keys to list
|
142 |
+
missing_keys = list(missing_keys)
|
143 |
+
|
144 |
+
if len(unexpected_keys) > 0:
|
145 |
+
logger.warning(
|
146 |
+
"Some weights of the Flax model were not used when initializing the PyTorch model"
|
147 |
+
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
|
148 |
+
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
|
149 |
+
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
|
150 |
+
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
|
151 |
+
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
|
152 |
+
" FlaxBertForSequenceClassification model)."
|
153 |
+
)
|
154 |
+
if len(missing_keys) > 0:
|
155 |
+
logger.warning(
|
156 |
+
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
|
157 |
+
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
|
158 |
+
" use it for predictions and inference."
|
159 |
+
)
|
160 |
+
|
161 |
+
return pt_model
|
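And the reverse direction defined here, going from a saved Flax checkpoint back into a PyTorch model. This is only a sketch: the paths, the msgpack filename, and the choice of `UNet2DConditionModel` are assumptions made for the example.

```python
import os

from diffusers import UNet2DConditionModel
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

# Assumes a directory written earlier by FlaxModelMixin.save_pretrained (illustrative path);
# "diffusion_flax_model.msgpack" is assumed to be what FLAX_WEIGHTS_NAME resolves to here.
flax_dir = "./sd15-unet-flax"
flax_checkpoint = os.path.join(flax_dir, "diffusion_flax_model.msgpack")

# Build a randomly initialized PyTorch model from the saved config, then copy the converted weights in.
pt_unet = UNet2DConditionModel.from_config(UNet2DConditionModel.load_config(flax_dir))
pt_unet = load_flax_checkpoint_in_pytorch_model(pt_unet, flax_checkpoint)
```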
diffusers/models/modeling_utils.py
ADDED
@@ -0,0 +1,1158 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from functools import partial
|
22 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import safetensors
|
25 |
+
import torch
|
26 |
+
from huggingface_hub import create_repo
|
27 |
+
from torch import Tensor, device, nn
|
28 |
+
|
29 |
+
from .. import __version__
|
30 |
+
from ..utils import (
|
31 |
+
CONFIG_NAME,
|
32 |
+
DIFFUSERS_CACHE,
|
33 |
+
FLAX_WEIGHTS_NAME,
|
34 |
+
HF_HUB_OFFLINE,
|
35 |
+
MIN_PEFT_VERSION,
|
36 |
+
SAFETENSORS_WEIGHTS_NAME,
|
37 |
+
WEIGHTS_NAME,
|
38 |
+
_add_variant,
|
39 |
+
_get_model_file,
|
40 |
+
check_peft_version,
|
41 |
+
deprecate,
|
42 |
+
is_accelerate_available,
|
43 |
+
is_torch_version,
|
44 |
+
logging,
|
45 |
+
)
|
46 |
+
from ..utils.hub_utils import PushToHubMixin
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
if is_torch_version(">=", "1.9.0"):
|
53 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
54 |
+
else:
|
55 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
56 |
+
|
57 |
+
|
58 |
+
if is_accelerate_available():
|
59 |
+
import accelerate
|
60 |
+
from accelerate.utils import set_module_tensor_to_device
|
61 |
+
from accelerate.utils.versions import is_torch_version
|
62 |
+
|
63 |
+
|
64 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
65 |
+
try:
|
66 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
67 |
+
return next(parameters_and_buffers).device
|
68 |
+
except StopIteration:
|
69 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
70 |
+
|
71 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
72 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
73 |
+
return tuples
|
74 |
+
|
75 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
76 |
+
first_tuple = next(gen)
|
77 |
+
return first_tuple[1].device
|
78 |
+
|
79 |
+
|
80 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
81 |
+
try:
|
82 |
+
params = tuple(parameter.parameters())
|
83 |
+
if len(params) > 0:
|
84 |
+
return params[0].dtype
|
85 |
+
|
86 |
+
buffers = tuple(parameter.buffers())
|
87 |
+
if len(buffers) > 0:
|
88 |
+
return buffers[0].dtype
|
89 |
+
|
90 |
+
except StopIteration:
|
91 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
92 |
+
|
93 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
94 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
95 |
+
return tuples
|
96 |
+
|
97 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
98 |
+
first_tuple = next(gen)
|
99 |
+
return first_tuple[1].dtype
|
100 |
+
|
101 |
+
|
102 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
103 |
+
"""
|
104 |
+
Reads a checkpoint file, returning properly formatted errors if they arise.
|
105 |
+
"""
|
106 |
+
try:
|
107 |
+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
108 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
109 |
+
else:
|
110 |
+
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
111 |
+
except Exception as e:
|
112 |
+
try:
|
113 |
+
with open(checkpoint_file) as f:
|
114 |
+
if f.read().startswith("version"):
|
115 |
+
raise OSError(
|
116 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
117 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
118 |
+
"you cloned."
|
119 |
+
)
|
120 |
+
else:
|
121 |
+
raise ValueError(
|
122 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
123 |
+
"model. Make sure you have saved the model properly."
|
124 |
+
) from e
|
125 |
+
except (UnicodeDecodeError, ValueError):
|
126 |
+
raise OSError(
|
127 |
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
128 |
+
f"at '{checkpoint_file}'. "
|
129 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
|
134 |
+
device = device or torch.device("cpu")
|
135 |
+
dtype = dtype or torch.float32
|
136 |
+
|
137 |
+
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
138 |
+
|
139 |
+
unexpected_keys = []
|
140 |
+
empty_state_dict = model.state_dict()
|
141 |
+
for param_name, param in state_dict.items():
|
142 |
+
if param_name not in empty_state_dict:
|
143 |
+
unexpected_keys.append(param_name)
|
144 |
+
continue
|
145 |
+
|
146 |
+
if empty_state_dict[param_name].shape != param.shape:
|
147 |
+
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
148 |
+
raise ValueError(
|
149 |
+
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
150 |
+
)
|
151 |
+
|
152 |
+
if accepts_dtype:
|
153 |
+
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
|
154 |
+
else:
|
155 |
+
set_module_tensor_to_device(model, param_name, device, value=param)
|
156 |
+
return unexpected_keys
|
157 |
+
|
158 |
+
|
159 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
160 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
161 |
+
# copy state_dict so _load_from_state_dict can modify it
|
162 |
+
state_dict = state_dict.copy()
|
163 |
+
error_msgs = []
|
164 |
+
|
165 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
166 |
+
# so we need to apply the function recursively.
|
167 |
+
def load(module: torch.nn.Module, prefix=""):
|
168 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
169 |
+
module._load_from_state_dict(*args)
|
170 |
+
|
171 |
+
for name, child in module._modules.items():
|
172 |
+
if child is not None:
|
173 |
+
load(child, prefix + name + ".")
|
174 |
+
|
175 |
+
load(model_to_load)
|
176 |
+
|
177 |
+
return error_msgs
|
178 |
+
|
179 |
+
|
180 |
+
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
181 |
+
r"""
|
182 |
+
Base class for all models.
|
183 |
+
|
184 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
185 |
+
saving models.
|
186 |
+
|
187 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
188 |
+
"""
|
189 |
+
config_name = CONFIG_NAME
|
190 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
191 |
+
_supports_gradient_checkpointing = False
|
192 |
+
_keys_to_ignore_on_load_unexpected = None
|
193 |
+
_hf_peft_config_loaded = False
|
194 |
+
|
195 |
+
def __init__(self):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
def __getattr__(self, name: str) -> Any:
|
199 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
200 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
201 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__:
|
202 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
203 |
+
"""
|
204 |
+
|
205 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
206 |
+
is_attribute = name in self.__dict__
|
207 |
+
|
208 |
+
if is_in_config and not is_attribute:
|
209 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
210 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
211 |
+
return self._internal_dict[name]
|
212 |
+
|
213 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
214 |
+
return super().__getattr__(name)
|
215 |
+
|
216 |
+
@property
|
217 |
+
def is_gradient_checkpointing(self) -> bool:
|
218 |
+
"""
|
219 |
+
Whether gradient checkpointing is activated for this model or not.
|
220 |
+
"""
|
221 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
222 |
+
|
223 |
+
def enable_gradient_checkpointing(self):
|
224 |
+
"""
|
225 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
226 |
+
*checkpoint activations* in other frameworks).
|
227 |
+
"""
|
228 |
+
if not self._supports_gradient_checkpointing:
|
229 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
230 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
231 |
+
|
232 |
+
def disable_gradient_checkpointing(self):
|
233 |
+
"""
|
234 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
235 |
+
*checkpoint activations* in other frameworks).
|
236 |
+
"""
|
237 |
+
if self._supports_gradient_checkpointing:
|
238 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
239 |
+
|
240 |
+
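For example (a minimal sketch; any model class that sets `_supports_gradient_checkpointing = True`, such as the UNet used in the docstrings below, works the same way):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_gradient_checkpointing()   # trade extra compute for lower activation memory while training
assert unet.is_gradient_checkpointing
unet.disable_gradient_checkpointing()  # back to the default behaviour for inference
```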
def set_use_memory_efficient_attention_xformers(
|
241 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
242 |
+
) -> None:
|
243 |
+
# Recursively walk through all the children.
|
244 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
245 |
+
# gets the message
|
246 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
247 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
248 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
249 |
+
|
250 |
+
for child in module.children():
|
251 |
+
fn_recursive_set_mem_eff(child)
|
252 |
+
|
253 |
+
for module in self.children():
|
254 |
+
if isinstance(module, torch.nn.Module):
|
255 |
+
fn_recursive_set_mem_eff(module)
|
256 |
+
|
257 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
258 |
+
r"""
|
259 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
260 |
+
|
261 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
262 |
+
inference. Speed up during training is not guaranteed.
|
263 |
+
|
264 |
+
<Tip warning={true}>
|
265 |
+
|
266 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
267 |
+
precedent.
|
268 |
+
|
269 |
+
</Tip>
|
270 |
+
|
271 |
+
Parameters:
|
272 |
+
attention_op (`Callable`, *optional*):
|
273 |
+
Override the default `None` operator for use as `op` argument to the
|
274 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
275 |
+
function of xFormers.
|
276 |
+
|
277 |
+
Examples:
|
278 |
+
|
279 |
+
```py
|
280 |
+
>>> import torch
|
281 |
+
>>> from diffusers import UNet2DConditionModel
|
282 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
283 |
+
|
284 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
285 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
286 |
+
... )
|
287 |
+
>>> model = model.to("cuda")
|
288 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
289 |
+
```
|
290 |
+
"""
|
291 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
292 |
+
|
293 |
+
def disable_xformers_memory_efficient_attention(self):
|
294 |
+
r"""
|
295 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
296 |
+
"""
|
297 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
298 |
+
|
299 |
+
def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
|
300 |
+
r"""
|
301 |
+
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
|
302 |
+
to the adapter to follow the convention of the PEFT library.
|
303 |
+
|
304 |
+
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
|
305 |
+
[documentation](https://huggingface.co/docs/peft).
|
306 |
+
|
307 |
+
Args:
|
308 |
+
adapter_config (`[~peft.PeftConfig]`):
|
309 |
+
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
|
310 |
+
methods.
|
311 |
+
adapter_name (`str`, *optional*, defaults to `"default"`):
|
312 |
+
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
|
313 |
+
"""
|
314 |
+
check_peft_version(min_version=MIN_PEFT_VERSION)
|
315 |
+
|
316 |
+
from peft import PeftConfig, inject_adapter_in_model
|
317 |
+
|
318 |
+
if not self._hf_peft_config_loaded:
|
319 |
+
self._hf_peft_config_loaded = True
|
320 |
+
elif adapter_name in self.peft_config:
|
321 |
+
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
|
322 |
+
|
323 |
+
if not isinstance(adapter_config, PeftConfig):
|
324 |
+
raise ValueError(
|
325 |
+
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
|
326 |
+
)
|
327 |
+
|
328 |
+
# Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
|
329 |
+
# handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
|
330 |
+
adapter_config.base_model_name_or_path = None
|
331 |
+
inject_adapter_in_model(adapter_config, self, adapter_name)
|
332 |
+
self.set_adapter(adapter_name)
|
333 |
+
|
334 |
+
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
|
335 |
+
"""
|
336 |
+
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
|
337 |
+
|
338 |
+
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
339 |
+
official documentation: https://huggingface.co/docs/peft
|
340 |
+
|
341 |
+
Args:
|
342 |
+
adapter_name (Union[str, List[str]])):
|
343 |
+
The list of adapters to set or the adapter name in case of single adapter.
|
344 |
+
"""
|
345 |
+
check_peft_version(min_version=MIN_PEFT_VERSION)
|
346 |
+
|
347 |
+
if not self._hf_peft_config_loaded:
|
348 |
+
raise ValueError("No adapter loaded. Please load an adapter first.")
|
349 |
+
|
350 |
+
if isinstance(adapter_name, str):
|
351 |
+
adapter_name = [adapter_name]
|
352 |
+
|
353 |
+
missing = set(adapter_name) - set(self.peft_config)
|
354 |
+
if len(missing) > 0:
|
355 |
+
raise ValueError(
|
356 |
+
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
|
357 |
+
f" current loaded adapters are: {list(self.peft_config.keys())}"
|
358 |
+
)
|
359 |
+
|
360 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
361 |
+
|
362 |
+
_adapters_has_been_set = False
|
363 |
+
|
364 |
+
for _, module in self.named_modules():
|
365 |
+
if isinstance(module, BaseTunerLayer):
|
366 |
+
if hasattr(module, "set_adapter"):
|
367 |
+
module.set_adapter(adapter_name)
|
368 |
+
# Previous versions of PEFT do not support multi-adapter inference
|
369 |
+
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
|
370 |
+
raise ValueError(
|
371 |
+
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
|
372 |
+
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
module.active_adapter = adapter_name
|
376 |
+
_adapters_has_been_set = True
|
377 |
+
|
378 |
+
if not _adapters_has_been_set:
|
379 |
+
raise ValueError(
|
380 |
+
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
|
381 |
+
)
|
382 |
+
|
383 |
+
def disable_adapters(self) -> None:
|
384 |
+
r"""
|
385 |
+
Disable all adapters attached to the model and fallback to inference with the base model only.
|
386 |
+
|
387 |
+
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
388 |
+
official documentation: https://huggingface.co/docs/peft
|
389 |
+
"""
|
390 |
+
check_peft_version(min_version=MIN_PEFT_VERSION)
|
391 |
+
|
392 |
+
if not self._hf_peft_config_loaded:
|
393 |
+
raise ValueError("No adapter loaded. Please load an adapter first.")
|
394 |
+
|
395 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
396 |
+
|
397 |
+
for _, module in self.named_modules():
|
398 |
+
if isinstance(module, BaseTunerLayer):
|
399 |
+
if hasattr(module, "enable_adapters"):
|
400 |
+
module.enable_adapters(enabled=False)
|
401 |
+
else:
|
402 |
+
# support for older PEFT versions
|
403 |
+
module.disable_adapters = True
|
404 |
+
|
405 |
+
def enable_adapters(self) -> None:
|
406 |
+
"""
|
407 |
+
Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
|
408 |
+
list of adapters to enable.
|
409 |
+
|
410 |
+
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
411 |
+
official documentation: https://huggingface.co/docs/peft
|
412 |
+
"""
|
413 |
+
check_peft_version(min_version=MIN_PEFT_VERSION)
|
414 |
+
|
415 |
+
if not self._hf_peft_config_loaded:
|
416 |
+
raise ValueError("No adapter loaded. Please load an adapter first.")
|
417 |
+
|
418 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
419 |
+
|
420 |
+
for _, module in self.named_modules():
|
421 |
+
if isinstance(module, BaseTunerLayer):
|
422 |
+
if hasattr(module, "enable_adapters"):
|
423 |
+
module.enable_adapters(enabled=True)
|
424 |
+
else:
|
425 |
+
# support for older PEFT versions
|
426 |
+
module.disable_adapters = False
|
427 |
+
|
428 |
+
def active_adapters(self) -> List[str]:
|
429 |
+
"""
|
430 |
+
Gets the current list of active adapters of the model.
|
431 |
+
|
432 |
+
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
433 |
+
official documentation: https://huggingface.co/docs/peft
|
434 |
+
"""
|
435 |
+
check_peft_version(min_version=MIN_PEFT_VERSION)
|
436 |
+
|
437 |
+
if not self._hf_peft_config_loaded:
|
438 |
+
raise ValueError("No adapter loaded. Please load an adapter first.")
|
439 |
+
|
440 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
441 |
+
|
442 |
+
for _, module in self.named_modules():
|
443 |
+
if isinstance(module, BaseTunerLayer):
|
444 |
+
return module.active_adapter
|
445 |
+
|
446 |
+
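Taken together, the adapter helpers above are typically driven like this (a minimal sketch, assuming `peft` is installed and `unet` is a `ModelMixin` subclass; the rank and target module names are illustrative):

```py
from peft import LoraConfig

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections, illustrative
)
unet.add_adapter(lora_config, adapter_name="my_lora")
unet.set_adapter("my_lora")
print(unet.active_adapters())  # the active adapter name(s), depending on the PEFT version
unet.disable_adapters()        # run with the base weights only
unet.enable_adapters()         # re-enable every loaded adapter
```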
def save_pretrained(
|
447 |
+
self,
|
448 |
+
save_directory: Union[str, os.PathLike],
|
449 |
+
is_main_process: bool = True,
|
450 |
+
save_function: Callable = None,
|
451 |
+
safe_serialization: bool = True,
|
452 |
+
variant: Optional[str] = None,
|
453 |
+
push_to_hub: bool = False,
|
454 |
+
**kwargs,
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
458 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
459 |
+
|
460 |
+
Arguments:
|
461 |
+
save_directory (`str` or `os.PathLike`):
|
462 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
463 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
464 |
+
Whether the process calling this is the main process or not. Useful during distributed training when you
|
465 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
466 |
+
process to avoid race conditions.
|
467 |
+
save_function (`Callable`):
|
468 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
469 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
470 |
+
`DIFFUSERS_SAVE_MODE`.
|
471 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
472 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
473 |
+
variant (`str`, *optional*):
|
474 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
475 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
476 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
477 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
478 |
+
namespace).
|
479 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
480 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
481 |
+
"""
|
482 |
+
if os.path.isfile(save_directory):
|
483 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
484 |
+
return
|
485 |
+
|
486 |
+
os.makedirs(save_directory, exist_ok=True)
|
487 |
+
|
488 |
+
if push_to_hub:
|
489 |
+
commit_message = kwargs.pop("commit_message", None)
|
490 |
+
private = kwargs.pop("private", False)
|
491 |
+
create_pr = kwargs.pop("create_pr", False)
|
492 |
+
token = kwargs.pop("token", None)
|
493 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
494 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
495 |
+
|
496 |
+
# Only save the model itself if we are using distributed training
|
497 |
+
model_to_save = self
|
498 |
+
|
499 |
+
# Attach architecture to the config
|
500 |
+
# Save the config
|
501 |
+
if is_main_process:
|
502 |
+
model_to_save.save_config(save_directory)
|
503 |
+
|
504 |
+
# Save the model
|
505 |
+
state_dict = model_to_save.state_dict()
|
506 |
+
|
507 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
508 |
+
weights_name = _add_variant(weights_name, variant)
|
509 |
+
|
510 |
+
# Save the model
|
511 |
+
if safe_serialization:
|
512 |
+
safetensors.torch.save_file(
|
513 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
514 |
+
)
|
515 |
+
else:
|
516 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
517 |
+
|
518 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
519 |
+
|
520 |
+
if push_to_hub:
|
521 |
+
self._upload_folder(
|
522 |
+
save_directory,
|
523 |
+
repo_id,
|
524 |
+
token=token,
|
525 |
+
commit_message=commit_message,
|
526 |
+
create_pr=create_pr,
|
527 |
+
)
|
528 |
+
|
529 |
+
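A minimal sketch of a save/reload round trip (the directory name is illustrative):

```py
import torch
from diffusers import UNet2DConditionModel

unet.half().save_pretrained("./my_unet", safe_serialization=True, variant="fp16")
# the weights are written under the variant-suffixed weights name
# (e.g. diffusion_pytorch_model.fp16.safetensors) next to config.json
unet = UNet2DConditionModel.from_pretrained("./my_unet", variant="fp16", torch_dtype=torch.float16)
```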
@classmethod
|
530 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
531 |
+
r"""
|
532 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
533 |
+
|
534 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
535 |
+
train the model, set it back in training mode with `model.train()`.
|
536 |
+
|
537 |
+
Parameters:
|
538 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
539 |
+
Can be either:
|
540 |
+
|
541 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
542 |
+
the Hub.
|
543 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
544 |
+
with [`~ModelMixin.save_pretrained`].
|
545 |
+
|
546 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
547 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
548 |
+
is not used.
|
549 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
550 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
551 |
+
dtype is automatically derived from the model's weights.
|
552 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
553 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
554 |
+
cached versions if they exist.
|
555 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
556 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
557 |
+
incompletely downloaded files are deleted.
|
558 |
+
proxies (`Dict[str, str]`, *optional*):
|
559 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
560 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
561 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
562 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
563 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
564 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
565 |
+
won't be downloaded from the Hub.
|
566 |
+
use_auth_token (`str` or *bool*, *optional*):
|
567 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
568 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
569 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
570 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
571 |
+
allowed by Git.
|
572 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
573 |
+
Load the model weights from a Flax checkpoint save file.
|
574 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
575 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
576 |
+
mirror (`str`, *optional*):
|
577 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
578 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
579 |
+
information.
|
580 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
581 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
582 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
583 |
+
same device.
|
584 |
+
|
585 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
586 |
+
more information about each option see [designing a device
|
587 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
588 |
+
max_memory (`Dict`, *optional*):
|
589 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
590 |
+
each GPU and the available CPU RAM if unset.
|
591 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
592 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
593 |
+
offload_state_dict (`bool`, *optional*):
|
594 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
595 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
596 |
+
when there is some disk offload.
|
597 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
598 |
+
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
599 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
600 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
601 |
+
argument to `True` will raise an error.
|
602 |
+
variant (`str`, *optional*):
|
603 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
604 |
+
loading `from_flax`.
|
605 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
606 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
607 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
608 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
609 |
+
|
610 |
+
<Tip>
|
611 |
+
|
612 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
613 |
+
`huggingface-cli login`. You can also activate the special
|
614 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
615 |
+
firewalled environment.
|
616 |
+
|
617 |
+
</Tip>
|
618 |
+
|
619 |
+
Example:
|
620 |
+
|
621 |
+
```py
|
622 |
+
from diffusers import UNet2DConditionModel
|
623 |
+
|
624 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
625 |
+
```
|
626 |
+
|
627 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
628 |
+
|
629 |
+
```bash
|
630 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
631 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
632 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
633 |
+
```
|
634 |
+
"""
|
635 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
636 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
637 |
+
force_download = kwargs.pop("force_download", False)
|
638 |
+
from_flax = kwargs.pop("from_flax", False)
|
639 |
+
resume_download = kwargs.pop("resume_download", False)
|
640 |
+
proxies = kwargs.pop("proxies", None)
|
641 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
642 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
643 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
644 |
+
revision = kwargs.pop("revision", None)
|
645 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
646 |
+
subfolder = kwargs.pop("subfolder", None)
|
647 |
+
device_map = kwargs.pop("device_map", None)
|
648 |
+
max_memory = kwargs.pop("max_memory", None)
|
649 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
650 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
651 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
652 |
+
variant = kwargs.pop("variant", None)
|
653 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
654 |
+
|
655 |
+
allow_pickle = False
|
656 |
+
if use_safetensors is None:
|
657 |
+
use_safetensors = True
|
658 |
+
allow_pickle = True
|
659 |
+
|
660 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
661 |
+
low_cpu_mem_usage = False
|
662 |
+
logger.warning(
|
663 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
664 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
665 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
666 |
+
" install accelerate\n```\n."
|
667 |
+
)
|
668 |
+
|
669 |
+
if device_map is not None and not is_accelerate_available():
|
670 |
+
raise NotImplementedError(
|
671 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
672 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
673 |
+
)
|
674 |
+
|
675 |
+
# Check if we can handle device_map and dispatching the weights
|
676 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
677 |
+
raise NotImplementedError(
|
678 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
679 |
+
" `device_map=None`."
|
680 |
+
)
|
681 |
+
|
682 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
683 |
+
raise NotImplementedError(
|
684 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
685 |
+
" `low_cpu_mem_usage=False`."
|
686 |
+
)
|
687 |
+
|
688 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
689 |
+
raise ValueError(
|
690 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
691 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
692 |
+
)
|
693 |
+
|
694 |
+
# Load config if we don't provide a configuration
|
695 |
+
config_path = pretrained_model_name_or_path
|
696 |
+
|
697 |
+
user_agent = {
|
698 |
+
"diffusers": __version__,
|
699 |
+
"file_type": "model",
|
700 |
+
"framework": "pytorch",
|
701 |
+
}
|
702 |
+
|
703 |
+
# load config
|
704 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
705 |
+
config_path,
|
706 |
+
cache_dir=cache_dir,
|
707 |
+
return_unused_kwargs=True,
|
708 |
+
return_commit_hash=True,
|
709 |
+
force_download=force_download,
|
710 |
+
resume_download=resume_download,
|
711 |
+
proxies=proxies,
|
712 |
+
local_files_only=local_files_only,
|
713 |
+
use_auth_token=use_auth_token,
|
714 |
+
revision=revision,
|
715 |
+
subfolder=subfolder,
|
716 |
+
device_map=device_map,
|
717 |
+
max_memory=max_memory,
|
718 |
+
offload_folder=offload_folder,
|
719 |
+
offload_state_dict=offload_state_dict,
|
720 |
+
user_agent=user_agent,
|
721 |
+
**kwargs,
|
722 |
+
)
|
723 |
+
|
724 |
+
# load model
|
725 |
+
model_file = None
|
726 |
+
if from_flax:
|
727 |
+
model_file = _get_model_file(
|
728 |
+
pretrained_model_name_or_path,
|
729 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
730 |
+
cache_dir=cache_dir,
|
731 |
+
force_download=force_download,
|
732 |
+
resume_download=resume_download,
|
733 |
+
proxies=proxies,
|
734 |
+
local_files_only=local_files_only,
|
735 |
+
use_auth_token=use_auth_token,
|
736 |
+
revision=revision,
|
737 |
+
subfolder=subfolder,
|
738 |
+
user_agent=user_agent,
|
739 |
+
commit_hash=commit_hash,
|
740 |
+
)
|
741 |
+
model = cls.from_config(config, **unused_kwargs)
|
742 |
+
|
743 |
+
# Convert the weights
|
744 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
745 |
+
|
746 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
747 |
+
else:
|
748 |
+
if use_safetensors:
|
749 |
+
try:
|
750 |
+
model_file = _get_model_file(
|
751 |
+
pretrained_model_name_or_path,
|
752 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
753 |
+
cache_dir=cache_dir,
|
754 |
+
force_download=force_download,
|
755 |
+
resume_download=resume_download,
|
756 |
+
proxies=proxies,
|
757 |
+
local_files_only=local_files_only,
|
758 |
+
use_auth_token=use_auth_token,
|
759 |
+
revision=revision,
|
760 |
+
subfolder=subfolder,
|
761 |
+
user_agent=user_agent,
|
762 |
+
commit_hash=commit_hash,
|
763 |
+
)
|
764 |
+
except IOError as e:
|
765 |
+
if not allow_pickle:
|
766 |
+
raise e
|
767 |
+
pass
|
768 |
+
if model_file is None:
|
769 |
+
model_file = _get_model_file(
|
770 |
+
pretrained_model_name_or_path,
|
771 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
772 |
+
cache_dir=cache_dir,
|
773 |
+
force_download=force_download,
|
774 |
+
resume_download=resume_download,
|
775 |
+
proxies=proxies,
|
776 |
+
local_files_only=local_files_only,
|
777 |
+
use_auth_token=use_auth_token,
|
778 |
+
revision=revision,
|
779 |
+
subfolder=subfolder,
|
780 |
+
user_agent=user_agent,
|
781 |
+
commit_hash=commit_hash,
|
782 |
+
)
|
783 |
+
|
784 |
+
if low_cpu_mem_usage:
|
785 |
+
# Instantiate model with empty weights
|
786 |
+
with accelerate.init_empty_weights():
|
787 |
+
model = cls.from_config(config, **unused_kwargs)
|
788 |
+
|
789 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
790 |
+
if device_map is None:
|
791 |
+
param_device = "cpu"
|
792 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
793 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
794 |
+
# move the params from meta device to cpu
|
795 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
796 |
+
if len(missing_keys) > 0:
|
797 |
+
raise ValueError(
|
798 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
799 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
800 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
801 |
+
" those weights or else make sure your checkpoint file is correct."
|
802 |
+
)
|
803 |
+
|
804 |
+
unexpected_keys = load_model_dict_into_meta(
|
805 |
+
model,
|
806 |
+
state_dict,
|
807 |
+
device=param_device,
|
808 |
+
dtype=torch_dtype,
|
809 |
+
model_name_or_path=pretrained_model_name_or_path,
|
810 |
+
)
|
811 |
+
|
812 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
813 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
814 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
815 |
+
|
816 |
+
if len(unexpected_keys) > 0:
|
817 |
+
logger.warning(
|
818 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
819 |
+
)
|
820 |
+
|
821 |
+
else: # else let accelerate handle loading and dispatching.
|
822 |
+
# Load weights and dispatch according to the device_map
|
823 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
824 |
+
try:
|
825 |
+
accelerate.load_checkpoint_and_dispatch(
|
826 |
+
model,
|
827 |
+
model_file,
|
828 |
+
device_map,
|
829 |
+
max_memory=max_memory,
|
830 |
+
offload_folder=offload_folder,
|
831 |
+
offload_state_dict=offload_state_dict,
|
832 |
+
dtype=torch_dtype,
|
833 |
+
)
|
834 |
+
except AttributeError as e:
|
835 |
+
# When using accelerate loading, we do not have the ability to load the state
|
836 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
837 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
838 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
839 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
840 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
841 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
842 |
+
# the weights so we don't have to do this again.
|
843 |
+
|
844 |
+
if "'Attention' object has no attribute" in str(e):
|
845 |
+
logger.warning(
|
846 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
847 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
848 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
849 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
850 |
+
" please also re-upload it or open a PR on the original repository."
|
851 |
+
)
|
852 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
853 |
+
accelerate.load_checkpoint_and_dispatch(
|
854 |
+
model,
|
855 |
+
model_file,
|
856 |
+
device_map,
|
857 |
+
max_memory=max_memory,
|
858 |
+
offload_folder=offload_folder,
|
859 |
+
offload_state_dict=offload_state_dict,
|
860 |
+
dtype=torch_dtype,
|
861 |
+
)
|
862 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
863 |
+
else:
|
864 |
+
raise e
|
865 |
+
|
866 |
+
loading_info = {
|
867 |
+
"missing_keys": [],
|
868 |
+
"unexpected_keys": [],
|
869 |
+
"mismatched_keys": [],
|
870 |
+
"error_msgs": [],
|
871 |
+
}
|
872 |
+
else:
|
873 |
+
model = cls.from_config(config, **unused_kwargs)
|
874 |
+
|
875 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
876 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
877 |
+
|
878 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
879 |
+
model,
|
880 |
+
state_dict,
|
881 |
+
model_file,
|
882 |
+
pretrained_model_name_or_path,
|
883 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
884 |
+
)
|
885 |
+
|
886 |
+
loading_info = {
|
887 |
+
"missing_keys": missing_keys,
|
888 |
+
"unexpected_keys": unexpected_keys,
|
889 |
+
"mismatched_keys": mismatched_keys,
|
890 |
+
"error_msgs": error_msgs,
|
891 |
+
}
|
892 |
+
|
893 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
894 |
+
raise ValueError(
|
895 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
896 |
+
)
|
897 |
+
elif torch_dtype is not None:
|
898 |
+
model = model.to(torch_dtype)
|
899 |
+
|
900 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
901 |
+
|
902 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
903 |
+
model.eval()
|
904 |
+
if output_loading_info:
|
905 |
+
return model, loading_info
|
906 |
+
|
907 |
+
return model
|
908 |
+
|
909 |
+
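A slightly fuller call than the docstring example, combining the loading-related keyword arguments (a minimal sketch; the `fp16` variant is assumed to be published for this repository):

```py
import torch
from diffusers import UNet2DConditionModel

unet, loading_info = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",              # assumed to exist on the Hub for this repo
    low_cpu_mem_usage=True,      # requires `accelerate` and torch >= 1.9.0
    output_loading_info=True,
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
```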
@classmethod
|
910 |
+
def _load_pretrained_model(
|
911 |
+
cls,
|
912 |
+
model,
|
913 |
+
state_dict,
|
914 |
+
resolved_archive_file,
|
915 |
+
pretrained_model_name_or_path,
|
916 |
+
ignore_mismatched_sizes=False,
|
917 |
+
):
|
918 |
+
# Retrieve missing & unexpected_keys
|
919 |
+
model_state_dict = model.state_dict()
|
920 |
+
loaded_keys = list(state_dict.keys())
|
921 |
+
|
922 |
+
expected_keys = list(model_state_dict.keys())
|
923 |
+
|
924 |
+
original_loaded_keys = loaded_keys
|
925 |
+
|
926 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
927 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
928 |
+
|
929 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
930 |
+
model_to_load = model
|
931 |
+
|
932 |
+
def _find_mismatched_keys(
|
933 |
+
state_dict,
|
934 |
+
model_state_dict,
|
935 |
+
loaded_keys,
|
936 |
+
ignore_mismatched_sizes,
|
937 |
+
):
|
938 |
+
mismatched_keys = []
|
939 |
+
if ignore_mismatched_sizes:
|
940 |
+
for checkpoint_key in loaded_keys:
|
941 |
+
model_key = checkpoint_key
|
942 |
+
|
943 |
+
if (
|
944 |
+
model_key in model_state_dict
|
945 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
946 |
+
):
|
947 |
+
mismatched_keys.append(
|
948 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
949 |
+
)
|
950 |
+
del state_dict[checkpoint_key]
|
951 |
+
return mismatched_keys
|
952 |
+
|
953 |
+
if state_dict is not None:
|
954 |
+
# Whole checkpoint
|
955 |
+
mismatched_keys = _find_mismatched_keys(
|
956 |
+
state_dict,
|
957 |
+
model_state_dict,
|
958 |
+
original_loaded_keys,
|
959 |
+
ignore_mismatched_sizes,
|
960 |
+
)
|
961 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
962 |
+
|
963 |
+
if len(error_msgs) > 0:
|
964 |
+
error_msg = "\n\t".join(error_msgs)
|
965 |
+
if "size mismatch" in error_msg:
|
966 |
+
error_msg += (
|
967 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
968 |
+
)
|
969 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
970 |
+
|
971 |
+
if len(unexpected_keys) > 0:
|
972 |
+
logger.warning(
|
973 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
974 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
975 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
976 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
977 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
978 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
979 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
980 |
+
" BertForSequenceClassification model)."
|
981 |
+
)
|
982 |
+
else:
|
983 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
984 |
+
if len(missing_keys) > 0:
|
985 |
+
logger.warning(
|
986 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
987 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
988 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
989 |
+
)
|
990 |
+
elif len(mismatched_keys) == 0:
|
991 |
+
logger.info(
|
992 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
993 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
994 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
995 |
+
" without further training."
|
996 |
+
)
|
997 |
+
if len(mismatched_keys) > 0:
|
998 |
+
mismatched_warning = "\n".join(
|
999 |
+
[
|
1000 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
1001 |
+
for key, shape1, shape2 in mismatched_keys
|
1002 |
+
]
|
1003 |
+
)
|
1004 |
+
logger.warning(
|
1005 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
1006 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
1007 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
1008 |
+
" able to use it for predictions and inference."
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
1012 |
+
|
1013 |
+
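The mismatch handling above is what enables the channel surgery mentioned in the `from_pretrained` docstring (a minimal sketch; widening `conv_in` from 4 to 9 input channels for inpainting-style finetuning):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    in_channels=9,                 # e.g. latents + mask + masked-image latents
    low_cpu_mem_usage=False,       # required together with ignore_mismatched_sizes
    ignore_mismatched_sizes=True,  # conv_in.weight is re-initialized instead of raising
)
```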
@property
|
1014 |
+
def device(self) -> device:
|
1015 |
+
"""
|
1016 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
1017 |
+
device).
|
1018 |
+
"""
|
1019 |
+
return get_parameter_device(self)
|
1020 |
+
|
1021 |
+
@property
|
1022 |
+
def dtype(self) -> torch.dtype:
|
1023 |
+
"""
|
1024 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1025 |
+
"""
|
1026 |
+
return get_parameter_dtype(self)
|
1027 |
+
|
1028 |
+
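Both properties simply inspect the parameters, so they track whatever the module was last moved to (a minimal sketch, assuming a CUDA device is available):

```py
import torch

unet = unet.to("cuda", torch.float16)
print(unet.device, unet.dtype)  # -> cuda:0 torch.float16
```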
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
1029 |
+
"""
|
1030 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
1031 |
+
|
1032 |
+
Args:
|
1033 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
1034 |
+
Whether or not to return only the number of trainable parameters.
|
1035 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
1036 |
+
Whether or not to return only the number of non-embedding parameters.
|
1037 |
+
|
1038 |
+
Returns:
|
1039 |
+
`int`: The number of parameters.
|
1040 |
+
|
1041 |
+
Example:
|
1042 |
+
|
1043 |
+
```py
|
1044 |
+
from diffusers import UNet2DConditionModel
|
1045 |
+
|
1046 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
1047 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
1048 |
+
unet.num_parameters(only_trainable=True)
|
1049 |
+
859520964
|
1050 |
+
```
|
1051 |
+
"""
|
1052 |
+
|
1053 |
+
if exclude_embeddings:
|
1054 |
+
embedding_param_names = [
|
1055 |
+
f"{name}.weight"
|
1056 |
+
for name, module_type in self.named_modules()
|
1057 |
+
if isinstance(module_type, torch.nn.Embedding)
|
1058 |
+
]
|
1059 |
+
non_embedding_parameters = [
|
1060 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
1061 |
+
]
|
1062 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
1063 |
+
else:
|
1064 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
1065 |
+
|
1066 |
+
def _convert_deprecated_attention_blocks(self, state_dict):
|
1067 |
+
deprecated_attention_block_paths = []
|
1068 |
+
|
1069 |
+
def recursive_find_attn_block(name, module):
|
1070 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1071 |
+
deprecated_attention_block_paths.append(name)
|
1072 |
+
|
1073 |
+
for sub_name, sub_module in module.named_children():
|
1074 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
1075 |
+
recursive_find_attn_block(sub_name, sub_module)
|
1076 |
+
|
1077 |
+
recursive_find_attn_block("", self)
|
1078 |
+
|
1079 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
1080 |
+
# because it is possible we are loading from a state dict that was already
|
1081 |
+
# converted
|
1082 |
+
|
1083 |
+
for path in deprecated_attention_block_paths:
|
1084 |
+
# group_norm path stays the same
|
1085 |
+
|
1086 |
+
# query -> to_q
|
1087 |
+
if f"{path}.query.weight" in state_dict:
|
1088 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
1089 |
+
if f"{path}.query.bias" in state_dict:
|
1090 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
1091 |
+
|
1092 |
+
# key -> to_k
|
1093 |
+
if f"{path}.key.weight" in state_dict:
|
1094 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
1095 |
+
if f"{path}.key.bias" in state_dict:
|
1096 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
1097 |
+
|
1098 |
+
# value -> to_v
|
1099 |
+
if f"{path}.value.weight" in state_dict:
|
1100 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
1101 |
+
if f"{path}.value.bias" in state_dict:
|
1102 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
1103 |
+
|
1104 |
+
# proj_attn -> to_out.0
|
1105 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
1106 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
1107 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
1108 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
1109 |
+
|
1110 |
+
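The renames above amount to the following key mapping, shown here on a plain list of keys (a minimal sketch; the real method only rewrites paths whose modules are flagged with `_from_deprecated_attn_block`):

```py
rename = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}
old_keys = ["mid_block.attentions.0.query.weight", "mid_block.attentions.0.proj_attn.bias"]
new_keys = []
for k in old_keys:
    prefix, leaf, suffix = k.rsplit(".", 2)            # e.g. ("...attentions.0", "query", "weight")
    new_keys.append(f"{prefix}.{rename[leaf]}.{suffix}")
# -> ["mid_block.attentions.0.to_q.weight", "mid_block.attentions.0.to_out.0.bias"]
```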
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
1111 |
+
deprecated_attention_block_modules = []
|
1112 |
+
|
1113 |
+
def recursive_find_attn_block(module):
|
1114 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1115 |
+
deprecated_attention_block_modules.append(module)
|
1116 |
+
|
1117 |
+
for sub_module in module.children():
|
1118 |
+
recursive_find_attn_block(sub_module)
|
1119 |
+
|
1120 |
+
recursive_find_attn_block(self)
|
1121 |
+
|
1122 |
+
for module in deprecated_attention_block_modules:
|
1123 |
+
module.query = module.to_q
|
1124 |
+
module.key = module.to_k
|
1125 |
+
module.value = module.to_v
|
1126 |
+
module.proj_attn = module.to_out[0]
|
1127 |
+
|
1128 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
1129 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
1130 |
+
# making an incorrect assumption that this model should be converted when
|
1131 |
+
# it really shouldn't be.
|
1132 |
+
del module.to_q
|
1133 |
+
del module.to_k
|
1134 |
+
del module.to_v
|
1135 |
+
del module.to_out
|
1136 |
+
|
1137 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
1138 |
+
deprecated_attention_block_modules = []
|
1139 |
+
|
1140 |
+
def recursive_find_attn_block(module):
|
1141 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1142 |
+
deprecated_attention_block_modules.append(module)
|
1143 |
+
|
1144 |
+
for sub_module in module.children():
|
1145 |
+
recursive_find_attn_block(sub_module)
|
1146 |
+
|
1147 |
+
recursive_find_attn_block(self)
|
1148 |
+
|
1149 |
+
for module in deprecated_attention_block_modules:
|
1150 |
+
module.to_q = module.query
|
1151 |
+
module.to_k = module.key
|
1152 |
+
module.to_v = module.value
|
1153 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
1154 |
+
|
1155 |
+
del module.query
|
1156 |
+
del module.key
|
1157 |
+
del module.value
|
1158 |
+
del module.proj_attn
|
diffusers/models/normalization.py  ADDED  @@ -0,0 +1,148 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 HuggingFace Inc.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Dict, Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from .activations import get_activation
|
23 |
+
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
|
24 |
+
|
25 |
+
|
26 |
+
class AdaLayerNorm(nn.Module):
|
27 |
+
r"""
|
28 |
+
Norm layer modified to incorporate timestep embeddings.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
embedding_dim (`int`): The size of each embedding vector.
|
32 |
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, embedding_dim: int, num_embeddings: int):
|
36 |
+
super().__init__()
|
37 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
38 |
+
self.silu = nn.SiLU()
|
39 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
40 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
43 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
44 |
+
scale, shift = torch.chunk(emb, 2)
|
45 |
+
x = self.norm(x) * (1 + scale) + shift
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
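The forward pass above is the usual adaptive LayerNorm modulation; a self-contained numeric sketch of the same computation (dimensions are illustrative):

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 64
emb = nn.Embedding(1000, dim)                       # timestep embedding table
proj = nn.Linear(dim, dim * 2)                      # produces concatenated (scale, shift)
norm = nn.LayerNorm(dim, elementwise_affine=False)  # parameter-free LayerNorm

x = torch.randn(2, 10, dim)                         # (batch, tokens, dim)
t = torch.tensor(5)                                 # a single timestep index
scale, shift = proj(F.silu(emb(t))).chunk(2)        # each of shape (dim,)
y = norm(x) * (1 + scale) + shift                   # broadcast over batch and tokens
```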
class AdaLayerNormZero(nn.Module):
|
50 |
+
r"""
|
51 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
embedding_dim (`int`): The size of each embedding vector.
|
55 |
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, embedding_dim: int, num_embeddings: int):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
62 |
+
|
63 |
+
self.silu = nn.SiLU()
|
64 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
65 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
66 |
+
|
67 |
+
def forward(
|
68 |
+
self,
|
69 |
+
x: torch.Tensor,
|
70 |
+
timestep: torch.Tensor,
|
71 |
+
class_labels: torch.LongTensor,
|
72 |
+
hidden_dtype: Optional[torch.dtype] = None,
|
73 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
74 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
75 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
76 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
77 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
78 |
+
|
79 |
+
|
80 |
+
class AdaLayerNormSingle(nn.Module):
|
81 |
+
r"""
|
82 |
+
Norm layer adaptive layer norm single (adaLN-single).
|
83 |
+
|
84 |
+
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
85 |
+
|
86 |
+
Parameters:
|
87 |
+
embedding_dim (`int`): The size of each embedding vector.
|
88 |
+
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
self.emb = CombinedTimestepSizeEmbeddings(
|
95 |
+
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
|
96 |
+
)
|
97 |
+
|
98 |
+
self.silu = nn.SiLU()
|
99 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
100 |
+
|
101 |
+
def forward(
|
102 |
+
self,
|
103 |
+
timestep: torch.Tensor,
|
104 |
+
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
105 |
+
batch_size: int = None,
|
106 |
+
hidden_dtype: Optional[torch.dtype] = None,
|
107 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
108 |
+
# No modulation happening here.
|
109 |
+
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
110 |
+
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
111 |
+
|
112 |
+
|
113 |
+
class AdaGroupNorm(nn.Module):
|
114 |
+
r"""
|
115 |
+
GroupNorm layer modified to incorporate timestep embeddings.
|
116 |
+
|
117 |
+
Parameters:
|
118 |
+
embedding_dim (`int`): The size of each embedding vector.
|
119 |
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
120 |
+
num_groups (`int`): The number of groups to separate the channels into.
|
121 |
+
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
|
122 |
+
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(
|
126 |
+
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.num_groups = num_groups
|
130 |
+
self.eps = eps
|
131 |
+
|
132 |
+
if act_fn is None:
|
133 |
+
self.act = None
|
134 |
+
else:
|
135 |
+
self.act = get_activation(act_fn)
|
136 |
+
|
137 |
+
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
138 |
+
|
139 |
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
140 |
+
if self.act:
|
141 |
+
emb = self.act(emb)
|
142 |
+
emb = self.linear(emb)
|
143 |
+
emb = emb[:, :, None, None]
|
144 |
+
scale, shift = emb.chunk(2, dim=1)
|
145 |
+
|
146 |
+
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
147 |
+
x = x * (1 + scale) + shift
|
148 |
+
return x
|
diffusers/models/prior_transformer.py  ADDED  @@ -0,0 +1,382 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    """
    A Prior Transformer model.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`.
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`.
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to `None`): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Set it to `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to `None`):
            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
            needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `"linear"`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `"prd"`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. If `prd` is chosen, a token indicating the (quantized) dot product between
            the text embedding and the image embedding is prepended, as proposed in the unCLIP paper
            (https://arxiv.org/abs/2204.06125). If it is `None`, no additional embeddings will be prepended.
        time_embed_dim (`int`, *optional*, defaults to `None`): The dimension of timestep embeddings.
            If `None`, will be set to `num_attention_heads * attention_head_dim`.
        embedding_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `proj_embedding`. If `None`, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, defaults to `None`):
            The dimension of the output. If `None`, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        if added_emb_type == "prd":
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If `return_dict` is `True`, a [`~models.prior_transformer.PriorTransformerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        # assemble the token sequence: text tokens + pooled text embedding + timestep embedding + image embedding (+ prd token)
        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        # the prediction is read off the prd token if present, otherwise off the image-embedding positions
        if self.prd_embedding is not None:
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
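
For reference, the snippet below is a minimal, illustrative sketch (not part of the committed prior_transformer.py) of how `PriorTransformer` can be exercised end to end: it builds a deliberately tiny configuration, runs one denoising step of the forward pass, and un-normalizes the result with `post_process_latents`. It assumes the vendored `diffusers` package in this repo is importable from the repository root; the shapes and config values are arbitrary choices for the demonstration, not the defaults used by real checkpoints.

import torch

from diffusers.models.prior_transformer import PriorTransformer

# Tiny configuration so the example runs quickly; real checkpoints use the
# defaults (32 heads x 64 dims, 20 layers, 768-dim embeddings, 77 text tokens).
prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    embedding_dim=16,
    num_embeddings=4,
    additional_embeddings=4,
)

batch_size = 2
hidden_states = torch.randn(batch_size, 16)              # current image-embedding estimate
proj_embedding = torch.randn(batch_size, 16)             # pooled CLIP text embedding
encoder_hidden_states = torch.randn(batch_size, 4, 16)   # per-token CLIP text hidden states
attention_mask = torch.ones(batch_size, 4, dtype=torch.bool)

out = prior(
    hidden_states,
    timestep=10,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
predicted = out.predicted_image_embedding                # shape: (batch_size, 16)

# Undo the CLIP-statistics normalization applied during prior training.
denormalized = prior.post_process_latents(predicted)
print(predicted.shape, denormalized.shape)

# The attention-processor plumbing mirrors UNet2DConditionModel: it can be
# inspected or reset in the same way.
print(len(prior.attn_processors))                        # one processor per Attention layer
prior.set_default_attn_processor()

In a full unCLIP-style pipeline this forward call sits inside a scheduler loop, with `hidden_states` replaced by the scheduler's current latent estimate at each step; the sketch above only shows a single call to make the tensor shapes concrete.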