Spaces:
Runtime error
Runtime error
initial commit
Browse files- README.md +5 -5
- app.py +316 -0
- requirements.txt +19 -0
- share_btn.py +72 -0
- showone/models/__init__.py +15 -0
- showone/models/transformer_temporal.py +179 -0
- showone/models/unet_3d_blocks.py +1619 -0
- showone/models/unet_3d_condition.py +985 -0
- showone/pipelines/__init__.py +37 -0
- showone/pipelines/pipeline_t2v_base_pixel.py +775 -0
- showone/pipelines/pipeline_t2v_interp_pixel.py +798 -0
- showone/pipelines/pipeline_t2v_sr_pixel.py +877 -0
- showone/pipelines/pipeline_t2v_sr_pixel_cond.py +890 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title: Show
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
1 |
---
|
2 |
+
title: Show-1
|
3 |
+
emoji: 🎬
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: gray
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.39.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
3 |
+
import torch
|
4 |
+
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
5 |
+
from diffusers.utils import export_to_video
|
6 |
+
|
7 |
+
import os
|
8 |
+
import imageio
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
|
14 |
+
from diffusers.utils import export_to_video
|
15 |
+
from diffusers.utils.torch_utils import randn_tensor
|
16 |
+
|
17 |
+
from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
|
18 |
+
from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
|
19 |
+
from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond
|
20 |
+
|
21 |
+
|
22 |
+
# Base Model
|
23 |
+
pretrained_model_path = "showlab/show-1-base"
|
24 |
+
pipe_base = TextToVideoIFPipeline.from_pretrained(
|
25 |
+
pretrained_model_path,
|
26 |
+
torch_dtype=torch.float16,
|
27 |
+
variant="fp16"
|
28 |
+
)
|
29 |
+
pipe_base.enable_model_cpu_offload()
|
30 |
+
|
31 |
+
# Interpolation Model
|
32 |
+
pretrained_model_path = "showlab/show-1-interpolation"
|
33 |
+
pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
|
34 |
+
pretrained_model_path,
|
35 |
+
torch_dtype=torch.float16,
|
36 |
+
variant="fp16"
|
37 |
+
)
|
38 |
+
pipe_interp_1.enable_model_cpu_offload()
|
39 |
+
|
40 |
+
# Super-Resolution Model 1
|
41 |
+
# Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
|
42 |
+
pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
|
43 |
+
pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
|
44 |
+
pretrained_model_path,
|
45 |
+
text_encoder=None,
|
46 |
+
torch_dtype=torch.float16,
|
47 |
+
variant="fp16"
|
48 |
+
)
|
49 |
+
pipe_sr_1_image.enable_model_cpu_offload()
|
50 |
+
|
51 |
+
pretrained_model_path = "showlab/show-1-sr1"
|
52 |
+
pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
|
53 |
+
pretrained_model_path,
|
54 |
+
torch_dtype=torch.float16
|
55 |
+
)
|
56 |
+
pipe_sr_1_cond.enable_model_cpu_offload()
|
57 |
+
|
58 |
+
# Super-Resolution Model 2
|
59 |
+
pretrained_model_path = "showlab/show-1-sr2"
|
60 |
+
pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
|
61 |
+
pretrained_model_path,
|
62 |
+
torch_dtype=torch.float16
|
63 |
+
)
|
64 |
+
pipe_sr_2.enable_model_cpu_offload()
|
65 |
+
pipe_sr_2.enable_vae_slicing()
|
66 |
+
|
67 |
+
def infer(prompt):
|
68 |
+
print(prompt)
|
69 |
+
negative_prompt = "low resolution, blur"
|
70 |
+
|
71 |
+
# Text embeds
|
72 |
+
prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
|
73 |
+
|
74 |
+
# Keyframes generation (8x64x40, 2fps)
|
75 |
+
video_frames = pipe_base(
|
76 |
+
prompt_embeds=prompt_embeds,
|
77 |
+
negative_prompt_embeds=negative_embeds,
|
78 |
+
num_frames=8,
|
79 |
+
height=40,
|
80 |
+
width=64,
|
81 |
+
num_inference_steps=75,
|
82 |
+
guidance_scale=9.0,
|
83 |
+
output_type="pt"
|
84 |
+
).frames
|
85 |
+
|
86 |
+
# Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
|
87 |
+
bsz, channel, num_frames, height, width = video_frames.shape
|
88 |
+
new_num_frames = 3 * (num_frames - 1) + num_frames
|
89 |
+
new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
|
90 |
+
dtype=video_frames.dtype, device=video_frames.device)
|
91 |
+
new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
|
92 |
+
init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
|
93 |
+
device=video_frames.device)
|
94 |
+
|
95 |
+
for i in range(num_frames - 1):
|
96 |
+
batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
|
97 |
+
batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
|
98 |
+
batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
|
99 |
+
batch_i = pipe_interp_1(
|
100 |
+
pixel_values=batch_i,
|
101 |
+
prompt_embeds=prompt_embeds,
|
102 |
+
negative_prompt_embeds=negative_embeds,
|
103 |
+
num_frames=batch_i.shape[2],
|
104 |
+
height=40,
|
105 |
+
width=64,
|
106 |
+
num_inference_steps=50,
|
107 |
+
guidance_scale=4.0,
|
108 |
+
output_type="pt",
|
109 |
+
init_noise=init_noise,
|
110 |
+
cond_interpolation=True,
|
111 |
+
).frames
|
112 |
+
|
113 |
+
new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i
|
114 |
+
|
115 |
+
video_frames = new_video_frames
|
116 |
+
|
117 |
+
# Super-resolution 1 (29x64x40 -> 29x256x160)
|
118 |
+
bsz, channel, num_frames, height, width = video_frames.shape
|
119 |
+
window_size, stride = 8, 7
|
120 |
+
new_video_frames = torch.zeros(
|
121 |
+
(bsz, channel, num_frames, height * 4, width * 4),
|
122 |
+
dtype=video_frames.dtype,
|
123 |
+
device=video_frames.device)
|
124 |
+
for i in range(0, num_frames - window_size + 1, stride):
|
125 |
+
batch_i = video_frames[:, :, i:i + window_size, ...]
|
126 |
+
|
127 |
+
if i == 0:
|
128 |
+
first_frame_cond = pipe_sr_1_image(
|
129 |
+
image=video_frames[:, :, 0, ...],
|
130 |
+
prompt_embeds=prompt_embeds,
|
131 |
+
negative_prompt_embeds=negative_embeds,
|
132 |
+
height=height * 4,
|
133 |
+
width=width * 4,
|
134 |
+
num_inference_steps=50,
|
135 |
+
guidance_scale=4.0,
|
136 |
+
noise_level=150,
|
137 |
+
output_type="pt"
|
138 |
+
).images
|
139 |
+
first_frame_cond = first_frame_cond.unsqueeze(2)
|
140 |
+
else:
|
141 |
+
first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
|
142 |
+
|
143 |
+
batch_i = pipe_sr_1_cond(
|
144 |
+
image=batch_i,
|
145 |
+
prompt_embeds=prompt_embeds,
|
146 |
+
negative_prompt_embeds=negative_embeds,
|
147 |
+
first_frame_cond=first_frame_cond,
|
148 |
+
height=height * 4,
|
149 |
+
width=width * 4,
|
150 |
+
num_inference_steps=50,
|
151 |
+
guidance_scale=7.0,
|
152 |
+
noise_level=250,
|
153 |
+
output_type="pt"
|
154 |
+
).frames
|
155 |
+
new_video_frames[:, :, i:i + window_size, ...] = batch_i
|
156 |
+
|
157 |
+
video_frames = new_video_frames
|
158 |
+
|
159 |
+
# Super-resolution 2 (29x256x160 -> 29x576x320)
|
160 |
+
video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
|
161 |
+
video_frames = pipe_sr_2(
|
162 |
+
prompt,
|
163 |
+
negative_prompt=negative_prompt,
|
164 |
+
video=video_frames,
|
165 |
+
strength=0.8,
|
166 |
+
num_inference_steps=50,
|
167 |
+
).frames
|
168 |
+
|
169 |
+
video_path = export_to_video(video_frames)
|
170 |
+
print(video_path)
|
171 |
+
return video_path, gr.Group.update(visible=True)
|
172 |
+
|
173 |
+
css = """
|
174 |
+
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
|
175 |
+
a {text-decoration-line: underline; font-weight: 600;}
|
176 |
+
.animate-spin {
|
177 |
+
animation: spin 1s linear infinite;
|
178 |
+
}
|
179 |
+
|
180 |
+
@keyframes spin {
|
181 |
+
from {
|
182 |
+
transform: rotate(0deg);
|
183 |
+
}
|
184 |
+
to {
|
185 |
+
transform: rotate(360deg);
|
186 |
+
}
|
187 |
+
}
|
188 |
+
|
189 |
+
#share-btn-container {
|
190 |
+
display: flex;
|
191 |
+
padding-left: 0.5rem !important;
|
192 |
+
padding-right: 0.5rem !important;
|
193 |
+
background-color: #000000;
|
194 |
+
justify-content: center;
|
195 |
+
align-items: center;
|
196 |
+
border-radius: 9999px !important;
|
197 |
+
max-width: 15rem;
|
198 |
+
height: 36px;
|
199 |
+
}
|
200 |
+
|
201 |
+
div#share-btn-container > div {
|
202 |
+
flex-direction: row;
|
203 |
+
background: black;
|
204 |
+
align-items: center;
|
205 |
+
}
|
206 |
+
|
207 |
+
#share-btn-container:hover {
|
208 |
+
background-color: #060606;
|
209 |
+
}
|
210 |
+
|
211 |
+
#share-btn {
|
212 |
+
all: initial;
|
213 |
+
color: #ffffff;
|
214 |
+
font-weight: 600;
|
215 |
+
cursor:pointer;
|
216 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
217 |
+
margin-left: 0.5rem !important;
|
218 |
+
padding-top: 0.5rem !important;
|
219 |
+
padding-bottom: 0.5rem !important;
|
220 |
+
right:0;
|
221 |
+
}
|
222 |
+
|
223 |
+
#share-btn * {
|
224 |
+
all: unset;
|
225 |
+
}
|
226 |
+
|
227 |
+
#share-btn-container div:nth-child(-n+2){
|
228 |
+
width: auto !important;
|
229 |
+
min-height: 0px !important;
|
230 |
+
}
|
231 |
+
|
232 |
+
#share-btn-container .wrap {
|
233 |
+
display: none !important;
|
234 |
+
}
|
235 |
+
|
236 |
+
#share-btn-container.hidden {
|
237 |
+
display: none!important;
|
238 |
+
}
|
239 |
+
img[src*='#center'] {
|
240 |
+
display: inline-block;
|
241 |
+
margin: unset;
|
242 |
+
}
|
243 |
+
|
244 |
+
.footer {
|
245 |
+
margin-bottom: 45px;
|
246 |
+
margin-top: 10px;
|
247 |
+
text-align: center;
|
248 |
+
border-bottom: 1px solid #e5e5e5;
|
249 |
+
}
|
250 |
+
.footer>p {
|
251 |
+
font-size: .8rem;
|
252 |
+
display: inline-block;
|
253 |
+
padding: 0 10px;
|
254 |
+
transform: translateY(10px);
|
255 |
+
background: white;
|
256 |
+
}
|
257 |
+
.dark .footer {
|
258 |
+
border-color: #303030;
|
259 |
+
}
|
260 |
+
.dark .footer>p {
|
261 |
+
background: #0b0f19;
|
262 |
+
}
|
263 |
+
"""
|
264 |
+
|
265 |
+
with gr.Blocks(css=css) as demo:
|
266 |
+
with gr.Column(elem_id="col-container"):
|
267 |
+
gr.Markdown(
|
268 |
+
"""
|
269 |
+
<h1 style="text-align: center;">Show-1 Text-to-Video</h1>
|
270 |
+
<p style="text-align: center;">
|
271 |
+
A text-to-video generation model that marries the strength and alleviates the weakness of pixel-based and latent-based VDMs. <br />
|
272 |
+
</p>
|
273 |
+
|
274 |
+
<p style="text-align: center;">
|
275 |
+
<a href="https://arxiv.org/abs/2309.15818" target="_blank">Paper</a> |
|
276 |
+
<a href="https://showlab.github.io/Show-1" target="_blank">Project Page</a> |
|
277 |
+
<a href="https://github.com/showlab/Show-1" target="_blank">Github</a>
|
278 |
+
</p>
|
279 |
+
|
280 |
+
"""
|
281 |
+
)
|
282 |
+
|
283 |
+
prompt_in = gr.Textbox(label="Prompt", placeholder="A panda taking a selfie", elem_id="prompt-in")
|
284 |
+
#neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
|
285 |
+
#inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
|
286 |
+
submit_btn = gr.Button("Submit")
|
287 |
+
video_result = gr.Video(label="Video Output", elem_id="video-output")
|
288 |
+
|
289 |
+
with gr.Row():
|
290 |
+
with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
|
291 |
+
community_icon = gr.HTML(community_icon_html)
|
292 |
+
loading_icon = gr.HTML(loading_icon_html)
|
293 |
+
share_button = gr.Button("Share with Community", elem_id="share-btn")
|
294 |
+
|
295 |
+
gr.Markdown("""
|
296 |
+
[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-lg.svg#center)](https://huggingface.co/spaces/showlab/Show-1?duplicate=true)
|
297 |
+
""")
|
298 |
+
|
299 |
+
gr.HTML("""
|
300 |
+
<div class="footer">
|
301 |
+
<p>
|
302 |
+
Demo adapted from <a href="https://huggingface.co/spaces/fffiloni/zeroscope" target="_blank">zeroscope</a>
|
303 |
+
by 🤗 <a href="https://twitter.com/fffiloni" target="_blank">Sylvain Filoni</a>
|
304 |
+
</p>
|
305 |
+
</div>
|
306 |
+
""")
|
307 |
+
|
308 |
+
submit_btn.click(fn=infer,
|
309 |
+
inputs=[prompt_in],
|
310 |
+
outputs=[video_result, share_group],
|
311 |
+
api_name="show-1")
|
312 |
+
|
313 |
+
share_button.click(None, [], [], _js=share_js)
|
314 |
+
|
315 |
+
demo.queue(max_size=12).launch(show_api=True, share=True)
|
316 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.19.3
|
2 |
+
bitsandbytes==0.35.4
|
3 |
+
decord==0.6.0
|
4 |
+
transformers==4.29.1
|
5 |
+
accelerate==0.18.0
|
6 |
+
imageio==2.14.1
|
7 |
+
torch==2.0.0
|
8 |
+
torchvision==0.15.0
|
9 |
+
beautifulsoup4
|
10 |
+
tensorboard
|
11 |
+
sentencepiece
|
12 |
+
safetensors
|
13 |
+
modelcards
|
14 |
+
omegaconf
|
15 |
+
pandas
|
16 |
+
einops
|
17 |
+
ftfy
|
18 |
+
opencv-python
|
19 |
+
|
share_btn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
2 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
3 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
4 |
+
</svg>"""
|
5 |
+
|
6 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
|
7 |
+
style="color: #ffffff;
|
8 |
+
"
|
9 |
+
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
10 |
+
|
11 |
+
share_js = """async () => {
|
12 |
+
async function uploadFile(file){
|
13 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
14 |
+
const response = await fetch(UPLOAD_URL, {
|
15 |
+
method: 'POST',
|
16 |
+
headers: {
|
17 |
+
'Content-Type': file.type,
|
18 |
+
'X-Requested-With': 'XMLHttpRequest',
|
19 |
+
},
|
20 |
+
body: file, /// <- File inherits from Blob
|
21 |
+
});
|
22 |
+
const url = await response.text();
|
23 |
+
return url;
|
24 |
+
}
|
25 |
+
|
26 |
+
async function getVideoBlobFile(videoEL){
|
27 |
+
const res = await fetch(videoEL.src);
|
28 |
+
const blob = await res.blob();
|
29 |
+
const videoId = Date.now() % 200;
|
30 |
+
const fileName = `vid-show1-${{videoId}}.mp4`;
|
31 |
+
const videoBlob = new File([blob], fileName, { type: 'video/mp4' });
|
32 |
+
console.log(videoBlob);
|
33 |
+
return videoBlob;
|
34 |
+
}
|
35 |
+
|
36 |
+
const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
|
37 |
+
const captionTxt = gradioEl.querySelector('#prompt-in textarea').value;
|
38 |
+
const outputVideo = gradioEl.querySelector('#video-output video');
|
39 |
+
|
40 |
+
|
41 |
+
const shareBtnEl = gradioEl.querySelector('#share-btn');
|
42 |
+
const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
|
43 |
+
const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
|
44 |
+
if(!outputVideo){
|
45 |
+
return;
|
46 |
+
};
|
47 |
+
shareBtnEl.style.pointerEvents = 'none';
|
48 |
+
shareIconEl.style.display = 'none';
|
49 |
+
loadingIconEl.style.removeProperty('display');
|
50 |
+
|
51 |
+
|
52 |
+
const videoOutFile = await getVideoBlobFile(outputVideo);
|
53 |
+
const dataOutputVid = await uploadFile(videoOutFile);
|
54 |
+
|
55 |
+
const descriptionMd = `
|
56 |
+
#### Prompt:
|
57 |
+
${captionTxt}
|
58 |
+
|
59 |
+
#### Show-1 video result:
|
60 |
+
${dataOutputVid}
|
61 |
+
|
62 |
+
`;
|
63 |
+
const params = new URLSearchParams({
|
64 |
+
title: captionTxt,
|
65 |
+
description: descriptionMd,
|
66 |
+
});
|
67 |
+
const paramsStr = params.toString();
|
68 |
+
window.open(`https://huggingface.co/spaces/showlab/Show-1/discussions/new?${paramsStr}`, '_blank');
|
69 |
+
shareBtnEl.style.removeProperty('pointer-events');
|
70 |
+
shareIconEl.style.removeProperty('display');
|
71 |
+
loadingIconEl.style.display = 'none';
|
72 |
+
}"""
|
showone/models/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .unet_3d_condition import UNet3DConditionModel
|
showone/models/transformer_temporal.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from diffusers.utils import BaseOutput
|
22 |
+
from diffusers.models.attention import BasicTransformerBlock
|
23 |
+
from diffusers.models.modeling_utils import ModelMixin
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class TransformerTemporalModelOutput(BaseOutput):
|
28 |
+
"""
|
29 |
+
The output of [`TransformerTemporalModel`].
|
30 |
+
|
31 |
+
Args:
|
32 |
+
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
33 |
+
The hidden states output conditioned on `encoder_hidden_states` input.
|
34 |
+
"""
|
35 |
+
|
36 |
+
sample: torch.FloatTensor
|
37 |
+
|
38 |
+
|
39 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
40 |
+
"""
|
41 |
+
A Transformer model for video-like data.
|
42 |
+
|
43 |
+
Parameters:
|
44 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
45 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
46 |
+
in_channels (`int`, *optional*):
|
47 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
48 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
49 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
50 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
51 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
52 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
53 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
54 |
+
attention_bias (`bool`, *optional*):
|
55 |
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
56 |
+
double_self_attention (`bool`, *optional*):
|
57 |
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
58 |
+
"""
|
59 |
+
|
60 |
+
@register_to_config
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
num_attention_heads: int = 16,
|
64 |
+
attention_head_dim: int = 88,
|
65 |
+
in_channels: Optional[int] = None,
|
66 |
+
out_channels: Optional[int] = None,
|
67 |
+
num_layers: int = 1,
|
68 |
+
dropout: float = 0.0,
|
69 |
+
norm_num_groups: int = 32,
|
70 |
+
cross_attention_dim: Optional[int] = None,
|
71 |
+
attention_bias: bool = False,
|
72 |
+
sample_size: Optional[int] = None,
|
73 |
+
activation_fn: str = "geglu",
|
74 |
+
norm_elementwise_affine: bool = True,
|
75 |
+
double_self_attention: bool = True,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
self.num_attention_heads = num_attention_heads
|
79 |
+
self.attention_head_dim = attention_head_dim
|
80 |
+
inner_dim = num_attention_heads * attention_head_dim
|
81 |
+
|
82 |
+
self.in_channels = in_channels
|
83 |
+
|
84 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
85 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
86 |
+
|
87 |
+
# 3. Define transformers blocks
|
88 |
+
self.transformer_blocks = nn.ModuleList(
|
89 |
+
[
|
90 |
+
BasicTransformerBlock(
|
91 |
+
inner_dim,
|
92 |
+
num_attention_heads,
|
93 |
+
attention_head_dim,
|
94 |
+
dropout=dropout,
|
95 |
+
cross_attention_dim=cross_attention_dim,
|
96 |
+
activation_fn=activation_fn,
|
97 |
+
attention_bias=attention_bias,
|
98 |
+
double_self_attention=double_self_attention,
|
99 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
100 |
+
)
|
101 |
+
for d in range(num_layers)
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
106 |
+
|
107 |
+
def forward(
|
108 |
+
self,
|
109 |
+
hidden_states,
|
110 |
+
encoder_hidden_states=None,
|
111 |
+
timestep=None,
|
112 |
+
class_labels=None,
|
113 |
+
num_frames=1,
|
114 |
+
cross_attention_kwargs=None,
|
115 |
+
return_dict: bool = True,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
The [`TransformerTemporal`] forward method.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
122 |
+
Input hidden_states.
|
123 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
124 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
125 |
+
self-attention.
|
126 |
+
timestep ( `torch.long`, *optional*):
|
127 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
128 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
129 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
130 |
+
`AdaLayerZeroNorm`.
|
131 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
132 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
133 |
+
tuple.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
137 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
138 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
139 |
+
"""
|
140 |
+
# 1. Input
|
141 |
+
batch_frames, channel, height, width = hidden_states.shape
|
142 |
+
batch_size = batch_frames // num_frames
|
143 |
+
|
144 |
+
residual = hidden_states
|
145 |
+
|
146 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
147 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
148 |
+
|
149 |
+
hidden_states = self.norm(hidden_states)
|
150 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
151 |
+
|
152 |
+
hidden_states = self.proj_in(hidden_states)
|
153 |
+
|
154 |
+
# 2. Blocks
|
155 |
+
for block in self.transformer_blocks:
|
156 |
+
hidden_states = block(
|
157 |
+
hidden_states,
|
158 |
+
encoder_hidden_states=encoder_hidden_states,
|
159 |
+
timestep=timestep,
|
160 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
161 |
+
class_labels=class_labels,
|
162 |
+
)
|
163 |
+
|
164 |
+
# 3. Output
|
165 |
+
hidden_states = self.proj_out(hidden_states)
|
166 |
+
hidden_states = (
|
167 |
+
hidden_states[None, None, :]
|
168 |
+
.reshape(batch_size, height, width, channel, num_frames)
|
169 |
+
.permute(0, 3, 4, 1, 2)
|
170 |
+
.contiguous()
|
171 |
+
)
|
172 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
173 |
+
|
174 |
+
output = hidden_states + residual
|
175 |
+
|
176 |
+
if not return_dict:
|
177 |
+
return (output,)
|
178 |
+
|
179 |
+
return TransformerTemporalModelOutput(sample=output)
|
showone/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,1619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional, Tuple
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.utils import is_torch_version, logging
|
22 |
+
from diffusers.models.attention import AdaGroupNorm
|
23 |
+
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
24 |
+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
25 |
+
from diffusers.models.transformer_2d import Transformer2DModel
|
26 |
+
from diffusers.models.transformer_temporal import TransformerTemporalModel
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29 |
+
|
30 |
+
|
31 |
+
def get_down_block(
|
32 |
+
down_block_type,
|
33 |
+
num_layers,
|
34 |
+
in_channels,
|
35 |
+
out_channels,
|
36 |
+
temb_channels,
|
37 |
+
add_downsample,
|
38 |
+
resnet_eps,
|
39 |
+
resnet_act_fn,
|
40 |
+
transformer_layers_per_block=1,
|
41 |
+
num_attention_heads=None,
|
42 |
+
resnet_groups=None,
|
43 |
+
cross_attention_dim=None,
|
44 |
+
downsample_padding=None,
|
45 |
+
dual_cross_attention=False,
|
46 |
+
use_linear_projection=False,
|
47 |
+
only_cross_attention=False,
|
48 |
+
upcast_attention=False,
|
49 |
+
resnet_time_scale_shift="default",
|
50 |
+
resnet_skip_time_act=False,
|
51 |
+
resnet_out_scale_factor=1.0,
|
52 |
+
cross_attention_norm=None,
|
53 |
+
attention_head_dim=None,
|
54 |
+
downsample_type=None,
|
55 |
+
):
|
56 |
+
# If attn head dim is not defined, we default it to the number of heads
|
57 |
+
if attention_head_dim is None:
|
58 |
+
logger.warn(
|
59 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
60 |
+
)
|
61 |
+
attention_head_dim = num_attention_heads
|
62 |
+
|
63 |
+
if down_block_type == "DownBlock3D":
|
64 |
+
return DownBlock3D(
|
65 |
+
num_layers=num_layers,
|
66 |
+
in_channels=in_channels,
|
67 |
+
out_channels=out_channels,
|
68 |
+
temb_channels=temb_channels,
|
69 |
+
add_downsample=add_downsample,
|
70 |
+
resnet_eps=resnet_eps,
|
71 |
+
resnet_act_fn=resnet_act_fn,
|
72 |
+
resnet_groups=resnet_groups,
|
73 |
+
downsample_padding=downsample_padding,
|
74 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
75 |
+
)
|
76 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
77 |
+
if cross_attention_dim is None:
|
78 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
79 |
+
return CrossAttnDownBlock3D(
|
80 |
+
num_layers=num_layers,
|
81 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
82 |
+
in_channels=in_channels,
|
83 |
+
out_channels=out_channels,
|
84 |
+
temb_channels=temb_channels,
|
85 |
+
add_downsample=add_downsample,
|
86 |
+
resnet_eps=resnet_eps,
|
87 |
+
resnet_act_fn=resnet_act_fn,
|
88 |
+
resnet_groups=resnet_groups,
|
89 |
+
downsample_padding=downsample_padding,
|
90 |
+
cross_attention_dim=cross_attention_dim,
|
91 |
+
num_attention_heads=num_attention_heads,
|
92 |
+
dual_cross_attention=dual_cross_attention,
|
93 |
+
use_linear_projection=use_linear_projection,
|
94 |
+
only_cross_attention=only_cross_attention,
|
95 |
+
upcast_attention=upcast_attention,
|
96 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
97 |
+
)
|
98 |
+
elif down_block_type == "SimpleCrossAttnDownBlock3D":
|
99 |
+
if cross_attention_dim is None:
|
100 |
+
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock3D")
|
101 |
+
return SimpleCrossAttnDownBlock3D(
|
102 |
+
num_layers=num_layers,
|
103 |
+
in_channels=in_channels,
|
104 |
+
out_channels=out_channels,
|
105 |
+
temb_channels=temb_channels,
|
106 |
+
add_downsample=add_downsample,
|
107 |
+
resnet_eps=resnet_eps,
|
108 |
+
resnet_act_fn=resnet_act_fn,
|
109 |
+
resnet_groups=resnet_groups,
|
110 |
+
cross_attention_dim=cross_attention_dim,
|
111 |
+
attention_head_dim=attention_head_dim,
|
112 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
113 |
+
skip_time_act=resnet_skip_time_act,
|
114 |
+
output_scale_factor=resnet_out_scale_factor,
|
115 |
+
only_cross_attention=only_cross_attention,
|
116 |
+
cross_attention_norm=cross_attention_norm,
|
117 |
+
)
|
118 |
+
elif down_block_type == "ResnetDownsampleBlock3D":
|
119 |
+
return ResnetDownsampleBlock3D(
|
120 |
+
num_layers=num_layers,
|
121 |
+
in_channels=in_channels,
|
122 |
+
out_channels=out_channels,
|
123 |
+
temb_channels=temb_channels,
|
124 |
+
add_downsample=add_downsample,
|
125 |
+
resnet_eps=resnet_eps,
|
126 |
+
resnet_act_fn=resnet_act_fn,
|
127 |
+
resnet_groups=resnet_groups,
|
128 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
129 |
+
skip_time_act=resnet_skip_time_act,
|
130 |
+
output_scale_factor=resnet_out_scale_factor,
|
131 |
+
)
|
132 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
133 |
+
|
134 |
+
|
135 |
+
def get_up_block(
|
136 |
+
up_block_type,
|
137 |
+
num_layers,
|
138 |
+
in_channels,
|
139 |
+
out_channels,
|
140 |
+
prev_output_channel,
|
141 |
+
temb_channels,
|
142 |
+
add_upsample,
|
143 |
+
resnet_eps,
|
144 |
+
resnet_act_fn,
|
145 |
+
transformer_layers_per_block=1,
|
146 |
+
num_attention_heads=None,
|
147 |
+
resnet_groups=None,
|
148 |
+
cross_attention_dim=None,
|
149 |
+
dual_cross_attention=False,
|
150 |
+
use_linear_projection=False,
|
151 |
+
only_cross_attention=False,
|
152 |
+
upcast_attention=False,
|
153 |
+
resnet_time_scale_shift="default",
|
154 |
+
resnet_skip_time_act=False,
|
155 |
+
resnet_out_scale_factor=1.0,
|
156 |
+
cross_attention_norm=None,
|
157 |
+
attention_head_dim=None,
|
158 |
+
upsample_type=None,
|
159 |
+
):
|
160 |
+
# If attn head dim is not defined, we default it to the number of heads
|
161 |
+
if attention_head_dim is None:
|
162 |
+
logger.warn(
|
163 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
164 |
+
)
|
165 |
+
attention_head_dim = num_attention_heads
|
166 |
+
|
167 |
+
if up_block_type == "UpBlock3D":
|
168 |
+
return UpBlock3D(
|
169 |
+
num_layers=num_layers,
|
170 |
+
in_channels=in_channels,
|
171 |
+
out_channels=out_channels,
|
172 |
+
prev_output_channel=prev_output_channel,
|
173 |
+
temb_channels=temb_channels,
|
174 |
+
add_upsample=add_upsample,
|
175 |
+
resnet_eps=resnet_eps,
|
176 |
+
resnet_act_fn=resnet_act_fn,
|
177 |
+
resnet_groups=resnet_groups,
|
178 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
179 |
+
)
|
180 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
181 |
+
if cross_attention_dim is None:
|
182 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
183 |
+
return CrossAttnUpBlock3D(
|
184 |
+
num_layers=num_layers,
|
185 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
186 |
+
in_channels=in_channels,
|
187 |
+
out_channels=out_channels,
|
188 |
+
prev_output_channel=prev_output_channel,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
add_upsample=add_upsample,
|
191 |
+
resnet_eps=resnet_eps,
|
192 |
+
resnet_act_fn=resnet_act_fn,
|
193 |
+
resnet_groups=resnet_groups,
|
194 |
+
cross_attention_dim=cross_attention_dim,
|
195 |
+
num_attention_heads=num_attention_heads,
|
196 |
+
dual_cross_attention=dual_cross_attention,
|
197 |
+
use_linear_projection=use_linear_projection,
|
198 |
+
only_cross_attention=only_cross_attention,
|
199 |
+
upcast_attention=upcast_attention,
|
200 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
201 |
+
)
|
202 |
+
elif up_block_type == "SimpleCrossAttnUpBlock3D":
|
203 |
+
if cross_attention_dim is None:
|
204 |
+
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock3D")
|
205 |
+
return SimpleCrossAttnUpBlock3D(
|
206 |
+
num_layers=num_layers,
|
207 |
+
in_channels=in_channels,
|
208 |
+
out_channels=out_channels,
|
209 |
+
prev_output_channel=prev_output_channel,
|
210 |
+
temb_channels=temb_channels,
|
211 |
+
add_upsample=add_upsample,
|
212 |
+
resnet_eps=resnet_eps,
|
213 |
+
resnet_act_fn=resnet_act_fn,
|
214 |
+
resnet_groups=resnet_groups,
|
215 |
+
cross_attention_dim=cross_attention_dim,
|
216 |
+
attention_head_dim=attention_head_dim,
|
217 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
218 |
+
skip_time_act=resnet_skip_time_act,
|
219 |
+
output_scale_factor=resnet_out_scale_factor,
|
220 |
+
only_cross_attention=only_cross_attention,
|
221 |
+
cross_attention_norm=cross_attention_norm,
|
222 |
+
)
|
223 |
+
elif up_block_type == "ResnetUpsampleBlock3D":
|
224 |
+
return ResnetUpsampleBlock3D(
|
225 |
+
num_layers=num_layers,
|
226 |
+
in_channels=in_channels,
|
227 |
+
out_channels=out_channels,
|
228 |
+
prev_output_channel=prev_output_channel,
|
229 |
+
temb_channels=temb_channels,
|
230 |
+
add_upsample=add_upsample,
|
231 |
+
resnet_eps=resnet_eps,
|
232 |
+
resnet_act_fn=resnet_act_fn,
|
233 |
+
resnet_groups=resnet_groups,
|
234 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
235 |
+
skip_time_act=resnet_skip_time_act,
|
236 |
+
output_scale_factor=resnet_out_scale_factor,
|
237 |
+
)
|
238 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
239 |
+
|
240 |
+
|
241 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
in_channels: int,
|
245 |
+
temb_channels: int,
|
246 |
+
dropout: float = 0.0,
|
247 |
+
num_layers: int = 1,
|
248 |
+
transformer_layers_per_block: int = 1,
|
249 |
+
resnet_eps: float = 1e-6,
|
250 |
+
resnet_time_scale_shift: str = "default",
|
251 |
+
resnet_act_fn: str = "swish",
|
252 |
+
resnet_groups: int = 32,
|
253 |
+
resnet_pre_norm: bool = True,
|
254 |
+
num_attention_heads=1,
|
255 |
+
output_scale_factor=1.0,
|
256 |
+
cross_attention_dim=1280,
|
257 |
+
dual_cross_attention=False,
|
258 |
+
use_linear_projection=False,
|
259 |
+
upcast_attention=False,
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
|
263 |
+
self.has_cross_attention = True
|
264 |
+
self.num_attention_heads = num_attention_heads
|
265 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
266 |
+
|
267 |
+
# there is always at least one resnet
|
268 |
+
resnets = [
|
269 |
+
ResnetBlock2D(
|
270 |
+
in_channels=in_channels,
|
271 |
+
out_channels=in_channels,
|
272 |
+
temb_channels=temb_channels,
|
273 |
+
eps=resnet_eps,
|
274 |
+
groups=resnet_groups,
|
275 |
+
dropout=dropout,
|
276 |
+
time_embedding_norm=resnet_time_scale_shift,
|
277 |
+
non_linearity=resnet_act_fn,
|
278 |
+
output_scale_factor=output_scale_factor,
|
279 |
+
pre_norm=resnet_pre_norm,
|
280 |
+
)
|
281 |
+
]
|
282 |
+
temp_convs = [
|
283 |
+
TemporalConvLayer(
|
284 |
+
in_channels,
|
285 |
+
in_channels,
|
286 |
+
dropout=0.1,
|
287 |
+
)
|
288 |
+
]
|
289 |
+
attentions = []
|
290 |
+
temp_attentions = []
|
291 |
+
|
292 |
+
for _ in range(num_layers):
|
293 |
+
attentions.append(
|
294 |
+
Transformer2DModel(
|
295 |
+
num_attention_heads,
|
296 |
+
in_channels // num_attention_heads,
|
297 |
+
in_channels=in_channels,
|
298 |
+
num_layers=transformer_layers_per_block,
|
299 |
+
cross_attention_dim=cross_attention_dim,
|
300 |
+
norm_num_groups=resnet_groups,
|
301 |
+
use_linear_projection=use_linear_projection,
|
302 |
+
upcast_attention=upcast_attention,
|
303 |
+
)
|
304 |
+
)
|
305 |
+
temp_attentions.append(
|
306 |
+
TransformerTemporalModel(
|
307 |
+
num_attention_heads,
|
308 |
+
in_channels // num_attention_heads,
|
309 |
+
in_channels=in_channels,
|
310 |
+
num_layers=1, #todo: transformer_layers_per_block?
|
311 |
+
cross_attention_dim=cross_attention_dim,
|
312 |
+
norm_num_groups=resnet_groups,
|
313 |
+
)
|
314 |
+
)
|
315 |
+
resnets.append(
|
316 |
+
ResnetBlock2D(
|
317 |
+
in_channels=in_channels,
|
318 |
+
out_channels=in_channels,
|
319 |
+
temb_channels=temb_channels,
|
320 |
+
eps=resnet_eps,
|
321 |
+
groups=resnet_groups,
|
322 |
+
dropout=dropout,
|
323 |
+
time_embedding_norm=resnet_time_scale_shift,
|
324 |
+
non_linearity=resnet_act_fn,
|
325 |
+
output_scale_factor=output_scale_factor,
|
326 |
+
pre_norm=resnet_pre_norm,
|
327 |
+
)
|
328 |
+
)
|
329 |
+
temp_convs.append(
|
330 |
+
TemporalConvLayer(
|
331 |
+
in_channels,
|
332 |
+
in_channels,
|
333 |
+
dropout=0.1,
|
334 |
+
)
|
335 |
+
)
|
336 |
+
|
337 |
+
self.resnets = nn.ModuleList(resnets)
|
338 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
339 |
+
self.attentions = nn.ModuleList(attentions)
|
340 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
341 |
+
|
342 |
+
def forward(
|
343 |
+
self,
|
344 |
+
hidden_states: torch.FloatTensor,
|
345 |
+
temb: Optional[torch.FloatTensor] = None,
|
346 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
347 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
348 |
+
num_frames: int = 1,
|
349 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
350 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
351 |
+
) -> torch.FloatTensor:
|
352 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
353 |
+
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
354 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
355 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
356 |
+
):
|
357 |
+
hidden_states = attn(
|
358 |
+
hidden_states,
|
359 |
+
encoder_hidden_states=encoder_hidden_states,
|
360 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
361 |
+
attention_mask=attention_mask,
|
362 |
+
encoder_attention_mask=encoder_attention_mask,
|
363 |
+
return_dict=False,
|
364 |
+
)[0]
|
365 |
+
hidden_states = temp_attn(
|
366 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
367 |
+
).sample
|
368 |
+
hidden_states = resnet(hidden_states, temb)
|
369 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
370 |
+
|
371 |
+
return hidden_states
|
372 |
+
|
373 |
+
|
374 |
+
class UNetMidBlock3DSimpleCrossAttn(nn.Module):
|
375 |
+
def __init__(
|
376 |
+
self,
|
377 |
+
in_channels: int,
|
378 |
+
temb_channels: int,
|
379 |
+
dropout: float = 0.0,
|
380 |
+
num_layers: int = 1,
|
381 |
+
resnet_eps: float = 1e-6,
|
382 |
+
resnet_time_scale_shift: str = "default",
|
383 |
+
resnet_act_fn: str = "swish",
|
384 |
+
resnet_groups: int = 32,
|
385 |
+
resnet_pre_norm: bool = True,
|
386 |
+
attention_head_dim=1,
|
387 |
+
output_scale_factor=1.0,
|
388 |
+
cross_attention_dim=1280,
|
389 |
+
skip_time_act=False,
|
390 |
+
only_cross_attention=False,
|
391 |
+
cross_attention_norm=None,
|
392 |
+
):
|
393 |
+
super().__init__()
|
394 |
+
|
395 |
+
self.has_cross_attention = True
|
396 |
+
|
397 |
+
self.attention_head_dim = attention_head_dim
|
398 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
399 |
+
|
400 |
+
self.num_heads = in_channels // self.attention_head_dim
|
401 |
+
|
402 |
+
# there is always at least one resnet
|
403 |
+
resnets = [
|
404 |
+
ResnetBlock2D(
|
405 |
+
in_channels=in_channels,
|
406 |
+
out_channels=in_channels,
|
407 |
+
temb_channels=temb_channels,
|
408 |
+
eps=resnet_eps,
|
409 |
+
groups=resnet_groups,
|
410 |
+
dropout=dropout,
|
411 |
+
time_embedding_norm=resnet_time_scale_shift,
|
412 |
+
non_linearity=resnet_act_fn,
|
413 |
+
output_scale_factor=output_scale_factor,
|
414 |
+
pre_norm=resnet_pre_norm,
|
415 |
+
skip_time_act=skip_time_act,
|
416 |
+
)
|
417 |
+
]
|
418 |
+
temp_convs = [
|
419 |
+
TemporalConvLayer(
|
420 |
+
in_channels,
|
421 |
+
in_channels,
|
422 |
+
dropout=0.1,
|
423 |
+
)
|
424 |
+
]
|
425 |
+
attentions = []
|
426 |
+
temp_attentions = []
|
427 |
+
|
428 |
+
for _ in range(num_layers):
|
429 |
+
processor = (
|
430 |
+
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
|
431 |
+
)
|
432 |
+
|
433 |
+
attentions.append(
|
434 |
+
Attention(
|
435 |
+
query_dim=in_channels,
|
436 |
+
cross_attention_dim=in_channels,
|
437 |
+
heads=self.num_heads,
|
438 |
+
dim_head=self.attention_head_dim,
|
439 |
+
added_kv_proj_dim=cross_attention_dim,
|
440 |
+
norm_num_groups=resnet_groups,
|
441 |
+
bias=True,
|
442 |
+
upcast_softmax=True,
|
443 |
+
only_cross_attention=only_cross_attention,
|
444 |
+
cross_attention_norm=cross_attention_norm,
|
445 |
+
processor=processor,
|
446 |
+
)
|
447 |
+
)
|
448 |
+
temp_attentions.append(
|
449 |
+
TransformerTemporalModel(
|
450 |
+
self.attention_head_dim,
|
451 |
+
in_channels // self.attention_head_dim,
|
452 |
+
in_channels=in_channels,
|
453 |
+
num_layers=1,
|
454 |
+
cross_attention_dim=cross_attention_dim,
|
455 |
+
norm_num_groups=resnet_groups,
|
456 |
+
)
|
457 |
+
)
|
458 |
+
resnets.append(
|
459 |
+
ResnetBlock2D(
|
460 |
+
in_channels=in_channels,
|
461 |
+
out_channels=in_channels,
|
462 |
+
temb_channels=temb_channels,
|
463 |
+
eps=resnet_eps,
|
464 |
+
groups=resnet_groups,
|
465 |
+
dropout=dropout,
|
466 |
+
time_embedding_norm=resnet_time_scale_shift,
|
467 |
+
non_linearity=resnet_act_fn,
|
468 |
+
output_scale_factor=output_scale_factor,
|
469 |
+
pre_norm=resnet_pre_norm,
|
470 |
+
skip_time_act=skip_time_act,
|
471 |
+
)
|
472 |
+
)
|
473 |
+
temp_convs.append(
|
474 |
+
TemporalConvLayer(
|
475 |
+
in_channels,
|
476 |
+
in_channels,
|
477 |
+
dropout=0.1,
|
478 |
+
)
|
479 |
+
)
|
480 |
+
|
481 |
+
self.resnets = nn.ModuleList(resnets)
|
482 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
483 |
+
self.attentions = nn.ModuleList(attentions)
|
484 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
485 |
+
|
486 |
+
def forward(
|
487 |
+
self,
|
488 |
+
hidden_states: torch.FloatTensor,
|
489 |
+
temb: Optional[torch.FloatTensor] = None,
|
490 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
491 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
492 |
+
num_frames: int = 1,
|
493 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
494 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
495 |
+
):
|
496 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
497 |
+
|
498 |
+
if attention_mask is None:
|
499 |
+
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
500 |
+
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
501 |
+
else:
|
502 |
+
# when attention_mask is defined: we don't even check for encoder_attention_mask.
|
503 |
+
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
|
504 |
+
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
|
505 |
+
# then we can simplify this whole if/else block to:
|
506 |
+
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
507 |
+
mask = attention_mask
|
508 |
+
|
509 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
510 |
+
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
511 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
512 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
513 |
+
):
|
514 |
+
hidden_states = attn(
|
515 |
+
hidden_states,
|
516 |
+
encoder_hidden_states=encoder_hidden_states,
|
517 |
+
attention_mask=mask,
|
518 |
+
**cross_attention_kwargs,
|
519 |
+
)
|
520 |
+
hidden_states = temp_attn(
|
521 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
522 |
+
).sample
|
523 |
+
hidden_states = resnet(hidden_states, temb)
|
524 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
525 |
+
|
526 |
+
return hidden_states
|
527 |
+
|
528 |
+
|
529 |
+
class CrossAttnDownBlock3D(nn.Module):
|
530 |
+
def __init__(
|
531 |
+
self,
|
532 |
+
in_channels: int,
|
533 |
+
out_channels: int,
|
534 |
+
temb_channels: int,
|
535 |
+
dropout: float = 0.0,
|
536 |
+
num_layers: int = 1,
|
537 |
+
transformer_layers_per_block: int = 1,
|
538 |
+
resnet_eps: float = 1e-6,
|
539 |
+
resnet_time_scale_shift: str = "default",
|
540 |
+
resnet_act_fn: str = "swish",
|
541 |
+
resnet_groups: int = 32,
|
542 |
+
resnet_pre_norm: bool = True,
|
543 |
+
num_attention_heads=1,
|
544 |
+
cross_attention_dim=1280,
|
545 |
+
output_scale_factor=1.0,
|
546 |
+
downsample_padding=1,
|
547 |
+
add_downsample=True,
|
548 |
+
dual_cross_attention=False,
|
549 |
+
use_linear_projection=False,
|
550 |
+
only_cross_attention=False,
|
551 |
+
upcast_attention=False,
|
552 |
+
):
|
553 |
+
super().__init__()
|
554 |
+
resnets = []
|
555 |
+
attentions = []
|
556 |
+
temp_attentions = []
|
557 |
+
temp_convs = []
|
558 |
+
|
559 |
+
self.has_cross_attention = True
|
560 |
+
self.num_attention_heads = num_attention_heads
|
561 |
+
|
562 |
+
for i in range(num_layers):
|
563 |
+
in_channels = in_channels if i == 0 else out_channels
|
564 |
+
resnets.append(
|
565 |
+
ResnetBlock2D(
|
566 |
+
in_channels=in_channels,
|
567 |
+
out_channels=out_channels,
|
568 |
+
temb_channels=temb_channels,
|
569 |
+
eps=resnet_eps,
|
570 |
+
groups=resnet_groups,
|
571 |
+
dropout=dropout,
|
572 |
+
time_embedding_norm=resnet_time_scale_shift,
|
573 |
+
non_linearity=resnet_act_fn,
|
574 |
+
output_scale_factor=output_scale_factor,
|
575 |
+
pre_norm=resnet_pre_norm,
|
576 |
+
)
|
577 |
+
)
|
578 |
+
temp_convs.append(
|
579 |
+
TemporalConvLayer(
|
580 |
+
out_channels,
|
581 |
+
out_channels,
|
582 |
+
dropout=0.1,
|
583 |
+
)
|
584 |
+
)
|
585 |
+
attentions.append(
|
586 |
+
Transformer2DModel(
|
587 |
+
num_attention_heads,
|
588 |
+
out_channels // num_attention_heads,
|
589 |
+
in_channels=out_channels,
|
590 |
+
num_layers=transformer_layers_per_block,
|
591 |
+
cross_attention_dim=cross_attention_dim,
|
592 |
+
norm_num_groups=resnet_groups,
|
593 |
+
use_linear_projection=use_linear_projection,
|
594 |
+
only_cross_attention=only_cross_attention,
|
595 |
+
upcast_attention=upcast_attention,
|
596 |
+
)
|
597 |
+
)
|
598 |
+
temp_attentions.append(
|
599 |
+
TransformerTemporalModel(
|
600 |
+
num_attention_heads,
|
601 |
+
out_channels // num_attention_heads,
|
602 |
+
in_channels=out_channels,
|
603 |
+
num_layers=1,
|
604 |
+
cross_attention_dim=cross_attention_dim,
|
605 |
+
norm_num_groups=resnet_groups,
|
606 |
+
)
|
607 |
+
)
|
608 |
+
self.resnets = nn.ModuleList(resnets)
|
609 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
610 |
+
self.attentions = nn.ModuleList(attentions)
|
611 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
612 |
+
|
613 |
+
if add_downsample:
|
614 |
+
self.downsamplers = nn.ModuleList(
|
615 |
+
[
|
616 |
+
Downsample2D(
|
617 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
618 |
+
)
|
619 |
+
]
|
620 |
+
)
|
621 |
+
else:
|
622 |
+
self.downsamplers = None
|
623 |
+
|
624 |
+
self.gradient_checkpointing = False
|
625 |
+
|
626 |
+
def forward(
|
627 |
+
self,
|
628 |
+
hidden_states: torch.FloatTensor,
|
629 |
+
temb: Optional[torch.FloatTensor] = None,
|
630 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
631 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
632 |
+
num_frames: int = 1,
|
633 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
634 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
635 |
+
):
|
636 |
+
output_states = ()
|
637 |
+
|
638 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
639 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
640 |
+
):
|
641 |
+
if self.training and self.gradient_checkpointing:
|
642 |
+
|
643 |
+
def create_custom_forward(module, return_dict=None):
|
644 |
+
def custom_forward(*inputs):
|
645 |
+
if return_dict is not None:
|
646 |
+
return module(*inputs, return_dict=return_dict)
|
647 |
+
else:
|
648 |
+
return module(*inputs)
|
649 |
+
|
650 |
+
return custom_forward
|
651 |
+
|
652 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
653 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,)
|
654 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,)
|
655 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
656 |
+
create_custom_forward(attn, return_dict=False),
|
657 |
+
hidden_states,
|
658 |
+
encoder_hidden_states,
|
659 |
+
None, # timestep
|
660 |
+
None, # class_labels
|
661 |
+
cross_attention_kwargs,
|
662 |
+
attention_mask,
|
663 |
+
encoder_attention_mask,
|
664 |
+
**ckpt_kwargs,
|
665 |
+
)[0]
|
666 |
+
hidden_states = temp_attn(
|
667 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, **ckpt_kwargs,
|
668 |
+
).sample
|
669 |
+
else:
|
670 |
+
hidden_states = resnet(hidden_states, temb)
|
671 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
672 |
+
hidden_states = attn(
|
673 |
+
hidden_states,
|
674 |
+
encoder_hidden_states=encoder_hidden_states,
|
675 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
676 |
+
attention_mask=attention_mask,
|
677 |
+
encoder_attention_mask=encoder_attention_mask,
|
678 |
+
return_dict=False,
|
679 |
+
)[0]
|
680 |
+
hidden_states = temp_attn(
|
681 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
682 |
+
).sample
|
683 |
+
|
684 |
+
output_states = output_states + (hidden_states,)
|
685 |
+
|
686 |
+
if self.downsamplers is not None:
|
687 |
+
for downsampler in self.downsamplers:
|
688 |
+
hidden_states = downsampler(hidden_states)
|
689 |
+
|
690 |
+
output_states = output_states + (hidden_states,)
|
691 |
+
|
692 |
+
return hidden_states, output_states
|
693 |
+
|
694 |
+
|
695 |
+
class DownBlock3D(nn.Module):
|
696 |
+
def __init__(
|
697 |
+
self,
|
698 |
+
in_channels: int,
|
699 |
+
out_channels: int,
|
700 |
+
temb_channels: int,
|
701 |
+
dropout: float = 0.0,
|
702 |
+
num_layers: int = 1,
|
703 |
+
resnet_eps: float = 1e-6,
|
704 |
+
resnet_time_scale_shift: str = "default",
|
705 |
+
resnet_act_fn: str = "swish",
|
706 |
+
resnet_groups: int = 32,
|
707 |
+
resnet_pre_norm: bool = True,
|
708 |
+
output_scale_factor=1.0,
|
709 |
+
add_downsample=True,
|
710 |
+
downsample_padding=1,
|
711 |
+
):
|
712 |
+
super().__init__()
|
713 |
+
resnets = []
|
714 |
+
temp_convs = []
|
715 |
+
|
716 |
+
for i in range(num_layers):
|
717 |
+
in_channels = in_channels if i == 0 else out_channels
|
718 |
+
resnets.append(
|
719 |
+
ResnetBlock2D(
|
720 |
+
in_channels=in_channels,
|
721 |
+
out_channels=out_channels,
|
722 |
+
temb_channels=temb_channels,
|
723 |
+
eps=resnet_eps,
|
724 |
+
groups=resnet_groups,
|
725 |
+
dropout=dropout,
|
726 |
+
time_embedding_norm=resnet_time_scale_shift,
|
727 |
+
non_linearity=resnet_act_fn,
|
728 |
+
output_scale_factor=output_scale_factor,
|
729 |
+
pre_norm=resnet_pre_norm,
|
730 |
+
)
|
731 |
+
)
|
732 |
+
temp_convs.append(
|
733 |
+
TemporalConvLayer(
|
734 |
+
out_channels,
|
735 |
+
out_channels,
|
736 |
+
dropout=0.1,
|
737 |
+
)
|
738 |
+
)
|
739 |
+
|
740 |
+
self.resnets = nn.ModuleList(resnets)
|
741 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
742 |
+
|
743 |
+
if add_downsample:
|
744 |
+
self.downsamplers = nn.ModuleList(
|
745 |
+
[
|
746 |
+
Downsample2D(
|
747 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
748 |
+
)
|
749 |
+
]
|
750 |
+
)
|
751 |
+
else:
|
752 |
+
self.downsamplers = None
|
753 |
+
|
754 |
+
self.gradient_checkpointing = False
|
755 |
+
|
756 |
+
def forward(self, hidden_states, temb=None, num_frames=1):
|
757 |
+
output_states = ()
|
758 |
+
|
759 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
760 |
+
if self.training and self.gradient_checkpointing:
|
761 |
+
|
762 |
+
def create_custom_forward(module):
|
763 |
+
def custom_forward(*inputs):
|
764 |
+
return module(*inputs)
|
765 |
+
|
766 |
+
return custom_forward
|
767 |
+
|
768 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
|
769 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
|
770 |
+
else:
|
771 |
+
hidden_states = resnet(hidden_states, temb)
|
772 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
773 |
+
|
774 |
+
output_states = output_states + (hidden_states,)
|
775 |
+
|
776 |
+
if self.downsamplers is not None:
|
777 |
+
for downsampler in self.downsamplers:
|
778 |
+
hidden_states = downsampler(hidden_states)
|
779 |
+
|
780 |
+
output_states = output_states + (hidden_states,)
|
781 |
+
|
782 |
+
return hidden_states, output_states
|
783 |
+
|
784 |
+
|
785 |
+
class ResnetDownsampleBlock3D(nn.Module):
|
786 |
+
def __init__(
|
787 |
+
self,
|
788 |
+
in_channels: int,
|
789 |
+
out_channels: int,
|
790 |
+
temb_channels: int,
|
791 |
+
dropout: float = 0.0,
|
792 |
+
num_layers: int = 1,
|
793 |
+
resnet_eps: float = 1e-6,
|
794 |
+
resnet_time_scale_shift: str = "default",
|
795 |
+
resnet_act_fn: str = "swish",
|
796 |
+
resnet_groups: int = 32,
|
797 |
+
resnet_pre_norm: bool = True,
|
798 |
+
output_scale_factor=1.0,
|
799 |
+
add_downsample=True,
|
800 |
+
skip_time_act=False,
|
801 |
+
):
|
802 |
+
super().__init__()
|
803 |
+
resnets = []
|
804 |
+
temp_convs = []
|
805 |
+
|
806 |
+
for i in range(num_layers):
|
807 |
+
in_channels = in_channels if i == 0 else out_channels
|
808 |
+
resnets.append(
|
809 |
+
ResnetBlock2D(
|
810 |
+
in_channels=in_channels,
|
811 |
+
out_channels=out_channels,
|
812 |
+
temb_channels=temb_channels,
|
813 |
+
eps=resnet_eps,
|
814 |
+
groups=resnet_groups,
|
815 |
+
dropout=dropout,
|
816 |
+
time_embedding_norm=resnet_time_scale_shift,
|
817 |
+
non_linearity=resnet_act_fn,
|
818 |
+
output_scale_factor=output_scale_factor,
|
819 |
+
pre_norm=resnet_pre_norm,
|
820 |
+
skip_time_act=skip_time_act,
|
821 |
+
)
|
822 |
+
)
|
823 |
+
temp_convs.append(
|
824 |
+
TemporalConvLayer(
|
825 |
+
out_channels,
|
826 |
+
out_channels,
|
827 |
+
dropout=0.1,
|
828 |
+
)
|
829 |
+
)
|
830 |
+
|
831 |
+
self.resnets = nn.ModuleList(resnets)
|
832 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
833 |
+
|
834 |
+
if add_downsample:
|
835 |
+
self.downsamplers = nn.ModuleList(
|
836 |
+
[
|
837 |
+
ResnetBlock2D(
|
838 |
+
in_channels=out_channels,
|
839 |
+
out_channels=out_channels,
|
840 |
+
temb_channels=temb_channels,
|
841 |
+
eps=resnet_eps,
|
842 |
+
groups=resnet_groups,
|
843 |
+
dropout=dropout,
|
844 |
+
time_embedding_norm=resnet_time_scale_shift,
|
845 |
+
non_linearity=resnet_act_fn,
|
846 |
+
output_scale_factor=output_scale_factor,
|
847 |
+
pre_norm=resnet_pre_norm,
|
848 |
+
skip_time_act=skip_time_act,
|
849 |
+
down=True,
|
850 |
+
)
|
851 |
+
]
|
852 |
+
)
|
853 |
+
else:
|
854 |
+
self.downsamplers = None
|
855 |
+
|
856 |
+
self.gradient_checkpointing = False
|
857 |
+
|
858 |
+
def forward(self, hidden_states, temb=None, num_frames=1):
|
859 |
+
output_states = ()
|
860 |
+
|
861 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
862 |
+
if self.training and self.gradient_checkpointing:
|
863 |
+
|
864 |
+
def create_custom_forward(module):
|
865 |
+
def custom_forward(*inputs):
|
866 |
+
return module(*inputs)
|
867 |
+
|
868 |
+
return custom_forward
|
869 |
+
|
870 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
|
871 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
|
872 |
+
else:
|
873 |
+
hidden_states = resnet(hidden_states, temb)
|
874 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
875 |
+
|
876 |
+
output_states = output_states + (hidden_states,)
|
877 |
+
|
878 |
+
if self.downsamplers is not None:
|
879 |
+
for downsampler in self.downsamplers:
|
880 |
+
hidden_states = downsampler(hidden_states, temb)
|
881 |
+
|
882 |
+
output_states = output_states + (hidden_states,)
|
883 |
+
|
884 |
+
return hidden_states, output_states
|
885 |
+
|
886 |
+
|
887 |
+
class SimpleCrossAttnDownBlock3D(nn.Module):
|
888 |
+
def __init__(
|
889 |
+
self,
|
890 |
+
in_channels: int,
|
891 |
+
out_channels: int,
|
892 |
+
temb_channels: int,
|
893 |
+
dropout: float = 0.0,
|
894 |
+
num_layers: int = 1,
|
895 |
+
resnet_eps: float = 1e-6,
|
896 |
+
resnet_time_scale_shift: str = "default",
|
897 |
+
resnet_act_fn: str = "swish",
|
898 |
+
resnet_groups: int = 32,
|
899 |
+
resnet_pre_norm: bool = True,
|
900 |
+
attention_head_dim=1,
|
901 |
+
cross_attention_dim=1280,
|
902 |
+
output_scale_factor=1.0,
|
903 |
+
add_downsample=True,
|
904 |
+
skip_time_act=False,
|
905 |
+
only_cross_attention=False,
|
906 |
+
cross_attention_norm=None,
|
907 |
+
):
|
908 |
+
super().__init__()
|
909 |
+
|
910 |
+
self.has_cross_attention = True
|
911 |
+
|
912 |
+
resnets = []
|
913 |
+
attentions = []
|
914 |
+
temp_attentions = []
|
915 |
+
temp_convs = []
|
916 |
+
|
917 |
+
self.attention_head_dim = attention_head_dim
|
918 |
+
self.num_heads = out_channels // self.attention_head_dim
|
919 |
+
|
920 |
+
for i in range(num_layers):
|
921 |
+
in_channels = in_channels if i == 0 else out_channels
|
922 |
+
resnets.append(
|
923 |
+
ResnetBlock2D(
|
924 |
+
in_channels=in_channels,
|
925 |
+
out_channels=out_channels,
|
926 |
+
temb_channels=temb_channels,
|
927 |
+
eps=resnet_eps,
|
928 |
+
groups=resnet_groups,
|
929 |
+
dropout=dropout,
|
930 |
+
time_embedding_norm=resnet_time_scale_shift,
|
931 |
+
non_linearity=resnet_act_fn,
|
932 |
+
output_scale_factor=output_scale_factor,
|
933 |
+
pre_norm=resnet_pre_norm,
|
934 |
+
skip_time_act=skip_time_act,
|
935 |
+
)
|
936 |
+
)
|
937 |
+
temp_convs.append(
|
938 |
+
TemporalConvLayer(
|
939 |
+
out_channels,
|
940 |
+
out_channels,
|
941 |
+
dropout=0.1,
|
942 |
+
)
|
943 |
+
)
|
944 |
+
processor = (
|
945 |
+
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
|
946 |
+
)
|
947 |
+
|
948 |
+
attentions.append(
|
949 |
+
Attention(
|
950 |
+
query_dim=out_channels,
|
951 |
+
cross_attention_dim=out_channels,
|
952 |
+
heads=self.num_heads,
|
953 |
+
dim_head=attention_head_dim,
|
954 |
+
added_kv_proj_dim=cross_attention_dim,
|
955 |
+
norm_num_groups=resnet_groups,
|
956 |
+
bias=True,
|
957 |
+
upcast_softmax=True,
|
958 |
+
only_cross_attention=only_cross_attention,
|
959 |
+
cross_attention_norm=cross_attention_norm,
|
960 |
+
processor=processor,
|
961 |
+
)
|
962 |
+
)
|
963 |
+
temp_attentions.append(
|
964 |
+
TransformerTemporalModel(
|
965 |
+
attention_head_dim,
|
966 |
+
out_channels // attention_head_dim,
|
967 |
+
in_channels=out_channels,
|
968 |
+
num_layers=1,
|
969 |
+
cross_attention_dim=cross_attention_dim,
|
970 |
+
norm_num_groups=resnet_groups,
|
971 |
+
)
|
972 |
+
)
|
973 |
+
self.resnets = nn.ModuleList(resnets)
|
974 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
975 |
+
self.attentions = nn.ModuleList(attentions)
|
976 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
977 |
+
|
978 |
+
if add_downsample:
|
979 |
+
self.downsamplers = nn.ModuleList(
|
980 |
+
[
|
981 |
+
ResnetBlock2D(
|
982 |
+
in_channels=out_channels,
|
983 |
+
out_channels=out_channels,
|
984 |
+
temb_channels=temb_channels,
|
985 |
+
eps=resnet_eps,
|
986 |
+
groups=resnet_groups,
|
987 |
+
dropout=dropout,
|
988 |
+
time_embedding_norm=resnet_time_scale_shift,
|
989 |
+
non_linearity=resnet_act_fn,
|
990 |
+
output_scale_factor=output_scale_factor,
|
991 |
+
pre_norm=resnet_pre_norm,
|
992 |
+
skip_time_act=skip_time_act,
|
993 |
+
down=True,
|
994 |
+
)
|
995 |
+
]
|
996 |
+
)
|
997 |
+
else:
|
998 |
+
self.downsamplers = None
|
999 |
+
|
1000 |
+
self.gradient_checkpointing = False
|
1001 |
+
|
1002 |
+
def forward(
|
1003 |
+
self,
|
1004 |
+
hidden_states: torch.FloatTensor,
|
1005 |
+
temb: Optional[torch.FloatTensor] = None,
|
1006 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1007 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1008 |
+
num_frames: int = 1,
|
1009 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1010 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1011 |
+
):
|
1012 |
+
output_states = ()
|
1013 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
1014 |
+
|
1015 |
+
if attention_mask is None:
|
1016 |
+
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
1017 |
+
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
1018 |
+
else:
|
1019 |
+
# when attention_mask is defined: we don't even check for encoder_attention_mask.
|
1020 |
+
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
|
1021 |
+
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
|
1022 |
+
# then we can simplify this whole if/else block to:
|
1023 |
+
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
1024 |
+
mask = attention_mask
|
1025 |
+
|
1026 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
1027 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
1028 |
+
):
|
1029 |
+
if self.training and self.gradient_checkpointing:
|
1030 |
+
|
1031 |
+
def create_custom_forward(module, return_dict=None):
|
1032 |
+
def custom_forward(*inputs):
|
1033 |
+
if return_dict is not None:
|
1034 |
+
return module(*inputs, return_dict=return_dict)
|
1035 |
+
else:
|
1036 |
+
return module(*inputs)
|
1037 |
+
|
1038 |
+
return custom_forward
|
1039 |
+
|
1040 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1041 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames)
|
1042 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1043 |
+
create_custom_forward(attn, return_dict=False),
|
1044 |
+
hidden_states,
|
1045 |
+
encoder_hidden_states,
|
1046 |
+
mask,
|
1047 |
+
cross_attention_kwargs,
|
1048 |
+
)[0]
|
1049 |
+
hidden_states = temp_attn(
|
1050 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1051 |
+
).sample
|
1052 |
+
else:
|
1053 |
+
hidden_states = resnet(hidden_states, temb)
|
1054 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
1055 |
+
hidden_states = attn(
|
1056 |
+
hidden_states,
|
1057 |
+
encoder_hidden_states=encoder_hidden_states,
|
1058 |
+
attention_mask=mask,
|
1059 |
+
**cross_attention_kwargs,
|
1060 |
+
)
|
1061 |
+
hidden_states = temp_attn(
|
1062 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1063 |
+
).sample
|
1064 |
+
|
1065 |
+
output_states = output_states + (hidden_states,)
|
1066 |
+
|
1067 |
+
if self.downsamplers is not None:
|
1068 |
+
for downsampler in self.downsamplers:
|
1069 |
+
hidden_states = downsampler(hidden_states, temb)
|
1070 |
+
|
1071 |
+
output_states = output_states + (hidden_states,)
|
1072 |
+
|
1073 |
+
return hidden_states, output_states
|
1074 |
+
|
1075 |
+
|
1076 |
+
class CrossAttnUpBlock3D(nn.Module):
|
1077 |
+
def __init__(
|
1078 |
+
self,
|
1079 |
+
in_channels: int,
|
1080 |
+
out_channels: int,
|
1081 |
+
prev_output_channel: int,
|
1082 |
+
temb_channels: int,
|
1083 |
+
dropout: float = 0.0,
|
1084 |
+
num_layers: int = 1,
|
1085 |
+
transformer_layers_per_block: int = 1,
|
1086 |
+
resnet_eps: float = 1e-6,
|
1087 |
+
resnet_time_scale_shift: str = "default",
|
1088 |
+
resnet_act_fn: str = "swish",
|
1089 |
+
resnet_groups: int = 32,
|
1090 |
+
resnet_pre_norm: bool = True,
|
1091 |
+
num_attention_heads=1,
|
1092 |
+
cross_attention_dim=1280,
|
1093 |
+
output_scale_factor=1.0,
|
1094 |
+
add_upsample=True,
|
1095 |
+
dual_cross_attention=False,
|
1096 |
+
use_linear_projection=False,
|
1097 |
+
only_cross_attention=False,
|
1098 |
+
upcast_attention=False,
|
1099 |
+
):
|
1100 |
+
super().__init__()
|
1101 |
+
resnets = []
|
1102 |
+
temp_convs = []
|
1103 |
+
attentions = []
|
1104 |
+
temp_attentions = []
|
1105 |
+
|
1106 |
+
self.has_cross_attention = True
|
1107 |
+
self.num_attention_heads = num_attention_heads
|
1108 |
+
|
1109 |
+
for i in range(num_layers):
|
1110 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1111 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1112 |
+
|
1113 |
+
resnets.append(
|
1114 |
+
ResnetBlock2D(
|
1115 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1116 |
+
out_channels=out_channels,
|
1117 |
+
temb_channels=temb_channels,
|
1118 |
+
eps=resnet_eps,
|
1119 |
+
groups=resnet_groups,
|
1120 |
+
dropout=dropout,
|
1121 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1122 |
+
non_linearity=resnet_act_fn,
|
1123 |
+
output_scale_factor=output_scale_factor,
|
1124 |
+
pre_norm=resnet_pre_norm,
|
1125 |
+
)
|
1126 |
+
)
|
1127 |
+
temp_convs.append(
|
1128 |
+
TemporalConvLayer(
|
1129 |
+
out_channels,
|
1130 |
+
out_channels,
|
1131 |
+
dropout=0.1,
|
1132 |
+
)
|
1133 |
+
)
|
1134 |
+
attentions.append(
|
1135 |
+
Transformer2DModel(
|
1136 |
+
num_attention_heads,
|
1137 |
+
out_channels // num_attention_heads,
|
1138 |
+
in_channels=out_channels,
|
1139 |
+
num_layers=transformer_layers_per_block,
|
1140 |
+
cross_attention_dim=cross_attention_dim,
|
1141 |
+
norm_num_groups=resnet_groups,
|
1142 |
+
use_linear_projection=use_linear_projection,
|
1143 |
+
only_cross_attention=only_cross_attention,
|
1144 |
+
upcast_attention=upcast_attention,
|
1145 |
+
)
|
1146 |
+
)
|
1147 |
+
temp_attentions.append(
|
1148 |
+
TransformerTemporalModel(
|
1149 |
+
num_attention_heads,
|
1150 |
+
out_channels // num_attention_heads,
|
1151 |
+
in_channels=out_channels,
|
1152 |
+
num_layers=1,
|
1153 |
+
cross_attention_dim=cross_attention_dim,
|
1154 |
+
norm_num_groups=resnet_groups,
|
1155 |
+
)
|
1156 |
+
)
|
1157 |
+
self.resnets = nn.ModuleList(resnets)
|
1158 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1159 |
+
self.attentions = nn.ModuleList(attentions)
|
1160 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
1161 |
+
|
1162 |
+
if add_upsample:
|
1163 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1164 |
+
else:
|
1165 |
+
self.upsamplers = None
|
1166 |
+
|
1167 |
+
self.gradient_checkpointing = False
|
1168 |
+
|
1169 |
+
def forward(
|
1170 |
+
self,
|
1171 |
+
hidden_states: torch.FloatTensor,
|
1172 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1173 |
+
temb: Optional[torch.FloatTensor] = None,
|
1174 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1175 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1176 |
+
upsample_size: Optional[int] = None,
|
1177 |
+
num_frames: int = 1,
|
1178 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1179 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1180 |
+
):
|
1181 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
1182 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
1183 |
+
):
|
1184 |
+
# pop res hidden states
|
1185 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1186 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1187 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1188 |
+
|
1189 |
+
if self.training and self.gradient_checkpointing:
|
1190 |
+
|
1191 |
+
def create_custom_forward(module, return_dict=None):
|
1192 |
+
def custom_forward(*inputs):
|
1193 |
+
if return_dict is not None:
|
1194 |
+
return module(*inputs, return_dict=return_dict)
|
1195 |
+
else:
|
1196 |
+
return module(*inputs)
|
1197 |
+
|
1198 |
+
return custom_forward
|
1199 |
+
|
1200 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1201 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,)
|
1202 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,)
|
1203 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1204 |
+
create_custom_forward(attn, return_dict=False),
|
1205 |
+
hidden_states,
|
1206 |
+
encoder_hidden_states,
|
1207 |
+
None, # timestep
|
1208 |
+
None, # class_labels
|
1209 |
+
cross_attention_kwargs,
|
1210 |
+
attention_mask,
|
1211 |
+
encoder_attention_mask,
|
1212 |
+
**ckpt_kwargs,
|
1213 |
+
)[0]
|
1214 |
+
hidden_states = temp_attn(
|
1215 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1216 |
+
).sample
|
1217 |
+
else:
|
1218 |
+
hidden_states = resnet(hidden_states, temb)
|
1219 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
1220 |
+
hidden_states = attn(
|
1221 |
+
hidden_states,
|
1222 |
+
encoder_hidden_states=encoder_hidden_states,
|
1223 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1224 |
+
attention_mask=attention_mask,
|
1225 |
+
encoder_attention_mask=encoder_attention_mask,
|
1226 |
+
return_dict=False,
|
1227 |
+
)[0]
|
1228 |
+
hidden_states = temp_attn(
|
1229 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1230 |
+
).sample
|
1231 |
+
|
1232 |
+
if self.upsamplers is not None:
|
1233 |
+
for upsampler in self.upsamplers:
|
1234 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1235 |
+
|
1236 |
+
return hidden_states
|
1237 |
+
|
1238 |
+
|
1239 |
+
class UpBlock3D(nn.Module):
|
1240 |
+
def __init__(
|
1241 |
+
self,
|
1242 |
+
in_channels: int,
|
1243 |
+
prev_output_channel: int,
|
1244 |
+
out_channels: int,
|
1245 |
+
temb_channels: int,
|
1246 |
+
dropout: float = 0.0,
|
1247 |
+
num_layers: int = 1,
|
1248 |
+
resnet_eps: float = 1e-6,
|
1249 |
+
resnet_time_scale_shift: str = "default",
|
1250 |
+
resnet_act_fn: str = "swish",
|
1251 |
+
resnet_groups: int = 32,
|
1252 |
+
resnet_pre_norm: bool = True,
|
1253 |
+
output_scale_factor=1.0,
|
1254 |
+
add_upsample=True,
|
1255 |
+
):
|
1256 |
+
super().__init__()
|
1257 |
+
resnets = []
|
1258 |
+
temp_convs = []
|
1259 |
+
|
1260 |
+
for i in range(num_layers):
|
1261 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1262 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1263 |
+
|
1264 |
+
resnets.append(
|
1265 |
+
ResnetBlock2D(
|
1266 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1267 |
+
out_channels=out_channels,
|
1268 |
+
temb_channels=temb_channels,
|
1269 |
+
eps=resnet_eps,
|
1270 |
+
groups=resnet_groups,
|
1271 |
+
dropout=dropout,
|
1272 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1273 |
+
non_linearity=resnet_act_fn,
|
1274 |
+
output_scale_factor=output_scale_factor,
|
1275 |
+
pre_norm=resnet_pre_norm,
|
1276 |
+
)
|
1277 |
+
)
|
1278 |
+
temp_convs.append(
|
1279 |
+
TemporalConvLayer(
|
1280 |
+
out_channels,
|
1281 |
+
out_channels,
|
1282 |
+
dropout=0.1,
|
1283 |
+
)
|
1284 |
+
)
|
1285 |
+
|
1286 |
+
self.resnets = nn.ModuleList(resnets)
|
1287 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1288 |
+
|
1289 |
+
if add_upsample:
|
1290 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1291 |
+
else:
|
1292 |
+
self.upsamplers = None
|
1293 |
+
|
1294 |
+
self.gradient_checkpointing = False
|
1295 |
+
|
1296 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
1297 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
1298 |
+
# pop res hidden states
|
1299 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1300 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1301 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1302 |
+
|
1303 |
+
if self.training and self.gradient_checkpointing:
|
1304 |
+
|
1305 |
+
def create_custom_forward(module):
|
1306 |
+
def custom_forward(*inputs):
|
1307 |
+
return module(*inputs)
|
1308 |
+
|
1309 |
+
return custom_forward
|
1310 |
+
|
1311 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
|
1312 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
|
1313 |
+
else:
|
1314 |
+
hidden_states = resnet(hidden_states, temb)
|
1315 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
1316 |
+
|
1317 |
+
if self.upsamplers is not None:
|
1318 |
+
for upsampler in self.upsamplers:
|
1319 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1320 |
+
|
1321 |
+
return hidden_states
|
1322 |
+
|
1323 |
+
|
1324 |
+
class ResnetUpsampleBlock3D(nn.Module):
|
1325 |
+
def __init__(
|
1326 |
+
self,
|
1327 |
+
in_channels: int,
|
1328 |
+
prev_output_channel: int,
|
1329 |
+
out_channels: int,
|
1330 |
+
temb_channels: int,
|
1331 |
+
dropout: float = 0.0,
|
1332 |
+
num_layers: int = 1,
|
1333 |
+
resnet_eps: float = 1e-6,
|
1334 |
+
resnet_time_scale_shift: str = "default",
|
1335 |
+
resnet_act_fn: str = "swish",
|
1336 |
+
resnet_groups: int = 32,
|
1337 |
+
resnet_pre_norm: bool = True,
|
1338 |
+
output_scale_factor=1.0,
|
1339 |
+
add_upsample=True,
|
1340 |
+
skip_time_act=False,
|
1341 |
+
):
|
1342 |
+
super().__init__()
|
1343 |
+
resnets = []
|
1344 |
+
temp_convs = []
|
1345 |
+
|
1346 |
+
for i in range(num_layers):
|
1347 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1348 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1349 |
+
|
1350 |
+
resnets.append(
|
1351 |
+
ResnetBlock2D(
|
1352 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1353 |
+
out_channels=out_channels,
|
1354 |
+
temb_channels=temb_channels,
|
1355 |
+
eps=resnet_eps,
|
1356 |
+
groups=resnet_groups,
|
1357 |
+
dropout=dropout,
|
1358 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1359 |
+
non_linearity=resnet_act_fn,
|
1360 |
+
output_scale_factor=output_scale_factor,
|
1361 |
+
pre_norm=resnet_pre_norm,
|
1362 |
+
skip_time_act=skip_time_act,
|
1363 |
+
)
|
1364 |
+
)
|
1365 |
+
temp_convs.append(
|
1366 |
+
TemporalConvLayer(
|
1367 |
+
out_channels,
|
1368 |
+
out_channels,
|
1369 |
+
dropout=0.1,
|
1370 |
+
)
|
1371 |
+
)
|
1372 |
+
|
1373 |
+
self.resnets = nn.ModuleList(resnets)
|
1374 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1375 |
+
|
1376 |
+
if add_upsample:
|
1377 |
+
self.upsamplers = nn.ModuleList(
|
1378 |
+
[
|
1379 |
+
ResnetBlock2D(
|
1380 |
+
in_channels=out_channels,
|
1381 |
+
out_channels=out_channels,
|
1382 |
+
temb_channels=temb_channels,
|
1383 |
+
eps=resnet_eps,
|
1384 |
+
groups=resnet_groups,
|
1385 |
+
dropout=dropout,
|
1386 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1387 |
+
non_linearity=resnet_act_fn,
|
1388 |
+
output_scale_factor=output_scale_factor,
|
1389 |
+
pre_norm=resnet_pre_norm,
|
1390 |
+
skip_time_act=skip_time_act,
|
1391 |
+
up=True,
|
1392 |
+
)
|
1393 |
+
]
|
1394 |
+
)
|
1395 |
+
else:
|
1396 |
+
self.upsamplers = None
|
1397 |
+
|
1398 |
+
self.gradient_checkpointing = False
|
1399 |
+
|
1400 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
1401 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
1402 |
+
# pop res hidden states
|
1403 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1404 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1405 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1406 |
+
|
1407 |
+
if self.training and self.gradient_checkpointing:
|
1408 |
+
|
1409 |
+
def create_custom_forward(module):
|
1410 |
+
def custom_forward(*inputs):
|
1411 |
+
return module(*inputs)
|
1412 |
+
|
1413 |
+
return custom_forward
|
1414 |
+
|
1415 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
|
1416 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
|
1417 |
+
else:
|
1418 |
+
hidden_states = resnet(hidden_states, temb)
|
1419 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
1420 |
+
|
1421 |
+
if self.upsamplers is not None:
|
1422 |
+
for upsampler in self.upsamplers:
|
1423 |
+
hidden_states = upsampler(hidden_states, temb)
|
1424 |
+
|
1425 |
+
return hidden_states
|
1426 |
+
|
1427 |
+
|
1428 |
+
class SimpleCrossAttnUpBlock3D(nn.Module):
|
1429 |
+
def __init__(
|
1430 |
+
self,
|
1431 |
+
in_channels: int,
|
1432 |
+
out_channels: int,
|
1433 |
+
prev_output_channel: int,
|
1434 |
+
temb_channels: int,
|
1435 |
+
dropout: float = 0.0,
|
1436 |
+
num_layers: int = 1,
|
1437 |
+
resnet_eps: float = 1e-6,
|
1438 |
+
resnet_time_scale_shift: str = "default",
|
1439 |
+
resnet_act_fn: str = "swish",
|
1440 |
+
resnet_groups: int = 32,
|
1441 |
+
resnet_pre_norm: bool = True,
|
1442 |
+
attention_head_dim=1,
|
1443 |
+
cross_attention_dim=1280,
|
1444 |
+
output_scale_factor=1.0,
|
1445 |
+
add_upsample=True,
|
1446 |
+
skip_time_act=False,
|
1447 |
+
only_cross_attention=False,
|
1448 |
+
cross_attention_norm=None,
|
1449 |
+
):
|
1450 |
+
super().__init__()
|
1451 |
+
resnets = []
|
1452 |
+
temp_convs = []
|
1453 |
+
attentions = []
|
1454 |
+
temp_attentions = []
|
1455 |
+
|
1456 |
+
self.has_cross_attention = True
|
1457 |
+
self.attention_head_dim = attention_head_dim
|
1458 |
+
|
1459 |
+
self.num_heads = out_channels // self.attention_head_dim
|
1460 |
+
|
1461 |
+
for i in range(num_layers):
|
1462 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1463 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1464 |
+
|
1465 |
+
resnets.append(
|
1466 |
+
ResnetBlock2D(
|
1467 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1468 |
+
out_channels=out_channels,
|
1469 |
+
temb_channels=temb_channels,
|
1470 |
+
eps=resnet_eps,
|
1471 |
+
groups=resnet_groups,
|
1472 |
+
dropout=dropout,
|
1473 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1474 |
+
non_linearity=resnet_act_fn,
|
1475 |
+
output_scale_factor=output_scale_factor,
|
1476 |
+
pre_norm=resnet_pre_norm,
|
1477 |
+
skip_time_act=skip_time_act,
|
1478 |
+
)
|
1479 |
+
)
|
1480 |
+
temp_convs.append(
|
1481 |
+
TemporalConvLayer(
|
1482 |
+
out_channels,
|
1483 |
+
out_channels,
|
1484 |
+
dropout=0.1,
|
1485 |
+
)
|
1486 |
+
)
|
1487 |
+
|
1488 |
+
processor = (
|
1489 |
+
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
|
1490 |
+
)
|
1491 |
+
|
1492 |
+
attentions.append(
|
1493 |
+
Attention(
|
1494 |
+
query_dim=out_channels,
|
1495 |
+
cross_attention_dim=out_channels,
|
1496 |
+
heads=self.num_heads,
|
1497 |
+
dim_head=self.attention_head_dim,
|
1498 |
+
added_kv_proj_dim=cross_attention_dim,
|
1499 |
+
norm_num_groups=resnet_groups,
|
1500 |
+
bias=True,
|
1501 |
+
upcast_softmax=True,
|
1502 |
+
only_cross_attention=only_cross_attention,
|
1503 |
+
cross_attention_norm=cross_attention_norm,
|
1504 |
+
processor=processor,
|
1505 |
+
)
|
1506 |
+
)
|
1507 |
+
temp_attentions.append(
|
1508 |
+
TransformerTemporalModel(
|
1509 |
+
attention_head_dim,
|
1510 |
+
out_channels // attention_head_dim,
|
1511 |
+
in_channels=out_channels,
|
1512 |
+
num_layers=1,
|
1513 |
+
cross_attention_dim=cross_attention_dim,
|
1514 |
+
norm_num_groups=resnet_groups,
|
1515 |
+
)
|
1516 |
+
)
|
1517 |
+
self.resnets = nn.ModuleList(resnets)
|
1518 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
1519 |
+
self.attentions = nn.ModuleList(attentions)
|
1520 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
1521 |
+
|
1522 |
+
if add_upsample:
|
1523 |
+
self.upsamplers = nn.ModuleList(
|
1524 |
+
[
|
1525 |
+
ResnetBlock2D(
|
1526 |
+
in_channels=out_channels,
|
1527 |
+
out_channels=out_channels,
|
1528 |
+
temb_channels=temb_channels,
|
1529 |
+
eps=resnet_eps,
|
1530 |
+
groups=resnet_groups,
|
1531 |
+
dropout=dropout,
|
1532 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1533 |
+
non_linearity=resnet_act_fn,
|
1534 |
+
output_scale_factor=output_scale_factor,
|
1535 |
+
pre_norm=resnet_pre_norm,
|
1536 |
+
skip_time_act=skip_time_act,
|
1537 |
+
up=True,
|
1538 |
+
)
|
1539 |
+
]
|
1540 |
+
)
|
1541 |
+
else:
|
1542 |
+
self.upsamplers = None
|
1543 |
+
|
1544 |
+
self.gradient_checkpointing = False
|
1545 |
+
|
1546 |
+
def forward(
|
1547 |
+
self,
|
1548 |
+
hidden_states: torch.FloatTensor,
|
1549 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1550 |
+
temb: Optional[torch.FloatTensor] = None,
|
1551 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1552 |
+
upsample_size: Optional[int] = None,
|
1553 |
+
num_frames: int = 1,
|
1554 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1555 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1556 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1557 |
+
):
|
1558 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
1559 |
+
|
1560 |
+
if attention_mask is None:
|
1561 |
+
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
1562 |
+
mask = None if encoder_hidden_states is None else encoder_attention_mask
|
1563 |
+
else:
|
1564 |
+
# when attention_mask is defined: we don't even check for encoder_attention_mask.
|
1565 |
+
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
|
1566 |
+
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
|
1567 |
+
# then we can simplify this whole if/else block to:
|
1568 |
+
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
1569 |
+
mask = attention_mask
|
1570 |
+
|
1571 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
1572 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
1573 |
+
):
|
1574 |
+
# pop res hidden states
|
1575 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1576 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1577 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1578 |
+
|
1579 |
+
if self.training and self.gradient_checkpointing:
|
1580 |
+
|
1581 |
+
def create_custom_forward(module, return_dict=None):
|
1582 |
+
def custom_forward(*inputs):
|
1583 |
+
if return_dict is not None:
|
1584 |
+
return module(*inputs, return_dict=return_dict)
|
1585 |
+
else:
|
1586 |
+
return module(*inputs)
|
1587 |
+
|
1588 |
+
return custom_forward
|
1589 |
+
|
1590 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1591 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames)
|
1592 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1593 |
+
create_custom_forward(attn, return_dict=False),
|
1594 |
+
hidden_states,
|
1595 |
+
encoder_hidden_states,
|
1596 |
+
mask,
|
1597 |
+
cross_attention_kwargs,
|
1598 |
+
)[0]
|
1599 |
+
hidden_states = temp_attn(
|
1600 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1601 |
+
).sample
|
1602 |
+
else:
|
1603 |
+
hidden_states = resnet(hidden_states, temb)
|
1604 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
1605 |
+
hidden_states = attn(
|
1606 |
+
hidden_states,
|
1607 |
+
encoder_hidden_states=encoder_hidden_states,
|
1608 |
+
attention_mask=mask,
|
1609 |
+
**cross_attention_kwargs,
|
1610 |
+
)
|
1611 |
+
hidden_states = temp_attn(
|
1612 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
|
1613 |
+
).sample
|
1614 |
+
|
1615 |
+
if self.upsamplers is not None:
|
1616 |
+
for upsampler in self.upsamplers:
|
1617 |
+
hidden_states = upsampler(hidden_states, temb)
|
1618 |
+
|
1619 |
+
return hidden_states
|
showone/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
25 |
+
from diffusers.utils import BaseOutput, logging
|
26 |
+
from diffusers.models.activations import get_activation
|
27 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
28 |
+
from diffusers.models.embeddings import (
|
29 |
+
GaussianFourierProjection,
|
30 |
+
ImageHintTimeEmbedding,
|
31 |
+
ImageProjection,
|
32 |
+
ImageTimeEmbedding,
|
33 |
+
TextImageProjection,
|
34 |
+
TextImageTimeEmbedding,
|
35 |
+
TextTimeEmbedding,
|
36 |
+
TimestepEmbedding,
|
37 |
+
Timesteps,
|
38 |
+
)
|
39 |
+
from diffusers.models.modeling_utils import ModelMixin
|
40 |
+
# from diffusers.models.transformer_temporal import TransformerTemporalModel
|
41 |
+
from .transformer_temporal import TransformerTemporalModel
|
42 |
+
from .unet_3d_blocks import (
|
43 |
+
CrossAttnDownBlock3D,
|
44 |
+
CrossAttnUpBlock3D,
|
45 |
+
DownBlock3D,
|
46 |
+
UNetMidBlock3DCrossAttn,
|
47 |
+
UNetMidBlock3DSimpleCrossAttn,
|
48 |
+
UpBlock3D,
|
49 |
+
get_down_block,
|
50 |
+
get_up_block,
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class UNet3DConditionOutput(BaseOutput):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
62 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
63 |
+
"""
|
64 |
+
|
65 |
+
sample: torch.FloatTensor
|
66 |
+
|
67 |
+
|
68 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
69 |
+
r"""
|
70 |
+
UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
71 |
+
and returns sample shaped output.
|
72 |
+
|
73 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
74 |
+
implements for all the models (such as downloading or saving, etc.)
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
78 |
+
Height and width of input/output sample.
|
79 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
80 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
81 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
82 |
+
The tuple of downsample blocks to use.
|
83 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
84 |
+
The tuple of upsample blocks to use.
|
85 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
86 |
+
The tuple of output channels for each block.
|
87 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
88 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
89 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
90 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
91 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
92 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
93 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
94 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
95 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
96 |
+
"""
|
97 |
+
|
98 |
+
_supports_gradient_checkpointing = True
|
99 |
+
|
100 |
+
@register_to_config
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
sample_size: Optional[int] = None,
|
104 |
+
in_channels: int = 4,
|
105 |
+
out_channels: int = 4,
|
106 |
+
center_input_sample: bool = False,
|
107 |
+
flip_sin_to_cos: bool = True,
|
108 |
+
freq_shift: int = 0,
|
109 |
+
down_block_types: Tuple[str] = (
|
110 |
+
"CrossAttnDownBlock3D",
|
111 |
+
"CrossAttnDownBlock3D",
|
112 |
+
"CrossAttnDownBlock3D",
|
113 |
+
"DownBlock3D",
|
114 |
+
),
|
115 |
+
mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
|
116 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
117 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
118 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
119 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
120 |
+
downsample_padding: int = 1,
|
121 |
+
mid_block_scale_factor: float = 1,
|
122 |
+
act_fn: str = "silu",
|
123 |
+
norm_num_groups: Optional[int] = 32,
|
124 |
+
norm_eps: float = 1e-5,
|
125 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
126 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
127 |
+
encoder_hid_dim: Optional[int] = None,
|
128 |
+
encoder_hid_dim_type: Optional[str] = None,
|
129 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
130 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
131 |
+
dual_cross_attention: bool = False,
|
132 |
+
use_linear_projection: bool = False,
|
133 |
+
class_embed_type: Optional[str] = None,
|
134 |
+
addition_embed_type: Optional[str] = None,
|
135 |
+
addition_time_embed_dim: Optional[int] = None,
|
136 |
+
num_class_embeds: Optional[int] = None,
|
137 |
+
upcast_attention: bool = False,
|
138 |
+
resnet_time_scale_shift: str = "default",
|
139 |
+
resnet_skip_time_act: bool = False,
|
140 |
+
resnet_out_scale_factor: int = 1.0,
|
141 |
+
time_embedding_type: str = "positional",
|
142 |
+
time_embedding_dim: Optional[int] = None,
|
143 |
+
time_embedding_act_fn: Optional[str] = None,
|
144 |
+
timestep_post_act: Optional[str] = None,
|
145 |
+
time_cond_proj_dim: Optional[int] = None,
|
146 |
+
conv_in_kernel: int = 3,
|
147 |
+
conv_out_kernel: int = 3,
|
148 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
149 |
+
class_embeddings_concat: bool = False,
|
150 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
151 |
+
cross_attention_norm: Optional[str] = None,
|
152 |
+
addition_embed_type_num_heads=64,
|
153 |
+
transfromer_in_opt: bool =False,
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
self.sample_size = sample_size
|
158 |
+
self.transformer_in_opt = transfromer_in_opt
|
159 |
+
|
160 |
+
if num_attention_heads is not None:
|
161 |
+
raise ValueError(
|
162 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
163 |
+
)
|
164 |
+
|
165 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
166 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
167 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
168 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
169 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
170 |
+
# which is why we correct for the naming here.
|
171 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
172 |
+
|
173 |
+
# Check inputs
|
174 |
+
if len(down_block_types) != len(up_block_types):
|
175 |
+
raise ValueError(
|
176 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
177 |
+
)
|
178 |
+
|
179 |
+
if len(block_out_channels) != len(down_block_types):
|
180 |
+
raise ValueError(
|
181 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
182 |
+
)
|
183 |
+
|
184 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
185 |
+
raise ValueError(
|
186 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
187 |
+
)
|
188 |
+
|
189 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
190 |
+
raise ValueError(
|
191 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
192 |
+
)
|
193 |
+
|
194 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
195 |
+
raise ValueError(
|
196 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
197 |
+
)
|
198 |
+
|
199 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
200 |
+
raise ValueError(
|
201 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
202 |
+
)
|
203 |
+
|
204 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
205 |
+
raise ValueError(
|
206 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
207 |
+
)
|
208 |
+
|
209 |
+
# input
|
210 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
211 |
+
self.conv_in = nn.Conv2d(
|
212 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
213 |
+
)
|
214 |
+
|
215 |
+
if self.transformer_in_opt:
|
216 |
+
self.transformer_in = TransformerTemporalModel(
|
217 |
+
num_attention_heads=8,
|
218 |
+
attention_head_dim=64,
|
219 |
+
in_channels=block_out_channels[0],
|
220 |
+
num_layers=1,
|
221 |
+
)
|
222 |
+
|
223 |
+
|
224 |
+
# time
|
225 |
+
if time_embedding_type == "fourier":
|
226 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
227 |
+
if time_embed_dim % 2 != 0:
|
228 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
229 |
+
self.time_proj = GaussianFourierProjection(
|
230 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
231 |
+
)
|
232 |
+
timestep_input_dim = time_embed_dim
|
233 |
+
elif time_embedding_type == "positional":
|
234 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
235 |
+
|
236 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
237 |
+
timestep_input_dim = block_out_channels[0]
|
238 |
+
else:
|
239 |
+
raise ValueError(
|
240 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
241 |
+
)
|
242 |
+
|
243 |
+
self.time_embedding = TimestepEmbedding(
|
244 |
+
timestep_input_dim,
|
245 |
+
time_embed_dim,
|
246 |
+
act_fn=act_fn,
|
247 |
+
post_act_fn=timestep_post_act,
|
248 |
+
cond_proj_dim=time_cond_proj_dim,
|
249 |
+
)
|
250 |
+
|
251 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
252 |
+
encoder_hid_dim_type = "text_proj"
|
253 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
254 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
255 |
+
|
256 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
257 |
+
raise ValueError(
|
258 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
259 |
+
)
|
260 |
+
|
261 |
+
if encoder_hid_dim_type == "text_proj":
|
262 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
263 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
264 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
265 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
266 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
267 |
+
self.encoder_hid_proj = TextImageProjection(
|
268 |
+
text_embed_dim=encoder_hid_dim,
|
269 |
+
image_embed_dim=cross_attention_dim,
|
270 |
+
cross_attention_dim=cross_attention_dim,
|
271 |
+
)
|
272 |
+
elif encoder_hid_dim_type == "image_proj":
|
273 |
+
# Kandinsky 2.2
|
274 |
+
self.encoder_hid_proj = ImageProjection(
|
275 |
+
image_embed_dim=encoder_hid_dim,
|
276 |
+
cross_attention_dim=cross_attention_dim,
|
277 |
+
)
|
278 |
+
elif encoder_hid_dim_type is not None:
|
279 |
+
raise ValueError(
|
280 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
self.encoder_hid_proj = None
|
284 |
+
|
285 |
+
# class embedding
|
286 |
+
if class_embed_type is None and num_class_embeds is not None:
|
287 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
288 |
+
elif class_embed_type == "timestep":
|
289 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
290 |
+
elif class_embed_type == "identity":
|
291 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
292 |
+
elif class_embed_type == "projection":
|
293 |
+
if projection_class_embeddings_input_dim is None:
|
294 |
+
raise ValueError(
|
295 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
296 |
+
)
|
297 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
298 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
299 |
+
# 2. it projects from an arbitrary input dimension.
|
300 |
+
#
|
301 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
302 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
303 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
304 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
305 |
+
elif class_embed_type == "simple_projection":
|
306 |
+
if projection_class_embeddings_input_dim is None:
|
307 |
+
raise ValueError(
|
308 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
309 |
+
)
|
310 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
311 |
+
else:
|
312 |
+
self.class_embedding = None
|
313 |
+
|
314 |
+
if addition_embed_type == "text":
|
315 |
+
if encoder_hid_dim is not None:
|
316 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
317 |
+
else:
|
318 |
+
text_time_embedding_from_dim = cross_attention_dim
|
319 |
+
|
320 |
+
self.add_embedding = TextTimeEmbedding(
|
321 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
322 |
+
)
|
323 |
+
elif addition_embed_type == "text_image":
|
324 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
325 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
326 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
327 |
+
self.add_embedding = TextImageTimeEmbedding(
|
328 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
329 |
+
)
|
330 |
+
elif addition_embed_type == "text_time":
|
331 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
332 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
333 |
+
elif addition_embed_type == "image":
|
334 |
+
# Kandinsky 2.2
|
335 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
336 |
+
elif addition_embed_type == "image_hint":
|
337 |
+
# Kandinsky 2.2 ControlNet
|
338 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
339 |
+
elif addition_embed_type is not None:
|
340 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
341 |
+
|
342 |
+
if time_embedding_act_fn is None:
|
343 |
+
self.time_embed_act = None
|
344 |
+
else:
|
345 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
346 |
+
|
347 |
+
self.down_blocks = nn.ModuleList([])
|
348 |
+
self.up_blocks = nn.ModuleList([])
|
349 |
+
|
350 |
+
if isinstance(only_cross_attention, bool):
|
351 |
+
if mid_block_only_cross_attention is None:
|
352 |
+
mid_block_only_cross_attention = only_cross_attention
|
353 |
+
|
354 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
355 |
+
|
356 |
+
if mid_block_only_cross_attention is None:
|
357 |
+
mid_block_only_cross_attention = False
|
358 |
+
|
359 |
+
if isinstance(num_attention_heads, int):
|
360 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
361 |
+
|
362 |
+
if isinstance(attention_head_dim, int):
|
363 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
364 |
+
|
365 |
+
if isinstance(cross_attention_dim, int):
|
366 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
367 |
+
|
368 |
+
if isinstance(layers_per_block, int):
|
369 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
370 |
+
|
371 |
+
if isinstance(transformer_layers_per_block, int):
|
372 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
373 |
+
|
374 |
+
if class_embeddings_concat:
|
375 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
376 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
377 |
+
# regular time embeddings
|
378 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
379 |
+
else:
|
380 |
+
blocks_time_embed_dim = time_embed_dim
|
381 |
+
|
382 |
+
# down
|
383 |
+
output_channel = block_out_channels[0]
|
384 |
+
for i, down_block_type in enumerate(down_block_types):
|
385 |
+
input_channel = output_channel
|
386 |
+
output_channel = block_out_channels[i]
|
387 |
+
is_final_block = i == len(block_out_channels) - 1
|
388 |
+
|
389 |
+
down_block = get_down_block(
|
390 |
+
down_block_type,
|
391 |
+
num_layers=layers_per_block[i],
|
392 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
393 |
+
in_channels=input_channel,
|
394 |
+
out_channels=output_channel,
|
395 |
+
temb_channels=blocks_time_embed_dim,
|
396 |
+
add_downsample=not is_final_block,
|
397 |
+
resnet_eps=norm_eps,
|
398 |
+
resnet_act_fn=act_fn,
|
399 |
+
resnet_groups=norm_num_groups,
|
400 |
+
cross_attention_dim=cross_attention_dim[i],
|
401 |
+
num_attention_heads=num_attention_heads[i],
|
402 |
+
downsample_padding=downsample_padding,
|
403 |
+
dual_cross_attention=dual_cross_attention,
|
404 |
+
use_linear_projection=use_linear_projection,
|
405 |
+
only_cross_attention=only_cross_attention[i],
|
406 |
+
upcast_attention=upcast_attention,
|
407 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
408 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
409 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
410 |
+
cross_attention_norm=cross_attention_norm,
|
411 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
412 |
+
)
|
413 |
+
self.down_blocks.append(down_block)
|
414 |
+
|
415 |
+
# mid
|
416 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
417 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
418 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
419 |
+
in_channels=block_out_channels[-1],
|
420 |
+
temb_channels=blocks_time_embed_dim,
|
421 |
+
resnet_eps=norm_eps,
|
422 |
+
resnet_act_fn=act_fn,
|
423 |
+
output_scale_factor=mid_block_scale_factor,
|
424 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
425 |
+
cross_attention_dim=cross_attention_dim[-1],
|
426 |
+
num_attention_heads=num_attention_heads[-1],
|
427 |
+
resnet_groups=norm_num_groups,
|
428 |
+
dual_cross_attention=dual_cross_attention,
|
429 |
+
use_linear_projection=use_linear_projection,
|
430 |
+
upcast_attention=upcast_attention,
|
431 |
+
)
|
432 |
+
elif mid_block_type == "UNetMidBlock3DSimpleCrossAttn":
|
433 |
+
self.mid_block = UNetMidBlock3DSimpleCrossAttn(
|
434 |
+
in_channels=block_out_channels[-1],
|
435 |
+
temb_channels=blocks_time_embed_dim,
|
436 |
+
resnet_eps=norm_eps,
|
437 |
+
resnet_act_fn=act_fn,
|
438 |
+
output_scale_factor=mid_block_scale_factor,
|
439 |
+
cross_attention_dim=cross_attention_dim[-1],
|
440 |
+
attention_head_dim=attention_head_dim[-1],
|
441 |
+
resnet_groups=norm_num_groups,
|
442 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
443 |
+
skip_time_act=resnet_skip_time_act,
|
444 |
+
only_cross_attention=mid_block_only_cross_attention,
|
445 |
+
cross_attention_norm=cross_attention_norm,
|
446 |
+
)
|
447 |
+
elif mid_block_type is None:
|
448 |
+
self.mid_block = None
|
449 |
+
else:
|
450 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
451 |
+
|
452 |
+
# count how many layers upsample the images
|
453 |
+
self.num_upsamplers = 0
|
454 |
+
|
455 |
+
# up
|
456 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
457 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
458 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
459 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
460 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
461 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
462 |
+
|
463 |
+
output_channel = reversed_block_out_channels[0]
|
464 |
+
for i, up_block_type in enumerate(up_block_types):
|
465 |
+
is_final_block = i == len(block_out_channels) - 1
|
466 |
+
|
467 |
+
prev_output_channel = output_channel
|
468 |
+
output_channel = reversed_block_out_channels[i]
|
469 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
470 |
+
|
471 |
+
# add upsample block for all BUT final layer
|
472 |
+
if not is_final_block:
|
473 |
+
add_upsample = True
|
474 |
+
self.num_upsamplers += 1
|
475 |
+
else:
|
476 |
+
add_upsample = False
|
477 |
+
|
478 |
+
up_block = get_up_block(
|
479 |
+
up_block_type,
|
480 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
481 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
482 |
+
in_channels=input_channel,
|
483 |
+
out_channels=output_channel,
|
484 |
+
prev_output_channel=prev_output_channel,
|
485 |
+
temb_channels=blocks_time_embed_dim,
|
486 |
+
add_upsample=add_upsample,
|
487 |
+
resnet_eps=norm_eps,
|
488 |
+
resnet_act_fn=act_fn,
|
489 |
+
resnet_groups=norm_num_groups,
|
490 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
491 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
492 |
+
dual_cross_attention=dual_cross_attention,
|
493 |
+
use_linear_projection=use_linear_projection,
|
494 |
+
only_cross_attention=only_cross_attention[i],
|
495 |
+
upcast_attention=upcast_attention,
|
496 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
497 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
498 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
499 |
+
cross_attention_norm=cross_attention_norm,
|
500 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
501 |
+
)
|
502 |
+
self.up_blocks.append(up_block)
|
503 |
+
prev_output_channel = output_channel
|
504 |
+
|
505 |
+
# out
|
506 |
+
if norm_num_groups is not None:
|
507 |
+
self.conv_norm_out = nn.GroupNorm(
|
508 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
509 |
+
)
|
510 |
+
|
511 |
+
self.conv_act = get_activation(act_fn)
|
512 |
+
|
513 |
+
else:
|
514 |
+
self.conv_norm_out = None
|
515 |
+
self.conv_act = None
|
516 |
+
|
517 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
518 |
+
self.conv_out = nn.Conv2d(
|
519 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
520 |
+
)
|
521 |
+
|
522 |
+
@property
|
523 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
524 |
+
r"""
|
525 |
+
Returns:
|
526 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
527 |
+
indexed by its weight name.
|
528 |
+
"""
|
529 |
+
# set recursively
|
530 |
+
processors = {}
|
531 |
+
|
532 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
533 |
+
if hasattr(module, "set_processor"):
|
534 |
+
processors[f"{name}.processor"] = module.processor
|
535 |
+
|
536 |
+
for sub_name, child in module.named_children():
|
537 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
538 |
+
|
539 |
+
return processors
|
540 |
+
|
541 |
+
for name, module in self.named_children():
|
542 |
+
fn_recursive_add_processors(name, module, processors)
|
543 |
+
|
544 |
+
return processors
|
545 |
+
|
546 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
547 |
+
r"""
|
548 |
+
Sets the attention processor to use to compute attention.
|
549 |
+
|
550 |
+
Parameters:
|
551 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
552 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
553 |
+
for **all** `Attention` layers.
|
554 |
+
|
555 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
556 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
557 |
+
|
558 |
+
"""
|
559 |
+
# count = len(self.attn_processors.keys())
|
560 |
+
# ignore temporal attention
|
561 |
+
count = len({k: v for k, v in self.attn_processors.items() if "temp_" not in k}.keys())
|
562 |
+
|
563 |
+
if isinstance(processor, dict) and len(processor) != count:
|
564 |
+
raise ValueError(
|
565 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
566 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
567 |
+
)
|
568 |
+
|
569 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
570 |
+
if hasattr(module, "set_processor") and "temp_" not in name:
|
571 |
+
if not isinstance(processor, dict):
|
572 |
+
module.set_processor(processor)
|
573 |
+
else:
|
574 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
575 |
+
|
576 |
+
for sub_name, child in module.named_children():
|
577 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
578 |
+
|
579 |
+
for name, module in self.named_children():
|
580 |
+
fn_recursive_attn_processor(name, module, processor)
|
581 |
+
|
582 |
+
def set_default_attn_processor(self):
|
583 |
+
"""
|
584 |
+
Disables custom attention processors and sets the default attention implementation.
|
585 |
+
"""
|
586 |
+
self.set_attn_processor(AttnProcessor())
|
587 |
+
|
588 |
+
def set_attention_slice(self, slice_size):
|
589 |
+
r"""
|
590 |
+
Enable sliced attention computation.
|
591 |
+
|
592 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
593 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
597 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
598 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
599 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
600 |
+
must be a multiple of `slice_size`.
|
601 |
+
"""
|
602 |
+
sliceable_head_dims = []
|
603 |
+
|
604 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
605 |
+
if hasattr(module, "set_attention_slice"):
|
606 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
607 |
+
|
608 |
+
for child in module.children():
|
609 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
610 |
+
|
611 |
+
# retrieve number of attention layers
|
612 |
+
for module in self.children():
|
613 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
614 |
+
|
615 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
616 |
+
|
617 |
+
if slice_size == "auto":
|
618 |
+
# half the attention head size is usually a good trade-off between
|
619 |
+
# speed and memory
|
620 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
621 |
+
elif slice_size == "max":
|
622 |
+
# make smallest slice possible
|
623 |
+
slice_size = num_sliceable_layers * [1]
|
624 |
+
|
625 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
626 |
+
|
627 |
+
if len(slice_size) != len(sliceable_head_dims):
|
628 |
+
raise ValueError(
|
629 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
630 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
631 |
+
)
|
632 |
+
|
633 |
+
for i in range(len(slice_size)):
|
634 |
+
size = slice_size[i]
|
635 |
+
dim = sliceable_head_dims[i]
|
636 |
+
if size is not None and size > dim:
|
637 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
638 |
+
|
639 |
+
# Recursively walk through all the children.
|
640 |
+
# Any children which exposes the set_attention_slice method
|
641 |
+
# gets the message
|
642 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
643 |
+
if hasattr(module, "set_attention_slice"):
|
644 |
+
module.set_attention_slice(slice_size.pop())
|
645 |
+
|
646 |
+
for child in module.children():
|
647 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
648 |
+
|
649 |
+
reversed_slice_size = list(reversed(slice_size))
|
650 |
+
for module in self.children():
|
651 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
652 |
+
|
653 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
654 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
655 |
+
module.gradient_checkpointing = value
|
656 |
+
|
657 |
+
def forward(
|
658 |
+
self,
|
659 |
+
sample: torch.FloatTensor,
|
660 |
+
timestep: Union[torch.Tensor, float, int],
|
661 |
+
encoder_hidden_states: torch.Tensor,
|
662 |
+
class_labels: Optional[torch.Tensor] = None,
|
663 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
664 |
+
attention_mask: Optional[torch.Tensor] = None,
|
665 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
666 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
667 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
668 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
669 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
670 |
+
return_dict: bool = True,
|
671 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
672 |
+
r"""
|
673 |
+
Args:
|
674 |
+
sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor
|
675 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
676 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
677 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
678 |
+
Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
|
679 |
+
cross_attention_kwargs (`dict`, *optional*):
|
680 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
681 |
+
`self.processor` in
|
682 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
683 |
+
|
684 |
+
Returns:
|
685 |
+
[`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
|
686 |
+
[`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
687 |
+
returning a tuple, the first element is the sample tensor.
|
688 |
+
"""
|
689 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
690 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
691 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
692 |
+
# on the fly if necessary.
|
693 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
694 |
+
|
695 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
696 |
+
forward_upsample_size = False
|
697 |
+
upsample_size = None
|
698 |
+
|
699 |
+
|
700 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
701 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
702 |
+
forward_upsample_size = True
|
703 |
+
|
704 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
705 |
+
# expects mask of shape:
|
706 |
+
# [batch, key_tokens]
|
707 |
+
# adds singleton query_tokens dimension:
|
708 |
+
# [batch, 1, key_tokens]
|
709 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
710 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
711 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
712 |
+
if attention_mask is not None:
|
713 |
+
# assume that mask is expressed as:
|
714 |
+
# (1 = keep, 0 = discard)
|
715 |
+
# convert mask into a bias that can be added to attention scores:
|
716 |
+
# (keep = +0, discard = -10000.0)
|
717 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
718 |
+
attention_mask = attention_mask.unsqueeze(1)
|
719 |
+
|
720 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
721 |
+
if encoder_attention_mask is not None:
|
722 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
723 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
724 |
+
|
725 |
+
# 0. center input if necessary
|
726 |
+
if self.config.center_input_sample:
|
727 |
+
sample = 2 * sample - 1.0
|
728 |
+
|
729 |
+
# 1. time
|
730 |
+
timesteps = timestep
|
731 |
+
if not torch.is_tensor(timesteps):
|
732 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
733 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
734 |
+
is_mps = sample.device.type == "mps"
|
735 |
+
if isinstance(timestep, float):
|
736 |
+
dtype = torch.float32 if is_mps else torch.float64
|
737 |
+
else:
|
738 |
+
dtype = torch.int32 if is_mps else torch.int64
|
739 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
740 |
+
elif len(timesteps.shape) == 0:
|
741 |
+
timesteps = timesteps[None].to(sample.device)
|
742 |
+
|
743 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
744 |
+
num_frames = sample.shape[2]
|
745 |
+
timesteps = timesteps.expand(sample.shape[0])
|
746 |
+
|
747 |
+
t_emb = self.time_proj(timesteps)
|
748 |
+
|
749 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
750 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
751 |
+
# there might be better ways to encapsulate this.
|
752 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
753 |
+
|
754 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
755 |
+
aug_emb = None
|
756 |
+
|
757 |
+
if self.class_embedding is not None:
|
758 |
+
if class_labels is None:
|
759 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
760 |
+
|
761 |
+
if self.config.class_embed_type == "timestep":
|
762 |
+
class_labels = self.time_proj(class_labels)
|
763 |
+
|
764 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
765 |
+
# there might be better ways to encapsulate this.
|
766 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
767 |
+
|
768 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
769 |
+
|
770 |
+
if self.config.class_embeddings_concat:
|
771 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
772 |
+
else:
|
773 |
+
emb = emb + class_emb
|
774 |
+
|
775 |
+
if self.config.addition_embed_type == "text":
|
776 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
777 |
+
elif self.config.addition_embed_type == "text_image":
|
778 |
+
# Kandinsky 2.1 - style
|
779 |
+
if "image_embeds" not in added_cond_kwargs:
|
780 |
+
raise ValueError(
|
781 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
782 |
+
)
|
783 |
+
|
784 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
785 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
786 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
787 |
+
elif self.config.addition_embed_type == "text_time":
|
788 |
+
if "text_embeds" not in added_cond_kwargs:
|
789 |
+
raise ValueError(
|
790 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
791 |
+
)
|
792 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
793 |
+
if "time_ids" not in added_cond_kwargs:
|
794 |
+
raise ValueError(
|
795 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
796 |
+
)
|
797 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
798 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
799 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
800 |
+
|
801 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
802 |
+
add_embeds = add_embeds.to(emb.dtype)
|
803 |
+
aug_emb = self.add_embedding(add_embeds)
|
804 |
+
elif self.config.addition_embed_type == "image":
|
805 |
+
# Kandinsky 2.2 - style
|
806 |
+
if "image_embeds" not in added_cond_kwargs:
|
807 |
+
raise ValueError(
|
808 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
809 |
+
)
|
810 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
811 |
+
aug_emb = self.add_embedding(image_embs)
|
812 |
+
elif self.config.addition_embed_type == "image_hint":
|
813 |
+
# Kandinsky 2.2 - style
|
814 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
815 |
+
raise ValueError(
|
816 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
817 |
+
)
|
818 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
819 |
+
hint = added_cond_kwargs.get("hint")
|
820 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
821 |
+
sample = torch.cat([sample, hint], dim=1)
|
822 |
+
|
823 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
824 |
+
|
825 |
+
if self.time_embed_act is not None:
|
826 |
+
emb = self.time_embed_act(emb)
|
827 |
+
|
828 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
829 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
830 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
831 |
+
# Kadinsky 2.1 - style
|
832 |
+
if "image_embeds" not in added_cond_kwargs:
|
833 |
+
raise ValueError(
|
834 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
835 |
+
)
|
836 |
+
|
837 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
838 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
839 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
840 |
+
# Kandinsky 2.2 - style
|
841 |
+
if "image_embeds" not in added_cond_kwargs:
|
842 |
+
raise ValueError(
|
843 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
844 |
+
)
|
845 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
846 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
847 |
+
|
848 |
+
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
849 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
850 |
+
|
851 |
+
# 2. pre-process
|
852 |
+
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
853 |
+
sample = self.conv_in(sample)
|
854 |
+
|
855 |
+
if self.transformer_in_opt:
|
856 |
+
|
857 |
+
sample = self.transformer_in(
|
858 |
+
sample,
|
859 |
+
num_frames=num_frames,
|
860 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
861 |
+
return_dict=False,
|
862 |
+
)[0]
|
863 |
+
|
864 |
+
# 3. down
|
865 |
+
down_block_res_samples = (sample,)
|
866 |
+
for downsample_block in self.down_blocks:
|
867 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
868 |
+
sample, res_samples = downsample_block(
|
869 |
+
hidden_states=sample,
|
870 |
+
temb=emb,
|
871 |
+
encoder_hidden_states=encoder_hidden_states,
|
872 |
+
attention_mask=attention_mask,
|
873 |
+
num_frames=num_frames,
|
874 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
875 |
+
encoder_attention_mask=encoder_attention_mask,
|
876 |
+
)
|
877 |
+
else:
|
878 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
879 |
+
|
880 |
+
down_block_res_samples += res_samples
|
881 |
+
|
882 |
+
if down_block_additional_residuals is not None:
|
883 |
+
new_down_block_res_samples = ()
|
884 |
+
|
885 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
886 |
+
down_block_res_samples, down_block_additional_residuals
|
887 |
+
):
|
888 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
889 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
890 |
+
|
891 |
+
down_block_res_samples = new_down_block_res_samples
|
892 |
+
|
893 |
+
# 4. mid
|
894 |
+
if self.mid_block is not None:
|
895 |
+
sample = self.mid_block(
|
896 |
+
sample,
|
897 |
+
emb,
|
898 |
+
encoder_hidden_states=encoder_hidden_states,
|
899 |
+
attention_mask=attention_mask,
|
900 |
+
num_frames=num_frames,
|
901 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
902 |
+
encoder_attention_mask=encoder_attention_mask,
|
903 |
+
)
|
904 |
+
|
905 |
+
if mid_block_additional_residual is not None:
|
906 |
+
sample = sample + mid_block_additional_residual
|
907 |
+
|
908 |
+
# 5. up
|
909 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
910 |
+
is_final_block = i == len(self.up_blocks) - 1
|
911 |
+
|
912 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
913 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
914 |
+
|
915 |
+
# if we have not reached the final block and need to forward the
|
916 |
+
# upsample size, we do it here
|
917 |
+
if not is_final_block and forward_upsample_size:
|
918 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
919 |
+
|
920 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
921 |
+
sample = upsample_block(
|
922 |
+
hidden_states=sample,
|
923 |
+
temb=emb,
|
924 |
+
res_hidden_states_tuple=res_samples,
|
925 |
+
encoder_hidden_states=encoder_hidden_states,
|
926 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
927 |
+
upsample_size=upsample_size,
|
928 |
+
attention_mask=attention_mask,
|
929 |
+
num_frames=num_frames,
|
930 |
+
encoder_attention_mask=encoder_attention_mask,
|
931 |
+
)
|
932 |
+
else:
|
933 |
+
sample = upsample_block(
|
934 |
+
hidden_states=sample,
|
935 |
+
temb=emb,
|
936 |
+
res_hidden_states_tuple=res_samples,
|
937 |
+
upsample_size=upsample_size,
|
938 |
+
num_frames=num_frames,
|
939 |
+
)
|
940 |
+
|
941 |
+
# 6. post-process
|
942 |
+
if self.conv_norm_out:
|
943 |
+
sample = self.conv_norm_out(sample)
|
944 |
+
sample = self.conv_act(sample)
|
945 |
+
sample = self.conv_out(sample)
|
946 |
+
|
947 |
+
# reshape to (batch, channel, framerate, width, height)
|
948 |
+
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
|
949 |
+
|
950 |
+
if not return_dict:
|
951 |
+
return (sample,)
|
952 |
+
|
953 |
+
return UNet3DConditionOutput(sample=sample)
|
954 |
+
|
955 |
+
@classmethod
|
956 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
|
957 |
+
import os, json
|
958 |
+
if subfolder is not None:
|
959 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
960 |
+
|
961 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
962 |
+
if not os.path.isfile(config_file):
|
963 |
+
raise RuntimeError(f"{config_file} does not exist")
|
964 |
+
with open(config_file, "r") as f:
|
965 |
+
config = json.load(f)
|
966 |
+
config["_class_name"] = cls.__name__
|
967 |
+
|
968 |
+
config["down_block_types"] = [x.replace("2D", "3D") for x in config["down_block_types"]]
|
969 |
+
if "mid_block_type" in config.keys():
|
970 |
+
config["mid_block_type"] = config["mid_block_type"].replace("2D", "3D")
|
971 |
+
config["up_block_types"] = [x.replace("2D", "3D") for x in config["up_block_types"]]
|
972 |
+
|
973 |
+
from diffusers.utils import WEIGHTS_NAME
|
974 |
+
model = cls.from_config(config)
|
975 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
976 |
+
if not os.path.isfile(model_file):
|
977 |
+
raise RuntimeError(f"{model_file} does not exist")
|
978 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
979 |
+
for k, v in model.state_dict().items():
|
980 |
+
if k not in state_dict:
|
981 |
+
|
982 |
+
state_dict.update({k: v})
|
983 |
+
model.load_state_dict(state_dict)
|
984 |
+
|
985 |
+
return model
|
showone/pipelines/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from diffusers.utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class TextToVideoPipelineOutput(BaseOutput):
|
12 |
+
"""
|
13 |
+
Output class for text to video pipelines.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
17 |
+
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
18 |
+
a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list
|
19 |
+
denotes the video length i.e., the number of frames.
|
20 |
+
"""
|
21 |
+
|
22 |
+
frames: Union[List[np.ndarray], torch.FloatTensor]
|
23 |
+
|
24 |
+
|
25 |
+
try:
|
26 |
+
if not (is_transformers_available() and is_torch_available()):
|
27 |
+
raise OptionalDependencyNotAvailable()
|
28 |
+
except OptionalDependencyNotAvailable:
|
29 |
+
from diffusers.utils.dummy_torch_and_transformers_objects import * # noqa F403
|
30 |
+
else:
|
31 |
+
# from .pipeline_t2v_base_latent import TextToVideoSDPipeline # noqa: F401
|
32 |
+
# from .pipeline_t2v_base_latent_sdxl import TextToVideoSDXLPipeline
|
33 |
+
from .pipeline_t2v_base_pixel import TextToVideoIFPipeline
|
34 |
+
from .pipeline_t2v_interp_pixel import TextToVideoIFInterpPipeline
|
35 |
+
# from .pipeline_t2v_sr_latent import TextToVideoSDSuperResolutionPipeline
|
36 |
+
from .pipeline_t2v_sr_pixel import TextToVideoIFSuperResolutionPipeline
|
37 |
+
# from .pipeline_t2v_base_latent_controlnet import TextToVideoSDControlNetPipeline
|
showone/pipelines/pipeline_t2v_base_pixel.py
ADDED
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import inspect
|
3 |
+
import re
|
4 |
+
import urllib.parse as ul
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
10 |
+
|
11 |
+
from diffusers.loaders import LoraLoaderMixin
|
12 |
+
from diffusers.schedulers import DDPMScheduler
|
13 |
+
from diffusers.utils import (
|
14 |
+
BACKENDS_MAPPING,
|
15 |
+
is_accelerate_available,
|
16 |
+
is_accelerate_version,
|
17 |
+
is_bs4_available,
|
18 |
+
is_ftfy_available,
|
19 |
+
logging,
|
20 |
+
randn_tensor,
|
21 |
+
)
|
22 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
23 |
+
|
24 |
+
from ..models import UNet3DConditionModel
|
25 |
+
from . import TextToVideoPipelineOutput
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29 |
+
|
30 |
+
if is_bs4_available():
|
31 |
+
from bs4 import BeautifulSoup
|
32 |
+
|
33 |
+
if is_ftfy_available():
|
34 |
+
import ftfy
|
35 |
+
|
36 |
+
|
37 |
+
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
38 |
+
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
39 |
+
# reshape to ncfhw
|
40 |
+
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
41 |
+
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
42 |
+
# unnormalize back to [0,1]
|
43 |
+
video = video.mul_(std).add_(mean)
|
44 |
+
video.clamp_(0, 1)
|
45 |
+
# prepare the final outputs
|
46 |
+
i, c, f, h, w = video.shape
|
47 |
+
images = video.permute(2, 3, 0, 4, 1).reshape(
|
48 |
+
f, h, i * w, c
|
49 |
+
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
50 |
+
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
51 |
+
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
52 |
+
return images
|
53 |
+
|
54 |
+
|
55 |
+
class TextToVideoIFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
56 |
+
tokenizer: T5Tokenizer
|
57 |
+
text_encoder: T5EncoderModel
|
58 |
+
|
59 |
+
unet: UNet3DConditionModel
|
60 |
+
scheduler: DDPMScheduler
|
61 |
+
|
62 |
+
feature_extractor: Optional[CLIPImageProcessor]
|
63 |
+
# safety_checker: Optional[IFSafetyChecker]
|
64 |
+
|
65 |
+
# watermarker: Optional[IFWatermarker]
|
66 |
+
|
67 |
+
bad_punct_regex = re.compile(
|
68 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
69 |
+
) # noqa
|
70 |
+
|
71 |
+
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
tokenizer: T5Tokenizer,
|
76 |
+
text_encoder: T5EncoderModel,
|
77 |
+
unet: UNet3DConditionModel,
|
78 |
+
scheduler: DDPMScheduler,
|
79 |
+
feature_extractor: Optional[CLIPImageProcessor],
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.register_modules(
|
84 |
+
tokenizer=tokenizer,
|
85 |
+
text_encoder=text_encoder,
|
86 |
+
unet=unet,
|
87 |
+
scheduler=scheduler,
|
88 |
+
feature_extractor=feature_extractor,
|
89 |
+
)
|
90 |
+
self.safety_checker = None
|
91 |
+
|
92 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
93 |
+
r"""
|
94 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
95 |
+
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
96 |
+
when their specific submodule has its `forward` method called.
|
97 |
+
"""
|
98 |
+
if is_accelerate_available():
|
99 |
+
from accelerate import cpu_offload
|
100 |
+
else:
|
101 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
102 |
+
|
103 |
+
device = torch.device(f"cuda:{gpu_id}")
|
104 |
+
|
105 |
+
models = [
|
106 |
+
self.text_encoder,
|
107 |
+
self.unet,
|
108 |
+
]
|
109 |
+
for cpu_offloaded_model in models:
|
110 |
+
if cpu_offloaded_model is not None:
|
111 |
+
cpu_offload(cpu_offloaded_model, device)
|
112 |
+
|
113 |
+
if self.safety_checker is not None:
|
114 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
115 |
+
|
116 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
117 |
+
r"""
|
118 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
119 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
120 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
121 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
122 |
+
"""
|
123 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
124 |
+
from accelerate import cpu_offload_with_hook
|
125 |
+
else:
|
126 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
127 |
+
|
128 |
+
device = torch.device(f"cuda:{gpu_id}")
|
129 |
+
|
130 |
+
self.unet.train()
|
131 |
+
|
132 |
+
if self.device.type != "cpu":
|
133 |
+
self.to("cpu", silence_dtype_warnings=True)
|
134 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
135 |
+
|
136 |
+
hook = None
|
137 |
+
|
138 |
+
if self.text_encoder is not None:
|
139 |
+
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
140 |
+
|
141 |
+
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
142 |
+
# previous model. This will cause both models to be present on the device at the same time.
|
143 |
+
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
144 |
+
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
145 |
+
# the GPU.
|
146 |
+
self.text_encoder_offload_hook = hook
|
147 |
+
|
148 |
+
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
149 |
+
|
150 |
+
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
151 |
+
self.unet_offload_hook = hook
|
152 |
+
|
153 |
+
if self.safety_checker is not None:
|
154 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
155 |
+
|
156 |
+
# We'll offload the last model manually.
|
157 |
+
self.final_offload_hook = hook
|
158 |
+
|
159 |
+
def remove_all_hooks(self):
|
160 |
+
if is_accelerate_available():
|
161 |
+
from accelerate.hooks import remove_hook_from_module
|
162 |
+
else:
|
163 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
164 |
+
|
165 |
+
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
166 |
+
if model is not None:
|
167 |
+
remove_hook_from_module(model, recurse=True)
|
168 |
+
|
169 |
+
self.unet_offload_hook = None
|
170 |
+
self.text_encoder_offload_hook = None
|
171 |
+
self.final_offload_hook = None
|
172 |
+
|
173 |
+
@property
|
174 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
175 |
+
def _execution_device(self):
|
176 |
+
r"""
|
177 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
178 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
179 |
+
hooks.
|
180 |
+
"""
|
181 |
+
if not hasattr(self.unet, "_hf_hook"):
|
182 |
+
return self.device
|
183 |
+
for module in self.unet.modules():
|
184 |
+
if (
|
185 |
+
hasattr(module, "_hf_hook")
|
186 |
+
and hasattr(module._hf_hook, "execution_device")
|
187 |
+
and module._hf_hook.execution_device is not None
|
188 |
+
):
|
189 |
+
return torch.device(module._hf_hook.execution_device)
|
190 |
+
return self.device
|
191 |
+
|
192 |
+
@torch.no_grad()
|
193 |
+
def encode_prompt(
|
194 |
+
self,
|
195 |
+
prompt,
|
196 |
+
do_classifier_free_guidance=True,
|
197 |
+
num_images_per_prompt=1,
|
198 |
+
device=None,
|
199 |
+
negative_prompt=None,
|
200 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
201 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
202 |
+
clean_caption: bool = False,
|
203 |
+
):
|
204 |
+
r"""
|
205 |
+
Encodes the prompt into text encoder hidden states.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
prompt (`str` or `List[str]`, *optional*):
|
209 |
+
prompt to be encoded
|
210 |
+
device: (`torch.device`, *optional*):
|
211 |
+
torch device to place the resulting embeddings on
|
212 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
213 |
+
number of images that should be generated per prompt
|
214 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
215 |
+
whether to use classifier free guidance or not
|
216 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
217 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
218 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
219 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
220 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
221 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
222 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
223 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
224 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
225 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
226 |
+
argument.
|
227 |
+
"""
|
228 |
+
if prompt is not None and negative_prompt is not None:
|
229 |
+
if type(prompt) is not type(negative_prompt):
|
230 |
+
raise TypeError(
|
231 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
232 |
+
f" {type(prompt)}."
|
233 |
+
)
|
234 |
+
|
235 |
+
if device is None:
|
236 |
+
device = self._execution_device
|
237 |
+
|
238 |
+
if prompt is not None and isinstance(prompt, str):
|
239 |
+
batch_size = 1
|
240 |
+
elif prompt is not None and isinstance(prompt, list):
|
241 |
+
batch_size = len(prompt)
|
242 |
+
else:
|
243 |
+
batch_size = prompt_embeds.shape[0]
|
244 |
+
|
245 |
+
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
246 |
+
max_length = 77
|
247 |
+
|
248 |
+
if prompt_embeds is None:
|
249 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
250 |
+
text_inputs = self.tokenizer(
|
251 |
+
prompt,
|
252 |
+
padding="max_length",
|
253 |
+
max_length=max_length,
|
254 |
+
truncation=True,
|
255 |
+
add_special_tokens=True,
|
256 |
+
return_tensors="pt",
|
257 |
+
)
|
258 |
+
text_input_ids = text_inputs.input_ids
|
259 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
260 |
+
|
261 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
262 |
+
text_input_ids, untruncated_ids
|
263 |
+
):
|
264 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
265 |
+
logger.warning(
|
266 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
267 |
+
f" {max_length} tokens: {removed_text}"
|
268 |
+
)
|
269 |
+
|
270 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
271 |
+
|
272 |
+
prompt_embeds = self.text_encoder(
|
273 |
+
text_input_ids.to(device),
|
274 |
+
attention_mask=attention_mask,
|
275 |
+
)
|
276 |
+
prompt_embeds = prompt_embeds[0]
|
277 |
+
|
278 |
+
if self.text_encoder is not None:
|
279 |
+
dtype = self.text_encoder.dtype
|
280 |
+
elif self.unet is not None:
|
281 |
+
dtype = self.unet.dtype
|
282 |
+
else:
|
283 |
+
dtype = None
|
284 |
+
|
285 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
286 |
+
|
287 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
288 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
289 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
290 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
291 |
+
|
292 |
+
# get unconditional embeddings for classifier free guidance
|
293 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
294 |
+
uncond_tokens: List[str]
|
295 |
+
if negative_prompt is None:
|
296 |
+
uncond_tokens = [""] * batch_size
|
297 |
+
elif isinstance(negative_prompt, str):
|
298 |
+
uncond_tokens = [negative_prompt]
|
299 |
+
elif batch_size != len(negative_prompt):
|
300 |
+
raise ValueError(
|
301 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
302 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
303 |
+
" the batch size of `prompt`."
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
uncond_tokens = negative_prompt
|
307 |
+
|
308 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
309 |
+
max_length = prompt_embeds.shape[1]
|
310 |
+
uncond_input = self.tokenizer(
|
311 |
+
uncond_tokens,
|
312 |
+
padding="max_length",
|
313 |
+
max_length=max_length,
|
314 |
+
truncation=True,
|
315 |
+
return_attention_mask=True,
|
316 |
+
add_special_tokens=True,
|
317 |
+
return_tensors="pt",
|
318 |
+
)
|
319 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
320 |
+
|
321 |
+
negative_prompt_embeds = self.text_encoder(
|
322 |
+
uncond_input.input_ids.to(device),
|
323 |
+
attention_mask=attention_mask,
|
324 |
+
)
|
325 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
326 |
+
|
327 |
+
if do_classifier_free_guidance:
|
328 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
329 |
+
seq_len = negative_prompt_embeds.shape[1]
|
330 |
+
|
331 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
332 |
+
|
333 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
334 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
335 |
+
|
336 |
+
# For classifier free guidance, we need to do two forward passes.
|
337 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
338 |
+
# to avoid doing two forward passes
|
339 |
+
else:
|
340 |
+
negative_prompt_embeds = None
|
341 |
+
|
342 |
+
return prompt_embeds, negative_prompt_embeds
|
343 |
+
|
344 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
345 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
346 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
347 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
348 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
349 |
+
# and should be between [0, 1]
|
350 |
+
|
351 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
352 |
+
extra_step_kwargs = {}
|
353 |
+
if accepts_eta:
|
354 |
+
extra_step_kwargs["eta"] = eta
|
355 |
+
|
356 |
+
# check if the scheduler accepts generator
|
357 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
358 |
+
if accepts_generator:
|
359 |
+
extra_step_kwargs["generator"] = generator
|
360 |
+
return extra_step_kwargs
|
361 |
+
|
362 |
+
def check_inputs(
|
363 |
+
self,
|
364 |
+
prompt,
|
365 |
+
callback_steps,
|
366 |
+
negative_prompt=None,
|
367 |
+
prompt_embeds=None,
|
368 |
+
negative_prompt_embeds=None,
|
369 |
+
):
|
370 |
+
if (callback_steps is None) or (
|
371 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
372 |
+
):
|
373 |
+
raise ValueError(
|
374 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
375 |
+
f" {type(callback_steps)}."
|
376 |
+
)
|
377 |
+
|
378 |
+
if prompt is not None and prompt_embeds is not None:
|
379 |
+
raise ValueError(
|
380 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
381 |
+
" only forward one of the two."
|
382 |
+
)
|
383 |
+
elif prompt is None and prompt_embeds is None:
|
384 |
+
raise ValueError(
|
385 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
386 |
+
)
|
387 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
388 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
389 |
+
|
390 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
391 |
+
raise ValueError(
|
392 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
393 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
394 |
+
)
|
395 |
+
|
396 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
397 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
398 |
+
raise ValueError(
|
399 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
400 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
401 |
+
f" {negative_prompt_embeds.shape}."
|
402 |
+
)
|
403 |
+
|
404 |
+
def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
|
405 |
+
shape = (batch_size, num_channels, num_frames, height, width)
|
406 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
407 |
+
raise ValueError(
|
408 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
409 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
410 |
+
)
|
411 |
+
|
412 |
+
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
413 |
+
|
414 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
415 |
+
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
416 |
+
return intermediate_images
|
417 |
+
|
418 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
419 |
+
if clean_caption and not is_bs4_available():
|
420 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
421 |
+
logger.warn("Setting `clean_caption` to False...")
|
422 |
+
clean_caption = False
|
423 |
+
|
424 |
+
if clean_caption and not is_ftfy_available():
|
425 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
426 |
+
logger.warn("Setting `clean_caption` to False...")
|
427 |
+
clean_caption = False
|
428 |
+
|
429 |
+
if not isinstance(text, (tuple, list)):
|
430 |
+
text = [text]
|
431 |
+
|
432 |
+
def process(text: str):
|
433 |
+
if clean_caption:
|
434 |
+
text = self._clean_caption(text)
|
435 |
+
text = self._clean_caption(text)
|
436 |
+
else:
|
437 |
+
text = text.lower().strip()
|
438 |
+
return text
|
439 |
+
|
440 |
+
return [process(t) for t in text]
|
441 |
+
|
442 |
+
def _clean_caption(self, caption):
|
443 |
+
caption = str(caption)
|
444 |
+
caption = ul.unquote_plus(caption)
|
445 |
+
caption = caption.strip().lower()
|
446 |
+
caption = re.sub("<person>", "person", caption)
|
447 |
+
# urls:
|
448 |
+
caption = re.sub(
|
449 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
450 |
+
"",
|
451 |
+
caption,
|
452 |
+
) # regex for urls
|
453 |
+
caption = re.sub(
|
454 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
455 |
+
"",
|
456 |
+
caption,
|
457 |
+
) # regex for urls
|
458 |
+
# html:
|
459 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
460 |
+
|
461 |
+
# @<nickname>
|
462 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
463 |
+
|
464 |
+
# 31C0—31EF CJK Strokes
|
465 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
466 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
467 |
+
# 3300—33FF CJK Compatibility
|
468 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
469 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
470 |
+
# 4E00—9FFF CJK Unified Ideographs
|
471 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
472 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
473 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
474 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
475 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
476 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
477 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
478 |
+
#######################################################
|
479 |
+
|
480 |
+
# все виды тире / all types of dash --> "-"
|
481 |
+
caption = re.sub(
|
482 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
483 |
+
"-",
|
484 |
+
caption,
|
485 |
+
)
|
486 |
+
|
487 |
+
# кавычки к одному стандарту
|
488 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
489 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
490 |
+
|
491 |
+
# "
|
492 |
+
caption = re.sub(r""?", "", caption)
|
493 |
+
# &
|
494 |
+
caption = re.sub(r"&", "", caption)
|
495 |
+
|
496 |
+
# ip adresses:
|
497 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
498 |
+
|
499 |
+
# article ids:
|
500 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
501 |
+
|
502 |
+
# \n
|
503 |
+
caption = re.sub(r"\\n", " ", caption)
|
504 |
+
|
505 |
+
# "#123"
|
506 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
507 |
+
# "#12345.."
|
508 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
509 |
+
# "123456.."
|
510 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
511 |
+
# filenames:
|
512 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
513 |
+
|
514 |
+
#
|
515 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
516 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
517 |
+
|
518 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
519 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
520 |
+
|
521 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
522 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
523 |
+
if len(re.findall(regex2, caption)) > 3:
|
524 |
+
caption = re.sub(regex2, " ", caption)
|
525 |
+
|
526 |
+
caption = ftfy.fix_text(caption)
|
527 |
+
caption = html.unescape(html.unescape(caption))
|
528 |
+
|
529 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
530 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
531 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
532 |
+
|
533 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
534 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
535 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
536 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
537 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
538 |
+
|
539 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
540 |
+
|
541 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
542 |
+
|
543 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
544 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
545 |
+
caption = re.sub(r"\s+", " ", caption)
|
546 |
+
|
547 |
+
caption.strip()
|
548 |
+
|
549 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
550 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
551 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
552 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
553 |
+
|
554 |
+
return caption.strip()
|
555 |
+
|
556 |
+
@torch.no_grad()
|
557 |
+
def __call__(
|
558 |
+
self,
|
559 |
+
prompt: Union[str, List[str]] = None,
|
560 |
+
num_inference_steps: int = 100,
|
561 |
+
timesteps: List[int] = None,
|
562 |
+
guidance_scale: float = 7.0,
|
563 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
564 |
+
num_images_per_prompt: Optional[int] = 1,
|
565 |
+
height: Optional[int] = None,
|
566 |
+
width: Optional[int] = None,
|
567 |
+
num_frames: int = 16,
|
568 |
+
eta: float = 0.0,
|
569 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
570 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
571 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
572 |
+
output_type: Optional[str] = "np",
|
573 |
+
return_dict: bool = True,
|
574 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
575 |
+
callback_steps: int = 1,
|
576 |
+
clean_caption: bool = True,
|
577 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
578 |
+
):
|
579 |
+
"""
|
580 |
+
Function invoked when calling the pipeline for generation.
|
581 |
+
|
582 |
+
Args:
|
583 |
+
prompt (`str` or `List[str]`, *optional*):
|
584 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
585 |
+
instead.
|
586 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
587 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
588 |
+
expense of slower inference.
|
589 |
+
timesteps (`List[int]`, *optional*):
|
590 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
591 |
+
timesteps are used. Must be in descending order.
|
592 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
593 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
594 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
595 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
596 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
597 |
+
usually at the expense of lower image quality.
|
598 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
599 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
600 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
601 |
+
less than `1`).
|
602 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
603 |
+
The number of images to generate per prompt.
|
604 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
605 |
+
The height in pixels of the generated image.
|
606 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
607 |
+
The width in pixels of the generated image.
|
608 |
+
eta (`float`, *optional*, defaults to 0.0):
|
609 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
610 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
611 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
612 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
613 |
+
to make generation deterministic.
|
614 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
615 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
616 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
617 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
618 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
619 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
620 |
+
argument.
|
621 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
622 |
+
The output format of the generate image. Choose between
|
623 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
624 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
625 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
626 |
+
callback (`Callable`, *optional*):
|
627 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
628 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
629 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
630 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
631 |
+
called at every step.
|
632 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
633 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
634 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
635 |
+
prompt.
|
636 |
+
cross_attention_kwargs (`dict`, *optional*):
|
637 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
638 |
+
`self.processor` in
|
639 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
640 |
+
|
641 |
+
Examples:
|
642 |
+
|
643 |
+
Returns:
|
644 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
645 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
646 |
+
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
647 |
+
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
648 |
+
or watermarked content, according to the `safety_checker`.
|
649 |
+
"""
|
650 |
+
# 1. Check inputs. Raise error if not correct
|
651 |
+
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
652 |
+
|
653 |
+
# 2. Define call parameters
|
654 |
+
height = height or self.unet.config.sample_size
|
655 |
+
width = width or self.unet.config.sample_size
|
656 |
+
|
657 |
+
if prompt is not None and isinstance(prompt, str):
|
658 |
+
batch_size = 1
|
659 |
+
elif prompt is not None and isinstance(prompt, list):
|
660 |
+
batch_size = len(prompt)
|
661 |
+
else:
|
662 |
+
batch_size = prompt_embeds.shape[0]
|
663 |
+
|
664 |
+
device = self._execution_device
|
665 |
+
|
666 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
667 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
668 |
+
# corresponds to doing no classifier free guidance.
|
669 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
670 |
+
|
671 |
+
# 3. Encode input prompt
|
672 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
673 |
+
prompt,
|
674 |
+
do_classifier_free_guidance,
|
675 |
+
num_images_per_prompt=num_images_per_prompt,
|
676 |
+
device=device,
|
677 |
+
negative_prompt=negative_prompt,
|
678 |
+
prompt_embeds=prompt_embeds,
|
679 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
680 |
+
clean_caption=clean_caption,
|
681 |
+
)
|
682 |
+
|
683 |
+
if do_classifier_free_guidance:
|
684 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
685 |
+
|
686 |
+
# 4. Prepare timesteps
|
687 |
+
if timesteps is not None:
|
688 |
+
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
689 |
+
timesteps = self.scheduler.timesteps
|
690 |
+
num_inference_steps = len(timesteps)
|
691 |
+
else:
|
692 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
693 |
+
timesteps = self.scheduler.timesteps
|
694 |
+
|
695 |
+
# 5. Prepare intermediate images
|
696 |
+
intermediate_images = self.prepare_intermediate_images(
|
697 |
+
batch_size * num_images_per_prompt,
|
698 |
+
self.unet.config.in_channels,
|
699 |
+
num_frames,
|
700 |
+
height,
|
701 |
+
width,
|
702 |
+
prompt_embeds.dtype,
|
703 |
+
device,
|
704 |
+
generator,
|
705 |
+
)
|
706 |
+
|
707 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
708 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
709 |
+
|
710 |
+
# HACK: see comment in `enable_model_cpu_offload`
|
711 |
+
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
712 |
+
self.text_encoder_offload_hook.offload()
|
713 |
+
|
714 |
+
# 7. Denoising loop
|
715 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
716 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
717 |
+
for i, t in enumerate(timesteps):
|
718 |
+
model_input = (
|
719 |
+
torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
|
720 |
+
)
|
721 |
+
model_input = self.scheduler.scale_model_input(model_input, t)
|
722 |
+
|
723 |
+
# predict the noise residual
|
724 |
+
noise_pred = self.unet(
|
725 |
+
model_input,
|
726 |
+
t,
|
727 |
+
encoder_hidden_states=prompt_embeds,
|
728 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
729 |
+
).sample
|
730 |
+
|
731 |
+
# perform guidance
|
732 |
+
if do_classifier_free_guidance:
|
733 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
734 |
+
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
|
735 |
+
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
|
736 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
737 |
+
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
738 |
+
|
739 |
+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
740 |
+
noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
|
741 |
+
|
742 |
+
# reshape latents
|
743 |
+
bsz, channel, frames, height, width = intermediate_images.shape
|
744 |
+
intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
|
745 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
|
746 |
+
|
747 |
+
# compute the previous noisy sample x_t -> x_t-1
|
748 |
+
intermediate_images = self.scheduler.step(
|
749 |
+
noise_pred, t, intermediate_images, **extra_step_kwargs
|
750 |
+
).prev_sample
|
751 |
+
|
752 |
+
# reshape latents back
|
753 |
+
intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
|
754 |
+
|
755 |
+
# call the callback, if provided
|
756 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
757 |
+
progress_bar.update()
|
758 |
+
if callback is not None and i % callback_steps == 0:
|
759 |
+
callback(i, t, intermediate_images)
|
760 |
+
|
761 |
+
video_tensor = intermediate_images
|
762 |
+
|
763 |
+
if output_type == "pt":
|
764 |
+
video = video_tensor
|
765 |
+
else:
|
766 |
+
video = tensor2vid(video_tensor)
|
767 |
+
|
768 |
+
# Offload last model to CPU
|
769 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
770 |
+
self.final_offload_hook.offload()
|
771 |
+
|
772 |
+
if not return_dict:
|
773 |
+
return (video,)
|
774 |
+
|
775 |
+
return TextToVideoPipelineOutput(frames=video)
|
showone/pipelines/pipeline_t2v_interp_pixel.py
ADDED
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import inspect
|
3 |
+
import re
|
4 |
+
import urllib.parse as ul
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
10 |
+
|
11 |
+
from diffusers.schedulers import DDPMScheduler
|
12 |
+
from diffusers.utils import (
|
13 |
+
BACKENDS_MAPPING,
|
14 |
+
is_accelerate_available,
|
15 |
+
is_accelerate_version,
|
16 |
+
is_bs4_available,
|
17 |
+
is_ftfy_available,
|
18 |
+
logging,
|
19 |
+
randn_tensor,
|
20 |
+
replace_example_docstring,
|
21 |
+
)
|
22 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
23 |
+
|
24 |
+
from ..models import UNet3DConditionModel
|
25 |
+
from . import TextToVideoPipelineOutput
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29 |
+
|
30 |
+
if is_bs4_available():
|
31 |
+
from bs4 import BeautifulSoup
|
32 |
+
|
33 |
+
if is_ftfy_available():
|
34 |
+
import ftfy
|
35 |
+
|
36 |
+
|
37 |
+
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
38 |
+
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
39 |
+
# reshape to ncfhw
|
40 |
+
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
41 |
+
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
42 |
+
# unnormalize back to [0,1]
|
43 |
+
video = video.mul_(std).add_(mean)
|
44 |
+
video.clamp_(0, 1)
|
45 |
+
# prepare the final outputs
|
46 |
+
i, c, f, h, w = video.shape
|
47 |
+
images = video.permute(2, 3, 0, 4, 1).reshape(
|
48 |
+
f, h, i * w, c
|
49 |
+
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
50 |
+
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
51 |
+
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
52 |
+
return images
|
53 |
+
|
54 |
+
|
55 |
+
class TextToVideoIFInterpPipeline(DiffusionPipeline):
|
56 |
+
tokenizer: T5Tokenizer
|
57 |
+
text_encoder: T5EncoderModel
|
58 |
+
|
59 |
+
unet: UNet3DConditionModel
|
60 |
+
scheduler: DDPMScheduler
|
61 |
+
|
62 |
+
feature_extractor: Optional[CLIPImageProcessor]
|
63 |
+
# safety_checker: Optional[IFSafetyChecker]
|
64 |
+
|
65 |
+
# watermarker: Optional[IFWatermarker]
|
66 |
+
|
67 |
+
bad_punct_regex = re.compile(
|
68 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
69 |
+
) # noqa
|
70 |
+
|
71 |
+
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
tokenizer: T5Tokenizer,
|
76 |
+
text_encoder: T5EncoderModel,
|
77 |
+
unet: UNet3DConditionModel,
|
78 |
+
scheduler: DDPMScheduler,
|
79 |
+
feature_extractor: Optional[CLIPImageProcessor],
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.register_modules(
|
84 |
+
tokenizer=tokenizer,
|
85 |
+
text_encoder=text_encoder,
|
86 |
+
unet=unet,
|
87 |
+
scheduler=scheduler,
|
88 |
+
feature_extractor=feature_extractor,
|
89 |
+
)
|
90 |
+
self.safety_checker = None
|
91 |
+
|
92 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
93 |
+
r"""
|
94 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
95 |
+
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
96 |
+
when their specific submodule has its `forward` method called.
|
97 |
+
"""
|
98 |
+
if is_accelerate_available():
|
99 |
+
from accelerate import cpu_offload
|
100 |
+
else:
|
101 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
102 |
+
|
103 |
+
device = torch.device(f"cuda:{gpu_id}")
|
104 |
+
|
105 |
+
models = [
|
106 |
+
self.text_encoder,
|
107 |
+
self.unet,
|
108 |
+
]
|
109 |
+
for cpu_offloaded_model in models:
|
110 |
+
if cpu_offloaded_model is not None:
|
111 |
+
cpu_offload(cpu_offloaded_model, device)
|
112 |
+
|
113 |
+
if self.safety_checker is not None:
|
114 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
115 |
+
|
116 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
117 |
+
r"""
|
118 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
119 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
120 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
121 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
122 |
+
"""
|
123 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
124 |
+
from accelerate import cpu_offload_with_hook
|
125 |
+
else:
|
126 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
127 |
+
|
128 |
+
device = torch.device(f"cuda:{gpu_id}")
|
129 |
+
|
130 |
+
|
131 |
+
if self.device.type != "cpu":
|
132 |
+
self.to("cpu", silence_dtype_warnings=True)
|
133 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
134 |
+
|
135 |
+
hook = None
|
136 |
+
|
137 |
+
if self.text_encoder is not None:
|
138 |
+
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
139 |
+
|
140 |
+
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
141 |
+
# previous model. This will cause both models to be present on the device at the same time.
|
142 |
+
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
143 |
+
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
144 |
+
# the GPU.
|
145 |
+
self.text_encoder_offload_hook = hook
|
146 |
+
|
147 |
+
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
148 |
+
|
149 |
+
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
150 |
+
self.unet_offload_hook = hook
|
151 |
+
|
152 |
+
if self.safety_checker is not None:
|
153 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
154 |
+
|
155 |
+
# We'll offload the last model manually.
|
156 |
+
self.final_offload_hook = hook
|
157 |
+
|
158 |
+
def remove_all_hooks(self):
|
159 |
+
if is_accelerate_available():
|
160 |
+
from accelerate.hooks import remove_hook_from_module
|
161 |
+
else:
|
162 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
163 |
+
|
164 |
+
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
165 |
+
if model is not None:
|
166 |
+
remove_hook_from_module(model, recurse=True)
|
167 |
+
|
168 |
+
self.unet_offload_hook = None
|
169 |
+
self.text_encoder_offload_hook = None
|
170 |
+
self.final_offload_hook = None
|
171 |
+
|
172 |
+
@property
|
173 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
174 |
+
def _execution_device(self):
|
175 |
+
r"""
|
176 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
177 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
178 |
+
hooks.
|
179 |
+
"""
|
180 |
+
if not hasattr(self.unet, "_hf_hook"):
|
181 |
+
return self.device
|
182 |
+
for module in self.unet.modules():
|
183 |
+
if (
|
184 |
+
hasattr(module, "_hf_hook")
|
185 |
+
and hasattr(module._hf_hook, "execution_device")
|
186 |
+
and module._hf_hook.execution_device is not None
|
187 |
+
):
|
188 |
+
return torch.device(module._hf_hook.execution_device)
|
189 |
+
return self.device
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def encode_prompt(
|
193 |
+
self,
|
194 |
+
prompt,
|
195 |
+
do_classifier_free_guidance=True,
|
196 |
+
num_images_per_prompt=1,
|
197 |
+
device=None,
|
198 |
+
negative_prompt=None,
|
199 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
200 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
201 |
+
clean_caption: bool = False,
|
202 |
+
):
|
203 |
+
r"""
|
204 |
+
Encodes the prompt into text encoder hidden states.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
prompt (`str` or `List[str]`, *optional*):
|
208 |
+
prompt to be encoded
|
209 |
+
device: (`torch.device`, *optional*):
|
210 |
+
torch device to place the resulting embeddings on
|
211 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
212 |
+
number of images that should be generated per prompt
|
213 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
214 |
+
whether to use classifier free guidance or not
|
215 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
216 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
217 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
218 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
219 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
220 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
221 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
222 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
223 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
224 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
225 |
+
argument.
|
226 |
+
"""
|
227 |
+
if prompt is not None and negative_prompt is not None:
|
228 |
+
if type(prompt) is not type(negative_prompt):
|
229 |
+
raise TypeError(
|
230 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
231 |
+
f" {type(prompt)}."
|
232 |
+
)
|
233 |
+
|
234 |
+
if device is None:
|
235 |
+
device = self._execution_device
|
236 |
+
|
237 |
+
if prompt is not None and isinstance(prompt, str):
|
238 |
+
batch_size = 1
|
239 |
+
elif prompt is not None and isinstance(prompt, list):
|
240 |
+
batch_size = len(prompt)
|
241 |
+
else:
|
242 |
+
batch_size = prompt_embeds.shape[0]
|
243 |
+
|
244 |
+
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
245 |
+
max_length = 77
|
246 |
+
|
247 |
+
if prompt_embeds is None:
|
248 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
249 |
+
text_inputs = self.tokenizer(
|
250 |
+
prompt,
|
251 |
+
padding="max_length",
|
252 |
+
max_length=max_length,
|
253 |
+
truncation=True,
|
254 |
+
add_special_tokens=True,
|
255 |
+
return_tensors="pt",
|
256 |
+
)
|
257 |
+
text_input_ids = text_inputs.input_ids
|
258 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
259 |
+
|
260 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
261 |
+
text_input_ids, untruncated_ids
|
262 |
+
):
|
263 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
264 |
+
logger.warning(
|
265 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
266 |
+
f" {max_length} tokens: {removed_text}"
|
267 |
+
)
|
268 |
+
|
269 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
270 |
+
|
271 |
+
prompt_embeds = self.text_encoder(
|
272 |
+
text_input_ids.to(device),
|
273 |
+
attention_mask=attention_mask,
|
274 |
+
)
|
275 |
+
prompt_embeds = prompt_embeds[0]
|
276 |
+
|
277 |
+
if self.text_encoder is not None:
|
278 |
+
dtype = self.text_encoder.dtype
|
279 |
+
elif self.unet is not None:
|
280 |
+
dtype = self.unet.dtype
|
281 |
+
else:
|
282 |
+
dtype = None
|
283 |
+
|
284 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
285 |
+
|
286 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
287 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
288 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
289 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
290 |
+
|
291 |
+
# get unconditional embeddings for classifier free guidance
|
292 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
293 |
+
uncond_tokens: List[str]
|
294 |
+
if negative_prompt is None:
|
295 |
+
uncond_tokens = [""] * batch_size
|
296 |
+
elif isinstance(negative_prompt, str):
|
297 |
+
uncond_tokens = [negative_prompt]
|
298 |
+
elif batch_size != len(negative_prompt):
|
299 |
+
raise ValueError(
|
300 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
301 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
302 |
+
" the batch size of `prompt`."
|
303 |
+
)
|
304 |
+
else:
|
305 |
+
uncond_tokens = negative_prompt
|
306 |
+
|
307 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
308 |
+
max_length = prompt_embeds.shape[1]
|
309 |
+
uncond_input = self.tokenizer(
|
310 |
+
uncond_tokens,
|
311 |
+
padding="max_length",
|
312 |
+
max_length=max_length,
|
313 |
+
truncation=True,
|
314 |
+
return_attention_mask=True,
|
315 |
+
add_special_tokens=True,
|
316 |
+
return_tensors="pt",
|
317 |
+
)
|
318 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
319 |
+
|
320 |
+
negative_prompt_embeds = self.text_encoder(
|
321 |
+
uncond_input.input_ids.to(device),
|
322 |
+
attention_mask=attention_mask,
|
323 |
+
)
|
324 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
325 |
+
|
326 |
+
if do_classifier_free_guidance:
|
327 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
328 |
+
seq_len = negative_prompt_embeds.shape[1]
|
329 |
+
|
330 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
331 |
+
|
332 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
333 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
334 |
+
|
335 |
+
# For classifier free guidance, we need to do two forward passes.
|
336 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
337 |
+
# to avoid doing two forward passes
|
338 |
+
else:
|
339 |
+
negative_prompt_embeds = None
|
340 |
+
|
341 |
+
return prompt_embeds, negative_prompt_embeds
|
342 |
+
|
343 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
344 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
345 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
346 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
347 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
348 |
+
# and should be between [0, 1]
|
349 |
+
|
350 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
351 |
+
extra_step_kwargs = {}
|
352 |
+
if accepts_eta:
|
353 |
+
extra_step_kwargs["eta"] = eta
|
354 |
+
|
355 |
+
# check if the scheduler accepts generator
|
356 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
357 |
+
if accepts_generator:
|
358 |
+
extra_step_kwargs["generator"] = generator
|
359 |
+
return extra_step_kwargs
|
360 |
+
|
361 |
+
def check_inputs(
|
362 |
+
self,
|
363 |
+
prompt,
|
364 |
+
callback_steps,
|
365 |
+
negative_prompt=None,
|
366 |
+
prompt_embeds=None,
|
367 |
+
negative_prompt_embeds=None,
|
368 |
+
):
|
369 |
+
if (callback_steps is None) or (
|
370 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
371 |
+
):
|
372 |
+
raise ValueError(
|
373 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
374 |
+
f" {type(callback_steps)}."
|
375 |
+
)
|
376 |
+
|
377 |
+
if prompt is not None and prompt_embeds is not None:
|
378 |
+
raise ValueError(
|
379 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
380 |
+
" only forward one of the two."
|
381 |
+
)
|
382 |
+
elif prompt is None and prompt_embeds is None:
|
383 |
+
raise ValueError(
|
384 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
385 |
+
)
|
386 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
387 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
388 |
+
|
389 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
390 |
+
raise ValueError(
|
391 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
392 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
393 |
+
)
|
394 |
+
|
395 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
396 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
397 |
+
raise ValueError(
|
398 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
399 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
400 |
+
f" {negative_prompt_embeds.shape}."
|
401 |
+
)
|
402 |
+
|
403 |
+
def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
|
404 |
+
shape = (batch_size, num_channels, num_frames, height, width)
|
405 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
406 |
+
raise ValueError(
|
407 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
408 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
409 |
+
)
|
410 |
+
|
411 |
+
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
412 |
+
|
413 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
414 |
+
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
415 |
+
return intermediate_images
|
416 |
+
|
417 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
418 |
+
if clean_caption and not is_bs4_available():
|
419 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
420 |
+
logger.warn("Setting `clean_caption` to False...")
|
421 |
+
clean_caption = False
|
422 |
+
|
423 |
+
if clean_caption and not is_ftfy_available():
|
424 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
425 |
+
logger.warn("Setting `clean_caption` to False...")
|
426 |
+
clean_caption = False
|
427 |
+
|
428 |
+
if not isinstance(text, (tuple, list)):
|
429 |
+
text = [text]
|
430 |
+
|
431 |
+
def process(text: str):
|
432 |
+
if clean_caption:
|
433 |
+
text = self._clean_caption(text)
|
434 |
+
text = self._clean_caption(text)
|
435 |
+
else:
|
436 |
+
text = text.lower().strip()
|
437 |
+
return text
|
438 |
+
|
439 |
+
return [process(t) for t in text]
|
440 |
+
|
441 |
+
def _clean_caption(self, caption):
|
442 |
+
caption = str(caption)
|
443 |
+
caption = ul.unquote_plus(caption)
|
444 |
+
caption = caption.strip().lower()
|
445 |
+
caption = re.sub("<person>", "person", caption)
|
446 |
+
# urls:
|
447 |
+
caption = re.sub(
|
448 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
449 |
+
"",
|
450 |
+
caption,
|
451 |
+
) # regex for urls
|
452 |
+
caption = re.sub(
|
453 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
454 |
+
"",
|
455 |
+
caption,
|
456 |
+
) # regex for urls
|
457 |
+
# html:
|
458 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
459 |
+
|
460 |
+
# @<nickname>
|
461 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
462 |
+
|
463 |
+
# 31C0—31EF CJK Strokes
|
464 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
465 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
466 |
+
# 3300—33FF CJK Compatibility
|
467 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
468 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
469 |
+
# 4E00—9FFF CJK Unified Ideographs
|
470 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
471 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
472 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
473 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
474 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
475 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
476 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
477 |
+
#######################################################
|
478 |
+
|
479 |
+
# все виды тире / all types of dash --> "-"
|
480 |
+
caption = re.sub(
|
481 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
482 |
+
"-",
|
483 |
+
caption,
|
484 |
+
)
|
485 |
+
|
486 |
+
# кавычки к одному стандарту
|
487 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
488 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
489 |
+
|
490 |
+
# "
|
491 |
+
caption = re.sub(r""?", "", caption)
|
492 |
+
# &
|
493 |
+
caption = re.sub(r"&", "", caption)
|
494 |
+
|
495 |
+
# ip adresses:
|
496 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
497 |
+
|
498 |
+
# article ids:
|
499 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
500 |
+
|
501 |
+
# \n
|
502 |
+
caption = re.sub(r"\\n", " ", caption)
|
503 |
+
|
504 |
+
# "#123"
|
505 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
506 |
+
# "#12345.."
|
507 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
508 |
+
# "123456.."
|
509 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
510 |
+
# filenames:
|
511 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
512 |
+
|
513 |
+
#
|
514 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
515 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
516 |
+
|
517 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
518 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
519 |
+
|
520 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
521 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
522 |
+
if len(re.findall(regex2, caption)) > 3:
|
523 |
+
caption = re.sub(regex2, " ", caption)
|
524 |
+
|
525 |
+
caption = ftfy.fix_text(caption)
|
526 |
+
caption = html.unescape(html.unescape(caption))
|
527 |
+
|
528 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
529 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
530 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
531 |
+
|
532 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
533 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
534 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
535 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
536 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
537 |
+
|
538 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
539 |
+
|
540 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
541 |
+
|
542 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
543 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
544 |
+
caption = re.sub(r"\s+", " ", caption)
|
545 |
+
|
546 |
+
caption.strip()
|
547 |
+
|
548 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
549 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
550 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
551 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
552 |
+
|
553 |
+
return caption.strip()
|
554 |
+
|
555 |
+
@torch.no_grad()
|
556 |
+
def __call__(
|
557 |
+
self,
|
558 |
+
pixel_values,
|
559 |
+
prompt: Union[str, List[str]] = None,
|
560 |
+
num_inference_steps: int = 100,
|
561 |
+
timesteps: List[int] = None,
|
562 |
+
guidance_scale: float = 7.0,
|
563 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
564 |
+
num_images_per_prompt: Optional[int] = 1,
|
565 |
+
height: Optional[int] = None,
|
566 |
+
width: Optional[int] = None,
|
567 |
+
num_frames: int = 16,
|
568 |
+
eta: float = 0.0,
|
569 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
570 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
571 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
572 |
+
output_type: Optional[str] = "np",
|
573 |
+
return_dict: bool = True,
|
574 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
575 |
+
callback_steps: int = 1,
|
576 |
+
clean_caption: bool = True,
|
577 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
578 |
+
init_noise = None,
|
579 |
+
cond_interpolation = False,
|
580 |
+
):
|
581 |
+
"""
|
582 |
+
Function invoked when calling the pipeline for generation.
|
583 |
+
|
584 |
+
Args:
|
585 |
+
prompt (`str` or `List[str]`, *optional*):
|
586 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
587 |
+
instead.
|
588 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
589 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
590 |
+
expense of slower inference.
|
591 |
+
timesteps (`List[int]`, *optional*):
|
592 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
593 |
+
timesteps are used. Must be in descending order.
|
594 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
595 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
596 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
597 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
598 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
599 |
+
usually at the expense of lower image quality.
|
600 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
601 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
602 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
603 |
+
less than `1`).
|
604 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
605 |
+
The number of images to generate per prompt.
|
606 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
607 |
+
The height in pixels of the generated image.
|
608 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
609 |
+
The width in pixels of the generated image.
|
610 |
+
eta (`float`, *optional*, defaults to 0.0):
|
611 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
612 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
613 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
614 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
615 |
+
to make generation deterministic.
|
616 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
617 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
618 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
619 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
620 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
621 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
622 |
+
argument.
|
623 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
624 |
+
The output format of the generate image. Choose between
|
625 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
626 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
627 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
628 |
+
callback (`Callable`, *optional*):
|
629 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
630 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
631 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
632 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
633 |
+
called at every step.
|
634 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
635 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
636 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
637 |
+
prompt.
|
638 |
+
cross_attention_kwargs (`dict`, *optional*):
|
639 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
640 |
+
`self.processor` in
|
641 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
642 |
+
|
643 |
+
Examples:
|
644 |
+
|
645 |
+
Returns:
|
646 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
647 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
648 |
+
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
649 |
+
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
650 |
+
or watermarked content, according to the `safety_checker`.
|
651 |
+
"""
|
652 |
+
# 1. Check inputs. Raise error if not correct
|
653 |
+
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
654 |
+
|
655 |
+
# 2. Define call parameters
|
656 |
+
height = height or self.unet.config.sample_size
|
657 |
+
width = width or self.unet.config.sample_size
|
658 |
+
|
659 |
+
if prompt is not None and isinstance(prompt, str):
|
660 |
+
batch_size = 1
|
661 |
+
elif prompt is not None and isinstance(prompt, list):
|
662 |
+
batch_size = len(prompt)
|
663 |
+
else:
|
664 |
+
batch_size = prompt_embeds.shape[0]
|
665 |
+
|
666 |
+
device = self._execution_device
|
667 |
+
|
668 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
669 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
670 |
+
# corresponds to doing no classifier free guidance.
|
671 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
672 |
+
|
673 |
+
# 3. Encode input prompt
|
674 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
675 |
+
prompt,
|
676 |
+
do_classifier_free_guidance,
|
677 |
+
num_images_per_prompt=num_images_per_prompt,
|
678 |
+
device=device,
|
679 |
+
negative_prompt=negative_prompt,
|
680 |
+
prompt_embeds=prompt_embeds,
|
681 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
682 |
+
clean_caption=clean_caption,
|
683 |
+
)
|
684 |
+
|
685 |
+
if do_classifier_free_guidance:
|
686 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
687 |
+
|
688 |
+
# 4. Prepare timesteps
|
689 |
+
if timesteps is not None:
|
690 |
+
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
691 |
+
timesteps = self.scheduler.timesteps
|
692 |
+
num_inference_steps = len(timesteps)
|
693 |
+
else:
|
694 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
695 |
+
timesteps = self.scheduler.timesteps
|
696 |
+
|
697 |
+
# 5. Prepare intermediate images
|
698 |
+
pixel_values = pixel_values.to(device)
|
699 |
+
if init_noise is not None:
|
700 |
+
intermediate_images = init_noise
|
701 |
+
else:
|
702 |
+
intermediate_images = self.prepare_intermediate_images(
|
703 |
+
batch_size * num_images_per_prompt,
|
704 |
+
# self.unet.config.in_channels, # mask not noise.
|
705 |
+
pixel_values.shape[1],
|
706 |
+
num_frames,
|
707 |
+
height,
|
708 |
+
width,
|
709 |
+
prompt_embeds.dtype,
|
710 |
+
device,
|
711 |
+
generator,
|
712 |
+
)
|
713 |
+
|
714 |
+
bsz = intermediate_images.shape[0]
|
715 |
+
interp_mask = torch.zeros(bsz, 1, *intermediate_images.shape[2:], device=device, dtype=intermediate_images.dtype)
|
716 |
+
interp_mask[:, :, 0, :, :] = 1
|
717 |
+
interp_mask[:, :, -1, :, :] = 1
|
718 |
+
|
719 |
+
if cond_interpolation:
|
720 |
+
import torch.nn.functional as F
|
721 |
+
pixel_values = F.interpolate(pixel_values[:, :, [0, -1], ...], pixel_values.shape[2:],
|
722 |
+
mode="trilinear", align_corners=True)
|
723 |
+
else:
|
724 |
+
raise Exception("apply mask to pixel_values")
|
725 |
+
|
726 |
+
# intermediate_images[:, :, 0, :, :] = pixel_values[:, :, 0, :, :]
|
727 |
+
# intermediate_images[:, :, -1, :, :] = pixel_values[:, :, -1, :, :]
|
728 |
+
pixel_values_condition = torch.cat((pixel_values, interp_mask), dim=1)
|
729 |
+
|
730 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
731 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
732 |
+
|
733 |
+
# HACK: see comment in `enable_model_cpu_offload`
|
734 |
+
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
735 |
+
self.text_encoder_offload_hook.offload()
|
736 |
+
|
737 |
+
# 7. Denoising loop
|
738 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
739 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
740 |
+
for i, t in enumerate(timesteps):
|
741 |
+
intermediate_images_input = torch.cat((intermediate_images, pixel_values_condition), dim=1)
|
742 |
+
model_input = (
|
743 |
+
torch.cat([intermediate_images_input] * 2) if do_classifier_free_guidance else intermediate_images_input
|
744 |
+
)
|
745 |
+
model_input = self.scheduler.scale_model_input(model_input, t)
|
746 |
+
|
747 |
+
# predict the noise residual
|
748 |
+
noise_pred = self.unet(
|
749 |
+
model_input,
|
750 |
+
t,
|
751 |
+
encoder_hidden_states=prompt_embeds,
|
752 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
753 |
+
).sample
|
754 |
+
# perform guidance
|
755 |
+
if do_classifier_free_guidance:
|
756 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
757 |
+
noise_pred_uncond, _ = noise_pred_uncond.split(intermediate_images.shape[1], dim=1)
|
758 |
+
noise_pred_text, predicted_variance = noise_pred_text.split(intermediate_images.shape[1], dim=1)
|
759 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
760 |
+
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
761 |
+
|
762 |
+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
763 |
+
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
|
764 |
+
|
765 |
+
# reshape latents
|
766 |
+
bsz, channel, frames, width, height = intermediate_images.shape
|
767 |
+
intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
768 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, width, height)
|
769 |
+
|
770 |
+
# compute the previous noisy sample x_t -> x_t-1
|
771 |
+
intermediate_images = self.scheduler.step(
|
772 |
+
noise_pred, t, intermediate_images, **extra_step_kwargs
|
773 |
+
).prev_sample
|
774 |
+
|
775 |
+
# reshape latents back
|
776 |
+
intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
|
777 |
+
|
778 |
+
# call the callback, if provided
|
779 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
780 |
+
progress_bar.update()
|
781 |
+
if callback is not None and i % callback_steps == 0:
|
782 |
+
callback(i, t, intermediate_images)
|
783 |
+
|
784 |
+
video_tensor = intermediate_images
|
785 |
+
|
786 |
+
if output_type == "pt":
|
787 |
+
video = video_tensor
|
788 |
+
else:
|
789 |
+
video = tensor2vid(video_tensor)
|
790 |
+
|
791 |
+
# Offload last model to CPU
|
792 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
793 |
+
self.final_offload_hook.offload()
|
794 |
+
|
795 |
+
if not return_dict:
|
796 |
+
return (video,)
|
797 |
+
|
798 |
+
return TextToVideoPipelineOutput(frames=video)
|
showone/pipelines/pipeline_t2v_sr_pixel.py
ADDED
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import inspect
|
3 |
+
import re
|
4 |
+
import urllib.parse as ul
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from einops import rearrange
|
9 |
+
import PIL
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
13 |
+
|
14 |
+
from diffusers.loaders import LoraLoaderMixin
|
15 |
+
from diffusers.schedulers import DDPMScheduler
|
16 |
+
from diffusers.utils import (
|
17 |
+
BACKENDS_MAPPING,
|
18 |
+
is_accelerate_available,
|
19 |
+
is_accelerate_version,
|
20 |
+
is_bs4_available,
|
21 |
+
is_ftfy_available,
|
22 |
+
logging,
|
23 |
+
randn_tensor,
|
24 |
+
)
|
25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
26 |
+
|
27 |
+
from ..models import UNet3DConditionModel
|
28 |
+
from . import TextToVideoPipelineOutput
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
if is_bs4_available():
|
34 |
+
from bs4 import BeautifulSoup
|
35 |
+
|
36 |
+
if is_ftfy_available():
|
37 |
+
import ftfy
|
38 |
+
|
39 |
+
|
40 |
+
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
41 |
+
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
42 |
+
# reshape to ncfhw
|
43 |
+
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
44 |
+
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
45 |
+
# unnormalize back to [0,1]
|
46 |
+
video = video.mul_(std).add_(mean)
|
47 |
+
video.clamp_(0, 1)
|
48 |
+
# prepare the final outputs
|
49 |
+
i, c, f, h, w = video.shape
|
50 |
+
images = video.permute(2, 3, 0, 4, 1).reshape(
|
51 |
+
f, h, i * w, c
|
52 |
+
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
53 |
+
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
54 |
+
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
55 |
+
return images
|
56 |
+
|
57 |
+
|
58 |
+
class TextToVideoIFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
59 |
+
tokenizer: T5Tokenizer
|
60 |
+
text_encoder: T5EncoderModel
|
61 |
+
|
62 |
+
unet: UNet3DConditionModel
|
63 |
+
scheduler: DDPMScheduler
|
64 |
+
image_noising_scheduler: DDPMScheduler
|
65 |
+
|
66 |
+
feature_extractor: Optional[CLIPImageProcessor]
|
67 |
+
# safety_checker: Optional[IFSafetyChecker]
|
68 |
+
|
69 |
+
# watermarker: Optional[IFWatermarker]
|
70 |
+
|
71 |
+
bad_punct_regex = re.compile(
|
72 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
73 |
+
) # noqa
|
74 |
+
|
75 |
+
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
76 |
+
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
tokenizer: T5Tokenizer,
|
80 |
+
text_encoder: T5EncoderModel,
|
81 |
+
unet: UNet3DConditionModel,
|
82 |
+
scheduler: DDPMScheduler,
|
83 |
+
image_noising_scheduler: DDPMScheduler,
|
84 |
+
feature_extractor: Optional[CLIPImageProcessor],
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
|
88 |
+
self.register_modules(
|
89 |
+
tokenizer=tokenizer,
|
90 |
+
text_encoder=text_encoder,
|
91 |
+
unet=unet,
|
92 |
+
scheduler=scheduler,
|
93 |
+
image_noising_scheduler=image_noising_scheduler,
|
94 |
+
feature_extractor=feature_extractor,
|
95 |
+
)
|
96 |
+
self.safety_checker = None
|
97 |
+
|
98 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
99 |
+
r"""
|
100 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
101 |
+
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
102 |
+
when their specific submodule has its `forward` method called.
|
103 |
+
"""
|
104 |
+
if is_accelerate_available():
|
105 |
+
from accelerate import cpu_offload
|
106 |
+
else:
|
107 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
108 |
+
|
109 |
+
device = torch.device(f"cuda:{gpu_id}")
|
110 |
+
|
111 |
+
models = [
|
112 |
+
self.text_encoder,
|
113 |
+
self.unet,
|
114 |
+
]
|
115 |
+
for cpu_offloaded_model in models:
|
116 |
+
if cpu_offloaded_model is not None:
|
117 |
+
cpu_offload(cpu_offloaded_model, device)
|
118 |
+
|
119 |
+
if self.safety_checker is not None:
|
120 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
121 |
+
|
122 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
123 |
+
r"""
|
124 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
125 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
126 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
127 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
128 |
+
"""
|
129 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
130 |
+
from accelerate import cpu_offload_with_hook
|
131 |
+
else:
|
132 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
133 |
+
|
134 |
+
device = torch.device(f"cuda:{gpu_id}")
|
135 |
+
|
136 |
+
if self.device.type != "cpu":
|
137 |
+
self.to("cpu", silence_dtype_warnings=True)
|
138 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
139 |
+
|
140 |
+
hook = None
|
141 |
+
|
142 |
+
if self.text_encoder is not None:
|
143 |
+
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
144 |
+
|
145 |
+
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
146 |
+
# previous model. This will cause both models to be present on the device at the same time.
|
147 |
+
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
148 |
+
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
149 |
+
# the GPU.
|
150 |
+
self.text_encoder_offload_hook = hook
|
151 |
+
|
152 |
+
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
153 |
+
|
154 |
+
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
155 |
+
self.unet_offload_hook = hook
|
156 |
+
|
157 |
+
if self.safety_checker is not None:
|
158 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
159 |
+
|
160 |
+
# We'll offload the last model manually.
|
161 |
+
self.final_offload_hook = hook
|
162 |
+
|
163 |
+
def remove_all_hooks(self):
|
164 |
+
if is_accelerate_available():
|
165 |
+
from accelerate.hooks import remove_hook_from_module
|
166 |
+
else:
|
167 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
168 |
+
|
169 |
+
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
170 |
+
if model is not None:
|
171 |
+
remove_hook_from_module(model, recurse=True)
|
172 |
+
|
173 |
+
self.unet_offload_hook = None
|
174 |
+
self.text_encoder_offload_hook = None
|
175 |
+
self.final_offload_hook = None
|
176 |
+
|
177 |
+
@property
|
178 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
179 |
+
def _execution_device(self):
|
180 |
+
r"""
|
181 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
182 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
183 |
+
hooks.
|
184 |
+
"""
|
185 |
+
if not hasattr(self.unet, "_hf_hook"):
|
186 |
+
return self.device
|
187 |
+
for module in self.unet.modules():
|
188 |
+
if (
|
189 |
+
hasattr(module, "_hf_hook")
|
190 |
+
and hasattr(module._hf_hook, "execution_device")
|
191 |
+
and module._hf_hook.execution_device is not None
|
192 |
+
):
|
193 |
+
return torch.device(module._hf_hook.execution_device)
|
194 |
+
return self.device
|
195 |
+
|
196 |
+
@torch.no_grad()
|
197 |
+
def encode_prompt(
|
198 |
+
self,
|
199 |
+
prompt,
|
200 |
+
do_classifier_free_guidance=True,
|
201 |
+
num_images_per_prompt=1,
|
202 |
+
device=None,
|
203 |
+
negative_prompt=None,
|
204 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
205 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
206 |
+
clean_caption: bool = False,
|
207 |
+
):
|
208 |
+
r"""
|
209 |
+
Encodes the prompt into text encoder hidden states.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
prompt (`str` or `List[str]`, *optional*):
|
213 |
+
prompt to be encoded
|
214 |
+
device: (`torch.device`, *optional*):
|
215 |
+
torch device to place the resulting embeddings on
|
216 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
217 |
+
number of images that should be generated per prompt
|
218 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
219 |
+
whether to use classifier free guidance or not
|
220 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
221 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
222 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
223 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
224 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
225 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
226 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
227 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
228 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
229 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
230 |
+
argument.
|
231 |
+
"""
|
232 |
+
if prompt is not None and negative_prompt is not None:
|
233 |
+
if type(prompt) is not type(negative_prompt):
|
234 |
+
raise TypeError(
|
235 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
236 |
+
f" {type(prompt)}."
|
237 |
+
)
|
238 |
+
|
239 |
+
if device is None:
|
240 |
+
device = self._execution_device
|
241 |
+
|
242 |
+
if prompt is not None and isinstance(prompt, str):
|
243 |
+
batch_size = 1
|
244 |
+
elif prompt is not None and isinstance(prompt, list):
|
245 |
+
batch_size = len(prompt)
|
246 |
+
else:
|
247 |
+
batch_size = prompt_embeds.shape[0]
|
248 |
+
|
249 |
+
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
250 |
+
max_length = 77
|
251 |
+
|
252 |
+
if prompt_embeds is None:
|
253 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
254 |
+
text_inputs = self.tokenizer(
|
255 |
+
prompt,
|
256 |
+
padding="max_length",
|
257 |
+
max_length=max_length,
|
258 |
+
truncation=True,
|
259 |
+
add_special_tokens=True,
|
260 |
+
return_tensors="pt",
|
261 |
+
)
|
262 |
+
text_input_ids = text_inputs.input_ids
|
263 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
264 |
+
|
265 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
266 |
+
text_input_ids, untruncated_ids
|
267 |
+
):
|
268 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
269 |
+
logger.warning(
|
270 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
271 |
+
f" {max_length} tokens: {removed_text}"
|
272 |
+
)
|
273 |
+
|
274 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
275 |
+
|
276 |
+
prompt_embeds = self.text_encoder(
|
277 |
+
text_input_ids.to(device),
|
278 |
+
attention_mask=attention_mask,
|
279 |
+
)
|
280 |
+
prompt_embeds = prompt_embeds[0]
|
281 |
+
|
282 |
+
if self.text_encoder is not None:
|
283 |
+
dtype = self.text_encoder.dtype
|
284 |
+
elif self.unet is not None:
|
285 |
+
dtype = self.unet.dtype
|
286 |
+
else:
|
287 |
+
dtype = None
|
288 |
+
|
289 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
290 |
+
|
291 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
292 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
293 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
294 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
295 |
+
|
296 |
+
# get unconditional embeddings for classifier free guidance
|
297 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
298 |
+
uncond_tokens: List[str]
|
299 |
+
if negative_prompt is None:
|
300 |
+
uncond_tokens = [""] * batch_size
|
301 |
+
elif isinstance(negative_prompt, str):
|
302 |
+
uncond_tokens = [negative_prompt]
|
303 |
+
elif batch_size != len(negative_prompt):
|
304 |
+
raise ValueError(
|
305 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
306 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
307 |
+
" the batch size of `prompt`."
|
308 |
+
)
|
309 |
+
else:
|
310 |
+
uncond_tokens = negative_prompt
|
311 |
+
|
312 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
313 |
+
max_length = prompt_embeds.shape[1]
|
314 |
+
uncond_input = self.tokenizer(
|
315 |
+
uncond_tokens,
|
316 |
+
padding="max_length",
|
317 |
+
max_length=max_length,
|
318 |
+
truncation=True,
|
319 |
+
return_attention_mask=True,
|
320 |
+
add_special_tokens=True,
|
321 |
+
return_tensors="pt",
|
322 |
+
)
|
323 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
324 |
+
|
325 |
+
negative_prompt_embeds = self.text_encoder(
|
326 |
+
uncond_input.input_ids.to(device),
|
327 |
+
attention_mask=attention_mask,
|
328 |
+
)
|
329 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
330 |
+
|
331 |
+
if do_classifier_free_guidance:
|
332 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
333 |
+
seq_len = negative_prompt_embeds.shape[1]
|
334 |
+
|
335 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
336 |
+
|
337 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
338 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
339 |
+
|
340 |
+
# For classifier free guidance, we need to do two forward passes.
|
341 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
342 |
+
# to avoid doing two forward passes
|
343 |
+
else:
|
344 |
+
negative_prompt_embeds = None
|
345 |
+
|
346 |
+
return prompt_embeds, negative_prompt_embeds
|
347 |
+
|
348 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
349 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
350 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
351 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
352 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
353 |
+
# and should be between [0, 1]
|
354 |
+
|
355 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
356 |
+
extra_step_kwargs = {}
|
357 |
+
if accepts_eta:
|
358 |
+
extra_step_kwargs["eta"] = eta
|
359 |
+
|
360 |
+
# check if the scheduler accepts generator
|
361 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
362 |
+
if accepts_generator:
|
363 |
+
extra_step_kwargs["generator"] = generator
|
364 |
+
return extra_step_kwargs
|
365 |
+
|
366 |
+
def check_inputs(
|
367 |
+
self,
|
368 |
+
prompt,
|
369 |
+
image,
|
370 |
+
batch_size,
|
371 |
+
noise_level,
|
372 |
+
callback_steps,
|
373 |
+
negative_prompt=None,
|
374 |
+
prompt_embeds=None,
|
375 |
+
negative_prompt_embeds=None,
|
376 |
+
):
|
377 |
+
if (callback_steps is None) or (
|
378 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
379 |
+
):
|
380 |
+
raise ValueError(
|
381 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
382 |
+
f" {type(callback_steps)}."
|
383 |
+
)
|
384 |
+
|
385 |
+
if prompt is not None and prompt_embeds is not None:
|
386 |
+
raise ValueError(
|
387 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
388 |
+
" only forward one of the two."
|
389 |
+
)
|
390 |
+
elif prompt is None and prompt_embeds is None:
|
391 |
+
raise ValueError(
|
392 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
393 |
+
)
|
394 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
395 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
396 |
+
|
397 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
398 |
+
raise ValueError(
|
399 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
400 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
401 |
+
)
|
402 |
+
|
403 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
404 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
405 |
+
raise ValueError(
|
406 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
407 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
408 |
+
f" {negative_prompt_embeds.shape}."
|
409 |
+
)
|
410 |
+
|
411 |
+
if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
|
412 |
+
raise ValueError(
|
413 |
+
f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
|
414 |
+
)
|
415 |
+
|
416 |
+
if isinstance(image, list):
|
417 |
+
check_image_type = image[0]
|
418 |
+
else:
|
419 |
+
check_image_type = image
|
420 |
+
|
421 |
+
if (
|
422 |
+
not isinstance(check_image_type, torch.Tensor)
|
423 |
+
and not isinstance(check_image_type, PIL.Image.Image)
|
424 |
+
and not isinstance(check_image_type, np.ndarray)
|
425 |
+
):
|
426 |
+
raise ValueError(
|
427 |
+
"`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
|
428 |
+
f" {type(check_image_type)}"
|
429 |
+
)
|
430 |
+
|
431 |
+
if isinstance(image, list):
|
432 |
+
image_batch_size = len(image)
|
433 |
+
elif isinstance(image, torch.Tensor):
|
434 |
+
image_batch_size = image.shape[0]
|
435 |
+
elif isinstance(image, PIL.Image.Image):
|
436 |
+
image_batch_size = 1
|
437 |
+
elif isinstance(image, np.ndarray):
|
438 |
+
image_batch_size = image.shape[0]
|
439 |
+
else:
|
440 |
+
assert False
|
441 |
+
|
442 |
+
if batch_size != image_batch_size:
|
443 |
+
raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
|
444 |
+
|
445 |
+
def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
|
446 |
+
shape = (batch_size, num_channels, num_frames, height, width)
|
447 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
448 |
+
raise ValueError(
|
449 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
450 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
451 |
+
)
|
452 |
+
|
453 |
+
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
454 |
+
|
455 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
456 |
+
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
457 |
+
return intermediate_images
|
458 |
+
|
459 |
+
def preprocess_image(self, image, num_images_per_prompt, device):
|
460 |
+
if not isinstance(image, torch.Tensor) and not isinstance(image, list):
|
461 |
+
image = [image]
|
462 |
+
|
463 |
+
if isinstance(image[0], PIL.Image.Image):
|
464 |
+
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
|
465 |
+
|
466 |
+
image = np.stack(image, axis=0) # to np
|
467 |
+
torch.from_numpy(image.transpose(0, 3, 1, 2))
|
468 |
+
elif isinstance(image[0], np.ndarray):
|
469 |
+
image = np.stack(image, axis=0) # to np
|
470 |
+
if image.ndim == 5:
|
471 |
+
image = image[0]
|
472 |
+
|
473 |
+
image = torch.from_numpy(image.transpose(0, 3, 1, 2))
|
474 |
+
elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
|
475 |
+
dims = image[0].ndim
|
476 |
+
|
477 |
+
if dims == 3:
|
478 |
+
image = torch.stack(image, dim=0)
|
479 |
+
elif dims == 4:
|
480 |
+
image = torch.concat(image, dim=0)
|
481 |
+
else:
|
482 |
+
raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
|
483 |
+
|
484 |
+
image = image.to(device=device, dtype=self.unet.dtype)
|
485 |
+
|
486 |
+
image = image.repeat_interleave(num_images_per_prompt, dim=0)
|
487 |
+
|
488 |
+
return image
|
489 |
+
|
490 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
491 |
+
if clean_caption and not is_bs4_available():
|
492 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
493 |
+
logger.warn("Setting `clean_caption` to False...")
|
494 |
+
clean_caption = False
|
495 |
+
|
496 |
+
if clean_caption and not is_ftfy_available():
|
497 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
498 |
+
logger.warn("Setting `clean_caption` to False...")
|
499 |
+
clean_caption = False
|
500 |
+
|
501 |
+
if not isinstance(text, (tuple, list)):
|
502 |
+
text = [text]
|
503 |
+
|
504 |
+
def process(text: str):
|
505 |
+
if clean_caption:
|
506 |
+
text = self._clean_caption(text)
|
507 |
+
text = self._clean_caption(text)
|
508 |
+
else:
|
509 |
+
text = text.lower().strip()
|
510 |
+
return text
|
511 |
+
|
512 |
+
return [process(t) for t in text]
|
513 |
+
|
514 |
+
def _clean_caption(self, caption):
|
515 |
+
caption = str(caption)
|
516 |
+
caption = ul.unquote_plus(caption)
|
517 |
+
caption = caption.strip().lower()
|
518 |
+
caption = re.sub("<person>", "person", caption)
|
519 |
+
# urls:
|
520 |
+
caption = re.sub(
|
521 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
522 |
+
"",
|
523 |
+
caption,
|
524 |
+
) # regex for urls
|
525 |
+
caption = re.sub(
|
526 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
527 |
+
"",
|
528 |
+
caption,
|
529 |
+
) # regex for urls
|
530 |
+
# html:
|
531 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
532 |
+
|
533 |
+
# @<nickname>
|
534 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
535 |
+
|
536 |
+
# 31C0—31EF CJK Strokes
|
537 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
538 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
539 |
+
# 3300—33FF CJK Compatibility
|
540 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
541 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
542 |
+
# 4E00—9FFF CJK Unified Ideographs
|
543 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
544 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
545 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
546 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
547 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
548 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
549 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
550 |
+
#######################################################
|
551 |
+
|
552 |
+
# все виды тире / all types of dash --> "-"
|
553 |
+
caption = re.sub(
|
554 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
555 |
+
"-",
|
556 |
+
caption,
|
557 |
+
)
|
558 |
+
|
559 |
+
# кавычки к одному стандарту
|
560 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
561 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
562 |
+
|
563 |
+
# "
|
564 |
+
caption = re.sub(r""?", "", caption)
|
565 |
+
# &
|
566 |
+
caption = re.sub(r"&", "", caption)
|
567 |
+
|
568 |
+
# ip adresses:
|
569 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
570 |
+
|
571 |
+
# article ids:
|
572 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
573 |
+
|
574 |
+
# \n
|
575 |
+
caption = re.sub(r"\\n", " ", caption)
|
576 |
+
|
577 |
+
# "#123"
|
578 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
579 |
+
# "#12345.."
|
580 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
581 |
+
# "123456.."
|
582 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
583 |
+
# filenames:
|
584 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
585 |
+
|
586 |
+
#
|
587 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
588 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
589 |
+
|
590 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
591 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
592 |
+
|
593 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
594 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
595 |
+
if len(re.findall(regex2, caption)) > 3:
|
596 |
+
caption = re.sub(regex2, " ", caption)
|
597 |
+
|
598 |
+
caption = ftfy.fix_text(caption)
|
599 |
+
caption = html.unescape(html.unescape(caption))
|
600 |
+
|
601 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
602 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
603 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
604 |
+
|
605 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
606 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
607 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
608 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
609 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
610 |
+
|
611 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
612 |
+
|
613 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
614 |
+
|
615 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
616 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
617 |
+
caption = re.sub(r"\s+", " ", caption)
|
618 |
+
|
619 |
+
caption.strip()
|
620 |
+
|
621 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
622 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
623 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
624 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
625 |
+
|
626 |
+
return caption.strip()
|
627 |
+
|
628 |
+
@torch.no_grad()
|
629 |
+
def __call__(
|
630 |
+
self,
|
631 |
+
prompt: Union[str, List[str]] = None,
|
632 |
+
height: Optional[int] = None,
|
633 |
+
width: Optional[int] = None,
|
634 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
|
635 |
+
num_inference_steps: int = 50,
|
636 |
+
timesteps: List[int] = None,
|
637 |
+
guidance_scale: float = 4.0,
|
638 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
639 |
+
num_images_per_prompt: Optional[int] = 1,
|
640 |
+
eta: float = 0.0,
|
641 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
642 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
643 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
644 |
+
output_type: Optional[str] = "np",
|
645 |
+
return_dict: bool = True,
|
646 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
647 |
+
callback_steps: int = 1,
|
648 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
649 |
+
noise_level: int = 20,
|
650 |
+
clean_caption: bool = True,
|
651 |
+
):
|
652 |
+
"""
|
653 |
+
Function invoked when calling the pipeline for generation.
|
654 |
+
|
655 |
+
Args:
|
656 |
+
prompt (`str` or `List[str]`, *optional*):
|
657 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
658 |
+
instead.
|
659 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
660 |
+
The height in pixels of the generated image.
|
661 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
662 |
+
The width in pixels of the generated image.
|
663 |
+
image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
|
664 |
+
The image to be upscaled.
|
665 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
666 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
667 |
+
expense of slower inference.
|
668 |
+
timesteps (`List[int]`, *optional*):
|
669 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
670 |
+
timesteps are used. Must be in descending order.
|
671 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
672 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
673 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
674 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
675 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
676 |
+
usually at the expense of lower image quality.
|
677 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
678 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
679 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
680 |
+
less than `1`).
|
681 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
682 |
+
The number of images to generate per prompt.
|
683 |
+
eta (`float`, *optional*, defaults to 0.0):
|
684 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
685 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
686 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
687 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
688 |
+
to make generation deterministic.
|
689 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
690 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
691 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
692 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
693 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
694 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
695 |
+
argument.
|
696 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
697 |
+
The output format of the generate image. Choose between
|
698 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
699 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
700 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
701 |
+
callback (`Callable`, *optional*):
|
702 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
703 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
704 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
705 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
706 |
+
called at every step.
|
707 |
+
cross_attention_kwargs (`dict`, *optional*):
|
708 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
709 |
+
`self.processor` in
|
710 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
711 |
+
noise_level (`int`, *optional*, defaults to 250):
|
712 |
+
The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
|
713 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
714 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
715 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
716 |
+
prompt.
|
717 |
+
|
718 |
+
Examples:
|
719 |
+
|
720 |
+
Returns:
|
721 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
722 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
723 |
+
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
724 |
+
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
725 |
+
or watermarked content, according to the `safety_checker`.
|
726 |
+
"""
|
727 |
+
# 1. Check inputs. Raise error if not correct
|
728 |
+
|
729 |
+
if prompt is not None and isinstance(prompt, str):
|
730 |
+
batch_size = 1
|
731 |
+
elif prompt is not None and isinstance(prompt, list):
|
732 |
+
batch_size = len(prompt)
|
733 |
+
else:
|
734 |
+
batch_size = prompt_embeds.shape[0]
|
735 |
+
|
736 |
+
self.check_inputs(
|
737 |
+
prompt,
|
738 |
+
image,
|
739 |
+
batch_size,
|
740 |
+
noise_level,
|
741 |
+
callback_steps,
|
742 |
+
negative_prompt,
|
743 |
+
prompt_embeds,
|
744 |
+
negative_prompt_embeds,
|
745 |
+
)
|
746 |
+
|
747 |
+
# 2. Define call parameters
|
748 |
+
|
749 |
+
height = height or self.unet.config.sample_size
|
750 |
+
width = width or self.unet.config.sample_size
|
751 |
+
assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
|
752 |
+
num_frames = image.shape[2]
|
753 |
+
|
754 |
+
device = self._execution_device
|
755 |
+
|
756 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
757 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
758 |
+
# corresponds to doing no classifier free guidance.
|
759 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
760 |
+
|
761 |
+
# 3. Encode input prompt
|
762 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
763 |
+
prompt,
|
764 |
+
do_classifier_free_guidance,
|
765 |
+
num_images_per_prompt=num_images_per_prompt,
|
766 |
+
device=device,
|
767 |
+
negative_prompt=negative_prompt,
|
768 |
+
prompt_embeds=prompt_embeds,
|
769 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
770 |
+
clean_caption=clean_caption,
|
771 |
+
)
|
772 |
+
|
773 |
+
if do_classifier_free_guidance:
|
774 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
775 |
+
|
776 |
+
# 4. Prepare timesteps
|
777 |
+
if timesteps is not None:
|
778 |
+
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
779 |
+
timesteps = self.scheduler.timesteps
|
780 |
+
num_inference_steps = len(timesteps)
|
781 |
+
else:
|
782 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
783 |
+
timesteps = self.scheduler.timesteps
|
784 |
+
|
785 |
+
# 5. Prepare intermediate images
|
786 |
+
num_channels = self.unet.config.in_channels // 2
|
787 |
+
intermediate_images = self.prepare_intermediate_images(
|
788 |
+
batch_size * num_images_per_prompt,
|
789 |
+
num_channels,
|
790 |
+
num_frames,
|
791 |
+
height,
|
792 |
+
width,
|
793 |
+
prompt_embeds.dtype,
|
794 |
+
device,
|
795 |
+
generator,
|
796 |
+
)
|
797 |
+
|
798 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
799 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
800 |
+
|
801 |
+
# 7. Prepare upscaled image and noise level
|
802 |
+
image = self.preprocess_image(image, num_images_per_prompt, device)
|
803 |
+
upscaled = rearrange(image, "b c f h w -> (b f) c h w")
|
804 |
+
upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
|
805 |
+
upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
|
806 |
+
|
807 |
+
noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
|
808 |
+
noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
|
809 |
+
upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
|
810 |
+
|
811 |
+
if do_classifier_free_guidance:
|
812 |
+
noise_level = torch.cat([noise_level] * 2)
|
813 |
+
|
814 |
+
# HACK: see comment in `enable_model_cpu_offload`
|
815 |
+
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
816 |
+
self.text_encoder_offload_hook.offload()
|
817 |
+
|
818 |
+
# 8. Denoising loop
|
819 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
820 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
821 |
+
for i, t in enumerate(timesteps):
|
822 |
+
model_input = torch.cat([intermediate_images, upscaled], dim=1)
|
823 |
+
|
824 |
+
model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
|
825 |
+
model_input = self.scheduler.scale_model_input(model_input, t)
|
826 |
+
|
827 |
+
# predict the noise residual
|
828 |
+
noise_pred = self.unet(
|
829 |
+
model_input,
|
830 |
+
t,
|
831 |
+
encoder_hidden_states=prompt_embeds,
|
832 |
+
class_labels=noise_level,
|
833 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
834 |
+
).sample
|
835 |
+
|
836 |
+
# perform guidance
|
837 |
+
if do_classifier_free_guidance:
|
838 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
839 |
+
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
|
840 |
+
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
|
841 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
842 |
+
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
843 |
+
|
844 |
+
# reshape latents
|
845 |
+
bsz, channel, frames, height, width = intermediate_images.shape
|
846 |
+
intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
|
847 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
|
848 |
+
|
849 |
+
# compute the previous noisy sample x_t -> x_t-1
|
850 |
+
intermediate_images = self.scheduler.step(
|
851 |
+
noise_pred, t, intermediate_images, **extra_step_kwargs
|
852 |
+
).prev_sample
|
853 |
+
|
854 |
+
# reshape latents back
|
855 |
+
intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
|
856 |
+
|
857 |
+
# call the callback, if provided
|
858 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
859 |
+
progress_bar.update()
|
860 |
+
if callback is not None and i % callback_steps == 0:
|
861 |
+
callback(i, t, intermediate_images)
|
862 |
+
|
863 |
+
video_tensor = intermediate_images
|
864 |
+
|
865 |
+
if output_type == "pt":
|
866 |
+
video = video_tensor
|
867 |
+
else:
|
868 |
+
video = tensor2vid(video_tensor)
|
869 |
+
|
870 |
+
# Offload last model to CPU
|
871 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
872 |
+
self.final_offload_hook.offload()
|
873 |
+
|
874 |
+
if not return_dict:
|
875 |
+
return (video,)
|
876 |
+
|
877 |
+
return TextToVideoPipelineOutput(frames=video)
|
showone/pipelines/pipeline_t2v_sr_pixel_cond.py
ADDED
@@ -0,0 +1,890 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import inspect
|
3 |
+
import re
|
4 |
+
import urllib.parse as ul
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from diffusers.loaders import LoraLoaderMixin
|
15 |
+
from diffusers.schedulers import DDPMScheduler
|
16 |
+
from diffusers.utils import (
|
17 |
+
BACKENDS_MAPPING,
|
18 |
+
is_accelerate_available,
|
19 |
+
is_accelerate_version,
|
20 |
+
is_bs4_available,
|
21 |
+
is_ftfy_available,
|
22 |
+
logging,
|
23 |
+
randn_tensor,
|
24 |
+
replace_example_docstring,
|
25 |
+
)
|
26 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
27 |
+
|
28 |
+
from ..models import UNet3DConditionModel
|
29 |
+
from . import TextToVideoPipelineOutput
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
33 |
+
|
34 |
+
if is_bs4_available():
|
35 |
+
from bs4 import BeautifulSoup
|
36 |
+
|
37 |
+
if is_ftfy_available():
|
38 |
+
import ftfy
|
39 |
+
|
40 |
+
|
41 |
+
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
42 |
+
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
43 |
+
# reshape to ncfhw
|
44 |
+
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
45 |
+
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
46 |
+
# unnormalize back to [0,1]
|
47 |
+
video = video.mul_(std).add_(mean)
|
48 |
+
video.clamp_(0, 1)
|
49 |
+
# prepare the final outputs
|
50 |
+
i, c, f, h, w = video.shape
|
51 |
+
images = video.permute(2, 3, 0, 4, 1).reshape(
|
52 |
+
f, h, i * w, c
|
53 |
+
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
54 |
+
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
55 |
+
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
56 |
+
return images
|
57 |
+
|
58 |
+
|
59 |
+
class TextToVideoIFSuperResolutionPipeline_Cond(DiffusionPipeline, LoraLoaderMixin):
|
60 |
+
tokenizer: T5Tokenizer
|
61 |
+
text_encoder: T5EncoderModel
|
62 |
+
|
63 |
+
unet: UNet3DConditionModel
|
64 |
+
scheduler: DDPMScheduler
|
65 |
+
image_noising_scheduler: DDPMScheduler
|
66 |
+
|
67 |
+
feature_extractor: Optional[CLIPImageProcessor]
|
68 |
+
# safety_checker: Optional[IFSafetyChecker]
|
69 |
+
|
70 |
+
# watermarker: Optional[IFWatermarker]
|
71 |
+
|
72 |
+
bad_punct_regex = re.compile(
|
73 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
74 |
+
) # noqa
|
75 |
+
|
76 |
+
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
tokenizer: T5Tokenizer,
|
81 |
+
text_encoder: T5EncoderModel,
|
82 |
+
unet: UNet3DConditionModel,
|
83 |
+
scheduler: DDPMScheduler,
|
84 |
+
image_noising_scheduler: DDPMScheduler,
|
85 |
+
feature_extractor: Optional[CLIPImageProcessor],
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.register_modules(
|
90 |
+
tokenizer=tokenizer,
|
91 |
+
text_encoder=text_encoder,
|
92 |
+
unet=unet,
|
93 |
+
scheduler=scheduler,
|
94 |
+
image_noising_scheduler=image_noising_scheduler,
|
95 |
+
feature_extractor=feature_extractor,
|
96 |
+
)
|
97 |
+
self.safety_checker = None
|
98 |
+
|
99 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
100 |
+
r"""
|
101 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
102 |
+
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
103 |
+
when their specific submodule has its `forward` method called.
|
104 |
+
"""
|
105 |
+
if is_accelerate_available():
|
106 |
+
from accelerate import cpu_offload
|
107 |
+
else:
|
108 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
109 |
+
|
110 |
+
device = torch.device(f"cuda:{gpu_id}")
|
111 |
+
|
112 |
+
models = [
|
113 |
+
self.text_encoder,
|
114 |
+
self.unet,
|
115 |
+
]
|
116 |
+
for cpu_offloaded_model in models:
|
117 |
+
if cpu_offloaded_model is not None:
|
118 |
+
cpu_offload(cpu_offloaded_model, device)
|
119 |
+
|
120 |
+
if self.safety_checker is not None:
|
121 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
122 |
+
|
123 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
124 |
+
r"""
|
125 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
126 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
127 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
128 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
129 |
+
"""
|
130 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
131 |
+
from accelerate import cpu_offload_with_hook
|
132 |
+
else:
|
133 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
134 |
+
|
135 |
+
device = torch.device(f"cuda:{gpu_id}")
|
136 |
+
|
137 |
+
if self.device.type != "cpu":
|
138 |
+
self.to("cpu", silence_dtype_warnings=True)
|
139 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
140 |
+
|
141 |
+
hook = None
|
142 |
+
|
143 |
+
if self.text_encoder is not None:
|
144 |
+
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
145 |
+
|
146 |
+
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
147 |
+
# previous model. This will cause both models to be present on the device at the same time.
|
148 |
+
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
149 |
+
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
150 |
+
# the GPU.
|
151 |
+
self.text_encoder_offload_hook = hook
|
152 |
+
|
153 |
+
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
154 |
+
|
155 |
+
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
156 |
+
self.unet_offload_hook = hook
|
157 |
+
|
158 |
+
if self.safety_checker is not None:
|
159 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
160 |
+
|
161 |
+
# We'll offload the last model manually.
|
162 |
+
self.final_offload_hook = hook
|
163 |
+
|
164 |
+
def remove_all_hooks(self):
|
165 |
+
if is_accelerate_available():
|
166 |
+
from accelerate.hooks import remove_hook_from_module
|
167 |
+
else:
|
168 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
169 |
+
|
170 |
+
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
171 |
+
if model is not None:
|
172 |
+
remove_hook_from_module(model, recurse=True)
|
173 |
+
|
174 |
+
self.unet_offload_hook = None
|
175 |
+
self.text_encoder_offload_hook = None
|
176 |
+
self.final_offload_hook = None
|
177 |
+
|
178 |
+
@property
|
179 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
180 |
+
def _execution_device(self):
|
181 |
+
r"""
|
182 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
183 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
184 |
+
hooks.
|
185 |
+
"""
|
186 |
+
if not hasattr(self.unet, "_hf_hook"):
|
187 |
+
return self.device
|
188 |
+
for module in self.unet.modules():
|
189 |
+
if (
|
190 |
+
hasattr(module, "_hf_hook")
|
191 |
+
and hasattr(module._hf_hook, "execution_device")
|
192 |
+
and module._hf_hook.execution_device is not None
|
193 |
+
):
|
194 |
+
return torch.device(module._hf_hook.execution_device)
|
195 |
+
return self.device
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def encode_prompt(
|
199 |
+
self,
|
200 |
+
prompt,
|
201 |
+
do_classifier_free_guidance=True,
|
202 |
+
num_images_per_prompt=1,
|
203 |
+
device=None,
|
204 |
+
negative_prompt=None,
|
205 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
206 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
207 |
+
clean_caption: bool = False,
|
208 |
+
):
|
209 |
+
r"""
|
210 |
+
Encodes the prompt into text encoder hidden states.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
prompt (`str` or `List[str]`, *optional*):
|
214 |
+
prompt to be encoded
|
215 |
+
device: (`torch.device`, *optional*):
|
216 |
+
torch device to place the resulting embeddings on
|
217 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
218 |
+
number of images that should be generated per prompt
|
219 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
220 |
+
whether to use classifier free guidance or not
|
221 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
222 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
223 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
224 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
225 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
226 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
227 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
228 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
229 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
230 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
231 |
+
argument.
|
232 |
+
"""
|
233 |
+
if prompt is not None and negative_prompt is not None:
|
234 |
+
if type(prompt) is not type(negative_prompt):
|
235 |
+
raise TypeError(
|
236 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
237 |
+
f" {type(prompt)}."
|
238 |
+
)
|
239 |
+
|
240 |
+
if device is None:
|
241 |
+
device = self._execution_device
|
242 |
+
|
243 |
+
if prompt is not None and isinstance(prompt, str):
|
244 |
+
batch_size = 1
|
245 |
+
elif prompt is not None and isinstance(prompt, list):
|
246 |
+
batch_size = len(prompt)
|
247 |
+
else:
|
248 |
+
batch_size = prompt_embeds.shape[0]
|
249 |
+
|
250 |
+
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
251 |
+
max_length = 77
|
252 |
+
|
253 |
+
if prompt_embeds is None:
|
254 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
255 |
+
text_inputs = self.tokenizer(
|
256 |
+
prompt,
|
257 |
+
padding="max_length",
|
258 |
+
max_length=max_length,
|
259 |
+
truncation=True,
|
260 |
+
add_special_tokens=True,
|
261 |
+
return_tensors="pt",
|
262 |
+
)
|
263 |
+
text_input_ids = text_inputs.input_ids
|
264 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
265 |
+
|
266 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
267 |
+
text_input_ids, untruncated_ids
|
268 |
+
):
|
269 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
270 |
+
logger.warning(
|
271 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
272 |
+
f" {max_length} tokens: {removed_text}"
|
273 |
+
)
|
274 |
+
|
275 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
276 |
+
|
277 |
+
prompt_embeds = self.text_encoder(
|
278 |
+
text_input_ids.to(device),
|
279 |
+
attention_mask=attention_mask,
|
280 |
+
)
|
281 |
+
prompt_embeds = prompt_embeds[0]
|
282 |
+
|
283 |
+
if self.text_encoder is not None:
|
284 |
+
dtype = self.text_encoder.dtype
|
285 |
+
elif self.unet is not None:
|
286 |
+
dtype = self.unet.dtype
|
287 |
+
else:
|
288 |
+
dtype = None
|
289 |
+
|
290 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
291 |
+
|
292 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
293 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
294 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
295 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
296 |
+
|
297 |
+
# get unconditional embeddings for classifier free guidance
|
298 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
299 |
+
uncond_tokens: List[str]
|
300 |
+
if negative_prompt is None:
|
301 |
+
uncond_tokens = [""] * batch_size
|
302 |
+
elif isinstance(negative_prompt, str):
|
303 |
+
uncond_tokens = [negative_prompt]
|
304 |
+
elif batch_size != len(negative_prompt):
|
305 |
+
raise ValueError(
|
306 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
307 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
308 |
+
" the batch size of `prompt`."
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
uncond_tokens = negative_prompt
|
312 |
+
|
313 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
314 |
+
max_length = prompt_embeds.shape[1]
|
315 |
+
uncond_input = self.tokenizer(
|
316 |
+
uncond_tokens,
|
317 |
+
padding="max_length",
|
318 |
+
max_length=max_length,
|
319 |
+
truncation=True,
|
320 |
+
return_attention_mask=True,
|
321 |
+
add_special_tokens=True,
|
322 |
+
return_tensors="pt",
|
323 |
+
)
|
324 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
325 |
+
|
326 |
+
negative_prompt_embeds = self.text_encoder(
|
327 |
+
uncond_input.input_ids.to(device),
|
328 |
+
attention_mask=attention_mask,
|
329 |
+
)
|
330 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
331 |
+
|
332 |
+
if do_classifier_free_guidance:
|
333 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
334 |
+
seq_len = negative_prompt_embeds.shape[1]
|
335 |
+
|
336 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
337 |
+
|
338 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
339 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
340 |
+
|
341 |
+
# For classifier free guidance, we need to do two forward passes.
|
342 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
343 |
+
# to avoid doing two forward passes
|
344 |
+
else:
|
345 |
+
negative_prompt_embeds = None
|
346 |
+
|
347 |
+
return prompt_embeds, negative_prompt_embeds
|
348 |
+
|
349 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
350 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
351 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
352 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
353 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
354 |
+
# and should be between [0, 1]
|
355 |
+
|
356 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
357 |
+
extra_step_kwargs = {}
|
358 |
+
if accepts_eta:
|
359 |
+
extra_step_kwargs["eta"] = eta
|
360 |
+
|
361 |
+
# check if the scheduler accepts generator
|
362 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
363 |
+
if accepts_generator:
|
364 |
+
extra_step_kwargs["generator"] = generator
|
365 |
+
return extra_step_kwargs
|
366 |
+
|
367 |
+
def check_inputs(
|
368 |
+
self,
|
369 |
+
prompt,
|
370 |
+
image,
|
371 |
+
batch_size,
|
372 |
+
noise_level,
|
373 |
+
callback_steps,
|
374 |
+
negative_prompt=None,
|
375 |
+
prompt_embeds=None,
|
376 |
+
negative_prompt_embeds=None,
|
377 |
+
):
|
378 |
+
if (callback_steps is None) or (
|
379 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
380 |
+
):
|
381 |
+
raise ValueError(
|
382 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
383 |
+
f" {type(callback_steps)}."
|
384 |
+
)
|
385 |
+
|
386 |
+
if prompt is not None and prompt_embeds is not None:
|
387 |
+
raise ValueError(
|
388 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
389 |
+
" only forward one of the two."
|
390 |
+
)
|
391 |
+
elif prompt is None and prompt_embeds is None:
|
392 |
+
raise ValueError(
|
393 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
394 |
+
)
|
395 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
396 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
397 |
+
|
398 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
399 |
+
raise ValueError(
|
400 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
401 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
402 |
+
)
|
403 |
+
|
404 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
405 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
406 |
+
raise ValueError(
|
407 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
408 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
409 |
+
f" {negative_prompt_embeds.shape}."
|
410 |
+
)
|
411 |
+
|
412 |
+
if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
|
413 |
+
raise ValueError(
|
414 |
+
f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
|
415 |
+
)
|
416 |
+
|
417 |
+
if isinstance(image, list):
|
418 |
+
check_image_type = image[0]
|
419 |
+
else:
|
420 |
+
check_image_type = image
|
421 |
+
|
422 |
+
if (
|
423 |
+
not isinstance(check_image_type, torch.Tensor)
|
424 |
+
and not isinstance(check_image_type, PIL.Image.Image)
|
425 |
+
and not isinstance(check_image_type, np.ndarray)
|
426 |
+
):
|
427 |
+
raise ValueError(
|
428 |
+
"`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
|
429 |
+
f" {type(check_image_type)}"
|
430 |
+
)
|
431 |
+
|
432 |
+
if isinstance(image, list):
|
433 |
+
image_batch_size = len(image)
|
434 |
+
elif isinstance(image, torch.Tensor):
|
435 |
+
image_batch_size = image.shape[0]
|
436 |
+
elif isinstance(image, PIL.Image.Image):
|
437 |
+
image_batch_size = 1
|
438 |
+
elif isinstance(image, np.ndarray):
|
439 |
+
image_batch_size = image.shape[0]
|
440 |
+
else:
|
441 |
+
assert False
|
442 |
+
|
443 |
+
if batch_size != image_batch_size:
|
444 |
+
raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
|
445 |
+
|
446 |
+
def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
|
447 |
+
shape = (batch_size, num_channels, num_frames, height, width)
|
448 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
449 |
+
raise ValueError(
|
450 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
451 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
452 |
+
)
|
453 |
+
|
454 |
+
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
455 |
+
|
456 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
457 |
+
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
458 |
+
return intermediate_images
|
459 |
+
|
460 |
+
def preprocess_image(self, image, num_images_per_prompt, device):
|
461 |
+
if not isinstance(image, torch.Tensor) and not isinstance(image, list):
|
462 |
+
image = [image]
|
463 |
+
|
464 |
+
if isinstance(image[0], PIL.Image.Image):
|
465 |
+
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
|
466 |
+
|
467 |
+
image = np.stack(image, axis=0) # to np
|
468 |
+
torch.from_numpy(image.transpose(0, 3, 1, 2))
|
469 |
+
elif isinstance(image[0], np.ndarray):
|
470 |
+
image = np.stack(image, axis=0) # to np
|
471 |
+
if image.ndim == 5:
|
472 |
+
image = image[0]
|
473 |
+
|
474 |
+
image = torch.from_numpy(image.transpose(0, 3, 1, 2))
|
475 |
+
elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
|
476 |
+
dims = image[0].ndim
|
477 |
+
|
478 |
+
if dims == 3:
|
479 |
+
image = torch.stack(image, dim=0)
|
480 |
+
elif dims == 4:
|
481 |
+
image = torch.concat(image, dim=0)
|
482 |
+
else:
|
483 |
+
raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
|
484 |
+
|
485 |
+
image = image.to(device=device, dtype=self.unet.dtype)
|
486 |
+
|
487 |
+
image = image.repeat_interleave(num_images_per_prompt, dim=0)
|
488 |
+
|
489 |
+
return image
|
490 |
+
|
491 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
492 |
+
if clean_caption and not is_bs4_available():
|
493 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
494 |
+
logger.warn("Setting `clean_caption` to False...")
|
495 |
+
clean_caption = False
|
496 |
+
|
497 |
+
if clean_caption and not is_ftfy_available():
|
498 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
499 |
+
logger.warn("Setting `clean_caption` to False...")
|
500 |
+
clean_caption = False
|
501 |
+
|
502 |
+
if not isinstance(text, (tuple, list)):
|
503 |
+
text = [text]
|
504 |
+
|
505 |
+
def process(text: str):
|
506 |
+
if clean_caption:
|
507 |
+
text = self._clean_caption(text)
|
508 |
+
text = self._clean_caption(text)
|
509 |
+
else:
|
510 |
+
text = text.lower().strip()
|
511 |
+
return text
|
512 |
+
|
513 |
+
return [process(t) for t in text]
|
514 |
+
|
515 |
+
def _clean_caption(self, caption):
|
516 |
+
caption = str(caption)
|
517 |
+
caption = ul.unquote_plus(caption)
|
518 |
+
caption = caption.strip().lower()
|
519 |
+
caption = re.sub("<person>", "person", caption)
|
520 |
+
# urls:
|
521 |
+
caption = re.sub(
|
522 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
523 |
+
"",
|
524 |
+
caption,
|
525 |
+
) # regex for urls
|
526 |
+
caption = re.sub(
|
527 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
528 |
+
"",
|
529 |
+
caption,
|
530 |
+
) # regex for urls
|
531 |
+
# html:
|
532 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
533 |
+
|
534 |
+
# @<nickname>
|
535 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
536 |
+
|
537 |
+
# 31C0—31EF CJK Strokes
|
538 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
539 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
540 |
+
# 3300—33FF CJK Compatibility
|
541 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
542 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
543 |
+
# 4E00—9FFF CJK Unified Ideographs
|
544 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
545 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
546 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
547 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
548 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
549 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
550 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
551 |
+
#######################################################
|
552 |
+
|
553 |
+
# все виды тире / all types of dash --> "-"
|
554 |
+
caption = re.sub(
|
555 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
556 |
+
"-",
|
557 |
+
caption,
|
558 |
+
)
|
559 |
+
|
560 |
+
# кавычки к одному стандарту
|
561 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
562 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
563 |
+
|
564 |
+
# "
|
565 |
+
caption = re.sub(r""?", "", caption)
|
566 |
+
# &
|
567 |
+
caption = re.sub(r"&", "", caption)
|
568 |
+
|
569 |
+
# ip adresses:
|
570 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
571 |
+
|
572 |
+
# article ids:
|
573 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
574 |
+
|
575 |
+
# \n
|
576 |
+
caption = re.sub(r"\\n", " ", caption)
|
577 |
+
|
578 |
+
# "#123"
|
579 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
580 |
+
# "#12345.."
|
581 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
582 |
+
# "123456.."
|
583 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
584 |
+
# filenames:
|
585 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
586 |
+
|
587 |
+
#
|
588 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
589 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
590 |
+
|
591 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
592 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
593 |
+
|
594 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
595 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
596 |
+
if len(re.findall(regex2, caption)) > 3:
|
597 |
+
caption = re.sub(regex2, " ", caption)
|
598 |
+
|
599 |
+
caption = ftfy.fix_text(caption)
|
600 |
+
caption = html.unescape(html.unescape(caption))
|
601 |
+
|
602 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
603 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
604 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
605 |
+
|
606 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
607 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
608 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
609 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
610 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
611 |
+
|
612 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
613 |
+
|
614 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
615 |
+
|
616 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
617 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
618 |
+
caption = re.sub(r"\s+", " ", caption)
|
619 |
+
|
620 |
+
caption.strip()
|
621 |
+
|
622 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
623 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
624 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
625 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
626 |
+
|
627 |
+
return caption.strip()
|
628 |
+
|
629 |
+
@torch.no_grad()
|
630 |
+
def __call__(
|
631 |
+
self,
|
632 |
+
prompt: Union[str, List[str]] = None,
|
633 |
+
height: Optional[int] = None,
|
634 |
+
width: Optional[int] = None,
|
635 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
|
636 |
+
first_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
|
637 |
+
all_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
|
638 |
+
num_inference_steps: int = 50,
|
639 |
+
timesteps: List[int] = None,
|
640 |
+
guidance_scale: float = 4.0,
|
641 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
642 |
+
num_images_per_prompt: Optional[int] = 1,
|
643 |
+
eta: float = 0.0,
|
644 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
645 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
646 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
647 |
+
output_type: Optional[str] = "np",
|
648 |
+
return_dict: bool = True,
|
649 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
650 |
+
callback_steps: int = 1,
|
651 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
652 |
+
noise_level: int = 250,
|
653 |
+
clean_caption: bool = True,
|
654 |
+
):
|
655 |
+
"""
|
656 |
+
Function invoked when calling the pipeline for generation.
|
657 |
+
|
658 |
+
Args:
|
659 |
+
prompt (`str` or `List[str]`, *optional*):
|
660 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
661 |
+
instead.
|
662 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
663 |
+
The height in pixels of the generated image.
|
664 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
665 |
+
The width in pixels of the generated image.
|
666 |
+
image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
|
667 |
+
The image to be upscaled.
|
668 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
669 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
670 |
+
expense of slower inference.
|
671 |
+
timesteps (`List[int]`, *optional*):
|
672 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
673 |
+
timesteps are used. Must be in descending order.
|
674 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
675 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
676 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
677 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
678 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
679 |
+
usually at the expense of lower image quality.
|
680 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
681 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
682 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
683 |
+
less than `1`).
|
684 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
685 |
+
The number of images to generate per prompt.
|
686 |
+
eta (`float`, *optional*, defaults to 0.0):
|
687 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
688 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
689 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
690 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
691 |
+
to make generation deterministic.
|
692 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
693 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
694 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
695 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
696 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
697 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
698 |
+
argument.
|
699 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
700 |
+
The output format of the generate image. Choose between
|
701 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
702 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
703 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
704 |
+
callback (`Callable`, *optional*):
|
705 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
706 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
707 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
708 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
709 |
+
called at every step.
|
710 |
+
cross_attention_kwargs (`dict`, *optional*):
|
711 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
712 |
+
`self.processor` in
|
713 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
714 |
+
noise_level (`int`, *optional*, defaults to 250):
|
715 |
+
The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
|
716 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
717 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
718 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
719 |
+
prompt.
|
720 |
+
|
721 |
+
Examples:
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
725 |
+
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
726 |
+
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
727 |
+
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
728 |
+
or watermarked content, according to the `safety_checker`.
|
729 |
+
"""
|
730 |
+
# 1. Check inputs. Raise error if not correct
|
731 |
+
|
732 |
+
if prompt is not None and isinstance(prompt, str):
|
733 |
+
batch_size = 1
|
734 |
+
elif prompt is not None and isinstance(prompt, list):
|
735 |
+
batch_size = len(prompt)
|
736 |
+
else:
|
737 |
+
batch_size = prompt_embeds.shape[0]
|
738 |
+
|
739 |
+
self.check_inputs(
|
740 |
+
prompt,
|
741 |
+
image,
|
742 |
+
batch_size,
|
743 |
+
noise_level,
|
744 |
+
callback_steps,
|
745 |
+
negative_prompt,
|
746 |
+
prompt_embeds,
|
747 |
+
negative_prompt_embeds,
|
748 |
+
)
|
749 |
+
|
750 |
+
# 2. Define call parameters
|
751 |
+
|
752 |
+
height = height or self.unet.config.sample_size
|
753 |
+
width = width or self.unet.config.sample_size
|
754 |
+
assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
|
755 |
+
num_frames = image.shape[2]
|
756 |
+
|
757 |
+
device = self._execution_device
|
758 |
+
|
759 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
760 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
761 |
+
# corresponds to doing no classifier free guidance.
|
762 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
763 |
+
|
764 |
+
# 3. Encode input prompt
|
765 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
766 |
+
prompt,
|
767 |
+
do_classifier_free_guidance,
|
768 |
+
num_images_per_prompt=num_images_per_prompt,
|
769 |
+
device=device,
|
770 |
+
negative_prompt=negative_prompt,
|
771 |
+
prompt_embeds=prompt_embeds,
|
772 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
773 |
+
clean_caption=clean_caption,
|
774 |
+
)
|
775 |
+
|
776 |
+
if do_classifier_free_guidance:
|
777 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
778 |
+
|
779 |
+
# 4. Prepare timesteps
|
780 |
+
if timesteps is not None:
|
781 |
+
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
782 |
+
timesteps = self.scheduler.timesteps
|
783 |
+
num_inference_steps = len(timesteps)
|
784 |
+
else:
|
785 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
786 |
+
timesteps = self.scheduler.timesteps
|
787 |
+
|
788 |
+
# 5. Prepare intermediate images
|
789 |
+
num_channels = self.unet.config.in_channels // 2
|
790 |
+
intermediate_images = self.prepare_intermediate_images(
|
791 |
+
batch_size * num_images_per_prompt,
|
792 |
+
num_channels,
|
793 |
+
num_frames,
|
794 |
+
height,
|
795 |
+
width,
|
796 |
+
prompt_embeds.dtype,
|
797 |
+
device,
|
798 |
+
generator,
|
799 |
+
)
|
800 |
+
|
801 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
802 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
803 |
+
|
804 |
+
# 7. Prepare upscaled image and noise level
|
805 |
+
image = self.preprocess_image(image, num_images_per_prompt, device)
|
806 |
+
# upscaled = F.interpolate(image, (num_frames, height, width), mode="trilinear", align_corners=True)
|
807 |
+
if all_frame_cond is not None:
|
808 |
+
upscaled = all_frame_cond
|
809 |
+
else:
|
810 |
+
upscaled = rearrange(image, "b c f h w -> (b f) c h w")
|
811 |
+
upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
|
812 |
+
upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
|
813 |
+
|
814 |
+
noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
|
815 |
+
noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
|
816 |
+
upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
|
817 |
+
if first_frame_cond is not None:
|
818 |
+
first_frame_cond = first_frame_cond.to(device=device, dtype=self.unet.dtype)
|
819 |
+
upscaled[:,:,:1,:,:] = first_frame_cond
|
820 |
+
|
821 |
+
if do_classifier_free_guidance:
|
822 |
+
noise_level = torch.cat([noise_level] * 2)
|
823 |
+
|
824 |
+
# HACK: see comment in `enable_model_cpu_offload`
|
825 |
+
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
826 |
+
self.text_encoder_offload_hook.offload()
|
827 |
+
|
828 |
+
# 8. Denoising loop
|
829 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
830 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
831 |
+
for i, t in enumerate(timesteps):
|
832 |
+
model_input = torch.cat([intermediate_images, upscaled], dim=1)
|
833 |
+
|
834 |
+
model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
|
835 |
+
model_input = self.scheduler.scale_model_input(model_input, t)
|
836 |
+
|
837 |
+
# predict the noise residual
|
838 |
+
noise_pred = self.unet(
|
839 |
+
model_input,
|
840 |
+
t,
|
841 |
+
encoder_hidden_states=prompt_embeds,
|
842 |
+
class_labels=noise_level,
|
843 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
844 |
+
).sample
|
845 |
+
|
846 |
+
# perform guidance
|
847 |
+
if do_classifier_free_guidance:
|
848 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
849 |
+
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
|
850 |
+
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
|
851 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
852 |
+
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
853 |
+
|
854 |
+
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
855 |
+
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
|
856 |
+
|
857 |
+
# reshape latents
|
858 |
+
bsz, channel, frames, height, width = intermediate_images.shape
|
859 |
+
intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
|
860 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
|
861 |
+
|
862 |
+
# compute the previous noisy sample x_t -> x_t-1
|
863 |
+
intermediate_images = self.scheduler.step(
|
864 |
+
noise_pred, t, intermediate_images, **extra_step_kwargs
|
865 |
+
).prev_sample
|
866 |
+
|
867 |
+
# reshape latents back
|
868 |
+
intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
|
869 |
+
|
870 |
+
# call the callback, if provided
|
871 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
872 |
+
progress_bar.update()
|
873 |
+
if callback is not None and i % callback_steps == 0:
|
874 |
+
callback(i, t, intermediate_images)
|
875 |
+
|
876 |
+
video_tensor = intermediate_images
|
877 |
+
|
878 |
+
if output_type == "pt":
|
879 |
+
video = video_tensor
|
880 |
+
else:
|
881 |
+
video = tensor2vid(video_tensor)
|
882 |
+
|
883 |
+
# Offload last model to CPU
|
884 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
885 |
+
self.final_offload_hook.offload()
|
886 |
+
|
887 |
+
if not return_dict:
|
888 |
+
return (video,)
|
889 |
+
|
890 |
+
return TextToVideoPipelineOutput(frames=video)
|