hjs8 akhaliq HF staff commited on
Commit
63e1dac
0 Parent(s):

Duplicate from THUDM/CogVideo

Browse files

Co-authored-by: Ahsen Khaliq <[email protected]>

Files changed (17) hide show
  1. .gitattributes +27 -0
  2. .gitignore +133 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +46 -0
  5. .style.yapf +5 -0
  6. CogVideo +1 -0
  7. LICENSE +21 -0
  8. LICENSE.CogVideo +201 -0
  9. README.md +14 -0
  10. app.py +126 -0
  11. icetk_models/.gitkeep +0 -0
  12. model.py +1243 -0
  13. patch +51 -0
  14. pretrained/.gitkeep +0 -0
  15. requirements.txt +7 -0
  16. samples.txt +2 -0
  17. style.css +7 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_queue.db*
2
+ pretrained/*
3
+ icetk_models/*
4
+ !*/.gitkeep
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
+ __pypackages__/
100
+
101
+ # Celery stuff
102
+ celerybeat-schedule
103
+ celerybeat.pid
104
+
105
+ # SageMath parsed files
106
+ *.sage.py
107
+
108
+ # Environments
109
+ .env
110
+ .venv
111
+ env/
112
+ venv/
113
+ ENV/
114
+ env.bak/
115
+ venv.bak/
116
+
117
+ # Spyder project settings
118
+ .spyderproject
119
+ .spyproject
120
+
121
+ # Rope project settings
122
+ .ropeproject
123
+
124
+ # mkdocs documentation
125
+ /site
126
+
127
+ # mypy
128
+ .mypy_cache/
129
+ .dmypy.json
130
+ dmypy.json
131
+
132
+ # Pyre type checker
133
+ .pyre/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "CogVideo"]
2
+ path = CogVideo
3
+ url = https://github.com/THUDM/CogVideo
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
CogVideo ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ff423aa169978fb2f636f761e348631fa3178b03
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE.CogVideo ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CogVideo
3
+ emoji: 🌍
4
+ colorFrom: indigo
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.1.6
8
+ python_version: 3.9.13
9
+ app_file: app.py
10
+ pinned: false
11
+ duplicated_from: THUDM/CogVideo
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+
7
+ # from model import AppModel
8
+
9
+ MAINTENANCE_NOTICE='Sorry, due to computing resources issues, this space is under maintenance, and will be restored as soon as possible. '
10
+
11
+ DESCRIPTION = '''# <a href="https://github.com/THUDM/CogVideo">CogVideo</a>
12
+
13
+ Currently, this Space only supports the first stage of the CogVideo pipeline due to hardware limitations.
14
+
15
+ The model accepts only Chinese as input.
16
+ By checking the "Translate to Chinese" checkbox, the results of English to Chinese translation with [this Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) will be used as input.
17
+ Since the translation model may mistranslate, you may want to use the translation results from other translation services.
18
+ '''
19
+ NOTES = 'This app is adapted from <a href="https://github.com/hysts/CogVideo_demo">https://github.com/hysts/CogVideo_demo</a>. It would be recommended to use the repo if you want to run the app yourself.'
20
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogVideo" />'
21
+
22
+ import json
23
+ import requests
24
+ import numpy as np
25
+ import imageio.v2 as iio
26
+
27
+ def post(
28
+ text,
29
+ translate,
30
+ seed,
31
+ only_first_stage,
32
+ image_prompt
33
+ ):
34
+ url = 'https://ccb8is4fqtofrtdsfjebg.ml-platform-cn-beijing.volces.com/devinstance/di-20221130120908-bhpxq/proxy/6201'
35
+ headers = {
36
+ "Content-Type": "application/json; charset=UTF-8",
37
+ "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.87 Safari/537.36",
38
+ }
39
+
40
+ data = json.dumps({'text': text,
41
+ 'translate': translate,
42
+ 'seed': seed,
43
+ 'only_first_stage': only_first_stage,
44
+ 'image_prompt': image_prompt
45
+ })
46
+ r = requests.post(url, data, headers=headers)
47
+
48
+ translated_text = r.json()['data']['translated_text']
49
+ result_video = r.json()['data']['result_video']
50
+ frames = r.json()['data']['frames']
51
+ for i in range(4):
52
+ writer = iio.get_writer(result_video[i], fps=4)
53
+ for frame in frames[i]:
54
+ writer.append_data(np.array(frame))
55
+ writer.close()
56
+ print('finish')
57
+ return result_video[0], result_video[1], result_video[2], result_video[3]
58
+
59
+ def main():
60
+ only_first_stage = True
61
+ # model = AppModel(only_first_stage)
62
+
63
+ with gr.Blocks(css='style.css') as demo:
64
+ # gr.Markdown(MAINTENANCE_NOTICE)
65
+
66
+ gr.Markdown(DESCRIPTION)
67
+
68
+ with gr.Row():
69
+ with gr.Column():
70
+ with gr.Group():
71
+ text = gr.Textbox(label='Input Text')
72
+ translate = gr.Checkbox(label='Translate to Chinese',
73
+ value=False)
74
+ seed = gr.Slider(0,
75
+ 100000,
76
+ step=1,
77
+ value=1234,
78
+ label='Seed')
79
+ only_first_stage = gr.Checkbox(
80
+ label='Only First Stage',
81
+ value=only_first_stage,
82
+ visible=not only_first_stage)
83
+ image_prompt = gr.Image(type="filepath",
84
+ label="Image Prompt",
85
+ value=None)
86
+ run_button = gr.Button('Run')
87
+
88
+ with gr.Column():
89
+ with gr.Group():
90
+ #translated_text = gr.Textbox(label='Translated Text')
91
+ with gr.Tabs():
92
+ with gr.TabItem('Output (Video)'):
93
+ result_video1 = gr.Video(show_label=False)
94
+ result_video2 = gr.Video(show_label=False)
95
+ result_video3 = gr.Video(show_label=False)
96
+ result_video4 = gr.Video(show_label=False)
97
+
98
+
99
+
100
+ # examples = gr.Examples(
101
+ # examples=[['骑滑板的皮卡丘', False, 1234, True,None],
102
+ # ['a cat playing chess', True, 1253, True,None]],
103
+ # fn=model.run_with_translation,
104
+ # inputs=[text, translate, seed, only_first_stage,image_prompt],
105
+ # outputs=[translated_text, result_video],
106
+ # cache_examples=True)
107
+
108
+ gr.Markdown(NOTES)
109
+ gr.Markdown(FOOTER)
110
+ print(gr.__version__)
111
+ run_button.click(fn=post,
112
+ inputs=[
113
+ text,
114
+ translate,
115
+ seed,
116
+ only_first_stage,
117
+ image_prompt
118
+ ],
119
+ outputs=[result_video1, result_video2, result_video3, result_video4])
120
+ print(gr.__version__)
121
+
122
+ demo.launch()
123
+
124
+
125
+ if __name__ == '__main__':
126
+ main()
icetk_models/.gitkeep ADDED
File without changes
model.py ADDED
@@ -0,0 +1,1243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is adapted from https://github.com/THUDM/CogVideo/blob/ff423aa169978fb2f636f761e348631fa3178b03/cogvideo_pipeline.py
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import logging
7
+ import os
8
+ import pathlib
9
+ import shutil
10
+ import subprocess
11
+ import sys
12
+ import tempfile
13
+ import time
14
+ import zipfile
15
+ from typing import Any
16
+
17
+ if os.getenv('SYSTEM') == 'spaces':
18
+ subprocess.run('pip install icetk==0.0.4'.split())
19
+ subprocess.run('pip install SwissArmyTransformer==0.2.9'.split())
20
+ subprocess.run(
21
+ 'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31'
22
+ .split())
23
+ #subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
24
+ #subprocess.run('git checkout 1403c21'.split(), cwd='apex')
25
+ #with open('patch.apex') as f:
26
+ # subprocess.run('patch -p1'.split(), cwd='apex', stdin=f)
27
+ #subprocess.run(
28
+ # 'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./'
29
+ # .split(),
30
+ # cwd='apex')
31
+ #subprocess.run('rm -rf apex'.split())
32
+ with open('patch') as f:
33
+ subprocess.run('patch -p1'.split(), cwd='CogVideo', stdin=f)
34
+
35
+ from huggingface_hub import hf_hub_download
36
+
37
+ def download_and_extract_icetk_models() -> None:
38
+ icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
39
+ icetk_model_dir.mkdir()
40
+ path = hf_hub_download('THUDM/icetk',
41
+ 'models.zip',
42
+ use_auth_token=os.getenv('HF_TOKEN'))
43
+ with zipfile.ZipFile(path) as f:
44
+ f.extractall(path=icetk_model_dir.as_posix())
45
+
46
+ def download_and_extract_cogvideo_models(name: str) -> None:
47
+ path = hf_hub_download('THUDM/CogVideo',
48
+ name,
49
+ use_auth_token=os.getenv('HF_TOKEN'))
50
+ with zipfile.ZipFile(path) as f:
51
+ f.extractall('pretrained')
52
+ os.remove(path)
53
+
54
+ def download_and_extract_cogview2_models(name: str) -> None:
55
+ path = hf_hub_download('THUDM/CogView2', name)
56
+ with zipfile.ZipFile(path) as f:
57
+ f.extractall()
58
+ shutil.move('/home/user/app/sharefs/cogview-new/cogview2-dsr',
59
+ 'pretrained')
60
+ shutil.rmtree('/home/user/app/sharefs/')
61
+ os.remove(path)
62
+
63
+ download_and_extract_icetk_models()
64
+ download_and_extract_cogvideo_models('cogvideo-stage1.zip')
65
+ #download_and_extract_cogvideo_models('cogvideo-stage2.zip')
66
+ #download_and_extract_cogview2_models('cogview2-dsr.zip')
67
+
68
+ os.environ['SAT_HOME'] = '/home/user/app/pretrained'
69
+
70
+ import gradio as gr
71
+ import imageio.v2 as iio
72
+ import numpy as np
73
+ import torch
74
+ from icetk import IceTokenizer
75
+ from SwissArmyTransformer import get_args
76
+ from SwissArmyTransformer.arguments import set_random_seed
77
+ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
78
+ from SwissArmyTransformer.resources import auto_create
79
+
80
+ app_dir = pathlib.Path(__file__).parent
81
+ submodule_dir = app_dir / 'CogVideo'
82
+ sys.path.insert(0, submodule_dir.as_posix())
83
+
84
+ from coglm_strategy import CoglmStrategy
85
+ from models.cogvideo_cache_model import CogVideoCacheModel
86
+ from sr_pipeline import DirectSuperResolution
87
+
88
+ formatter = logging.Formatter(
89
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
90
+ datefmt='%Y-%m-%d %H:%M:%S')
91
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
92
+ stream_handler.setLevel(logging.INFO)
93
+ stream_handler.setFormatter(formatter)
94
+ logger = logging.getLogger(__name__)
95
+ logger.setLevel(logging.INFO)
96
+ logger.propagate = False
97
+ logger.addHandler(stream_handler)
98
+
99
+ ICETK_MODEL_DIR = app_dir / 'icetk_models'
100
+
101
+
102
+ def get_masks_and_position_ids_stage1(data, textlen, framelen):
103
+ # Extract batch size and sequence length.
104
+ tokens = data
105
+ seq_length = len(data[0])
106
+ # Attention mask (lower triangular).
107
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
108
+ device=data.device)
109
+ attention_mask[:, :textlen, textlen:] = 0
110
+ attention_mask[:, textlen:, textlen:].tril_()
111
+ attention_mask.unsqueeze_(1)
112
+ # Unaligned version
113
+ position_ids = torch.zeros(seq_length,
114
+ dtype=torch.long,
115
+ device=data.device)
116
+ torch.arange(textlen,
117
+ out=position_ids[:textlen],
118
+ dtype=torch.long,
119
+ device=data.device)
120
+ torch.arange(512,
121
+ 512 + seq_length - textlen,
122
+ out=position_ids[textlen:],
123
+ dtype=torch.long,
124
+ device=data.device)
125
+ position_ids = position_ids.unsqueeze(0)
126
+
127
+ return tokens, attention_mask, position_ids
128
+
129
+
130
+ def get_masks_and_position_ids_stage2(data, textlen, framelen):
131
+ # Extract batch size and sequence length.
132
+ tokens = data
133
+ seq_length = len(data[0])
134
+
135
+ # Attention mask (lower triangular).
136
+ attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
137
+ device=data.device)
138
+ attention_mask[:, :textlen, textlen:] = 0
139
+ attention_mask[:, textlen:, textlen:].tril_()
140
+ attention_mask.unsqueeze_(1)
141
+
142
+ # Unaligned version
143
+ position_ids = torch.zeros(seq_length,
144
+ dtype=torch.long,
145
+ device=data.device)
146
+ torch.arange(textlen,
147
+ out=position_ids[:textlen],
148
+ dtype=torch.long,
149
+ device=data.device)
150
+ frame_num = (seq_length - textlen) // framelen
151
+ assert frame_num == 5
152
+ torch.arange(512,
153
+ 512 + framelen,
154
+ out=position_ids[textlen:textlen + framelen],
155
+ dtype=torch.long,
156
+ device=data.device)
157
+ torch.arange(512 + framelen * 2,
158
+ 512 + framelen * 3,
159
+ out=position_ids[textlen + framelen:textlen + framelen * 2],
160
+ dtype=torch.long,
161
+ device=data.device)
162
+ torch.arange(512 + framelen * (frame_num - 1),
163
+ 512 + framelen * frame_num,
164
+ out=position_ids[textlen + framelen * 2:textlen +
165
+ framelen * 3],
166
+ dtype=torch.long,
167
+ device=data.device)
168
+ torch.arange(512 + framelen * 1,
169
+ 512 + framelen * 2,
170
+ out=position_ids[textlen + framelen * 3:textlen +
171
+ framelen * 4],
172
+ dtype=torch.long,
173
+ device=data.device)
174
+ torch.arange(512 + framelen * 3,
175
+ 512 + framelen * 4,
176
+ out=position_ids[textlen + framelen * 4:textlen +
177
+ framelen * 5],
178
+ dtype=torch.long,
179
+ device=data.device)
180
+
181
+ position_ids = position_ids.unsqueeze(0)
182
+
183
+ return tokens, attention_mask, position_ids
184
+
185
+
186
+ def my_update_mems(hiddens, mems_buffers, mems_indexs,
187
+ limited_spatial_channel_mem, text_len, frame_len):
188
+ if hiddens is None:
189
+ return None, mems_indexs
190
+ mem_num = len(hiddens)
191
+ ret_mem = []
192
+ with torch.no_grad():
193
+ for id in range(mem_num):
194
+ if hiddens[id][0] is None:
195
+ ret_mem.append(None)
196
+ else:
197
+ if id == 0 and limited_spatial_channel_mem and mems_indexs[
198
+ id] + hiddens[0][0].shape[1] >= text_len + frame_len:
199
+ if mems_indexs[id] == 0:
200
+ for layer, hidden in enumerate(hiddens[id]):
201
+ mems_buffers[id][
202
+ layer, :, :text_len] = hidden.expand(
203
+ mems_buffers[id].shape[1], -1,
204
+ -1)[:, :text_len]
205
+ new_mem_len_part2 = (mems_indexs[id] +
206
+ hiddens[0][0].shape[1] -
207
+ text_len) % frame_len
208
+ if new_mem_len_part2 > 0:
209
+ for layer, hidden in enumerate(hiddens[id]):
210
+ mems_buffers[id][
211
+ layer, :, text_len:text_len +
212
+ new_mem_len_part2] = hidden.expand(
213
+ mems_buffers[id].shape[1], -1,
214
+ -1)[:, -new_mem_len_part2:]
215
+ mems_indexs[id] = text_len + new_mem_len_part2
216
+ else:
217
+ for layer, hidden in enumerate(hiddens[id]):
218
+ mems_buffers[id][layer, :,
219
+ mems_indexs[id]:mems_indexs[id] +
220
+ hidden.shape[1]] = hidden.expand(
221
+ mems_buffers[id].shape[1], -1, -1)
222
+ mems_indexs[id] += hidden.shape[1]
223
+ ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
224
+ return ret_mem, mems_indexs
225
+
226
+
227
+ def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
228
+ # The fisrt token's position id of the frame that the next token belongs to;
229
+ if total_len < text_len:
230
+ return None
231
+ return (total_len - text_len) // frame_len * frame_len + text_len
232
+
233
+
234
+ def my_filling_sequence(
235
+ model,
236
+ tokenizer,
237
+ args,
238
+ seq,
239
+ batch_size,
240
+ get_masks_and_position_ids,
241
+ text_len,
242
+ frame_len,
243
+ strategy=BaseStrategy(),
244
+ strategy2=BaseStrategy(),
245
+ mems=None,
246
+ log_text_attention_weights=0, # default to 0: no artificial change
247
+ mode_stage1=True,
248
+ enforce_no_swin=False,
249
+ guider_seq=None,
250
+ guider_text_len=0,
251
+ guidance_alpha=1,
252
+ limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内
253
+ **kw_args):
254
+ '''
255
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
256
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
257
+ cache, should be first mems.shape[1] parts of context_tokens.
258
+ mems are the first-level citizens here, but we don't assume what is memorized.
259
+ input mems are used when multi-phase generation.
260
+ '''
261
+ if guider_seq is not None:
262
+ logger.debug('Using Guidance In Inference')
263
+ if limited_spatial_channel_mem:
264
+ logger.debug("Limit spatial-channel's mem to current frame")
265
+ assert len(seq.shape) == 2
266
+
267
+ # building the initial tokens, attention_mask, and position_ids
268
+ actual_context_length = 0
269
+
270
+ while seq[-1][
271
+ actual_context_length] >= 0: # the last seq has least given tokens
272
+ actual_context_length += 1 # [0, context_length-1] are given
273
+ assert actual_context_length > 0
274
+ current_frame_num = (actual_context_length - text_len) // frame_len
275
+ assert current_frame_num >= 0
276
+ context_length = text_len + current_frame_num * frame_len
277
+
278
+ tokens, attention_mask, position_ids = get_masks_and_position_ids(
279
+ seq, text_len, frame_len)
280
+ tokens = tokens[..., :context_length]
281
+ input_tokens = tokens.clone()
282
+
283
+ if guider_seq is not None:
284
+ guider_index_delta = text_len - guider_text_len
285
+ guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(
286
+ guider_seq, guider_text_len, frame_len)
287
+ guider_tokens = guider_tokens[..., :context_length -
288
+ guider_index_delta]
289
+ guider_input_tokens = guider_tokens.clone()
290
+
291
+ for fid in range(current_frame_num):
292
+ input_tokens[:, text_len + 400 * fid] = tokenizer['<start_of_image>']
293
+ if guider_seq is not None:
294
+ guider_input_tokens[:, guider_text_len +
295
+ 400 * fid] = tokenizer['<start_of_image>']
296
+
297
+ attention_mask = attention_mask.type_as(next(
298
+ model.parameters())) # if fp16
299
+ # initialize generation
300
+ counter = context_length - 1 # Last fixed index is ``counter''
301
+ index = 0 # Next forward starting index, also the length of cache.
302
+ mems_buffers_on_GPU = False
303
+ mems_indexs = [0, 0]
304
+ mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
305
+ 5 * 400 + 74]
306
+ mems_buffers = [
307
+ torch.zeros(args.num_layers,
308
+ batch_size,
309
+ mem_len,
310
+ args.hidden_size * 2,
311
+ dtype=next(model.parameters()).dtype)
312
+ for mem_len in mems_len
313
+ ]
314
+
315
+ if guider_seq is not None:
316
+ guider_attention_mask = guider_attention_mask.type_as(
317
+ next(model.parameters())) # if fp16
318
+ guider_mems_buffers = [
319
+ torch.zeros(args.num_layers,
320
+ batch_size,
321
+ mem_len,
322
+ args.hidden_size * 2,
323
+ dtype=next(model.parameters()).dtype)
324
+ for mem_len in mems_len
325
+ ]
326
+ guider_mems_indexs = [0, 0]
327
+ guider_mems = None
328
+
329
+ torch.cuda.empty_cache()
330
+ # step-by-step generation
331
+ while counter < len(seq[0]) - 1:
332
+ # we have generated counter+1 tokens
333
+ # Now, we want to generate seq[counter + 1],
334
+ # token[:, index: counter+1] needs forwarding.
335
+ if index == 0:
336
+ group_size = 2 if (input_tokens.shape[0] == batch_size
337
+ and not mode_stage1) else batch_size
338
+
339
+ logits_all = None
340
+ for batch_idx in range(0, input_tokens.shape[0], group_size):
341
+ logits, *output_per_layers = model(
342
+ input_tokens[batch_idx:batch_idx + group_size, index:],
343
+ position_ids[..., index:counter + 1],
344
+ attention_mask, # TODO memlen
345
+ mems=mems,
346
+ text_len=text_len,
347
+ frame_len=frame_len,
348
+ counter=counter,
349
+ log_text_attention_weights=log_text_attention_weights,
350
+ enforce_no_swin=enforce_no_swin,
351
+ **kw_args)
352
+ logits_all = torch.cat(
353
+ (logits_all,
354
+ logits), dim=0) if logits_all is not None else logits
355
+ mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers],
356
+ [o['mem_kv'][1] for o in output_per_layers]]
357
+ next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
358
+ text_len, frame_len, mem_kv01[0][0].shape[1])
359
+ for id, mem_kv in enumerate(mem_kv01):
360
+ for layer, mem_kv_perlayer in enumerate(mem_kv):
361
+ if limited_spatial_channel_mem and id == 0:
362
+ mems_buffers[id][
363
+ layer, batch_idx:batch_idx + group_size, :
364
+ text_len] = mem_kv_perlayer.expand(
365
+ min(group_size,
366
+ input_tokens.shape[0] - batch_idx), -1,
367
+ -1)[:, :text_len]
368
+ mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
369
+ mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
370
+ else:
371
+ mems_buffers[id][
372
+ layer, batch_idx:batch_idx +
373
+ group_size, :mem_kv_perlayer.
374
+ shape[1]] = mem_kv_perlayer.expand(
375
+ min(group_size,
376
+ input_tokens.shape[0] - batch_idx), -1,
377
+ -1)
378
+ mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[
379
+ 1], mem_kv01[1][0].shape[1]
380
+ if limited_spatial_channel_mem:
381
+ mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
382
+
383
+ mems = [
384
+ mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)
385
+ ]
386
+ logits = logits_all
387
+
388
+ # Guider
389
+ if guider_seq is not None:
390
+ guider_logits_all = None
391
+ for batch_idx in range(0, guider_input_tokens.shape[0],
392
+ group_size):
393
+ guider_logits, *guider_output_per_layers = model(
394
+ guider_input_tokens[batch_idx:batch_idx + group_size,
395
+ max(index -
396
+ guider_index_delta, 0):],
397
+ guider_position_ids[
398
+ ...,
399
+ max(index - guider_index_delta, 0):counter + 1 -
400
+ guider_index_delta],
401
+ guider_attention_mask,
402
+ mems=guider_mems,
403
+ text_len=guider_text_len,
404
+ frame_len=frame_len,
405
+ counter=counter - guider_index_delta,
406
+ log_text_attention_weights=log_text_attention_weights,
407
+ enforce_no_swin=enforce_no_swin,
408
+ **kw_args)
409
+ guider_logits_all = torch.cat(
410
+ (guider_logits_all, guider_logits), dim=0
411
+ ) if guider_logits_all is not None else guider_logits
412
+ guider_mem_kv01 = [[
413
+ o['mem_kv'][0] for o in guider_output_per_layers
414
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]]
415
+ for id, guider_mem_kv in enumerate(guider_mem_kv01):
416
+ for layer, guider_mem_kv_perlayer in enumerate(
417
+ guider_mem_kv):
418
+ if limited_spatial_channel_mem and id == 0:
419
+ guider_mems_buffers[id][
420
+ layer, batch_idx:batch_idx + group_size, :
421
+ guider_text_len] = guider_mem_kv_perlayer.expand(
422
+ min(group_size,
423
+ input_tokens.shape[0] - batch_idx),
424
+ -1, -1)[:, :guider_text_len]
425
+ guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
426
+ guider_text_len, frame_len,
427
+ guider_mem_kv_perlayer.shape[1])
428
+ guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
429
+ guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
430
+ else:
431
+ guider_mems_buffers[id][
432
+ layer, batch_idx:batch_idx +
433
+ group_size, :guider_mem_kv_perlayer.
434
+ shape[1]] = guider_mem_kv_perlayer.expand(
435
+ min(group_size,
436
+ input_tokens.shape[0] - batch_idx),
437
+ -1, -1)
438
+ guider_mems_indexs[0], guider_mems_indexs[
439
+ 1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[
440
+ 1][0].shape[1]
441
+ if limited_spatial_channel_mem:
442
+ guider_mems_indexs[0] -= (
443
+ guider_next_tokens_frame_begin_id -
444
+ guider_text_len)
445
+ guider_mems = [
446
+ guider_mems_buffers[id][:, :, :guider_mems_indexs[id]]
447
+ for id in range(2)
448
+ ]
449
+ guider_logits = guider_logits_all
450
+ else:
451
+ if not mems_buffers_on_GPU:
452
+ if not mode_stage1:
453
+ torch.cuda.empty_cache()
454
+ for idx, mem in enumerate(mems):
455
+ mems[idx] = mem.to(next(model.parameters()).device)
456
+ if guider_seq is not None:
457
+ for idx, mem in enumerate(guider_mems):
458
+ guider_mems[idx] = mem.to(
459
+ next(model.parameters()).device)
460
+ else:
461
+ torch.cuda.empty_cache()
462
+ for idx, mem_buffer in enumerate(mems_buffers):
463
+ mems_buffers[idx] = mem_buffer.to(
464
+ next(model.parameters()).device)
465
+ mems = [
466
+ mems_buffers[id][:, :, :mems_indexs[id]]
467
+ for id in range(2)
468
+ ]
469
+ if guider_seq is not None:
470
+ for idx, guider_mem_buffer in enumerate(
471
+ guider_mems_buffers):
472
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
473
+ next(model.parameters()).device)
474
+ guider_mems = [
475
+ guider_mems_buffers[id]
476
+ [:, :, :guider_mems_indexs[id]] for id in range(2)
477
+ ]
478
+ mems_buffers_on_GPU = True
479
+
480
+ logits, *output_per_layers = model(
481
+ input_tokens[:, index:],
482
+ position_ids[..., index:counter + 1],
483
+ attention_mask, # TODO memlen
484
+ mems=mems,
485
+ text_len=text_len,
486
+ frame_len=frame_len,
487
+ counter=counter,
488
+ log_text_attention_weights=log_text_attention_weights,
489
+ enforce_no_swin=enforce_no_swin,
490
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
491
+ **kw_args)
492
+ mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers
493
+ ], [o['mem_kv'][1] for o in output_per_layers]
494
+
495
+ if guider_seq is not None:
496
+ guider_logits, *guider_output_per_layers = model(
497
+ guider_input_tokens[:,
498
+ max(index - guider_index_delta, 0):],
499
+ guider_position_ids[...,
500
+ max(index -
501
+ guider_index_delta, 0):counter +
502
+ 1 - guider_index_delta],
503
+ guider_attention_mask,
504
+ mems=guider_mems,
505
+ text_len=guider_text_len,
506
+ frame_len=frame_len,
507
+ counter=counter - guider_index_delta,
508
+ log_text_attention_weights=0,
509
+ enforce_no_swin=enforce_no_swin,
510
+ limited_spatial_channel_mem=limited_spatial_channel_mem,
511
+ **kw_args)
512
+ guider_mem_kv0, guider_mem_kv1 = [
513
+ o['mem_kv'][0] for o in guider_output_per_layers
514
+ ], [o['mem_kv'][1] for o in guider_output_per_layers]
515
+
516
+ if not mems_buffers_on_GPU:
517
+ torch.cuda.empty_cache()
518
+ for idx, mem_buffer in enumerate(mems_buffers):
519
+ mems_buffers[idx] = mem_buffer.to(
520
+ next(model.parameters()).device)
521
+ if guider_seq is not None:
522
+ for idx, guider_mem_buffer in enumerate(
523
+ guider_mems_buffers):
524
+ guider_mems_buffers[idx] = guider_mem_buffer.to(
525
+ next(model.parameters()).device)
526
+ mems_buffers_on_GPU = True
527
+
528
+ mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
529
+ mems_buffers, mems_indexs,
530
+ limited_spatial_channel_mem,
531
+ text_len, frame_len)
532
+ if guider_seq is not None:
533
+ guider_mems, guider_mems_indexs = my_update_mems(
534
+ [guider_mem_kv0, guider_mem_kv1], guider_mems_buffers,
535
+ guider_mems_indexs, limited_spatial_channel_mem,
536
+ guider_text_len, frame_len)
537
+
538
+ counter += 1
539
+ index = counter
540
+
541
+ logits = logits[:, -1].expand(batch_size,
542
+ -1) # [batch size, vocab size]
543
+ tokens = tokens.expand(batch_size, -1)
544
+ if guider_seq is not None:
545
+ guider_logits = guider_logits[:, -1].expand(batch_size, -1)
546
+ guider_tokens = guider_tokens.expand(batch_size, -1)
547
+
548
+ if seq[-1][counter].item() < 0:
549
+ # sampling
550
+ guided_logits = guider_logits + (
551
+ logits - guider_logits
552
+ ) * guidance_alpha if guider_seq is not None else logits
553
+ if mode_stage1 and counter < text_len + 400:
554
+ tokens, mems = strategy.forward(guided_logits, tokens, mems)
555
+ else:
556
+ tokens, mems = strategy2.forward(guided_logits, tokens, mems)
557
+ if guider_seq is not None:
558
+ guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]),
559
+ dim=1)
560
+
561
+ if seq[0][counter].item() >= 0:
562
+ for si in range(seq.shape[0]):
563
+ if seq[si][counter].item() >= 0:
564
+ tokens[si, -1] = seq[si, counter]
565
+ if guider_seq is not None:
566
+ guider_tokens[si,
567
+ -1] = guider_seq[si, counter -
568
+ guider_index_delta]
569
+
570
+ else:
571
+ tokens = torch.cat(
572
+ (tokens, seq[:, counter:counter + 1].clone().expand(
573
+ tokens.shape[0], 1).to(device=tokens.device,
574
+ dtype=tokens.dtype)),
575
+ dim=1)
576
+ if guider_seq is not None:
577
+ guider_tokens = torch.cat(
578
+ (guider_tokens,
579
+ guider_seq[:, counter - guider_index_delta:counter + 1 -
580
+ guider_index_delta].clone().expand(
581
+ guider_tokens.shape[0], 1).to(
582
+ device=guider_tokens.device,
583
+ dtype=guider_tokens.dtype)),
584
+ dim=1)
585
+
586
+ input_tokens = tokens.clone()
587
+ if guider_seq is not None:
588
+ guider_input_tokens = guider_tokens.clone()
589
+ if (index - text_len - 1) // 400 < (input_tokens.shape[-1] - text_len -
590
+ 1) // 400:
591
+ boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
592
+ while boi_idx < input_tokens.shape[-1]:
593
+ input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
594
+ if guider_seq is not None:
595
+ guider_input_tokens[:, boi_idx -
596
+ guider_index_delta] = tokenizer[
597
+ '<start_of_image>']
598
+ boi_idx += 400
599
+
600
+ if strategy.is_done:
601
+ break
602
+ return strategy.finalize(tokens, mems)
603
+
604
+
605
+ class InferenceModel_Sequential(CogVideoCacheModel):
606
+ def __init__(self, args, transformer=None, parallel_output=True):
607
+ super().__init__(args,
608
+ transformer=transformer,
609
+ parallel_output=parallel_output,
610
+ window_size=-1,
611
+ cogvideo_stage=1)
612
+
613
+ # TODO: check it
614
+
615
+ def final_forward(self, logits, **kwargs):
616
+ logits_parallel = logits
617
+ logits_parallel = torch.nn.functional.linear(
618
+ logits_parallel.float(),
619
+ self.transformer.word_embeddings.weight[:20000].float())
620
+ return logits_parallel
621
+
622
+
623
+ class InferenceModel_Interpolate(CogVideoCacheModel):
624
+ def __init__(self, args, transformer=None, parallel_output=True):
625
+ super().__init__(args,
626
+ transformer=transformer,
627
+ parallel_output=parallel_output,
628
+ window_size=10,
629
+ cogvideo_stage=2)
630
+
631
+ # TODO: check it
632
+
633
+ def final_forward(self, logits, **kwargs):
634
+ logits_parallel = logits
635
+ logits_parallel = torch.nn.functional.linear(
636
+ logits_parallel.float(),
637
+ self.transformer.word_embeddings.weight[:20000].float())
638
+ return logits_parallel
639
+
640
+
641
+ def get_default_args() -> argparse.Namespace:
642
+ known = argparse.Namespace(generate_frame_num=5,
643
+ coglm_temperature2=0.89,
644
+ use_guidance_stage1=True,
645
+ use_guidance_stage2=False,
646
+ guidance_alpha=3.0,
647
+ stage_1=True,
648
+ stage_2=False,
649
+ both_stages=False,
650
+ parallel_size=1,
651
+ stage1_max_inference_batch_size=-1,
652
+ multi_gpu=False,
653
+ layout='64, 464, 2064',
654
+ window_size=10,
655
+ additional_seqlen=2000,
656
+ cogvideo_stage=1)
657
+
658
+ args_list = [
659
+ '--tokenizer-type',
660
+ 'fake',
661
+ '--mode',
662
+ 'inference',
663
+ '--distributed-backend',
664
+ 'nccl',
665
+ '--fp16',
666
+ '--model-parallel-size',
667
+ '1',
668
+ '--temperature',
669
+ '1.05',
670
+ '--top_k',
671
+ '12',
672
+ '--sandwich-ln',
673
+ '--seed',
674
+ '1234',
675
+ '--num-workers',
676
+ '0',
677
+ '--batch-size',
678
+ '1',
679
+ '--max-inference-batch-size',
680
+ '8',
681
+ ]
682
+ args = get_args(args_list)
683
+ args = argparse.Namespace(**vars(args), **vars(known))
684
+ args.layout = [int(x) for x in args.layout.split(',')]
685
+ args.do_train = False
686
+ return args
687
+
688
+
689
+ class Model:
690
+ def __init__(self, only_first_stage: bool = False):
691
+ self.args = get_default_args()
692
+ if only_first_stage:
693
+ self.args.stage_1 = True
694
+ self.args.both_stages = False
695
+ else:
696
+ self.args.stage_1 = False
697
+ self.args.both_stages = True
698
+
699
+ self.tokenizer = self.load_tokenizer()
700
+
701
+ self.model_stage1, self.args = self.load_model_stage1()
702
+ self.model_stage2, self.args = self.load_model_stage2()
703
+
704
+ self.strategy_cogview2, self.strategy_cogvideo = self.load_strategies()
705
+ self.dsr = self.load_dsr()
706
+
707
+ self.device = torch.device(self.args.device)
708
+
709
+ def load_tokenizer(self) -> IceTokenizer:
710
+ logger.info('--- load_tokenizer ---')
711
+ start = time.perf_counter()
712
+
713
+ tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
714
+ tokenizer.add_special_tokens(
715
+ ['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
716
+
717
+ elapsed = time.perf_counter() - start
718
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
719
+ return tokenizer
720
+
721
+ def load_model_stage1(
722
+ self) -> tuple[CogVideoCacheModel, argparse.Namespace]:
723
+ logger.info('--- load_model_stage1 ---')
724
+ start = time.perf_counter()
725
+
726
+ args = self.args
727
+ model_stage1, args = InferenceModel_Sequential.from_pretrained(
728
+ args, 'cogvideo-stage1')
729
+ model_stage1.eval()
730
+ if args.both_stages:
731
+ model_stage1 = model_stage1.cpu()
732
+
733
+ elapsed = time.perf_counter() - start
734
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
735
+ return model_stage1, args
736
+
737
+ def load_model_stage2(
738
+ self) -> tuple[CogVideoCacheModel | None, argparse.Namespace]:
739
+ logger.info('--- load_model_stage2 ---')
740
+ start = time.perf_counter()
741
+
742
+ args = self.args
743
+ if args.both_stages:
744
+ model_stage2, args = InferenceModel_Interpolate.from_pretrained(
745
+ args, 'cogvideo-stage2')
746
+ model_stage2.eval()
747
+ if args.both_stages:
748
+ model_stage2 = model_stage2.cpu()
749
+ else:
750
+ model_stage2 = None
751
+
752
+ elapsed = time.perf_counter() - start
753
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
754
+ return model_stage2, args
755
+
756
+ def load_strategies(self) -> tuple[CoglmStrategy, CoglmStrategy]:
757
+ logger.info('--- load_strategies ---')
758
+ start = time.perf_counter()
759
+
760
+ invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
761
+ strategy_cogview2 = CoglmStrategy(invalid_slices,
762
+ temperature=1.0,
763
+ top_k=16)
764
+ strategy_cogvideo = CoglmStrategy(
765
+ invalid_slices,
766
+ temperature=self.args.temperature,
767
+ top_k=self.args.top_k,
768
+ temperature2=self.args.coglm_temperature2)
769
+
770
+ elapsed = time.perf_counter() - start
771
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
772
+ return strategy_cogview2, strategy_cogvideo
773
+
774
+ def load_dsr(self) -> DirectSuperResolution | None:
775
+ logger.info('--- load_dsr ---')
776
+ start = time.perf_counter()
777
+
778
+ if self.args.both_stages:
779
+ path = auto_create('cogview2-dsr', path=None)
780
+ dsr = DirectSuperResolution(self.args,
781
+ path,
782
+ max_bz=12,
783
+ onCUDA=False)
784
+ else:
785
+ dsr = None
786
+
787
+ elapsed = time.perf_counter() - start
788
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
789
+ return dsr
790
+
791
+ @torch.inference_mode()
792
+ def process_stage1(self,
793
+ model,
794
+ seq_text,
795
+ duration,
796
+ video_raw_text=None,
797
+ video_guidance_text='视频',
798
+ image_text_suffix='',
799
+ batch_size=1,
800
+ image_prompt=None):
801
+ process_start_time = time.perf_counter()
802
+
803
+ generate_frame_num = self.args.generate_frame_num
804
+ tokenizer = self.tokenizer
805
+ use_guide = self.args.use_guidance_stage1
806
+
807
+ if next(model.parameters()).device != self.device:
808
+ move_start_time = time.perf_counter()
809
+ logger.debug('moving stage 1 model to cuda')
810
+
811
+ model = model.to(self.device)
812
+
813
+ elapsed = time.perf_counter() - move_start_time
814
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
815
+
816
+ if video_raw_text is None:
817
+ video_raw_text = seq_text
818
+ mbz = self.args.stage1_max_inference_batch_size if self.args.stage1_max_inference_batch_size > 0 else self.args.max_inference_batch_size
819
+ assert batch_size < mbz or batch_size % mbz == 0
820
+ frame_len = 400
821
+
822
+ # generate the first frame:
823
+ enc_text = tokenizer.encode(seq_text + image_text_suffix)
824
+ seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1] * 400
825
+ logger.info(
826
+ f'[Generating First Frame with CogView2] Raw text: {tokenizer.decode(enc_text):s}'
827
+ )
828
+ text_len_1st = len(seq_1st) - frame_len * 1 - 1
829
+
830
+ seq_1st = torch.tensor(seq_1st, dtype=torch.long,
831
+ device=self.device).unsqueeze(0)
832
+ if image_prompt is None:
833
+ output_list_1st = []
834
+ for tim in range(max(batch_size // mbz, 1)):
835
+ start_time = time.perf_counter()
836
+ output_list_1st.append(
837
+ my_filling_sequence(
838
+ model,
839
+ tokenizer,
840
+ self.args,
841
+ seq_1st.clone(),
842
+ batch_size=min(batch_size, mbz),
843
+ get_masks_and_position_ids=
844
+ get_masks_and_position_ids_stage1,
845
+ text_len=text_len_1st,
846
+ frame_len=frame_len,
847
+ strategy=self.strategy_cogview2,
848
+ strategy2=self.strategy_cogvideo,
849
+ log_text_attention_weights=1.4,
850
+ enforce_no_swin=True,
851
+ mode_stage1=True,
852
+ )[0])
853
+ elapsed = time.perf_counter() - start_time
854
+ logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
855
+ output_tokens_1st = torch.cat(output_list_1st, dim=0)
856
+ given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
857
+ 401].unsqueeze(
858
+ 1
859
+ ) # given_tokens.shape: [bs, frame_num, 400]
860
+ else:
861
+ given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)
862
+
863
+ # generate subsequent frames:
864
+ total_frames = generate_frame_num
865
+ enc_duration = tokenizer.encode(f'{float(duration)}秒')
866
+ if use_guide:
867
+ video_raw_text = video_raw_text + ' 视频'
868
+ enc_text_video = tokenizer.encode(video_raw_text)
869
+ seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [
870
+ tokenizer['<start_of_image>']
871
+ ] + [-1] * 400 * generate_frame_num
872
+ guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(
873
+ video_guidance_text) + [tokenizer['<start_of_image>']
874
+ ] + [-1] * 400 * generate_frame_num
875
+ logger.info(
876
+ f'[Stage1: Generating Subsequent Frames, Frame Rate {4/duration:.1f}] raw text: {tokenizer.decode(enc_text_video):s}'
877
+ )
878
+
879
+ text_len = len(seq) - frame_len * generate_frame_num - 1
880
+ guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
881
+ seq = torch.tensor(seq, dtype=torch.long,
882
+ device=self.device).unsqueeze(0).repeat(
883
+ batch_size, 1)
884
+ guider_seq = torch.tensor(guider_seq,
885
+ dtype=torch.long,
886
+ device=self.device).unsqueeze(0).repeat(
887
+ batch_size, 1)
888
+
889
+ for given_frame_id in range(given_tokens.shape[1]):
890
+ seq[:, text_len + 1 + given_frame_id * 400:text_len + 1 +
891
+ (given_frame_id + 1) * 400] = given_tokens[:, given_frame_id]
892
+ guider_seq[:, guider_text_len + 1 +
893
+ given_frame_id * 400:guider_text_len + 1 +
894
+ (given_frame_id + 1) *
895
+ 400] = given_tokens[:, given_frame_id]
896
+ output_list = []
897
+
898
+ if use_guide:
899
+ video_log_text_attention_weights = 0
900
+ else:
901
+ guider_seq = None
902
+ video_log_text_attention_weights = 1.4
903
+
904
+ for tim in range(max(batch_size // mbz, 1)):
905
+ input_seq = seq[:min(batch_size, mbz)].clone(
906
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
907
+ guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone()
908
+ if tim == 0 else guider_seq[mbz * tim:mbz *
909
+ (tim + 1)].clone()
910
+ ) if guider_seq is not None else None
911
+ output_list.append(
912
+ my_filling_sequence(
913
+ model,
914
+ tokenizer,
915
+ self.args,
916
+ input_seq,
917
+ batch_size=min(batch_size, mbz),
918
+ get_masks_and_position_ids=
919
+ get_masks_and_position_ids_stage1,
920
+ text_len=text_len,
921
+ frame_len=frame_len,
922
+ strategy=self.strategy_cogview2,
923
+ strategy2=self.strategy_cogvideo,
924
+ log_text_attention_weights=video_log_text_attention_weights,
925
+ guider_seq=guider_seq2,
926
+ guider_text_len=guider_text_len,
927
+ guidance_alpha=self.args.guidance_alpha,
928
+ limited_spatial_channel_mem=True,
929
+ mode_stage1=True,
930
+ )[0])
931
+
932
+ output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len:]
933
+
934
+ if self.args.both_stages:
935
+ move_start_time = time.perf_counter()
936
+ logger.debug('moving stage 1 model to cpu')
937
+ model = model.cpu()
938
+ torch.cuda.empty_cache()
939
+ elapsed = time.perf_counter() - move_start_time
940
+ logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
941
+
942
+ # decoding
943
+ res = []
944
+ for seq in output_tokens:
945
+ decoded_imgs = [
946
+ self.postprocess(
947
+ torch.nn.functional.interpolate(tokenizer.decode(
948
+ image_ids=seq.tolist()[i * 400:(i + 1) * 400]),
949
+ size=(480, 480))[0])
950
+ for i in range(total_frames)
951
+ ]
952
+ res.append(decoded_imgs) # only the last image (target)
953
+
954
+ assert len(res) == batch_size
955
+ tokens = output_tokens[:, :+total_frames * 400].reshape(
956
+ -1, total_frames, 400).cpu()
957
+
958
+ elapsed = time.perf_counter() - process_start_time
959
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
960
+ return tokens, res[0]
961
+
962
+ @torch.inference_mode()
963
+ def process_stage2(self,
964
+ model,
965
+ seq_text,
966
+ duration,
967
+ parent_given_tokens,
968
+ video_raw_text=None,
969
+ video_guidance_text='视频',
970
+ gpu_rank=0,
971
+ gpu_parallel_size=1):
972
+ process_start_time = time.perf_counter()
973
+
974
+ generate_frame_num = self.args.generate_frame_num
975
+ tokenizer = self.tokenizer
976
+ use_guidance = self.args.use_guidance_stage2
977
+
978
+ stage2_start_time = time.perf_counter()
979
+
980
+ if next(model.parameters()).device != self.device:
981
+ move_start_time = time.perf_counter()
982
+ logger.debug('moving stage-2 model to cuda')
983
+
984
+ model = model.to(self.device)
985
+
986
+ elapsed = time.perf_counter() - move_start_time
987
+ logger.debug(f'moving in stage-2 model takes time: {elapsed:.2f}')
988
+
989
+ try:
990
+ sample_num_allgpu = parent_given_tokens.shape[0]
991
+ sample_num = sample_num_allgpu // gpu_parallel_size
992
+ assert sample_num * gpu_parallel_size == sample_num_allgpu
993
+ parent_given_tokens = parent_given_tokens[gpu_rank *
994
+ sample_num:(gpu_rank +
995
+ 1) *
996
+ sample_num]
997
+ except:
998
+ logger.critical('No frame_tokens found in interpolation, skip')
999
+ return False, []
1000
+
1001
+ # CogVideo Stage2 Generation
1002
+ while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
1003
+ parent_given_tokens_num = parent_given_tokens.shape[1]
1004
+ generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
1005
+ generate_batchsize_total = generate_batchsize_persample * sample_num
1006
+ total_frames = generate_frame_num
1007
+ frame_len = 400
1008
+ enc_text = tokenizer.encode(seq_text)
1009
+ enc_duration = tokenizer.encode(str(float(duration)) + '秒')
1010
+ seq = enc_duration + [tokenizer['<n>']] + enc_text + [
1011
+ tokenizer['<start_of_image>']
1012
+ ] + [-1] * 400 * generate_frame_num
1013
+ text_len = len(seq) - frame_len * generate_frame_num - 1
1014
+
1015
+ logger.info(
1016
+ f'[Stage2: Generating Frames, Frame Rate {int(4/duration):d}] raw text: {tokenizer.decode(enc_text):s}'
1017
+ )
1018
+
1019
+ # generation
1020
+ seq = torch.tensor(seq, dtype=torch.long,
1021
+ device=self.device).unsqueeze(0).repeat(
1022
+ generate_batchsize_total, 1)
1023
+ for sample_i in range(sample_num):
1024
+ for i in range(generate_batchsize_persample):
1025
+ seq[sample_i * generate_batchsize_persample +
1026
+ i][text_len + 1:text_len + 1 +
1027
+ 400] = parent_given_tokens[sample_i][2 * i]
1028
+ seq[sample_i * generate_batchsize_persample +
1029
+ i][text_len + 1 + 400:text_len + 1 +
1030
+ 800] = parent_given_tokens[sample_i][2 * i + 1]
1031
+ seq[sample_i * generate_batchsize_persample +
1032
+ i][text_len + 1 + 800:text_len + 1 +
1033
+ 1200] = parent_given_tokens[sample_i][2 * i + 2]
1034
+
1035
+ if use_guidance:
1036
+ guider_seq = enc_duration + [
1037
+ tokenizer['<n>']
1038
+ ] + tokenizer.encode(video_guidance_text) + [
1039
+ tokenizer['<start_of_image>']
1040
+ ] + [-1] * 400 * generate_frame_num
1041
+ guider_text_len = len(
1042
+ guider_seq) - frame_len * generate_frame_num - 1
1043
+ guider_seq = torch.tensor(
1044
+ guider_seq, dtype=torch.long,
1045
+ device=self.device).unsqueeze(0).repeat(
1046
+ generate_batchsize_total, 1)
1047
+ for sample_i in range(sample_num):
1048
+ for i in range(generate_batchsize_persample):
1049
+ guider_seq[sample_i * generate_batchsize_persample +
1050
+ i][text_len + 1:text_len + 1 +
1051
+ 400] = parent_given_tokens[sample_i][2 *
1052
+ i]
1053
+ guider_seq[sample_i * generate_batchsize_persample +
1054
+ i][text_len + 1 + 400:text_len + 1 +
1055
+ 800] = parent_given_tokens[sample_i][2 *
1056
+ i +
1057
+ 1]
1058
+ guider_seq[sample_i * generate_batchsize_persample +
1059
+ i][text_len + 1 + 800:text_len + 1 +
1060
+ 1200] = parent_given_tokens[sample_i][2 *
1061
+ i +
1062
+ 2]
1063
+ video_log_text_attention_weights = 0
1064
+ else:
1065
+ guider_seq = None
1066
+ guider_text_len = 0
1067
+ video_log_text_attention_weights = 1.4
1068
+
1069
+ mbz = self.args.max_inference_batch_size
1070
+
1071
+ assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
1072
+ output_list = []
1073
+ start_time = time.perf_counter()
1074
+ for tim in range(max(generate_batchsize_total // mbz, 1)):
1075
+ input_seq = seq[:min(generate_batchsize_total, mbz)].clone(
1076
+ ) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
1077
+ guider_seq2 = (
1078
+ guider_seq[:min(generate_batchsize_total, mbz)].clone()
1079
+ if tim == 0 else guider_seq[mbz * tim:mbz *
1080
+ (tim + 1)].clone()
1081
+ ) if guider_seq is not None else None
1082
+ output_list.append(
1083
+ my_filling_sequence(
1084
+ model,
1085
+ tokenizer,
1086
+ self.args,
1087
+ input_seq,
1088
+ batch_size=min(generate_batchsize_total, mbz),
1089
+ get_masks_and_position_ids=
1090
+ get_masks_and_position_ids_stage2,
1091
+ text_len=text_len,
1092
+ frame_len=frame_len,
1093
+ strategy=self.strategy_cogview2,
1094
+ strategy2=self.strategy_cogvideo,
1095
+ log_text_attention_weights=
1096
+ video_log_text_attention_weights,
1097
+ mode_stage1=False,
1098
+ guider_seq=guider_seq2,
1099
+ guider_text_len=guider_text_len,
1100
+ guidance_alpha=self.args.guidance_alpha,
1101
+ limited_spatial_channel_mem=True,
1102
+ )[0])
1103
+ elapsed = time.perf_counter() - start_time
1104
+ logger.info(f'Duration {duration:.2f}, Elapsed: {elapsed:.2f}\n')
1105
+
1106
+ output_tokens = torch.cat(output_list, dim=0)
1107
+ output_tokens = output_tokens[:, text_len + 1:text_len + 1 +
1108
+ (total_frames) * 400].reshape(
1109
+ sample_num, -1,
1110
+ 400 * total_frames)
1111
+ output_tokens_merge = torch.cat(
1112
+ (output_tokens[:, :, :1 * 400], output_tokens[:, :,
1113
+ 400 * 3:4 * 400],
1114
+ output_tokens[:, :, 400 * 1:2 * 400],
1115
+ output_tokens[:, :, 400 * 4:(total_frames) * 400]),
1116
+ dim=2).reshape(sample_num, -1, 400)
1117
+
1118
+ output_tokens_merge = torch.cat(
1119
+ (output_tokens_merge, output_tokens[:, -1:, 400 * 2:3 * 400]),
1120
+ dim=1)
1121
+ duration /= 2
1122
+ parent_given_tokens = output_tokens_merge
1123
+
1124
+ if self.args.both_stages:
1125
+ move_start_time = time.perf_counter()
1126
+ logger.debug('moving stage 2 model to cpu')
1127
+ model = model.cpu()
1128
+ torch.cuda.empty_cache()
1129
+ elapsed = time.perf_counter() - move_start_time
1130
+ logger.debug(f'moving out model2 takes time: {elapsed:.2f}')
1131
+
1132
+ elapsed = time.perf_counter() - stage2_start_time
1133
+ logger.info(f'CogVideo Stage2 completed. Elapsed: {elapsed:.2f}\n')
1134
+
1135
+ # direct super-resolution by CogView2
1136
+ logger.info('[Direct super-resolution]')
1137
+ dsr_start_time = time.perf_counter()
1138
+
1139
+ enc_text = tokenizer.encode(seq_text)
1140
+ frame_num_per_sample = parent_given_tokens.shape[1]
1141
+ parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
1142
+ text_seq = torch.tensor(enc_text, dtype=torch.long,
1143
+ device=self.device).unsqueeze(0).repeat(
1144
+ parent_given_tokens_2d.shape[0], 1)
1145
+ sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
1146
+
1147
+ decoded_sr_videos = []
1148
+ for sample_i in range(sample_num):
1149
+ decoded_sr_imgs = []
1150
+ for frame_i in range(frame_num_per_sample):
1151
+ decoded_sr_img = tokenizer.decode(
1152
+ image_ids=sred_tokens[frame_i + sample_i *
1153
+ frame_num_per_sample][-3600:])
1154
+ decoded_sr_imgs.append(
1155
+ self.postprocess(
1156
+ torch.nn.functional.interpolate(decoded_sr_img,
1157
+ size=(480, 480))[0]))
1158
+ decoded_sr_videos.append(decoded_sr_imgs)
1159
+
1160
+ elapsed = time.perf_counter() - dsr_start_time
1161
+ logger.info(
1162
+ f'Direct super-resolution completed. Elapsed: {elapsed:.2f}')
1163
+
1164
+ elapsed = time.perf_counter() - process_start_time
1165
+ logger.info(f'--- done ({elapsed=:.3f}) ---')
1166
+ return True, decoded_sr_videos[0]
1167
+
1168
+ @staticmethod
1169
+ def postprocess(tensor: torch.Tensor) -> np.ndarray:
1170
+ return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1171
+ 1, 2, 0).to(torch.uint8).numpy()
1172
+
1173
+ def run(self, text: str, seed: int,
1174
+ only_first_stage: bool,image_prompt: None) -> list[np.ndarray]:
1175
+ logger.info('==================== run ====================')
1176
+ start = time.perf_counter()
1177
+
1178
+ set_random_seed(seed)
1179
+ self.args.seed = seed
1180
+
1181
+ if only_first_stage:
1182
+ self.args.stage_1 = True
1183
+ self.args.both_stages = False
1184
+ else:
1185
+ self.args.stage_1 = False
1186
+ self.args.both_stages = True
1187
+
1188
+ parent_given_tokens, res = self.process_stage1(
1189
+ self.model_stage1,
1190
+ text,
1191
+ duration=4.0,
1192
+ video_raw_text=text,
1193
+ video_guidance_text='视频',
1194
+ image_text_suffix=' 高清摄影',
1195
+ batch_size=self.args.batch_size,
1196
+ image_prompt=image_prompt)
1197
+ if not only_first_stage:
1198
+ _, res = self.process_stage2(
1199
+ self.model_stage2,
1200
+ text,
1201
+ duration=2.0,
1202
+ parent_given_tokens=parent_given_tokens,
1203
+ video_raw_text=text + ' 视频',
1204
+ video_guidance_text='视频',
1205
+ gpu_rank=0,
1206
+ gpu_parallel_size=1) # TODO: 修改
1207
+
1208
+ elapsed = time.perf_counter() - start
1209
+ logger.info(f'Elapsed: {elapsed:.3f}')
1210
+ logger.info('==================== done ====================')
1211
+ return res
1212
+
1213
+
1214
+ class AppModel(Model):
1215
+ def __init__(self, only_first_stage: bool):
1216
+ super().__init__(only_first_stage)
1217
+ self.translator = gr.Interface.load(
1218
+ 'spaces/chinhon/translation_eng2ch')
1219
+
1220
+ def to_video(self, frames: list[np.ndarray]) -> str:
1221
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
1222
+ if self.args.stage_1:
1223
+ fps = 4
1224
+ else:
1225
+ fps = 8
1226
+ writer = iio.get_writer(out_file.name, fps=fps)
1227
+ for frame in frames:
1228
+ writer.append_data(frame)
1229
+ writer.close()
1230
+ return out_file.name
1231
+
1232
+ def run_with_translation(
1233
+ self, text: str, translate: bool, seed: int,
1234
+ only_first_stage: bool,image_prompt: None) -> tuple[str | None, str | None]:
1235
+
1236
+ logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=},{image_prompt=}')
1237
+ if translate:
1238
+ text = translated_text = self.translator(text)
1239
+ else:
1240
+ translated_text = None
1241
+ frames = self.run(text, seed, only_first_stage,image_prompt)
1242
+ video_path = self.to_video(frames)
1243
+ return translated_text, video_path
patch ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/coglm_strategy.py b/coglm_strategy.py
2
+ index d485715..a9eab3b 100644
3
+ --- a/coglm_strategy.py
4
+ +++ b/coglm_strategy.py
5
+ @@ -8,6 +8,7 @@
6
+
7
+ # here put the import lib
8
+ import os
9
+ +import pathlib
10
+ import sys
11
+ import math
12
+ import random
13
+ @@ -58,7 +59,8 @@ class CoglmStrategy:
14
+ self._is_done = False
15
+ self.outlier_count_down = torch.zeros(16)
16
+ self.vis_list = [[]for i in range(16)]
17
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
18
+ + cluster_label_path = pathlib.Path(__file__).parent / 'cluster_label2.npy'
19
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
20
+ self.start_pos = -1
21
+ self.white_cluster = []
22
+ # self.fout = open('tmp.txt', 'w')
23
+ @@ -98,4 +100,4 @@ class CoglmStrategy:
24
+
25
+ def finalize(self, tokens, mems):
26
+ self._is_done = False
27
+ - return tokens, mems
28
+
29
+ + return tokens, mems
30
+ diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py
31
+ index 5b8dded..07e97fd 100644
32
+ --- a/sr_pipeline/dsr_sampling.py
33
+ +++ b/sr_pipeline/dsr_sampling.py
34
+ @@ -8,6 +8,7 @@
35
+
36
+ # here put the import lib
37
+ import os
38
+ +import pathlib
39
+ import sys
40
+ import math
41
+ import random
42
+ @@ -28,7 +29,8 @@ class IterativeEntfilterStrategy:
43
+ self.invalid_slices = invalid_slices
44
+ self.temperature = temperature
45
+ self.topk = topk
46
+ - self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
47
+ + cluster_label_path = pathlib.Path(__file__).parents[1] / 'cluster_label2.npy'
48
+ + self.cluster_labels = torch.tensor(np.load(cluster_label_path), device='cuda', dtype=torch.long)
49
+
50
+
51
+ def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
pretrained/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ imageio==2.19.5
3
+ imageio-ffmpeg==0.4.7
4
+ numpy==1.22.4
5
+ opencv-python-headless==4.6.0.66
6
+ torch==1.12.0+cu113
7
+ torchvision==0.13.0+cu113
samples.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ 骑滑板的皮卡丘
2
+ a cat playing chess
style.css ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#visitor-badge {
5
+ display: block;
6
+ margin: auto;
7
+ }