kadirnar commited on
Commit
6da9572
1 Parent(s): 610f381

⭐ Add Paints-Undo Library

Browse files
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
4
+ result_dir = os.path.join('./', 'results')
5
+ os.makedirs(result_dir, exist_ok=True)
6
+
7
+
8
+ import functools
9
+ import os
10
+ import random
11
+ import gradio as gr
12
+ import numpy as np
13
+ import torch
14
+ import wd14tagger
15
+ import memory_management
16
+ import uuid
17
+ import spaces
18
+ from PIL import Image
19
+ from diffusers_helper.code_cond import unet_add_coded_conds
20
+ from diffusers_helper.cat_cond import unet_add_concat_conds
21
+ from diffusers_helper.k_diffusion import KDiffusionSampler
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel
23
+ from diffusers.models.attention_processor import AttnProcessor2_0
24
+ from transformers import CLIPTextModel, CLIPTokenizer
25
+ from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
26
+ from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
27
+
28
+
29
+ class ModifiedUNet(UNet2DConditionModel):
30
+ @classmethod
31
+ def from_config(cls, *args, **kwargs):
32
+ m = super().from_config(*args, **kwargs)
33
+ unet_add_concat_conds(unet=m, new_channels=4)
34
+ unet_add_coded_conds(unet=m, added_number_count=1)
35
+ return m
36
+
37
+
38
+ model_name = 'lllyasviel/paints_undo_single_frame'
39
+ tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
40
+ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
41
+ vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16) # bfloat16 vae
42
+ unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
43
+
44
+ unet.set_attn_processor(AttnProcessor2_0())
45
+ vae.set_attn_processor(AttnProcessor2_0())
46
+
47
+ video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
48
+ 'lllyasviel/paints_undo_multi_frame',
49
+ fp16=True
50
+ )
51
+
52
+ memory_management.unload_all_models([
53
+ video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
54
+ unet, vae, text_encoder
55
+ ])
56
+
57
+ k_sampler = KDiffusionSampler(
58
+ unet=unet,
59
+ timesteps=1000,
60
+ linear_start=0.00085,
61
+ linear_end=0.020,
62
+ linear=True
63
+ )
64
+
65
+
66
+ def find_best_bucket(h, w, options):
67
+ min_metric = float('inf')
68
+ best_bucket = None
69
+ for (bucket_h, bucket_w) in options:
70
+ metric = abs(h * bucket_w - w * bucket_h)
71
+ if metric <= min_metric:
72
+ min_metric = metric
73
+ best_bucket = (bucket_h, bucket_w)
74
+ return best_bucket
75
+
76
+
77
+ @torch.inference_mode()
78
+ def encode_cropped_prompt_77tokens(txt: str):
79
+ memory_management.load_models_to_gpu(text_encoder)
80
+ cond_ids = tokenizer(txt,
81
+ padding="max_length",
82
+ max_length=tokenizer.model_max_length,
83
+ truncation=True,
84
+ return_tensors="pt").input_ids.to(device=text_encoder.device)
85
+ text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
86
+ return text_cond
87
+
88
+
89
+ @torch.inference_mode()
90
+ def pytorch2numpy(imgs):
91
+ results = []
92
+ for x in imgs:
93
+ y = x.movedim(0, -1)
94
+ y = y * 127.5 + 127.5
95
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
96
+ results.append(y)
97
+ return results
98
+
99
+
100
+ @torch.inference_mode()
101
+ def numpy2pytorch(imgs):
102
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
103
+ h = h.movedim(-1, 1)
104
+ return h
105
+
106
+
107
+ def resize_without_crop(image, target_width, target_height):
108
+ pil_image = Image.fromarray(image)
109
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
110
+ return np.array(resized_image)
111
+
112
+
113
+ @torch.inference_mode()
114
+ def interrogator_process(x):
115
+ return wd14tagger.default_interrogator(x)
116
+
117
+
118
+ @torch.inference_mode()
119
+ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
120
+ progress=gr.Progress()):
121
+ rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
122
+
123
+ memory_management.load_models_to_gpu(vae)
124
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
125
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
126
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
127
+
128
+ memory_management.load_models_to_gpu(text_encoder)
129
+ conds = encode_cropped_prompt_77tokens(prompt)
130
+ unconds = encode_cropped_prompt_77tokens(n_prompt)
131
+
132
+ memory_management.load_models_to_gpu(unet)
133
+ fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
134
+ initial_latents = torch.zeros_like(concat_conds)
135
+ concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
136
+ latents = k_sampler(
137
+ initial_latent=initial_latents,
138
+ strength=1.0,
139
+ num_inference_steps=steps,
140
+ guidance_scale=cfg,
141
+ batch_size=len(input_undo_steps),
142
+ generator=rng,
143
+ prompt_embeds=conds,
144
+ negative_prompt_embeds=unconds,
145
+ cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
146
+ same_noise_in_batch=True,
147
+ progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
148
+ ).to(vae.dtype) / vae.config.scaling_factor
149
+
150
+ memory_management.load_models_to_gpu(vae)
151
+ pixels = vae.decode(latents).sample
152
+ pixels = pytorch2numpy(pixels)
153
+ pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
154
+
155
+ return pixels
156
+
157
+
158
+ @torch.inference_mode()
159
+ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
160
+ random.seed(seed)
161
+ np.random.seed(seed)
162
+ torch.manual_seed(seed)
163
+ torch.cuda.manual_seed_all(seed)
164
+
165
+ frames = 16
166
+
167
+ target_height, target_width = find_best_bucket(
168
+ image_1.shape[0], image_1.shape[1],
169
+ options=[(320, 512), (384, 448), (448, 384), (512, 320)]
170
+ )
171
+
172
+ image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
173
+ image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
174
+ input_frames = numpy2pytorch([image_1, image_2])
175
+ input_frames = input_frames.unsqueeze(0).movedim(1, 2)
176
+
177
+ memory_management.load_models_to_gpu(video_pipe.text_encoder)
178
+ positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
179
+ negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
180
+
181
+ memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
182
+ input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
183
+ positive_image_cond = video_pipe.encode_clip_vision(input_frames)
184
+ positive_image_cond = video_pipe.image_projection(positive_image_cond)
185
+ negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
186
+ negative_image_cond = video_pipe.image_projection(negative_image_cond)
187
+
188
+ memory_management.load_models_to_gpu([video_pipe.vae])
189
+ input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
190
+ input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
191
+ first_frame = input_frame_latents[:, :, 0]
192
+ last_frame = input_frame_latents[:, :, 1]
193
+ concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
194
+
195
+ memory_management.load_models_to_gpu([video_pipe.unet])
196
+ latents = video_pipe(
197
+ batch_size=1,
198
+ steps=int(steps),
199
+ guidance_scale=cfg_scale,
200
+ positive_text_cond=positive_text_cond,
201
+ negative_text_cond=negative_text_cond,
202
+ positive_image_cond=positive_image_cond,
203
+ negative_image_cond=negative_image_cond,
204
+ concat_cond=concat_cond,
205
+ fs=fs,
206
+ progress_tqdm=progress_tqdm
207
+ )
208
+
209
+ memory_management.load_models_to_gpu([video_pipe.vae])
210
+ video = video_pipe.decode_latents(latents, vae_hidden_states)
211
+ return video, image_1, image_2
212
+
213
+ @spaces.GPU
214
+ @torch.inference_mode()
215
+ def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
216
+ result_frames = []
217
+ cropped_images = []
218
+
219
+ for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
220
+ im1 = np.array(Image.open(im1[0]))
221
+ im2 = np.array(Image.open(im2[0]))
222
+ frames, im1, im2 = process_video_inner(
223
+ im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
224
+ progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
225
+ )
226
+ result_frames.append(frames[:, :, :-1, :, :])
227
+ cropped_images.append([im1, im2])
228
+
229
+ video = torch.cat(result_frames, dim=2)
230
+ video = torch.flip(video, dims=[2])
231
+
232
+ uuid_name = str(uuid.uuid4())
233
+ output_filename = os.path.join(result_dir, uuid_name + '.mp4')
234
+ Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
235
+ video = save_bcthw_as_mp4(video, output_filename, fps=fps)
236
+ video = [x.cpu().numpy() for x in video]
237
+ return output_filename, video
238
+
239
+
240
+ block = gr.Blocks().queue()
241
+ with block:
242
+ gr.Markdown('# Paints-Undo')
243
+
244
+ with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True):
245
+ with gr.Row():
246
+ with gr.Column():
247
+ input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512)
248
+ with gr.Column():
249
+ prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False)
250
+ prompt = gr.Textbox(label="Output Prompt", interactive=True)
251
+
252
+ with gr.Accordion(label='Step 2: Generate Key Frames', open=True):
253
+ with gr.Row():
254
+ with gr.Column():
255
+ input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999],
256
+ choices=list(range(1000)), multiselect=True)
257
+ seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345)
258
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
259
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
260
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
261
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01)
262
+ n_prompt = gr.Textbox(label="Negative Prompt",
263
+ value='lowres, bad anatomy, bad hands, cropped, worst quality')
264
+
265
+ with gr.Column():
266
+ key_gen_button = gr.Button(value="Generate Key Frames", interactive=False)
267
+ result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4)
268
+
269
+ with gr.Accordion(label='Step 3: Generate All Videos', open=True):
270
+ with gr.Row():
271
+ with gr.Column():
272
+ i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality')
273
+ i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123)
274
+ i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5,
275
+ elem_id="i2v_cfg_scale")
276
+ i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps",
277
+ label="Sampling steps", value=50)
278
+ i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4)
279
+ with gr.Column():
280
+ i2v_end_btn = gr.Button("Generate Video", interactive=False)
281
+ i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
282
+ show_share_button=True, height=512)
283
+ with gr.Row():
284
+ i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8)
285
+
286
+ input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)],
287
+ outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn])
288
+
289
+ prompt_gen_button.click(
290
+ fn=interrogator_process,
291
+ inputs=[input_fg],
292
+ outputs=[prompt]
293
+ ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
294
+ outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
295
+
296
+ key_gen_button.click(
297
+ fn=process,
298
+ inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg],
299
+ outputs=[result_gallery]
300
+ ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
301
+ outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])
302
+
303
+ i2v_end_btn.click(
304
+ inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed],
305
+ outputs=[i2v_output_video, i2v_output_images],
306
+ fn=process_video
307
+ )
308
+
309
+ dbs = [
310
+ ['./imgs/1.jpg', 12345, 123],
311
+ ['./imgs/2.jpg', 37000, 12345],
312
+ ['./imgs/3.jpg', 3000, 3000],
313
+ ]
314
+
315
+ gr.Examples(
316
+ examples=dbs,
317
+ inputs=[input_fg, seed, i2v_seed],
318
+ examples_per_page=1024
319
+ )
320
+
321
+ block.queue().launch(server_name='0.0.0.0')
diffusers_helper/cat_cond.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def unet_add_concat_conds(unet, new_channels=4):
5
+ with torch.no_grad():
6
+ new_conv_in = torch.nn.Conv2d(4 + new_channels, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
7
+ new_conv_in.weight.zero_()
8
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
9
+ new_conv_in.bias = unet.conv_in.bias
10
+ unet.conv_in = new_conv_in
11
+
12
+ unet_original_forward = unet.forward
13
+
14
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
15
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
16
+ c_concat = cross_attention_kwargs.pop('concat_conds')
17
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
18
+
19
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0).to(sample)
20
+ new_sample = torch.cat([sample, c_concat], dim=1)
21
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
22
+
23
+ unet.forward = hooked_unet_forward
24
+ return
diffusers_helper/code_cond.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
4
+
5
+
6
+ def unet_add_coded_conds(unet, added_number_count=1):
7
+ unet.add_time_proj = Timesteps(256, True, 0)
8
+ unet.add_embedding = TimestepEmbedding(256 * added_number_count, 1280)
9
+
10
+ def get_aug_embed(emb, encoder_hidden_states, added_cond_kwargs):
11
+ coded_conds = added_cond_kwargs.get("coded_conds")
12
+ batch_size = coded_conds.shape[0]
13
+ time_embeds = unet.add_time_proj(coded_conds.flatten())
14
+ time_embeds = time_embeds.reshape((batch_size, -1))
15
+ time_embeds = time_embeds.to(emb)
16
+ aug_emb = unet.add_embedding(time_embeds)
17
+ return aug_emb
18
+
19
+ unet.get_aug_embed = get_aug_embed
20
+
21
+ unet_original_forward = unet.forward
22
+
23
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
24
+ cross_attention_kwargs = {k: v for k, v in kwargs['cross_attention_kwargs'].items()}
25
+ coded_conds = cross_attention_kwargs.pop('coded_conds')
26
+ kwargs['cross_attention_kwargs'] = cross_attention_kwargs
27
+
28
+ coded_conds = torch.cat([coded_conds] * (sample.shape[0] // coded_conds.shape[0]), dim=0).to(sample.device)
29
+ kwargs['added_cond_kwargs'] = dict(coded_conds=coded_conds)
30
+ return unet_original_forward(sample, timestep, encoder_hidden_states, **kwargs)
31
+
32
+ unet.forward = hooked_unet_forward
33
+
34
+ return
diffusers_helper/k_diffusion.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from tqdm import tqdm
5
+
6
+
7
+ @torch.no_grad()
8
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, progress_tqdm=None):
9
+ """DPM-Solver++(2M)."""
10
+ extra_args = {} if extra_args is None else extra_args
11
+ s_in = x.new_ones([x.shape[0]])
12
+ sigma_fn = lambda t: t.neg().exp()
13
+ t_fn = lambda sigma: sigma.log().neg()
14
+ old_denoised = None
15
+
16
+ bar = tqdm if progress_tqdm is None else progress_tqdm
17
+
18
+ for i in bar(range(len(sigmas) - 1)):
19
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
20
+ if callback is not None:
21
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
22
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
23
+ h = t_next - t
24
+ if old_denoised is None or sigmas[i + 1] == 0:
25
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
26
+ else:
27
+ h_last = t - t_fn(sigmas[i - 1])
28
+ r = h_last / h
29
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
30
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
31
+ old_denoised = denoised
32
+ return x
33
+
34
+
35
+ class KModel:
36
+ def __init__(self, unet, timesteps=1000, linear_start=0.00085, linear_end=0.012, linear=False):
37
+ if linear:
38
+ betas = torch.linspace(linear_start, linear_end, timesteps, dtype=torch.float64)
39
+ else:
40
+ betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float64) ** 2
41
+
42
+ alphas = 1. - betas
43
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
44
+
45
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
46
+ self.log_sigmas = self.sigmas.log()
47
+ self.sigma_data = 1.0
48
+ self.unet = unet
49
+ return
50
+
51
+ @property
52
+ def sigma_min(self):
53
+ return self.sigmas[0]
54
+
55
+ @property
56
+ def sigma_max(self):
57
+ return self.sigmas[-1]
58
+
59
+ def timestep(self, sigma):
60
+ log_sigma = sigma.log()
61
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
62
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
63
+
64
+ def get_sigmas_karras(self, n, rho=7.):
65
+ ramp = torch.linspace(0, 1, n)
66
+ min_inv_rho = self.sigma_min ** (1 / rho)
67
+ max_inv_rho = self.sigma_max ** (1 / rho)
68
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
69
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
70
+
71
+ def __call__(self, x, sigma, **extra_args):
72
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data ** 2) ** 0.5
73
+ x_ddim_space = x_ddim_space.to(dtype=self.unet.dtype)
74
+ t = self.timestep(sigma)
75
+ cfg_scale = extra_args['cfg_scale']
76
+ eps_positive = self.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
77
+ eps_negative = self.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
78
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
79
+ return x - noise_pred * sigma[:, None, None, None]
80
+
81
+
82
+ class KDiffusionSampler:
83
+ def __init__(self, unet, **kwargs):
84
+ self.unet = unet
85
+ self.k_model = KModel(unet=unet, **kwargs)
86
+
87
+ @torch.inference_mode()
88
+ def __call__(
89
+ self,
90
+ initial_latent = None,
91
+ strength = 1.0,
92
+ num_inference_steps = 25,
93
+ guidance_scale = 5.0,
94
+ batch_size = 1,
95
+ generator = None,
96
+ prompt_embeds = None,
97
+ negative_prompt_embeds = None,
98
+ cross_attention_kwargs = None,
99
+ same_noise_in_batch = False,
100
+ progress_tqdm = None,
101
+ ):
102
+
103
+ device = self.unet.device
104
+
105
+ # Sigmas
106
+
107
+ sigmas = self.k_model.get_sigmas_karras(int(num_inference_steps/strength))
108
+ sigmas = sigmas[-(num_inference_steps + 1):].to(device)
109
+
110
+ # Initial latents
111
+
112
+ if same_noise_in_batch:
113
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype).repeat(batch_size, 1, 1, 1)
114
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
115
+ else:
116
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1).to(device=device, dtype=self.unet.dtype)
117
+ noise = torch.randn(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)
118
+
119
+ latents = initial_latent + noise * sigmas[0].to(initial_latent)
120
+
121
+ # Batch
122
+
123
+ latents = latents.to(device)
124
+ prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
125
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
126
+
127
+ # Feeds
128
+
129
+ sampler_kwargs = dict(
130
+ cfg_scale=guidance_scale,
131
+ positive=dict(
132
+ encoder_hidden_states=prompt_embeds,
133
+ cross_attention_kwargs=cross_attention_kwargs
134
+ ),
135
+ negative=dict(
136
+ encoder_hidden_states=negative_prompt_embeds,
137
+ cross_attention_kwargs=cross_attention_kwargs,
138
+ )
139
+ )
140
+
141
+ # Sample
142
+
143
+ results = sample_dpmpp_2m(self.k_model, latents, sigmas, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
144
+
145
+ return results
diffusers_helper/utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import glob
5
+ import torch
6
+ import einops
7
+ import torchvision
8
+
9
+ import safetensors.torch as sf
10
+
11
+
12
+ def write_to_json(data, file_path):
13
+ temp_file_path = file_path + ".tmp"
14
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
15
+ json.dump(data, temp_file, indent=4)
16
+ os.replace(temp_file_path, file_path)
17
+ return
18
+
19
+
20
+ def read_from_json(file_path):
21
+ with open(file_path, 'rt', encoding='utf-8') as file:
22
+ data = json.load(file)
23
+ return data
24
+
25
+
26
+ def get_active_parameters(m):
27
+ return {k:v for k, v in m.named_parameters() if v.requires_grad}
28
+
29
+
30
+ def cast_training_params(m, dtype=torch.float32):
31
+ for param in m.parameters():
32
+ if param.requires_grad:
33
+ param.data = param.to(dtype)
34
+ return
35
+
36
+
37
+ def set_attr_recursive(obj, attr, value):
38
+ attrs = attr.split(".")
39
+ for name in attrs[:-1]:
40
+ obj = getattr(obj, name)
41
+ setattr(obj, attrs[-1], value)
42
+ return
43
+
44
+
45
+ @torch.no_grad()
46
+ def batch_mixture(a, b, probability_a=0.5, mask_a=None):
47
+ assert a.shape == b.shape, "Tensors must have the same shape"
48
+ batch_size = a.size(0)
49
+
50
+ if mask_a is None:
51
+ mask_a = torch.rand(batch_size) < probability_a
52
+
53
+ mask_a = mask_a.to(a.device)
54
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
55
+ result = torch.where(mask_a, a, b)
56
+ return result
57
+
58
+
59
+ @torch.no_grad()
60
+ def zero_module(module):
61
+ for p in module.parameters():
62
+ p.detach().zero_()
63
+ return module
64
+
65
+
66
+ def load_last_state(model, folder='accelerator_output'):
67
+ file_pattern = os.path.join(folder, '**', 'model.safetensors')
68
+ files = glob.glob(file_pattern, recursive=True)
69
+
70
+ if not files:
71
+ print("No model.safetensors files found in the specified folder.")
72
+ return
73
+
74
+ newest_file = max(files, key=os.path.getmtime)
75
+ state_dict = sf.load_file(newest_file)
76
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
77
+
78
+ if missing_keys:
79
+ print("Missing keys:", missing_keys)
80
+ if unexpected_keys:
81
+ print("Unexpected keys:", unexpected_keys)
82
+
83
+ print("Loaded model state from:", newest_file)
84
+ return
85
+
86
+
87
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
88
+ tags = tags_str.split(', ')
89
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
90
+ prompt = ', '.join(tags)
91
+ return prompt
92
+
93
+
94
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
95
+ b, c, t, h, w = x.shape
96
+
97
+ per_row = b
98
+ for p in [6, 5, 4, 3, 2]:
99
+ if b % p == 0:
100
+ per_row = p
101
+ break
102
+
103
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
104
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
105
+ x = x.detach().cpu().to(torch.uint8)
106
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
107
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'})
108
+ return x
109
+
110
+
111
+ def save_bcthw_as_png(x, output_filename):
112
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
113
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
114
+ x = x.detach().cpu().to(torch.uint8)
115
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
116
+ torchvision.io.write_png(x, output_filename)
117
+ return output_filename
118
+
119
+
120
+ def add_tensors_with_padding(tensor1, tensor2):
121
+ if tensor1.shape == tensor2.shape:
122
+ return tensor1 + tensor2
123
+
124
+ shape1 = tensor1.shape
125
+ shape2 = tensor2.shape
126
+
127
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
128
+
129
+ padded_tensor1 = torch.zeros(new_shape)
130
+ padded_tensor2 = torch.zeros(new_shape)
131
+
132
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
133
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
134
+
135
+ result = padded_tensor1 + padded_tensor2
136
+ return result
diffusers_vdm/attention.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import xformers.ops
3
+ import torch.nn.functional as F
4
+
5
+ from torch import nn
6
+ from einops import rearrange, repeat
7
+ from functools import partial
8
+ from diffusers_vdm.basics import zero_module, checkpoint, default, make_temporal_window
9
+
10
+
11
+ def sdp(q, k, v, heads):
12
+ b, _, C = q.shape
13
+ dim_head = C // heads
14
+
15
+ q, k, v = map(
16
+ lambda t: t.unsqueeze(3)
17
+ .reshape(b, t.shape[1], heads, dim_head)
18
+ .permute(0, 2, 1, 3)
19
+ .reshape(b * heads, t.shape[1], dim_head)
20
+ .contiguous(),
21
+ (q, k, v),
22
+ )
23
+
24
+ out = xformers.ops.memory_efficient_attention(q, k, v)
25
+
26
+ out = (
27
+ out.unsqueeze(0)
28
+ .reshape(b, heads, out.shape[1], dim_head)
29
+ .permute(0, 2, 1, 3)
30
+ .reshape(b, out.shape[1], heads * dim_head)
31
+ )
32
+
33
+ return out
34
+
35
+
36
+ class RelativePosition(nn.Module):
37
+ """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
38
+
39
+ def __init__(self, num_units, max_relative_position):
40
+ super().__init__()
41
+ self.num_units = num_units
42
+ self.max_relative_position = max_relative_position
43
+ self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
44
+ nn.init.xavier_uniform_(self.embeddings_table)
45
+
46
+ def forward(self, length_q, length_k):
47
+ device = self.embeddings_table.device
48
+ range_vec_q = torch.arange(length_q, device=device)
49
+ range_vec_k = torch.arange(length_k, device=device)
50
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
51
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
52
+ final_mat = distance_mat_clipped + self.max_relative_position
53
+ final_mat = final_mat.long()
54
+ embeddings = self.embeddings_table[final_mat]
55
+ return embeddings
56
+
57
+
58
+ class CrossAttention(nn.Module):
59
+
60
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
61
+ relative_position=False, temporal_length=None, video_length=None, image_cross_attention=False,
62
+ image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False,
63
+ text_context_len=77, temporal_window_for_spatial_self_attention=False):
64
+ super().__init__()
65
+ inner_dim = dim_head * heads
66
+ context_dim = default(context_dim, query_dim)
67
+
68
+ self.scale = dim_head**-0.5
69
+ self.heads = heads
70
+ self.dim_head = dim_head
71
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
72
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
73
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
74
+
75
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
76
+
77
+ self.is_temporal_attention = temporal_length is not None
78
+
79
+ self.relative_position = relative_position
80
+ if self.relative_position:
81
+ assert self.is_temporal_attention
82
+ self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
83
+ self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
84
+
85
+ self.video_length = video_length
86
+ self.temporal_window_for_spatial_self_attention = temporal_window_for_spatial_self_attention
87
+ self.temporal_window_type = 'prv'
88
+
89
+ self.image_cross_attention = image_cross_attention
90
+ self.image_cross_attention_scale = image_cross_attention_scale
91
+ self.text_context_len = text_context_len
92
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
93
+ if self.image_cross_attention:
94
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
95
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
96
+ if image_cross_attention_scale_learnable:
97
+ self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) )
98
+
99
+ def forward(self, x, context=None, mask=None):
100
+ if self.is_temporal_attention:
101
+ return self.temporal_forward(x, context=context, mask=mask)
102
+ else:
103
+ return self.spatial_forward(x, context=context, mask=mask)
104
+
105
+ def temporal_forward(self, x, context=None, mask=None):
106
+ assert mask is None, 'Attention mask not implemented!'
107
+ assert context is None, 'Temporal attention only supports self attention!'
108
+
109
+ q = self.to_q(x)
110
+ k = self.to_k(x)
111
+ v = self.to_v(x)
112
+
113
+ out = sdp(q, k, v, self.heads)
114
+
115
+ return self.to_out(out)
116
+
117
+ def spatial_forward(self, x, context=None, mask=None):
118
+ assert mask is None, 'Attention mask not implemented!'
119
+
120
+ spatial_self_attn = (context is None)
121
+ k_ip, v_ip, out_ip = None, None, None
122
+
123
+ q = self.to_q(x)
124
+ context = default(context, x)
125
+
126
+ if spatial_self_attn:
127
+ k = self.to_k(context)
128
+ v = self.to_v(context)
129
+
130
+ if self.temporal_window_for_spatial_self_attention:
131
+ k = make_temporal_window(k, t=self.video_length, method=self.temporal_window_type)
132
+ v = make_temporal_window(v, t=self.video_length, method=self.temporal_window_type)
133
+ elif self.image_cross_attention:
134
+ context, context_image = context
135
+ k = self.to_k(context)
136
+ v = self.to_v(context)
137
+ k_ip = self.to_k_ip(context_image)
138
+ v_ip = self.to_v_ip(context_image)
139
+ else:
140
+ raise NotImplementedError('Traditional prompt-only attention without IP-Adapter is illegal now.')
141
+
142
+ out = sdp(q, k, v, self.heads)
143
+
144
+ if k_ip is not None:
145
+ out_ip = sdp(q, k_ip, v_ip, self.heads)
146
+
147
+ if self.image_cross_attention_scale_learnable:
148
+ out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha) + 1)
149
+ else:
150
+ out = out + self.image_cross_attention_scale * out_ip
151
+
152
+ return self.to_out(out)
153
+
154
+
155
+ class BasicTransformerBlock(nn.Module):
156
+
157
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
158
+ disable_self_attn=False, attention_cls=None, video_length=None, image_cross_attention=False, image_cross_attention_scale=1.0, image_cross_attention_scale_learnable=False, text_context_len=77):
159
+ super().__init__()
160
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
161
+ self.disable_self_attn = disable_self_attn
162
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
163
+ context_dim=context_dim if self.disable_self_attn else None, video_length=video_length)
164
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
165
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, video_length=video_length, image_cross_attention=image_cross_attention, image_cross_attention_scale=image_cross_attention_scale, image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,text_context_len=text_context_len)
166
+ self.image_cross_attention = image_cross_attention
167
+
168
+ self.norm1 = nn.LayerNorm(dim)
169
+ self.norm2 = nn.LayerNorm(dim)
170
+ self.norm3 = nn.LayerNorm(dim)
171
+ self.checkpoint = checkpoint
172
+
173
+
174
+ def forward(self, x, context=None, mask=None, **kwargs):
175
+ ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
176
+ input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
177
+ if context is not None:
178
+ input_tuple = (x, context)
179
+ if mask is not None:
180
+ forward_mask = partial(self._forward, mask=mask)
181
+ return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint)
182
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint)
183
+
184
+
185
+ def _forward(self, x, context=None, mask=None):
186
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
187
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
188
+ x = self.ff(self.norm3(x)) + x
189
+ return x
190
+
191
+
192
+ class SpatialTransformer(nn.Module):
193
+ """
194
+ Transformer block for image-like data in spatial axis.
195
+ First, project the input (aka embedding)
196
+ and reshape to b, t, d.
197
+ Then apply standard transformer action.
198
+ Finally, reshape to image
199
+ NEW: use_linear for more efficiency instead of the 1x1 convs
200
+ """
201
+
202
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
203
+ use_checkpoint=True, disable_self_attn=False, use_linear=False, video_length=None,
204
+ image_cross_attention=False, image_cross_attention_scale_learnable=False):
205
+ super().__init__()
206
+ self.in_channels = in_channels
207
+ inner_dim = n_heads * d_head
208
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
209
+ if not use_linear:
210
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
211
+ else:
212
+ self.proj_in = nn.Linear(in_channels, inner_dim)
213
+
214
+ attention_cls = None
215
+ self.transformer_blocks = nn.ModuleList([
216
+ BasicTransformerBlock(
217
+ inner_dim,
218
+ n_heads,
219
+ d_head,
220
+ dropout=dropout,
221
+ context_dim=context_dim,
222
+ disable_self_attn=disable_self_attn,
223
+ checkpoint=use_checkpoint,
224
+ attention_cls=attention_cls,
225
+ video_length=video_length,
226
+ image_cross_attention=image_cross_attention,
227
+ image_cross_attention_scale_learnable=image_cross_attention_scale_learnable,
228
+ ) for d in range(depth)
229
+ ])
230
+ if not use_linear:
231
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
232
+ else:
233
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
234
+ self.use_linear = use_linear
235
+
236
+
237
+ def forward(self, x, context=None, **kwargs):
238
+ b, c, h, w = x.shape
239
+ x_in = x
240
+ x = self.norm(x)
241
+ if not self.use_linear:
242
+ x = self.proj_in(x)
243
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
244
+ if self.use_linear:
245
+ x = self.proj_in(x)
246
+ for i, block in enumerate(self.transformer_blocks):
247
+ x = block(x, context=context, **kwargs)
248
+ if self.use_linear:
249
+ x = self.proj_out(x)
250
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
251
+ if not self.use_linear:
252
+ x = self.proj_out(x)
253
+ return x + x_in
254
+
255
+
256
+ class TemporalTransformer(nn.Module):
257
+ """
258
+ Transformer block for image-like data in temporal axis.
259
+ First, reshape to b, t, d.
260
+ Then apply standard transformer action.
261
+ Finally, reshape to image
262
+ """
263
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None,
264
+ use_checkpoint=True, use_linear=False, only_self_att=True, causal_attention=False, causal_block_size=1,
265
+ relative_position=False, temporal_length=None):
266
+ super().__init__()
267
+ self.only_self_att = only_self_att
268
+ self.relative_position = relative_position
269
+ self.causal_attention = causal_attention
270
+ self.causal_block_size = causal_block_size
271
+
272
+ self.in_channels = in_channels
273
+ inner_dim = n_heads * d_head
274
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
275
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
276
+ if not use_linear:
277
+ self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
278
+ else:
279
+ self.proj_in = nn.Linear(in_channels, inner_dim)
280
+
281
+ if relative_position:
282
+ assert(temporal_length is not None)
283
+ attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length)
284
+ else:
285
+ attention_cls = partial(CrossAttention, temporal_length=temporal_length)
286
+ if self.causal_attention:
287
+ assert(temporal_length is not None)
288
+ self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length]))
289
+
290
+ if self.only_self_att:
291
+ context_dim = None
292
+ self.transformer_blocks = nn.ModuleList([
293
+ BasicTransformerBlock(
294
+ inner_dim,
295
+ n_heads,
296
+ d_head,
297
+ dropout=dropout,
298
+ context_dim=context_dim,
299
+ attention_cls=attention_cls,
300
+ checkpoint=use_checkpoint) for d in range(depth)
301
+ ])
302
+ if not use_linear:
303
+ self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
304
+ else:
305
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
306
+ self.use_linear = use_linear
307
+
308
+ def forward(self, x, context=None):
309
+ b, c, t, h, w = x.shape
310
+ x_in = x
311
+ x = self.norm(x)
312
+ x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
313
+ if not self.use_linear:
314
+ x = self.proj_in(x)
315
+ x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
316
+ if self.use_linear:
317
+ x = self.proj_in(x)
318
+
319
+ temp_mask = None
320
+ if self.causal_attention:
321
+ # slice the from mask map
322
+ temp_mask = self.mask[:,:t,:t].to(x.device)
323
+
324
+ if temp_mask is not None:
325
+ mask = temp_mask.to(x.device)
326
+ mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w)
327
+ else:
328
+ mask = None
329
+
330
+ if self.only_self_att:
331
+ ## note: if no context is given, cross-attention defaults to self-attention
332
+ for i, block in enumerate(self.transformer_blocks):
333
+ x = block(x, mask=mask)
334
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
335
+ else:
336
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
337
+ context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
338
+ for i, block in enumerate(self.transformer_blocks):
339
+ # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
340
+ for j in range(b):
341
+ context_j = repeat(
342
+ context[j],
343
+ 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous()
344
+ ## note: causal mask will not applied in cross-attention case
345
+ x[j] = block(x[j], context=context_j)
346
+
347
+ if self.use_linear:
348
+ x = self.proj_out(x)
349
+ x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
350
+ if not self.use_linear:
351
+ x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
352
+ x = self.proj_out(x)
353
+ x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()
354
+
355
+ return x + x_in
356
+
357
+
358
+ class GEGLU(nn.Module):
359
+ def __init__(self, dim_in, dim_out):
360
+ super().__init__()
361
+ self.proj = nn.Linear(dim_in, dim_out * 2)
362
+
363
+ def forward(self, x):
364
+ x, gate = self.proj(x).chunk(2, dim=-1)
365
+ return x * F.gelu(gate)
366
+
367
+
368
+ class FeedForward(nn.Module):
369
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
370
+ super().__init__()
371
+ inner_dim = int(dim * mult)
372
+ dim_out = default(dim_out, dim)
373
+ project_in = nn.Sequential(
374
+ nn.Linear(dim, inner_dim),
375
+ nn.GELU()
376
+ ) if not glu else GEGLU(dim, inner_dim)
377
+
378
+ self.net = nn.Sequential(
379
+ project_in,
380
+ nn.Dropout(dropout),
381
+ nn.Linear(inner_dim, dim_out)
382
+ )
383
+
384
+ def forward(self, x):
385
+ return self.net(x)
diffusers_vdm/basics.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import einops
14
+
15
+ from inspect import isfunction
16
+
17
+
18
+ def zero_module(module):
19
+ """
20
+ Zero out the parameters of a module and return it.
21
+ """
22
+ for p in module.parameters():
23
+ p.detach().zero_()
24
+ return module
25
+
26
+ def scale_module(module, scale):
27
+ """
28
+ Scale the parameters of a module and return it.
29
+ """
30
+ for p in module.parameters():
31
+ p.detach().mul_(scale)
32
+ return module
33
+
34
+
35
+ def conv_nd(dims, *args, **kwargs):
36
+ """
37
+ Create a 1D, 2D, or 3D convolution module.
38
+ """
39
+ if dims == 1:
40
+ return nn.Conv1d(*args, **kwargs)
41
+ elif dims == 2:
42
+ return nn.Conv2d(*args, **kwargs)
43
+ elif dims == 3:
44
+ return nn.Conv3d(*args, **kwargs)
45
+ raise ValueError(f"unsupported dimensions: {dims}")
46
+
47
+
48
+ def linear(*args, **kwargs):
49
+ """
50
+ Create a linear module.
51
+ """
52
+ return nn.Linear(*args, **kwargs)
53
+
54
+
55
+ def avg_pool_nd(dims, *args, **kwargs):
56
+ """
57
+ Create a 1D, 2D, or 3D average pooling module.
58
+ """
59
+ if dims == 1:
60
+ return nn.AvgPool1d(*args, **kwargs)
61
+ elif dims == 2:
62
+ return nn.AvgPool2d(*args, **kwargs)
63
+ elif dims == 3:
64
+ return nn.AvgPool3d(*args, **kwargs)
65
+ raise ValueError(f"unsupported dimensions: {dims}")
66
+
67
+
68
+ def nonlinearity(type='silu'):
69
+ if type == 'silu':
70
+ return nn.SiLU()
71
+ elif type == 'leaky_relu':
72
+ return nn.LeakyReLU()
73
+
74
+
75
+ def normalization(channels, num_groups=32):
76
+ """
77
+ Make a standard normalization layer.
78
+ :param channels: number of input channels.
79
+ :return: an nn.Module for normalization.
80
+ """
81
+ return nn.GroupNorm(num_groups, channels)
82
+
83
+
84
+ def default(val, d):
85
+ if exists(val):
86
+ return val
87
+ return d() if isfunction(d) else d
88
+
89
+
90
+ def exists(val):
91
+ return val is not None
92
+
93
+
94
+ def extract_into_tensor(a, t, x_shape):
95
+ b, *_ = t.shape
96
+ out = a.gather(-1, t)
97
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
98
+
99
+
100
+ def make_temporal_window(x, t, method):
101
+ assert method in ['roll', 'prv', 'first']
102
+
103
+ if method == 'roll':
104
+ m = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
105
+ l = torch.roll(m, shifts=1, dims=1)
106
+ r = torch.roll(m, shifts=-1, dims=1)
107
+
108
+ recon = torch.cat([l, m, r], dim=2)
109
+ del l, m, r
110
+
111
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
112
+ return recon
113
+
114
+ if method == 'prv':
115
+ x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
116
+ prv = torch.cat([x[:, :1], x[:, :-1]], dim=1)
117
+
118
+ recon = torch.cat([x, prv], dim=2)
119
+ del x, prv
120
+
121
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
122
+ return recon
123
+
124
+ if method == 'first':
125
+ x = einops.rearrange(x, '(b t) d c -> b t d c', t=t)
126
+ prv = x[:, [0], :, :].repeat(1, t, 1, 1)
127
+
128
+ recon = torch.cat([x, prv], dim=2)
129
+ del x, prv
130
+
131
+ recon = einops.rearrange(recon, 'b t d c -> (b t) d c')
132
+ return recon
133
+
134
+
135
+ def checkpoint(func, inputs, params, flag):
136
+ """
137
+ Evaluate a function without caching intermediate activations, allowing for
138
+ reduced memory at the expense of extra compute in the backward pass.
139
+ :param func: the function to evaluate.
140
+ :param inputs: the argument sequence to pass to `func`.
141
+ :param params: a sequence of parameters `func` depends on but does not
142
+ explicitly take as arguments.
143
+ :param flag: if False, disable gradient checkpointing.
144
+ """
145
+ if flag:
146
+ return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
147
+ else:
148
+ return func(*inputs)
diffusers_vdm/dynamic_tsnr_sampler.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # everything that can improve v-prediction model
2
+ # dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
3
+ # written by lvmin at stanford 2024
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from functools import partial
10
+ from diffusers_vdm.basics import extract_into_tensor
11
+
12
+
13
+ to_torch = partial(torch.tensor, dtype=torch.float32)
14
+
15
+
16
+ def rescale_zero_terminal_snr(betas):
17
+ # Convert betas to alphas_bar_sqrt
18
+ alphas = 1.0 - betas
19
+ alphas_cumprod = np.cumprod(alphas, axis=0)
20
+ alphas_bar_sqrt = np.sqrt(alphas_cumprod)
21
+
22
+ # Store old values.
23
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
24
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
25
+
26
+ # Shift so the last timestep is zero.
27
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
28
+
29
+ # Scale so the first timestep is back to the old value.
30
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
31
+
32
+ # Convert alphas_bar_sqrt to betas
33
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
34
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
35
+ alphas = np.concatenate([alphas_bar[0:1], alphas])
36
+ betas = 1 - alphas
37
+
38
+ return betas
39
+
40
+
41
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
42
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
43
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
44
+
45
+ # rescale the results from guidance (fixes overexposure)
46
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
47
+
48
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
49
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
50
+
51
+ return noise_cfg
52
+
53
+
54
+ class SamplerDynamicTSNR(torch.nn.Module):
55
+ @torch.no_grad()
56
+ def __init__(self, unet, terminal_scale=0.7):
57
+ super().__init__()
58
+ self.unet = unet
59
+
60
+ self.is_v = True
61
+ self.n_timestep = 1000
62
+ self.guidance_rescale = 0.7
63
+
64
+ linear_start = 0.00085
65
+ linear_end = 0.012
66
+
67
+ betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
68
+ betas = rescale_zero_terminal_snr(betas)
69
+ alphas = 1. - betas
70
+
71
+ alphas_cumprod = np.cumprod(alphas, axis=0)
72
+
73
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
74
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
75
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))
76
+
77
+ # Dynamic TSNR
78
+ turning_step = 400
79
+ scale_arr = np.concatenate([
80
+ np.linspace(1.0, terminal_scale, turning_step),
81
+ np.full(self.n_timestep - turning_step, terminal_scale)
82
+ ])
83
+ self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))
84
+
85
+ def predict_eps_from_z_and_v(self, x_t, t, v):
86
+ return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t
87
+
88
+ def predict_start_from_z_and_v(self, x_t, t, v):
89
+ return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v
90
+
91
+ def q_sample(self, x0, t, noise):
92
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
93
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
94
+
95
+ def get_v(self, x0, t, noise):
96
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
97
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)
98
+
99
+ def dynamic_x0_rescale(self, x0, t):
100
+ return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)
101
+
102
+ @torch.no_grad()
103
+ def get_ground_truth(self, x0, noise, t):
104
+ x0 = self.dynamic_x0_rescale(x0, t)
105
+ xt = self.q_sample(x0, t, noise)
106
+ target = self.get_v(x0, t, noise) if self.is_v else noise
107
+ return xt, target
108
+
109
+ def get_uniform_trailing_steps(self, steps):
110
+ c = self.n_timestep / steps
111
+ ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
112
+ steps_out = ddim_timesteps - 1
113
+ return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)
114
+
115
+ @torch.no_grad()
116
+ def forward(self, latent_shape, steps, extra_args, progress_tqdm=None):
117
+ bar = tqdm if progress_tqdm is None else progress_tqdm
118
+
119
+ eta = 1.0
120
+
121
+ timesteps = self.get_uniform_trailing_steps(steps)
122
+ timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
123
+
124
+ x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)
125
+
126
+ alphas = self.alphas_cumprod[timesteps]
127
+ alphas_prev = self.alphas_cumprod[timesteps_prev]
128
+ scale_arr = self.scale_arr[timesteps]
129
+ scale_arr_prev = self.scale_arr[timesteps_prev]
130
+
131
+ sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
132
+ sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
133
+
134
+ s_in = x.new_ones((x.shape[0]))
135
+ s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
136
+ for i in bar(range(len(timesteps))):
137
+ index = len(timesteps) - 1 - i
138
+ t = timesteps[index].item()
139
+
140
+ model_output = self.model_apply(x, t * s_in, **extra_args)
141
+
142
+ if self.is_v:
143
+ e_t = self.predict_eps_from_z_and_v(x, t, model_output)
144
+ else:
145
+ e_t = model_output
146
+
147
+ a_prev = alphas_prev[index].item() * s_x
148
+ sigma_t = sigmas[index].item() * s_x
149
+
150
+ if self.is_v:
151
+ pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
152
+ else:
153
+ a_t = alphas[index].item() * s_x
154
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
155
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
156
+
157
+ # dynamic rescale
158
+ scale_t = scale_arr[index].item() * s_x
159
+ prev_scale_t = scale_arr_prev[index].item() * s_x
160
+ rescale = (prev_scale_t / scale_t)
161
+ pred_x0 = pred_x0 * rescale
162
+
163
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
164
+ noise = sigma_t * torch.randn_like(x)
165
+ x = a_prev.sqrt() * pred_x0 + dir_xt + noise
166
+
167
+ return x
168
+
169
+ @torch.no_grad()
170
+ def model_apply(self, x, t, **extra_args):
171
+ x = x.to(device=self.unet.device, dtype=self.unet.dtype)
172
+ cfg_scale = extra_args['cfg_scale']
173
+ p = self.unet(x, t, **extra_args['positive'])
174
+ n = self.unet(x, t, **extra_args['negative'])
175
+ o = n + cfg_scale * (p - n)
176
+ o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
177
+ return o_better
diffusers_vdm/improved_clip_vision.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel
2
+ # The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range)
3
+
4
+ import torch
5
+ import types
6
+ import einops
7
+
8
+ from abc import ABCMeta
9
+ from transformers import CLIPVisionModelWithProjection
10
+
11
+
12
+ def preprocess(image):
13
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
14
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
15
+
16
+ scale = 16 / min(image.shape[2], image.shape[3])
17
+ image = torch.nn.functional.interpolate(
18
+ image,
19
+ size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
20
+ mode="bicubic",
21
+ antialias=True
22
+ )
23
+
24
+ return (image - mean) / std
25
+
26
+
27
+ def arbitrary_positional_encoding(p, H, W):
28
+ weight = p.weight
29
+ cls = weight[:1]
30
+ pos = weight[1:]
31
+ pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
32
+ pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
33
+ pos = einops.rearrange(pos, '1 C H W -> (H W) C')
34
+ weight = torch.cat([cls, pos])[None]
35
+ return weight
36
+
37
+
38
+ def improved_clipvision_embedding_forward(self, pixel_values):
39
+ pixel_values = pixel_values * 0.5 + 0.5
40
+ pixel_values = preprocess(pixel_values)
41
+ batch_size = pixel_values.shape[0]
42
+ target_dtype = self.patch_embedding.weight.dtype
43
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
44
+ B, C, H, W = patch_embeds.shape
45
+ patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
46
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
47
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
48
+ embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
49
+ return embeddings
50
+
51
+
52
+ class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
53
+ def __init__(self, config):
54
+ super().__init__(config)
55
+ self.vision_model.embeddings.forward = types.MethodType(
56
+ improved_clipvision_embedding_forward,
57
+ self.vision_model.embeddings
58
+ )
diffusers_vdm/pipeline.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import einops
4
+
5
+ from diffusers import DiffusionPipeline
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from huggingface_hub import snapshot_download
8
+ from diffusers_vdm.vae import VideoAutoencoderKL
9
+ from diffusers_vdm.projection import Resampler
10
+ from diffusers_vdm.unet import UNet3DModel
11
+ from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
12
+ from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR
13
+
14
+
15
+ class LatentVideoDiffusionPipeline(DiffusionPipeline):
16
+ def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
17
+ super().__init__()
18
+
19
+ self.loading_components = dict(
20
+ vae=vae,
21
+ text_encoder=text_encoder,
22
+ tokenizer=tokenizer,
23
+ unet=unet,
24
+ image_encoder=image_encoder,
25
+ image_projection=image_projection
26
+ )
27
+
28
+ for k, v in self.loading_components.items():
29
+ setattr(self, k, v)
30
+
31
+ if fp16:
32
+ self.vae.half()
33
+ self.text_encoder.half()
34
+ self.unet.half()
35
+ self.image_encoder.half()
36
+ self.image_projection.half()
37
+
38
+ self.vae.requires_grad_(False)
39
+ self.text_encoder.requires_grad_(False)
40
+ self.image_encoder.requires_grad_(False)
41
+
42
+ self.vae.eval()
43
+ self.text_encoder.eval()
44
+ self.image_encoder.eval()
45
+
46
+ if eval:
47
+ self.unet.eval()
48
+ self.image_projection.eval()
49
+ else:
50
+ self.unet.train()
51
+ self.image_projection.train()
52
+
53
+ def to(self, *args, **kwargs):
54
+ for k, v in self.loading_components.items():
55
+ if hasattr(v, 'to'):
56
+ v.to(*args, **kwargs)
57
+ return self
58
+
59
+ def save_pretrained(self, save_directory, **kwargs):
60
+ for k, v in self.loading_components.items():
61
+ folder = os.path.join(save_directory, k)
62
+ os.makedirs(folder, exist_ok=True)
63
+ v.save_pretrained(folder)
64
+ return
65
+
66
+ @classmethod
67
+ def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
68
+ local_folder = snapshot_download(repo_id=repo_id, token=token)
69
+ return cls(
70
+ tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
71
+ text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
72
+ image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
73
+ vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
74
+ image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
75
+ unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
76
+ fp16=fp16,
77
+ eval=eval
78
+ )
79
+
80
+ @torch.inference_mode()
81
+ def encode_cropped_prompt_77tokens(self, prompt: str):
82
+ cond_ids = self.tokenizer(prompt,
83
+ padding="max_length",
84
+ max_length=self.tokenizer.model_max_length,
85
+ truncation=True,
86
+ return_tensors="pt").input_ids.to(self.text_encoder.device)
87
+ cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
88
+ return cond
89
+
90
+ @torch.inference_mode()
91
+ def encode_clip_vision(self, frames):
92
+ b, c, t, h, w = frames.shape
93
+ frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
94
+ clipvision_embed = self.image_encoder(frames).last_hidden_state
95
+ clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
96
+ return clipvision_embed
97
+
98
+ @torch.inference_mode()
99
+ def encode_latents(self, videos, return_hidden_states=True):
100
+ b, c, t, h, w = videos.shape
101
+ x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
102
+ encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
103
+ z = encoder_posterior.mode() * self.vae.scale_factor
104
+ z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
105
+
106
+ if not return_hidden_states:
107
+ return z
108
+
109
+ hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
110
+ hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states] # only need first and last
111
+
112
+ return z, hidden_states
113
+
114
+ @torch.inference_mode()
115
+ def decode_latents(self, latents, hidden_states):
116
+ B, C, T, H, W = latents.shape
117
+ latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
118
+ latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
119
+ pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
120
+ pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
121
+ return pixels
122
+
123
+ @torch.inference_mode()
124
+ def __call__(
125
+ self,
126
+ batch_size: int = 1,
127
+ steps: int = 50,
128
+ guidance_scale: float = 5.0,
129
+ positive_text_cond = None,
130
+ negative_text_cond = None,
131
+ positive_image_cond = None,
132
+ negative_image_cond = None,
133
+ concat_cond = None,
134
+ fs = 3,
135
+ progress_tqdm = None,
136
+ ):
137
+ unet_is_training = self.unet.training
138
+
139
+ if unet_is_training:
140
+ self.unet.eval()
141
+
142
+ device = self.unet.device
143
+ dtype = self.unet.dtype
144
+ dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)
145
+
146
+ # Batch
147
+
148
+ concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype) # b, c, t, h, w
149
+ positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
150
+ negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond) # b, f, c
151
+ positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond) # b, t, l, c
152
+ negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)
153
+
154
+ if isinstance(fs, torch.Tensor):
155
+ fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device) # b
156
+ else:
157
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # b
158
+
159
+ # Initial latents
160
+
161
+ latent_shape = concat_cond.shape
162
+
163
+ # Feeds
164
+
165
+ sampler_kwargs = dict(
166
+ cfg_scale=guidance_scale,
167
+ positive=dict(
168
+ context_text=positive_text_cond,
169
+ context_img=positive_image_cond,
170
+ fs=fs,
171
+ concat_cond=concat_cond
172
+ ),
173
+ negative=dict(
174
+ context_text=negative_text_cond,
175
+ context_img=negative_image_cond,
176
+ fs=fs,
177
+ concat_cond=concat_cond
178
+ )
179
+ )
180
+
181
+ # Sample
182
+
183
+ results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
184
+
185
+ if unet_is_training:
186
+ self.unet.train()
187
+
188
+ return results
diffusers_vdm/projection.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+ # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
4
+
5
+
6
+ import math
7
+ import torch
8
+ import einops
9
+ import torch.nn as nn
10
+
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+
14
+ class ImageProjModel(nn.Module):
15
+ """Projection Model"""
16
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
17
+ super().__init__()
18
+ self.cross_attention_dim = cross_attention_dim
19
+ self.clip_extra_context_tokens = clip_extra_context_tokens
20
+ self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
21
+ self.norm = nn.LayerNorm(cross_attention_dim)
22
+
23
+ def forward(self, image_embeds):
24
+ #embeds = image_embeds
25
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
26
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28
+ return clip_extra_context_tokens
29
+
30
+
31
+ # FFN
32
+ def FeedForward(dim, mult=4):
33
+ inner_dim = int(dim * mult)
34
+ return nn.Sequential(
35
+ nn.LayerNorm(dim),
36
+ nn.Linear(dim, inner_dim, bias=False),
37
+ nn.GELU(),
38
+ nn.Linear(inner_dim, dim, bias=False),
39
+ )
40
+
41
+
42
+ def reshape_tensor(x, heads):
43
+ bs, length, width = x.shape
44
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
45
+ x = x.view(bs, length, heads, -1)
46
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
47
+ x = x.transpose(1, 2)
48
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
49
+ x = x.reshape(bs, heads, length, -1)
50
+ return x
51
+
52
+
53
+ class PerceiverAttention(nn.Module):
54
+ def __init__(self, *, dim, dim_head=64, heads=8):
55
+ super().__init__()
56
+ self.scale = dim_head**-0.5
57
+ self.dim_head = dim_head
58
+ self.heads = heads
59
+ inner_dim = dim_head * heads
60
+
61
+ self.norm1 = nn.LayerNorm(dim)
62
+ self.norm2 = nn.LayerNorm(dim)
63
+
64
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
65
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
66
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
67
+
68
+
69
+ def forward(self, x, latents):
70
+ """
71
+ Args:
72
+ x (torch.Tensor): image features
73
+ shape (b, n1, D)
74
+ latent (torch.Tensor): latent features
75
+ shape (b, n2, D)
76
+ """
77
+ x = self.norm1(x)
78
+ latents = self.norm2(latents)
79
+
80
+ b, l, _ = latents.shape
81
+
82
+ q = self.to_q(latents)
83
+ kv_input = torch.cat((x, latents), dim=-2)
84
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
85
+
86
+ q = reshape_tensor(q, self.heads)
87
+ k = reshape_tensor(k, self.heads)
88
+ v = reshape_tensor(v, self.heads)
89
+
90
+ # attention
91
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
92
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
93
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
94
+ out = weight @ v
95
+
96
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
97
+
98
+ return self.to_out(out)
99
+
100
+
101
+ class Resampler(nn.Module, PyTorchModelHubMixin):
102
+ def __init__(
103
+ self,
104
+ dim=1024,
105
+ depth=8,
106
+ dim_head=64,
107
+ heads=16,
108
+ num_queries=8,
109
+ embedding_dim=768,
110
+ output_dim=1024,
111
+ ff_mult=4,
112
+ video_length=16,
113
+ input_frames_length=2,
114
+ ):
115
+ super().__init__()
116
+ self.num_queries = num_queries
117
+ self.video_length = video_length
118
+
119
+ self.latents = nn.Parameter(torch.randn(1, num_queries * video_length, dim) / dim**0.5)
120
+ self.input_pos = nn.Parameter(torch.zeros(1, input_frames_length, 1, embedding_dim))
121
+
122
+ self.proj_in = nn.Linear(embedding_dim, dim)
123
+ self.proj_out = nn.Linear(dim, output_dim)
124
+ self.norm_out = nn.LayerNorm(output_dim)
125
+
126
+ self.layers = nn.ModuleList([])
127
+ for _ in range(depth):
128
+ self.layers.append(
129
+ nn.ModuleList(
130
+ [
131
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
132
+ FeedForward(dim=dim, mult=ff_mult),
133
+ ]
134
+ )
135
+ )
136
+
137
+ def forward(self, x):
138
+ latents = self.latents.repeat(x.size(0), 1, 1)
139
+
140
+ x = x + self.input_pos
141
+ x = einops.rearrange(x, 'b ti d c -> b (ti d) c')
142
+ x = self.proj_in(x)
143
+
144
+ for attn, ff in self.layers:
145
+ latents = attn(x, latents) + latents
146
+ latents = ff(latents) + latents
147
+
148
+ latents = self.proj_out(latents)
149
+ latents = self.norm_out(latents)
150
+
151
+ latents = einops.rearrange(latents, 'b (to l) c -> b to l c', to=self.video_length)
152
+ return latents
153
+
154
+ @property
155
+ def device(self):
156
+ return next(self.parameters()).device
157
+
158
+ @property
159
+ def dtype(self):
160
+ return next(self.parameters()).dtype
diffusers_vdm/unet.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/AILab-CVC/VideoCrafter
2
+ # https://github.com/Doubiiu/DynamiCrafter
3
+ # https://github.com/ToonCrafter/ToonCrafter
4
+ # Then edited by lllyasviel
5
+
6
+ from functools import partial
7
+ from abc import abstractmethod
8
+ import torch
9
+ import math
10
+ import torch.nn as nn
11
+ from einops import rearrange, repeat
12
+ import torch.nn.functional as F
13
+ from diffusers_vdm.basics import checkpoint
14
+ from diffusers_vdm.basics import (
15
+ zero_module,
16
+ conv_nd,
17
+ linear,
18
+ avg_pool_nd,
19
+ normalization
20
+ )
21
+ from diffusers_vdm.attention import SpatialTransformer, TemporalTransformer
22
+ from huggingface_hub import PyTorchModelHubMixin
23
+
24
+
25
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
26
+ """
27
+ Create sinusoidal timestep embeddings.
28
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
29
+ These may be fractional.
30
+ :param dim: the dimension of the output.
31
+ :param max_period: controls the minimum frequency of the embeddings.
32
+ :return: an [N x dim] Tensor of positional embeddings.
33
+ """
34
+ if not repeat_only:
35
+ half = dim // 2
36
+ freqs = torch.exp(
37
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
38
+ ).to(device=timesteps.device)
39
+ args = timesteps[:, None].float() * freqs[None]
40
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
41
+ if dim % 2:
42
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
43
+ else:
44
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
45
+ return embedding
46
+
47
+
48
+ class TimestepBlock(nn.Module):
49
+ """
50
+ Any module where forward() takes timestep embeddings as a second argument.
51
+ """
52
+
53
+ @abstractmethod
54
+ def forward(self, x, emb):
55
+ """
56
+ Apply the module to `x` given `emb` timestep embeddings.
57
+ """
58
+
59
+
60
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
61
+ """
62
+ A sequential module that passes timestep embeddings to the children that
63
+ support it as an extra input.
64
+ """
65
+
66
+ def forward(self, x, emb, context=None, batch_size=None):
67
+ for layer in self:
68
+ if isinstance(layer, TimestepBlock):
69
+ x = layer(x, emb, batch_size=batch_size)
70
+ elif isinstance(layer, SpatialTransformer):
71
+ x = layer(x, context)
72
+ elif isinstance(layer, TemporalTransformer):
73
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
74
+ x = layer(x, context)
75
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class Downsample(nn.Module):
82
+ """
83
+ A downsampling layer with an optional convolution.
84
+ :param channels: channels in the inputs and outputs.
85
+ :param use_conv: a bool determining if a convolution is applied.
86
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
87
+ downsampling occurs in the inner-two dimensions.
88
+ """
89
+
90
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
91
+ super().__init__()
92
+ self.channels = channels
93
+ self.out_channels = out_channels or channels
94
+ self.use_conv = use_conv
95
+ self.dims = dims
96
+ stride = 2 if dims != 3 else (1, 2, 2)
97
+ if use_conv:
98
+ self.op = conv_nd(
99
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
100
+ )
101
+ else:
102
+ assert self.channels == self.out_channels
103
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
104
+
105
+ def forward(self, x):
106
+ assert x.shape[1] == self.channels
107
+ return self.op(x)
108
+
109
+
110
+ class Upsample(nn.Module):
111
+ """
112
+ An upsampling layer with an optional convolution.
113
+ :param channels: channels in the inputs and outputs.
114
+ :param use_conv: a bool determining if a convolution is applied.
115
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
116
+ upsampling occurs in the inner-two dimensions.
117
+ """
118
+
119
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.out_channels = out_channels or channels
123
+ self.use_conv = use_conv
124
+ self.dims = dims
125
+ if use_conv:
126
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
127
+
128
+ def forward(self, x):
129
+ assert x.shape[1] == self.channels
130
+ if self.dims == 3:
131
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
132
+ else:
133
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
134
+ if self.use_conv:
135
+ x = self.conv(x)
136
+ return x
137
+
138
+
139
+ class ResBlock(TimestepBlock):
140
+ """
141
+ A residual block that can optionally change the number of channels.
142
+ :param channels: the number of input channels.
143
+ :param emb_channels: the number of timestep embedding channels.
144
+ :param dropout: the rate of dropout.
145
+ :param out_channels: if specified, the number of out channels.
146
+ :param use_conv: if True and out_channels is specified, use a spatial
147
+ convolution instead of a smaller 1x1 convolution to change the
148
+ channels in the skip connection.
149
+ :param dims: determines if the signal is 1D, 2D, or 3D.
150
+ :param up: if True, use this block for upsampling.
151
+ :param down: if True, use this block for downsampling.
152
+ :param use_temporal_conv: if True, use the temporal convolution.
153
+ :param use_image_dataset: if True, the temporal parameters will not be optimized.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ channels,
159
+ emb_channels,
160
+ dropout,
161
+ out_channels=None,
162
+ use_scale_shift_norm=False,
163
+ dims=2,
164
+ use_checkpoint=False,
165
+ use_conv=False,
166
+ up=False,
167
+ down=False,
168
+ use_temporal_conv=False,
169
+ tempspatial_aware=False
170
+ ):
171
+ super().__init__()
172
+ self.channels = channels
173
+ self.emb_channels = emb_channels
174
+ self.dropout = dropout
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+ self.use_checkpoint = use_checkpoint
178
+ self.use_scale_shift_norm = use_scale_shift_norm
179
+ self.use_temporal_conv = use_temporal_conv
180
+
181
+ self.in_layers = nn.Sequential(
182
+ normalization(channels),
183
+ nn.SiLU(),
184
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
185
+ )
186
+
187
+ self.updown = up or down
188
+
189
+ if up:
190
+ self.h_upd = Upsample(channels, False, dims)
191
+ self.x_upd = Upsample(channels, False, dims)
192
+ elif down:
193
+ self.h_upd = Downsample(channels, False, dims)
194
+ self.x_upd = Downsample(channels, False, dims)
195
+ else:
196
+ self.h_upd = self.x_upd = nn.Identity()
197
+
198
+ self.emb_layers = nn.Sequential(
199
+ nn.SiLU(),
200
+ nn.Linear(
201
+ emb_channels,
202
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
203
+ ),
204
+ )
205
+ self.out_layers = nn.Sequential(
206
+ normalization(self.out_channels),
207
+ nn.SiLU(),
208
+ nn.Dropout(p=dropout),
209
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
210
+ )
211
+
212
+ if self.out_channels == channels:
213
+ self.skip_connection = nn.Identity()
214
+ elif use_conv:
215
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
216
+ else:
217
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
218
+
219
+ if self.use_temporal_conv:
220
+ self.temopral_conv = TemporalConvBlock(
221
+ self.out_channels,
222
+ self.out_channels,
223
+ dropout=0.1,
224
+ spatial_aware=tempspatial_aware
225
+ )
226
+
227
+ def forward(self, x, emb, batch_size=None):
228
+ """
229
+ Apply the block to a Tensor, conditioned on a timestep embedding.
230
+ :param x: an [N x C x ...] Tensor of features.
231
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
232
+ :return: an [N x C x ...] Tensor of outputs.
233
+ """
234
+ input_tuple = (x, emb)
235
+ if batch_size:
236
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
237
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
238
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
239
+
240
+ def _forward(self, x, emb, batch_size=None):
241
+ if self.updown:
242
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
243
+ h = in_rest(x)
244
+ h = self.h_upd(h)
245
+ x = self.x_upd(x)
246
+ h = in_conv(h)
247
+ else:
248
+ h = self.in_layers(x)
249
+ emb_out = self.emb_layers(emb).type(h.dtype)
250
+ while len(emb_out.shape) < len(h.shape):
251
+ emb_out = emb_out[..., None]
252
+ if self.use_scale_shift_norm:
253
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
254
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
255
+ h = out_norm(h) * (1 + scale) + shift
256
+ h = out_rest(h)
257
+ else:
258
+ h = h + emb_out
259
+ h = self.out_layers(h)
260
+ h = self.skip_connection(x) + h
261
+
262
+ if self.use_temporal_conv and batch_size:
263
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
264
+ h = self.temopral_conv(h)
265
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
266
+ return h
267
+
268
+
269
+ class TemporalConvBlock(nn.Module):
270
+ """
271
+ Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
272
+ """
273
+
274
+ def __init__(self, in_channels, out_channels=None, dropout=0.0, spatial_aware=False):
275
+ super(TemporalConvBlock, self).__init__()
276
+ if out_channels is None:
277
+ out_channels = in_channels
278
+ self.in_channels = in_channels
279
+ self.out_channels = out_channels
280
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
281
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
282
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
283
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
284
+
285
+ # conv layers
286
+ self.conv1 = nn.Sequential(
287
+ nn.GroupNorm(32, in_channels), nn.SiLU(),
288
+ nn.Conv3d(in_channels, out_channels, th_kernel_shape, padding=th_padding_shape))
289
+ self.conv2 = nn.Sequential(
290
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
291
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
292
+ self.conv3 = nn.Sequential(
293
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
294
+ nn.Conv3d(out_channels, in_channels, th_kernel_shape, padding=th_padding_shape))
295
+ self.conv4 = nn.Sequential(
296
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
297
+ nn.Conv3d(out_channels, in_channels, tw_kernel_shape, padding=tw_padding_shape))
298
+
299
+ # zero out the last layer params,so the conv block is identity
300
+ nn.init.zeros_(self.conv4[-1].weight)
301
+ nn.init.zeros_(self.conv4[-1].bias)
302
+
303
+ def forward(self, x):
304
+ identity = x
305
+ x = self.conv1(x)
306
+ x = self.conv2(x)
307
+ x = self.conv3(x)
308
+ x = self.conv4(x)
309
+
310
+ return identity + x
311
+
312
+
313
+ class UNet3DModel(nn.Module, PyTorchModelHubMixin):
314
+ """
315
+ The full UNet model with attention and timestep embedding.
316
+ :param in_channels: in_channels in the input Tensor.
317
+ :param model_channels: base channel count for the model.
318
+ :param out_channels: channels in the output Tensor.
319
+ :param num_res_blocks: number of residual blocks per downsample.
320
+ :param attention_resolutions: a collection of downsample rates at which
321
+ attention will take place. May be a set, list, or tuple.
322
+ For example, if this contains 4, then at 4x downsampling, attention
323
+ will be used.
324
+ :param dropout: the dropout probability.
325
+ :param channel_mult: channel multiplier for each level of the UNet.
326
+ :param conv_resample: if True, use learned convolutions for upsampling and
327
+ downsampling.
328
+ :param dims: determines if the signal is 1D, 2D, or 3D.
329
+ :param num_classes: if specified (as an int), then this model will be
330
+ class-conditional with `num_classes` classes.
331
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
332
+ :param num_heads: the number of attention heads in each attention layer.
333
+ :param num_heads_channels: if specified, ignore num_heads and instead use
334
+ a fixed channel width per attention head.
335
+ :param num_heads_upsample: works with num_heads to set a different number
336
+ of heads for upsampling. Deprecated.
337
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
338
+ :param resblock_updown: use residual blocks for up/downsampling.
339
+ :param use_new_attention_order: use a different attention pattern for potentially
340
+ increased efficiency.
341
+ """
342
+
343
+ def __init__(self,
344
+ in_channels,
345
+ model_channels,
346
+ out_channels,
347
+ num_res_blocks,
348
+ attention_resolutions,
349
+ dropout=0.0,
350
+ channel_mult=(1, 2, 4, 8),
351
+ conv_resample=True,
352
+ dims=2,
353
+ context_dim=None,
354
+ use_scale_shift_norm=False,
355
+ resblock_updown=False,
356
+ num_heads=-1,
357
+ num_head_channels=-1,
358
+ transformer_depth=1,
359
+ use_linear=False,
360
+ temporal_conv=False,
361
+ tempspatial_aware=False,
362
+ temporal_attention=True,
363
+ use_relative_position=True,
364
+ use_causal_attention=False,
365
+ temporal_length=None,
366
+ addition_attention=False,
367
+ temporal_selfatt_only=True,
368
+ image_cross_attention=False,
369
+ image_cross_attention_scale_learnable=False,
370
+ default_fs=4,
371
+ fs_condition=False,
372
+ ):
373
+ super(UNet3DModel, self).__init__()
374
+ if num_heads == -1:
375
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
376
+ if num_head_channels == -1:
377
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
378
+
379
+ self.in_channels = in_channels
380
+ self.model_channels = model_channels
381
+ self.out_channels = out_channels
382
+ self.num_res_blocks = num_res_blocks
383
+ self.attention_resolutions = attention_resolutions
384
+ self.dropout = dropout
385
+ self.channel_mult = channel_mult
386
+ self.conv_resample = conv_resample
387
+ self.temporal_attention = temporal_attention
388
+ time_embed_dim = model_channels * 4
389
+ self.use_checkpoint = use_checkpoint = False # moved to self.enable_gradient_checkpointing()
390
+ temporal_self_att_only = True
391
+ self.addition_attention = addition_attention
392
+ self.temporal_length = temporal_length
393
+ self.image_cross_attention = image_cross_attention
394
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
395
+ self.default_fs = default_fs
396
+ self.fs_condition = fs_condition
397
+
398
+ ## Time embedding blocks
399
+ self.time_embed = nn.Sequential(
400
+ linear(model_channels, time_embed_dim),
401
+ nn.SiLU(),
402
+ linear(time_embed_dim, time_embed_dim),
403
+ )
404
+ if fs_condition:
405
+ self.fps_embedding = nn.Sequential(
406
+ linear(model_channels, time_embed_dim),
407
+ nn.SiLU(),
408
+ linear(time_embed_dim, time_embed_dim),
409
+ )
410
+ nn.init.zeros_(self.fps_embedding[-1].weight)
411
+ nn.init.zeros_(self.fps_embedding[-1].bias)
412
+ ## Input Block
413
+ self.input_blocks = nn.ModuleList(
414
+ [
415
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
416
+ ]
417
+ )
418
+ if self.addition_attention:
419
+ self.init_attn = TimestepEmbedSequential(
420
+ TemporalTransformer(
421
+ model_channels,
422
+ n_heads=8,
423
+ d_head=num_head_channels,
424
+ depth=transformer_depth,
425
+ context_dim=context_dim,
426
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
427
+ causal_attention=False, relative_position=use_relative_position,
428
+ temporal_length=temporal_length))
429
+
430
+ input_block_chans = [model_channels]
431
+ ch = model_channels
432
+ ds = 1
433
+ for level, mult in enumerate(channel_mult):
434
+ for _ in range(num_res_blocks):
435
+ layers = [
436
+ ResBlock(ch, time_embed_dim, dropout,
437
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
438
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
439
+ use_temporal_conv=temporal_conv
440
+ )
441
+ ]
442
+ ch = mult * model_channels
443
+ if ds in attention_resolutions:
444
+ if num_head_channels == -1:
445
+ dim_head = ch // num_heads
446
+ else:
447
+ num_heads = ch // num_head_channels
448
+ dim_head = num_head_channels
449
+ layers.append(
450
+ SpatialTransformer(ch, num_heads, dim_head,
451
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
452
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
453
+ video_length=temporal_length,
454
+ image_cross_attention=self.image_cross_attention,
455
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
456
+ )
457
+ )
458
+ if self.temporal_attention:
459
+ layers.append(
460
+ TemporalTransformer(ch, num_heads, dim_head,
461
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
462
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
463
+ causal_attention=use_causal_attention,
464
+ relative_position=use_relative_position,
465
+ temporal_length=temporal_length
466
+ )
467
+ )
468
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
469
+ input_block_chans.append(ch)
470
+ if level != len(channel_mult) - 1:
471
+ out_ch = ch
472
+ self.input_blocks.append(
473
+ TimestepEmbedSequential(
474
+ ResBlock(ch, time_embed_dim, dropout,
475
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
476
+ use_scale_shift_norm=use_scale_shift_norm,
477
+ down=True
478
+ )
479
+ if resblock_updown
480
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
481
+ )
482
+ )
483
+ ch = out_ch
484
+ input_block_chans.append(ch)
485
+ ds *= 2
486
+
487
+ if num_head_channels == -1:
488
+ dim_head = ch // num_heads
489
+ else:
490
+ num_heads = ch // num_head_channels
491
+ dim_head = num_head_channels
492
+ layers = [
493
+ ResBlock(ch, time_embed_dim, dropout,
494
+ dims=dims, use_checkpoint=use_checkpoint,
495
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
496
+ use_temporal_conv=temporal_conv
497
+ ),
498
+ SpatialTransformer(ch, num_heads, dim_head,
499
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
500
+ use_checkpoint=use_checkpoint, disable_self_attn=False, video_length=temporal_length,
501
+ image_cross_attention=self.image_cross_attention,
502
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
503
+ )
504
+ ]
505
+ if self.temporal_attention:
506
+ layers.append(
507
+ TemporalTransformer(ch, num_heads, dim_head,
508
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
509
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
510
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
511
+ temporal_length=temporal_length
512
+ )
513
+ )
514
+ layers.append(
515
+ ResBlock(ch, time_embed_dim, dropout,
516
+ dims=dims, use_checkpoint=use_checkpoint,
517
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
518
+ use_temporal_conv=temporal_conv
519
+ )
520
+ )
521
+
522
+ ## Middle Block
523
+ self.middle_block = TimestepEmbedSequential(*layers)
524
+
525
+ ## Output Block
526
+ self.output_blocks = nn.ModuleList([])
527
+ for level, mult in list(enumerate(channel_mult))[::-1]:
528
+ for i in range(num_res_blocks + 1):
529
+ ich = input_block_chans.pop()
530
+ layers = [
531
+ ResBlock(ch + ich, time_embed_dim, dropout,
532
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
533
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
534
+ use_temporal_conv=temporal_conv
535
+ )
536
+ ]
537
+ ch = model_channels * mult
538
+ if ds in attention_resolutions:
539
+ if num_head_channels == -1:
540
+ dim_head = ch // num_heads
541
+ else:
542
+ num_heads = ch // num_head_channels
543
+ dim_head = num_head_channels
544
+ layers.append(
545
+ SpatialTransformer(ch, num_heads, dim_head,
546
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
547
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
548
+ video_length=temporal_length,
549
+ image_cross_attention=self.image_cross_attention,
550
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable
551
+ )
552
+ )
553
+ if self.temporal_attention:
554
+ layers.append(
555
+ TemporalTransformer(ch, num_heads, dim_head,
556
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
557
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
558
+ causal_attention=use_causal_attention,
559
+ relative_position=use_relative_position,
560
+ temporal_length=temporal_length
561
+ )
562
+ )
563
+ if level and i == num_res_blocks:
564
+ out_ch = ch
565
+ layers.append(
566
+ ResBlock(ch, time_embed_dim, dropout,
567
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
568
+ use_scale_shift_norm=use_scale_shift_norm,
569
+ up=True
570
+ )
571
+ if resblock_updown
572
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
573
+ )
574
+ ds //= 2
575
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
576
+
577
+ self.out = nn.Sequential(
578
+ normalization(ch),
579
+ nn.SiLU(),
580
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
581
+ )
582
+
583
+ @property
584
+ def device(self):
585
+ return next(self.parameters()).device
586
+
587
+ @property
588
+ def dtype(self):
589
+ return next(self.parameters()).dtype
590
+
591
+ def forward(self, x, timesteps, context_text=None, context_img=None, concat_cond=None, fs=None, **kwargs):
592
+ b, _, t, _, _ = x.shape
593
+
594
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype)
595
+ emb = self.time_embed(t_emb)
596
+
597
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
598
+ context_img = rearrange(context_img, 'b t l c -> (b t) l c')
599
+
600
+ context = (context_text, context_img)
601
+
602
+ emb = emb.repeat_interleave(repeats=t, dim=0)
603
+
604
+ if concat_cond is not None:
605
+ x = torch.cat([x, concat_cond], dim=1)
606
+
607
+ ## always in shape (b t) c h w, except for temporal layer
608
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
609
+
610
+ ## combine emb
611
+ if self.fs_condition:
612
+ if fs is None:
613
+ fs = torch.tensor(
614
+ [self.default_fs] * b, dtype=torch.long, device=x.device)
615
+ fs_emb = timestep_embedding(fs, self.model_channels, repeat_only=False).type(x.dtype)
616
+
617
+ fs_embed = self.fps_embedding(fs_emb)
618
+ fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
619
+ emb = emb + fs_embed
620
+
621
+ h = x
622
+ hs = []
623
+ for id, module in enumerate(self.input_blocks):
624
+ h = module(h, emb, context=context, batch_size=b)
625
+ if id == 0 and self.addition_attention:
626
+ h = self.init_attn(h, emb, context=context, batch_size=b)
627
+ hs.append(h)
628
+
629
+ h = self.middle_block(h, emb, context=context, batch_size=b)
630
+
631
+ for module in self.output_blocks:
632
+ h = torch.cat([h, hs.pop()], dim=1)
633
+ h = module(h, emb, context=context, batch_size=b)
634
+ h = h.type(x.dtype)
635
+ y = self.out(h)
636
+
637
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
638
+ return y
639
+
640
+ def enable_gradient_checkpointing(self, enable=True, verbose=False):
641
+ for k, v in self.named_modules():
642
+ if hasattr(v, 'checkpoint'):
643
+ v.checkpoint = enable
644
+ if verbose:
645
+ print(f'{k}.checkpoint = {enable}')
646
+ if hasattr(v, 'use_checkpoint'):
647
+ v.use_checkpoint = enable
648
+ if verbose:
649
+ print(f'{k}.use_checkpoint = {enable}')
650
+ return
diffusers_vdm/utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import einops
5
+ import torchvision
6
+
7
+
8
+ def resize_and_center_crop(image, target_width, target_height, interpolation=cv2.INTER_AREA):
9
+ original_height, original_width = image.shape[:2]
10
+ k = max(target_height / original_height, target_width / original_width)
11
+ new_width = int(round(original_width * k))
12
+ new_height = int(round(original_height * k))
13
+ resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
14
+ x_start = (new_width - target_width) // 2
15
+ y_start = (new_height - target_height) // 2
16
+ cropped_image = resized_image[y_start:y_start + target_height, x_start:x_start + target_width]
17
+ return cropped_image
18
+
19
+
20
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
21
+ b, c, t, h, w = x.shape
22
+
23
+ per_row = b
24
+ for p in [6, 5, 4, 3, 2]:
25
+ if b % p == 0:
26
+ per_row = p
27
+ break
28
+
29
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
30
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
31
+ x = x.detach().cpu().to(torch.uint8)
32
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
33
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '1'})
34
+ return x
35
+
36
+
37
+ def save_bcthw_as_png(x, output_filename):
38
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
39
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
40
+ x = x.detach().cpu().to(torch.uint8)
41
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
42
+ torchvision.io.write_png(x, output_filename)
43
+ return output_filename
diffusers_vdm/vae.py ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # video VAE with many components from lots of repos
2
+ # collected by lvmin
3
+
4
+
5
+ import torch
6
+ import xformers.ops
7
+ import torch.nn as nn
8
+
9
+ from einops import rearrange, repeat
10
+ from diffusers_vdm.basics import default, exists, zero_module, conv_nd, linear, normalization
11
+ from diffusers_vdm.unet import Upsample, Downsample
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+
14
+
15
+ def chunked_attention(q, k, v, batch_chunk=0):
16
+ # if batch_chunk > 0 and not torch.is_grad_enabled():
17
+ # batch_size = q.size(0)
18
+ # chunks = [slice(i, i + batch_chunk) for i in range(0, batch_size, batch_chunk)]
19
+ #
20
+ # out_chunks = []
21
+ # for chunk in chunks:
22
+ # q_chunk = q[chunk]
23
+ # k_chunk = k[chunk]
24
+ # v_chunk = v[chunk]
25
+ #
26
+ # out_chunk = torch.nn.functional.scaled_dot_product_attention(
27
+ # q_chunk, k_chunk, v_chunk, attn_mask=None
28
+ # )
29
+ # out_chunks.append(out_chunk)
30
+ #
31
+ # out = torch.cat(out_chunks, dim=0)
32
+ # else:
33
+ # out = torch.nn.functional.scaled_dot_product_attention(
34
+ # q, k, v, attn_mask=None
35
+ # )
36
+ out = xformers.ops.memory_efficient_attention(q, k, v)
37
+ return out
38
+
39
+
40
+ def nonlinearity(x):
41
+ return x * torch.sigmoid(x)
42
+
43
+
44
+ def GroupNorm(in_channels, num_groups=32):
45
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
46
+
47
+
48
+ class DiagonalGaussianDistribution:
49
+ def __init__(self, parameters, deterministic=False):
50
+ self.parameters = parameters
51
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
52
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
53
+ self.deterministic = deterministic
54
+ self.std = torch.exp(0.5 * self.logvar)
55
+ self.var = torch.exp(self.logvar)
56
+ if self.deterministic:
57
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
58
+
59
+ def sample(self, noise=None):
60
+ if noise is None:
61
+ noise = torch.randn(self.mean.shape)
62
+
63
+ x = self.mean + self.std * noise.to(device=self.parameters.device)
64
+ return x
65
+
66
+ def mode(self):
67
+ return self.mean
68
+
69
+
70
+ class EncoderDownSampleBlock(nn.Module):
71
+ def __init__(self, in_channels, with_conv):
72
+ super().__init__()
73
+ self.with_conv = with_conv
74
+ self.in_channels = in_channels
75
+ if self.with_conv:
76
+ self.conv = torch.nn.Conv2d(in_channels,
77
+ in_channels,
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=0)
81
+
82
+ def forward(self, x):
83
+ if self.with_conv:
84
+ pad = (0, 1, 0, 1)
85
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
86
+ x = self.conv(x)
87
+ else:
88
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
89
+ return x
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
94
+ dropout, temb_channels=512):
95
+ super().__init__()
96
+ self.in_channels = in_channels
97
+ out_channels = in_channels if out_channels is None else out_channels
98
+ self.out_channels = out_channels
99
+ self.use_conv_shortcut = conv_shortcut
100
+
101
+ self.norm1 = GroupNorm(in_channels)
102
+ self.conv1 = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if temb_channels > 0:
108
+ self.temb_proj = torch.nn.Linear(temb_channels,
109
+ out_channels)
110
+ self.norm2 = GroupNorm(out_channels)
111
+ self.dropout = torch.nn.Dropout(dropout)
112
+ self.conv2 = torch.nn.Conv2d(out_channels,
113
+ out_channels,
114
+ kernel_size=3,
115
+ stride=1,
116
+ padding=1)
117
+ if self.in_channels != self.out_channels:
118
+ if self.use_conv_shortcut:
119
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1)
124
+ else:
125
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
126
+ out_channels,
127
+ kernel_size=1,
128
+ stride=1,
129
+ padding=0)
130
+
131
+ def forward(self, x, temb):
132
+ h = x
133
+ h = self.norm1(h)
134
+ h = nonlinearity(h)
135
+ h = self.conv1(h)
136
+
137
+ if temb is not None:
138
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
139
+
140
+ h = self.norm2(h)
141
+ h = nonlinearity(h)
142
+ h = self.dropout(h)
143
+ h = self.conv2(h)
144
+
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ x = self.conv_shortcut(x)
148
+ else:
149
+ x = self.nin_shortcut(x)
150
+
151
+ return x + h
152
+
153
+
154
+ class Encoder(nn.Module):
155
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
156
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
157
+ resolution, z_channels, double_z=True, **kwargs):
158
+ super().__init__()
159
+ self.ch = ch
160
+ self.temb_ch = 0
161
+ self.num_resolutions = len(ch_mult)
162
+ self.num_res_blocks = num_res_blocks
163
+ self.resolution = resolution
164
+ self.in_channels = in_channels
165
+
166
+ # downsampling
167
+ self.conv_in = torch.nn.Conv2d(in_channels,
168
+ self.ch,
169
+ kernel_size=3,
170
+ stride=1,
171
+ padding=1)
172
+
173
+ curr_res = resolution
174
+ in_ch_mult = (1,) + tuple(ch_mult)
175
+ self.in_ch_mult = in_ch_mult
176
+ self.down = nn.ModuleList()
177
+ for i_level in range(self.num_resolutions):
178
+ block = nn.ModuleList()
179
+ attn = nn.ModuleList()
180
+ block_in = ch * in_ch_mult[i_level]
181
+ block_out = ch * ch_mult[i_level]
182
+ for i_block in range(self.num_res_blocks):
183
+ block.append(ResnetBlock(in_channels=block_in,
184
+ out_channels=block_out,
185
+ temb_channels=self.temb_ch,
186
+ dropout=dropout))
187
+ block_in = block_out
188
+ if curr_res in attn_resolutions:
189
+ attn.append(Attention(block_in))
190
+ down = nn.Module()
191
+ down.block = block
192
+ down.attn = attn
193
+ if i_level != self.num_resolutions - 1:
194
+ down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv)
195
+ curr_res = curr_res // 2
196
+ self.down.append(down)
197
+
198
+ # middle
199
+ self.mid = nn.Module()
200
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
201
+ out_channels=block_in,
202
+ temb_channels=self.temb_ch,
203
+ dropout=dropout)
204
+ self.mid.attn_1 = Attention(block_in)
205
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
206
+ out_channels=block_in,
207
+ temb_channels=self.temb_ch,
208
+ dropout=dropout)
209
+
210
+ # end
211
+ self.norm_out = GroupNorm(block_in)
212
+ self.conv_out = torch.nn.Conv2d(block_in,
213
+ 2 * z_channels if double_z else z_channels,
214
+ kernel_size=3,
215
+ stride=1,
216
+ padding=1)
217
+
218
+ def forward(self, x, return_hidden_states=False):
219
+ # timestep embedding
220
+ temb = None
221
+
222
+ # print(f'encoder-input={x.shape}')
223
+ # downsampling
224
+ hs = [self.conv_in(x)]
225
+
226
+ ## if we return hidden states for decoder usage, we will store them in a list
227
+ if return_hidden_states:
228
+ hidden_states = []
229
+ # print(f'encoder-conv in feat={hs[0].shape}')
230
+ for i_level in range(self.num_resolutions):
231
+ for i_block in range(self.num_res_blocks):
232
+ h = self.down[i_level].block[i_block](hs[-1], temb)
233
+ # print(f'encoder-down feat={h.shape}')
234
+ if len(self.down[i_level].attn) > 0:
235
+ h = self.down[i_level].attn[i_block](h)
236
+ hs.append(h)
237
+ if return_hidden_states:
238
+ hidden_states.append(h)
239
+ if i_level != self.num_resolutions - 1:
240
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
241
+ hs.append(self.down[i_level].downsample(hs[-1]))
242
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
243
+ if return_hidden_states:
244
+ hidden_states.append(hs[0])
245
+ # middle
246
+ h = hs[-1]
247
+ h = self.mid.block_1(h, temb)
248
+ # print(f'encoder-mid1 feat={h.shape}')
249
+ h = self.mid.attn_1(h)
250
+ h = self.mid.block_2(h, temb)
251
+ # print(f'encoder-mid2 feat={h.shape}')
252
+
253
+ # end
254
+ h = self.norm_out(h)
255
+ h = nonlinearity(h)
256
+ h = self.conv_out(h)
257
+ # print(f'end feat={h.shape}')
258
+ if return_hidden_states:
259
+ return h, hidden_states
260
+ else:
261
+ return h
262
+
263
+
264
+ class ConvCombiner(nn.Module):
265
+ def __init__(self, ch):
266
+ super().__init__()
267
+ self.conv = nn.Conv2d(ch, ch, 1, padding=0)
268
+
269
+ nn.init.zeros_(self.conv.weight)
270
+ nn.init.zeros_(self.conv.bias)
271
+
272
+ def forward(self, x, context):
273
+ ## x: b c h w, context: b c 2 h w
274
+ b, c, l, h, w = context.shape
275
+ bt, c, h, w = x.shape
276
+ context = rearrange(context, "b c l h w -> (b l) c h w")
277
+ context = self.conv(context)
278
+ context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
279
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=bt // b)
280
+ x[:, :, 0] = x[:, :, 0] + context[:, :, 0]
281
+ x[:, :, -1] = x[:, :, -1] + context[:, :, -1]
282
+ x = rearrange(x, "b c t h w -> (b t) c h w")
283
+ return x
284
+
285
+
286
+ class AttentionCombiner(nn.Module):
287
+ def __init__(
288
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
289
+ ):
290
+ super().__init__()
291
+
292
+ inner_dim = dim_head * heads
293
+ context_dim = default(context_dim, query_dim)
294
+
295
+ self.heads = heads
296
+ self.dim_head = dim_head
297
+
298
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
299
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
300
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
301
+
302
+ self.to_out = nn.Sequential(
303
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
304
+ )
305
+ self.attention_op = None
306
+
307
+ self.norm = GroupNorm(query_dim)
308
+ nn.init.zeros_(self.to_out[0].weight)
309
+ nn.init.zeros_(self.to_out[0].bias)
310
+
311
+ def forward(
312
+ self,
313
+ x,
314
+ context=None,
315
+ mask=None,
316
+ ):
317
+ bt, c, h, w = x.shape
318
+ h_ = self.norm(x)
319
+ h_ = rearrange(h_, "b c h w -> b (h w) c")
320
+ q = self.to_q(h_)
321
+
322
+ b, c, l, h, w = context.shape
323
+ context = rearrange(context, "b c l h w -> (b l) (h w) c")
324
+ k = self.to_k(context)
325
+ v = self.to_v(context)
326
+
327
+ t = bt // b
328
+ k = repeat(k, "(b l) d c -> (b t) (l d) c", l=l, t=t)
329
+ v = repeat(v, "(b l) d c -> (b t) (l d) c", l=l, t=t)
330
+
331
+ b, _, _ = q.shape
332
+ q, k, v = map(
333
+ lambda t: t.unsqueeze(3)
334
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
335
+ .permute(0, 2, 1, 3)
336
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
337
+ .contiguous(),
338
+ (q, k, v),
339
+ )
340
+
341
+ out = chunked_attention(
342
+ q, k, v, batch_chunk=1
343
+ )
344
+
345
+ if exists(mask):
346
+ raise NotImplementedError
347
+
348
+ out = (
349
+ out.unsqueeze(0)
350
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
351
+ .permute(0, 2, 1, 3)
352
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
353
+ )
354
+ out = self.to_out(out)
355
+ out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
356
+ return x + out
357
+
358
+
359
+ class Attention(nn.Module):
360
+ def __init__(self, in_channels):
361
+ super().__init__()
362
+ self.in_channels = in_channels
363
+
364
+ self.norm = GroupNorm(in_channels)
365
+ self.q = torch.nn.Conv2d(
366
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
367
+ )
368
+ self.k = torch.nn.Conv2d(
369
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
370
+ )
371
+ self.v = torch.nn.Conv2d(
372
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
373
+ )
374
+ self.proj_out = torch.nn.Conv2d(
375
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
376
+ )
377
+
378
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
379
+ h_ = self.norm(h_)
380
+ q = self.q(h_)
381
+ k = self.k(h_)
382
+ v = self.v(h_)
383
+
384
+ # compute attention
385
+ B, C, H, W = q.shape
386
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
387
+
388
+ q, k, v = map(
389
+ lambda t: t.unsqueeze(3)
390
+ .reshape(B, t.shape[1], 1, C)
391
+ .permute(0, 2, 1, 3)
392
+ .reshape(B * 1, t.shape[1], C)
393
+ .contiguous(),
394
+ (q, k, v),
395
+ )
396
+
397
+ out = chunked_attention(
398
+ q, k, v, batch_chunk=1
399
+ )
400
+
401
+ out = (
402
+ out.unsqueeze(0)
403
+ .reshape(B, 1, out.shape[1], C)
404
+ .permute(0, 2, 1, 3)
405
+ .reshape(B, out.shape[1], C)
406
+ )
407
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
408
+
409
+ def forward(self, x, **kwargs):
410
+ h_ = x
411
+ h_ = self.attention(h_)
412
+ h_ = self.proj_out(h_)
413
+ return x + h_
414
+
415
+
416
+ class VideoDecoder(nn.Module):
417
+ def __init__(
418
+ self,
419
+ *,
420
+ ch,
421
+ out_ch,
422
+ ch_mult=(1, 2, 4, 8),
423
+ num_res_blocks,
424
+ attn_resolutions,
425
+ dropout=0.0,
426
+ resamp_with_conv=True,
427
+ in_channels,
428
+ resolution,
429
+ z_channels,
430
+ give_pre_end=False,
431
+ tanh_out=False,
432
+ use_linear_attn=False,
433
+ attn_level=[2, 3],
434
+ video_kernel_size=[3, 1, 1],
435
+ alpha: float = 0.0,
436
+ merge_strategy: str = "learned",
437
+ **kwargs,
438
+ ):
439
+ super().__init__()
440
+ self.video_kernel_size = video_kernel_size
441
+ self.alpha = alpha
442
+ self.merge_strategy = merge_strategy
443
+ self.ch = ch
444
+ self.temb_ch = 0
445
+ self.num_resolutions = len(ch_mult)
446
+ self.num_res_blocks = num_res_blocks
447
+ self.resolution = resolution
448
+ self.in_channels = in_channels
449
+ self.give_pre_end = give_pre_end
450
+ self.tanh_out = tanh_out
451
+ self.attn_level = attn_level
452
+ # compute in_ch_mult, block_in and curr_res at lowest res
453
+ in_ch_mult = (1,) + tuple(ch_mult)
454
+ block_in = ch * ch_mult[self.num_resolutions - 1]
455
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
456
+ self.z_shape = (1, z_channels, curr_res, curr_res)
457
+
458
+ # z to block_in
459
+ self.conv_in = torch.nn.Conv2d(
460
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
461
+ )
462
+
463
+ # middle
464
+ self.mid = nn.Module()
465
+ self.mid.block_1 = VideoResBlock(
466
+ in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout,
470
+ video_kernel_size=self.video_kernel_size,
471
+ alpha=self.alpha,
472
+ merge_strategy=self.merge_strategy,
473
+ )
474
+ self.mid.attn_1 = Attention(block_in)
475
+ self.mid.block_2 = VideoResBlock(
476
+ in_channels=block_in,
477
+ out_channels=block_in,
478
+ temb_channels=self.temb_ch,
479
+ dropout=dropout,
480
+ video_kernel_size=self.video_kernel_size,
481
+ alpha=self.alpha,
482
+ merge_strategy=self.merge_strategy,
483
+ )
484
+
485
+ # upsampling
486
+ self.up = nn.ModuleList()
487
+ self.attn_refinement = nn.ModuleList()
488
+ for i_level in reversed(range(self.num_resolutions)):
489
+ block = nn.ModuleList()
490
+ attn = nn.ModuleList()
491
+ block_out = ch * ch_mult[i_level]
492
+ for i_block in range(self.num_res_blocks + 1):
493
+ block.append(
494
+ VideoResBlock(
495
+ in_channels=block_in,
496
+ out_channels=block_out,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout,
499
+ video_kernel_size=self.video_kernel_size,
500
+ alpha=self.alpha,
501
+ merge_strategy=self.merge_strategy,
502
+ )
503
+ )
504
+ block_in = block_out
505
+ if curr_res in attn_resolutions:
506
+ attn.append(Attention(block_in))
507
+ up = nn.Module()
508
+ up.block = block
509
+ up.attn = attn
510
+ if i_level != 0:
511
+ up.upsample = Upsample(block_in, resamp_with_conv)
512
+ curr_res = curr_res * 2
513
+ self.up.insert(0, up) # prepend to get consistent order
514
+
515
+ if i_level in self.attn_level:
516
+ self.attn_refinement.insert(0, AttentionCombiner(block_in))
517
+ else:
518
+ self.attn_refinement.insert(0, ConvCombiner(block_in))
519
+ # end
520
+ self.norm_out = GroupNorm(block_in)
521
+ self.attn_refinement.append(ConvCombiner(block_in))
522
+ self.conv_out = DecoderConv3D(
523
+ block_in, out_ch, kernel_size=3, stride=1, padding=1, video_kernel_size=self.video_kernel_size
524
+ )
525
+
526
+ def forward(self, z, ref_context=None, **kwargs):
527
+ ## ref_context: b c 2 h w, 2 means starting and ending frame
528
+ # assert z.shape[1:] == self.z_shape[1:]
529
+ self.last_z_shape = z.shape
530
+ # timestep embedding
531
+ temb = None
532
+
533
+ # z to block_in
534
+ h = self.conv_in(z)
535
+
536
+ # middle
537
+ h = self.mid.block_1(h, temb, **kwargs)
538
+ h = self.mid.attn_1(h, **kwargs)
539
+ h = self.mid.block_2(h, temb, **kwargs)
540
+
541
+ # upsampling
542
+ for i_level in reversed(range(self.num_resolutions)):
543
+ for i_block in range(self.num_res_blocks + 1):
544
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
545
+ if len(self.up[i_level].attn) > 0:
546
+ h = self.up[i_level].attn[i_block](h, **kwargs)
547
+ if ref_context:
548
+ h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
549
+ if i_level != 0:
550
+ h = self.up[i_level].upsample(h)
551
+
552
+ # end
553
+ if self.give_pre_end:
554
+ return h
555
+
556
+ h = self.norm_out(h)
557
+ h = nonlinearity(h)
558
+ if ref_context:
559
+ # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
560
+ h = self.attn_refinement[-1](x=h, context=ref_context[-1])
561
+ h = self.conv_out(h, **kwargs)
562
+ if self.tanh_out:
563
+ h = torch.tanh(h)
564
+ return h
565
+
566
+
567
+ class TimeStackBlock(torch.nn.Module):
568
+ def __init__(
569
+ self,
570
+ channels: int,
571
+ emb_channels: int,
572
+ dropout: float,
573
+ out_channels: int = None,
574
+ use_conv: bool = False,
575
+ use_scale_shift_norm: bool = False,
576
+ dims: int = 2,
577
+ use_checkpoint: bool = False,
578
+ up: bool = False,
579
+ down: bool = False,
580
+ kernel_size: int = 3,
581
+ exchange_temb_dims: bool = False,
582
+ skip_t_emb: bool = False,
583
+ ):
584
+ super().__init__()
585
+ self.channels = channels
586
+ self.emb_channels = emb_channels
587
+ self.dropout = dropout
588
+ self.out_channels = out_channels or channels
589
+ self.use_conv = use_conv
590
+ self.use_checkpoint = use_checkpoint
591
+ self.use_scale_shift_norm = use_scale_shift_norm
592
+ self.exchange_temb_dims = exchange_temb_dims
593
+
594
+ if isinstance(kernel_size, list):
595
+ padding = [k // 2 for k in kernel_size]
596
+ else:
597
+ padding = kernel_size // 2
598
+
599
+ self.in_layers = nn.Sequential(
600
+ normalization(channels),
601
+ nn.SiLU(),
602
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
603
+ )
604
+
605
+ self.updown = up or down
606
+
607
+ if up:
608
+ self.h_upd = Upsample(channels, False, dims)
609
+ self.x_upd = Upsample(channels, False, dims)
610
+ elif down:
611
+ self.h_upd = Downsample(channels, False, dims)
612
+ self.x_upd = Downsample(channels, False, dims)
613
+ else:
614
+ self.h_upd = self.x_upd = nn.Identity()
615
+
616
+ self.skip_t_emb = skip_t_emb
617
+ self.emb_out_channels = (
618
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
619
+ )
620
+ if self.skip_t_emb:
621
+ # print(f"Skipping timestep embedding in {self.__class__.__name__}")
622
+ assert not self.use_scale_shift_norm
623
+ self.emb_layers = None
624
+ self.exchange_temb_dims = False
625
+ else:
626
+ self.emb_layers = nn.Sequential(
627
+ nn.SiLU(),
628
+ linear(
629
+ emb_channels,
630
+ self.emb_out_channels,
631
+ ),
632
+ )
633
+
634
+ self.out_layers = nn.Sequential(
635
+ normalization(self.out_channels),
636
+ nn.SiLU(),
637
+ nn.Dropout(p=dropout),
638
+ zero_module(
639
+ conv_nd(
640
+ dims,
641
+ self.out_channels,
642
+ self.out_channels,
643
+ kernel_size,
644
+ padding=padding,
645
+ )
646
+ ),
647
+ )
648
+
649
+ if self.out_channels == channels:
650
+ self.skip_connection = nn.Identity()
651
+ elif use_conv:
652
+ self.skip_connection = conv_nd(
653
+ dims, channels, self.out_channels, kernel_size, padding=padding
654
+ )
655
+ else:
656
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
657
+
658
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
659
+ if self.updown:
660
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
661
+ h = in_rest(x)
662
+ h = self.h_upd(h)
663
+ x = self.x_upd(x)
664
+ h = in_conv(h)
665
+ else:
666
+ h = self.in_layers(x)
667
+
668
+ if self.skip_t_emb:
669
+ emb_out = torch.zeros_like(h)
670
+ else:
671
+ emb_out = self.emb_layers(emb).type(h.dtype)
672
+ while len(emb_out.shape) < len(h.shape):
673
+ emb_out = emb_out[..., None]
674
+ if self.use_scale_shift_norm:
675
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
676
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
677
+ h = out_norm(h) * (1 + scale) + shift
678
+ h = out_rest(h)
679
+ else:
680
+ if self.exchange_temb_dims:
681
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
682
+ h = h + emb_out
683
+ h = self.out_layers(h)
684
+ return self.skip_connection(x) + h
685
+
686
+
687
+ class VideoResBlock(ResnetBlock):
688
+ def __init__(
689
+ self,
690
+ out_channels,
691
+ *args,
692
+ dropout=0.0,
693
+ video_kernel_size=3,
694
+ alpha=0.0,
695
+ merge_strategy="learned",
696
+ **kwargs,
697
+ ):
698
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
699
+ if video_kernel_size is None:
700
+ video_kernel_size = [3, 1, 1]
701
+ self.time_stack = TimeStackBlock(
702
+ channels=out_channels,
703
+ emb_channels=0,
704
+ dropout=dropout,
705
+ dims=3,
706
+ use_scale_shift_norm=False,
707
+ use_conv=False,
708
+ up=False,
709
+ down=False,
710
+ kernel_size=video_kernel_size,
711
+ use_checkpoint=True,
712
+ skip_t_emb=True,
713
+ )
714
+
715
+ self.merge_strategy = merge_strategy
716
+ if self.merge_strategy == "fixed":
717
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
718
+ elif self.merge_strategy == "learned":
719
+ self.register_parameter(
720
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
721
+ )
722
+ else:
723
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
724
+
725
+ def get_alpha(self, bs):
726
+ if self.merge_strategy == "fixed":
727
+ return self.mix_factor
728
+ elif self.merge_strategy == "learned":
729
+ return torch.sigmoid(self.mix_factor)
730
+ else:
731
+ raise NotImplementedError()
732
+
733
+ def forward(self, x, temb, skip_video=False, timesteps=None):
734
+ assert isinstance(timesteps, int)
735
+
736
+ b, c, h, w = x.shape
737
+
738
+ x = super().forward(x, temb)
739
+
740
+ if not skip_video:
741
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
742
+
743
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
744
+
745
+ x = self.time_stack(x, temb)
746
+
747
+ alpha = self.get_alpha(bs=b // timesteps)
748
+ x = alpha * x + (1.0 - alpha) * x_mix
749
+
750
+ x = rearrange(x, "b c t h w -> (b t) c h w")
751
+ return x
752
+
753
+
754
+ class DecoderConv3D(torch.nn.Conv2d):
755
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
756
+ super().__init__(in_channels, out_channels, *args, **kwargs)
757
+ if isinstance(video_kernel_size, list):
758
+ padding = [int(k // 2) for k in video_kernel_size]
759
+ else:
760
+ padding = int(video_kernel_size // 2)
761
+
762
+ self.time_mix_conv = torch.nn.Conv3d(
763
+ in_channels=out_channels,
764
+ out_channels=out_channels,
765
+ kernel_size=video_kernel_size,
766
+ padding=padding,
767
+ )
768
+
769
+ def forward(self, input, timesteps, skip_video=False):
770
+ x = super().forward(input)
771
+ if skip_video:
772
+ return x
773
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
774
+ x = self.time_mix_conv(x)
775
+ return rearrange(x, "b c t h w -> (b t) c h w")
776
+
777
+
778
+ class VideoAutoencoderKL(torch.nn.Module, PyTorchModelHubMixin):
779
+ def __init__(self,
780
+ double_z=True,
781
+ z_channels=4,
782
+ resolution=256,
783
+ in_channels=3,
784
+ out_ch=3,
785
+ ch=128,
786
+ ch_mult=[],
787
+ num_res_blocks=2,
788
+ attn_resolutions=[],
789
+ dropout=0.0,
790
+ ):
791
+ super().__init__()
792
+ self.encoder = Encoder(double_z=double_z, z_channels=z_channels, resolution=resolution, in_channels=in_channels,
793
+ out_ch=out_ch, ch=ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks,
794
+ attn_resolutions=attn_resolutions, dropout=dropout)
795
+ self.decoder = VideoDecoder(double_z=double_z, z_channels=z_channels, resolution=resolution,
796
+ in_channels=in_channels, out_ch=out_ch, ch=ch, ch_mult=ch_mult,
797
+ num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout)
798
+ self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
799
+ self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
800
+ self.scale_factor = 0.18215
801
+
802
+ def encode(self, x, return_hidden_states=False, **kwargs):
803
+ if return_hidden_states:
804
+ h, hidden = self.encoder(x, return_hidden_states)
805
+ moments = self.quant_conv(h)
806
+ posterior = DiagonalGaussianDistribution(moments)
807
+ return posterior, hidden
808
+ else:
809
+ h = self.encoder(x)
810
+ moments = self.quant_conv(h)
811
+ posterior = DiagonalGaussianDistribution(moments)
812
+ return posterior, None
813
+
814
+ def decode(self, z, **kwargs):
815
+ if len(kwargs) == 0:
816
+ z = self.post_quant_conv(z)
817
+ dec = self.decoder(z, **kwargs)
818
+ return dec
819
+
820
+ @property
821
+ def device(self):
822
+ return next(self.parameters()).device
823
+
824
+ @property
825
+ def dtype(self):
826
+ return next(self.parameters()).dtype
imgs/1.jpg ADDED
imgs/2.jpg ADDED
imgs/3.jpg ADDED
memory_management.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
+ high_vram = False
6
+ gpu = torch.device('cuda')
7
+ cpu = torch.device('cpu')
8
+
9
+ torch.zeros((1, 1)).to(gpu, torch.float32)
10
+ torch.cuda.empty_cache()
11
+
12
+ models_in_gpu = []
13
+
14
+
15
+ @contextmanager
16
+ def movable_bnb_model(m):
17
+ if hasattr(m, 'quantization_method'):
18
+ m.quantization_method_backup = m.quantization_method
19
+ del m.quantization_method
20
+ try:
21
+ yield None
22
+ finally:
23
+ if hasattr(m, 'quantization_method_backup'):
24
+ m.quantization_method = m.quantization_method_backup
25
+ del m.quantization_method_backup
26
+ return
27
+
28
+
29
+ def load_models_to_gpu(models):
30
+ global models_in_gpu
31
+
32
+ if not isinstance(models, (tuple, list)):
33
+ models = [models]
34
+
35
+ models_to_remain = [m for m in set(models) if m in models_in_gpu]
36
+ models_to_load = [m for m in set(models) if m not in models_in_gpu]
37
+ models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]
38
+
39
+ if not high_vram:
40
+ for m in models_to_unload:
41
+ with movable_bnb_model(m):
42
+ m.to(cpu)
43
+ print('Unload to CPU:', m.__class__.__name__)
44
+ models_in_gpu = models_to_remain
45
+
46
+ for m in models_to_load:
47
+ with movable_bnb_model(m):
48
+ m.to(gpu)
49
+ print('Load to GPU:', m.__class__.__name__)
50
+
51
+ models_in_gpu = list(set(models_in_gpu + models))
52
+ torch.cuda.empty_cache()
53
+ return
54
+
55
+
56
+ def unload_all_models(extra_models=None):
57
+ global models_in_gpu
58
+
59
+ if extra_models is None:
60
+ extra_models = []
61
+
62
+ if not isinstance(extra_models, (tuple, list)):
63
+ extra_models = [extra_models]
64
+
65
+ models_in_gpu = list(set(models_in_gpu + extra_models))
66
+
67
+ return load_models_to_gpu([])
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.28.0
2
+ transformers==4.41.1
3
+ gradio==4.31.5
4
+ bitsandbytes==0.43.1
5
+ accelerate==0.30.1
6
+ protobuf==3.20
7
+ opencv-python
8
+ tensorboardX
9
+ safetensors
10
+ pillow
11
+ einops
12
+ peft
13
+ xformers
14
+ onnxruntime
15
+ av
16
+ torchvision
wd14tagger.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
2
+
3
+
4
+ import os
5
+ import csv
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+
9
+ from PIL import Image
10
+ from onnxruntime import InferenceSession
11
+ from torch.hub import download_url_to_file
12
+
13
+
14
+ global_model = None
15
+ global_csv = None
16
+
17
+
18
+ def download_model(url, local_path):
19
+ if os.path.exists(local_path):
20
+ return local_path
21
+
22
+ temp_path = local_path + '.tmp'
23
+ download_url_to_file(url=url, dst=temp_path)
24
+ os.rename(temp_path, local_path)
25
+ return local_path
26
+
27
+
28
+ def default_interrogator(image, threshold=0.35, character_threshold=0.85, exclude_tags=""):
29
+ global global_model, global_csv
30
+
31
+ model_name = "wd-v1-4-moat-tagger-v2"
32
+
33
+ model_onnx_filename = download_model(
34
+ url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
35
+ local_path=f'./{model_name}.onnx',
36
+ )
37
+
38
+ model_csv_filename = download_model(
39
+ url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
40
+ local_path=f'./{model_name}.csv',
41
+ )
42
+
43
+ if global_model is not None:
44
+ model = global_model
45
+ else:
46
+ # assert 'CUDAExecutionProvider' in ort.get_available_providers(), 'CUDA Install Failed!'
47
+ # model = InferenceSession(model_onnx_filename, providers=['CUDAExecutionProvider'])
48
+ model = InferenceSession(model_onnx_filename, providers=['CPUExecutionProvider'])
49
+ global_model = model
50
+
51
+ input = model.get_inputs()[0]
52
+ height = input.shape[1]
53
+
54
+ if isinstance(image, str):
55
+ image = Image.open(image) # RGB
56
+ elif isinstance(image, np.ndarray):
57
+ image = Image.fromarray(image)
58
+ else:
59
+ image = image
60
+
61
+ ratio = float(height) / max(image.size)
62
+ new_size = tuple([int(x*ratio) for x in image.size])
63
+ image = image.resize(new_size, Image.LANCZOS)
64
+ square = Image.new("RGB", (height, height), (255, 255, 255))
65
+ square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
66
+
67
+ image = np.array(square).astype(np.float32)
68
+ image = image[:, :, ::-1] # RGB -> BGR
69
+ image = np.expand_dims(image, 0)
70
+
71
+ if global_csv is not None:
72
+ csv_lines = global_csv
73
+ else:
74
+ csv_lines = []
75
+ with open(model_csv_filename) as f:
76
+ reader = csv.reader(f)
77
+ next(reader)
78
+ for row in reader:
79
+ csv_lines.append(row)
80
+ global_csv = csv_lines
81
+
82
+ tags = []
83
+ general_index = None
84
+ character_index = None
85
+ for line_num, row in enumerate(csv_lines):
86
+ if general_index is None and row[2] == "0":
87
+ general_index = line_num
88
+ elif character_index is None and row[2] == "4":
89
+ character_index = line_num
90
+ tags.append(row[1])
91
+
92
+ label_name = model.get_outputs()[0].name
93
+ probs = model.run([label_name], {input.name: image})[0]
94
+
95
+ result = list(zip(tags, probs[0]))
96
+
97
+ general = [item for item in result[general_index:character_index] if item[1] > threshold]
98
+ character = [item for item in result[character_index:] if item[1] > character_threshold]
99
+
100
+ all = character + general
101
+ remove = [s.strip() for s in exclude_tags.lower().split(",")]
102
+ all = [tag for tag in all if tag[0] not in remove]
103
+
104
+ res = ", ".join((item[0].replace("(", "\\(").replace(")", "\\)") for item in all)).replace('_', ' ')
105
+ return res