Spaces:
Paused
Paused
upload git code base
Browse files- .gitattributes +10 -0
- .gitignore +6 -0
- LICENSE +21 -0
- README.md +154 -13
- assets/.DS_Store +0 -0
- assets/comparison.jpg +3 -0
- assets/embeddings_sd_1.4/cat.pt +3 -0
- assets/embeddings_sd_1.4/dog.pt +3 -0
- assets/embeddings_sd_1.4/horse.pt +3 -0
- assets/embeddings_sd_1.4/zebra.pt +3 -0
- assets/grid_cat2dog.jpg +3 -0
- assets/grid_dog2cat.jpg +3 -0
- assets/grid_horse2zebra.jpg +3 -0
- assets/grid_tree2fall.jpg +3 -0
- assets/grid_zebra2horse.jpg +3 -0
- assets/main.gif +3 -0
- assets/method.jpeg +3 -0
- assets/results_real.jpg +3 -0
- assets/results_syn.jpg +3 -0
- assets/results_teaser.jpg +0 -0
- assets/test_images/cats/cat_1.png +0 -0
- assets/test_images/cats/cat_2.png +0 -0
- assets/test_images/cats/cat_3.png +0 -0
- assets/test_images/cats/cat_4.png +0 -0
- assets/test_images/cats/cat_5.png +0 -0
- assets/test_images/cats/cat_6.png +0 -0
- assets/test_images/cats/cat_7.png +0 -0
- assets/test_images/cats/cat_8.png +0 -0
- assets/test_images/cats/cat_9.png +0 -0
- assets/test_images/dogs/dog_1.png +0 -0
- assets/test_images/dogs/dog_2.png +0 -0
- assets/test_images/dogs/dog_3.png +0 -0
- assets/test_images/dogs/dog_4.png +0 -0
- assets/test_images/dogs/dog_5.png +0 -0
- assets/test_images/dogs/dog_6.png +0 -0
- assets/test_images/dogs/dog_7.png +0 -0
- assets/test_images/dogs/dog_8.png +0 -0
- assets/test_images/dogs/dog_9.png +0 -0
- environment.yml +23 -0
- src/edit_real.py +65 -0
- src/edit_synthetic.py +52 -0
- src/inversion.py +64 -0
- src/make_edit_direction.py +61 -0
- src/utils/base_pipeline.py +322 -0
- src/utils/cross_attention.py +57 -0
- src/utils/ddim_inv.py +140 -0
- src/utils/edit_directions.py +29 -0
- src/utils/edit_pipeline.py +174 -0
- src/utils/scheduler.py +289 -0
.gitattributes
CHANGED
@@ -32,3 +32,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
assets/comparison.jpg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/grid_cat2dog.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/grid_dog2cat.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/grid_horse2zebra.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/grid_tree2fall.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/grid_zebra2horse.jpg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/main.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/method.jpeg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/results_real.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/results_syn.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output
|
2 |
+
scripts
|
3 |
+
src/folder_*.py
|
4 |
+
src/ig_*.py
|
5 |
+
assets/edit_sentences
|
6 |
+
src/utils/edit_pipeline_spatial.py
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 pix2pixzero
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,154 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pix2pix-zero
|
2 |
+
|
3 |
+
## [**[website]**](https://pix2pixzero.github.io/)
|
4 |
+
|
5 |
+
|
6 |
+
This is author's reimplementation of "Zero-shot Image-to-Image Translation" using the diffusers library. <br>
|
7 |
+
The results in the paper are based on the [CompVis](https://github.com/CompVis/stable-diffusion) library, which will be released later.
|
8 |
+
|
9 |
+
**[New!]** Code for editing real and synthetic images released!
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
<br>
|
14 |
+
<div class="gif">
|
15 |
+
<p align="center">
|
16 |
+
<img src='assets/main.gif' align="center">
|
17 |
+
</p>
|
18 |
+
</div>
|
19 |
+
|
20 |
+
|
21 |
+
We propose pix2pix-zero, a diffusion-based image-to-image approach that allows users to specify the edit direction on-the-fly (e.g., cat to dog). Our method can directly use pre-trained [Stable Diffusion](https://github.com/CompVis/stable-diffusion), for editing real and synthetic images while preserving the input image's structure. Our method is training-free and prompt-free, as it requires neither manual text prompting for each input image nor costly fine-tuning for each task.
|
22 |
+
|
23 |
+
**TL;DR**: no finetuning required, no text input needed, input structure preserved.
|
24 |
+
|
25 |
+
## Results
|
26 |
+
All our results are based on [stable-diffusion-v1-4](https://github.com/CompVis/stable-diffusion) model. Please the website for more results.
|
27 |
+
|
28 |
+
<div>
|
29 |
+
<p align="center">
|
30 |
+
<img src='assets/results_teaser.jpg' align="center" width=800px>
|
31 |
+
</p>
|
32 |
+
</div>
|
33 |
+
<hr>
|
34 |
+
|
35 |
+
The top row for each of the results below show editing of real images, and the bottom row shows synthetic image editing.
|
36 |
+
<div>
|
37 |
+
<p align="center">
|
38 |
+
<img src='assets/grid_dog2cat.jpg' align="center" width=800px>
|
39 |
+
</p>
|
40 |
+
<p align="center">
|
41 |
+
<img src='assets/grid_zebra2horse.jpg' align="center" width=800px>
|
42 |
+
</p>
|
43 |
+
<p align="center">
|
44 |
+
<img src='assets/grid_cat2dog.jpg' align="center" width=800px>
|
45 |
+
</p>
|
46 |
+
<p align="center">
|
47 |
+
<img src='assets/grid_horse2zebra.jpg' align="center" width=800px>
|
48 |
+
</p>
|
49 |
+
<p align="center">
|
50 |
+
<img src='assets/grid_tree2fall.jpg' align="center" width=800px>
|
51 |
+
</p>
|
52 |
+
</div>
|
53 |
+
|
54 |
+
## Real Image Editing
|
55 |
+
<div>
|
56 |
+
<p align="center">
|
57 |
+
<img src='assets/results_real.jpg' align="center" width=800px>
|
58 |
+
</p>
|
59 |
+
</div>
|
60 |
+
|
61 |
+
## Synthetic Image Editing
|
62 |
+
<div>
|
63 |
+
<p align="center">
|
64 |
+
<img src='assets/results_syn.jpg' align="center" width=800px>
|
65 |
+
</p>
|
66 |
+
</div>
|
67 |
+
|
68 |
+
## Method Details
|
69 |
+
|
70 |
+
Given an input image, we first generate text captions using [BLIP](https://github.com/salesforce/LAVIS) and apply regularized DDIM inversion to obtain our inverted noise map.
|
71 |
+
Then, we obtain reference cross-attention maps that correspoind to the structure of the input image by denoising, guided with the CLIP embeddings
|
72 |
+
of our generated text (c). Next, we denoise with edited text embeddings, while enforcing a loss to match current cross-attention maps with the
|
73 |
+
reference cross-attention maps.
|
74 |
+
|
75 |
+
<div>
|
76 |
+
<p align="center">
|
77 |
+
<img src='assets/method.jpeg' align="center" width=900>
|
78 |
+
</p>
|
79 |
+
</div>
|
80 |
+
|
81 |
+
|
82 |
+
## Getting Started
|
83 |
+
|
84 |
+
**Environment Setup**
|
85 |
+
- We provide a [conda env file](environment.yml) that contains all the required dependencies
|
86 |
+
```
|
87 |
+
conda env create -f environment.yml
|
88 |
+
```
|
89 |
+
- Following this, you can activate the conda environment with the command below.
|
90 |
+
```
|
91 |
+
conda activate pix2pix-zero
|
92 |
+
```
|
93 |
+
|
94 |
+
**Real Image Translation**
|
95 |
+
- First, run the inversion command below to obtain the input noise that reconstructs the image.
|
96 |
+
The command below will save the inversion in the results folder as `output/test_cat/inversion/cat_1.pt`
|
97 |
+
and the BLIP-generated prompt as `output/test_cat/prompt/cat_1.txt`
|
98 |
+
```
|
99 |
+
python src/inversion.py \
|
100 |
+
--input_image "assets/test_images/cats/cat_1.png" \
|
101 |
+
--results_folder "output/test_cat"
|
102 |
+
```
|
103 |
+
- Next, we can perform image editing with the editing direction as shown below.
|
104 |
+
The command below will save the edited image as `output/test_cat/edit/cat_1.png`
|
105 |
+
```
|
106 |
+
python src/edit_real.py \
|
107 |
+
--inversion "output/test_cat/inversion/cat_1.pt" \
|
108 |
+
--prompt "output/test_cat/prompt/cat_1.txt" \
|
109 |
+
--task_name "cat2dog" \
|
110 |
+
--results_folder "output/test_cat/"
|
111 |
+
```
|
112 |
+
|
113 |
+
**Editing Synthetic Images**
|
114 |
+
- Similarly, we can edit the synthetic images generated by Stable Diffusion with the following command.
|
115 |
+
```
|
116 |
+
python src/edit_synthetic.py \
|
117 |
+
--results_folder "output/synth_editing" \
|
118 |
+
--prompt_str "a high resolution painting of a cat in the style of van gough" \
|
119 |
+
--task "cat2dog"
|
120 |
+
```
|
121 |
+
|
122 |
+
### **Tips and Debugging**
|
123 |
+
- **Controlling the Image Structure:**<br>
|
124 |
+
The `--xa_guidance` flag controls the amount of cross-attention guidance to be applied when performing the edit. If the output edited image does not retain the structure from the input, increasing the value will typically address the issue. We recommend changing the value in increments of 0.05.
|
125 |
+
|
126 |
+
- **Improving Image Quality:**<br>
|
127 |
+
If the output image quality is low or has some artifacts, using more steps for both the inversion and editing would be helpful.
|
128 |
+
This can be controlled with the `--num_ddim_steps` flag.
|
129 |
+
|
130 |
+
- **Reducing the VRAM Requirements:**<br>
|
131 |
+
We can reduce the VRAM requirements using lower precision and setting the flag `--use_float_16`.
|
132 |
+
|
133 |
+
<br>
|
134 |
+
|
135 |
+
**Finding Custom Edit Directions**<br>
|
136 |
+
- We provide some pre-computed directions in the assets [folder](assets/embeddings_sd_1.4).
|
137 |
+
To generate new edit directions, users can first generate two files containing a large number of sentences (~1000) and then run the command as shown below.
|
138 |
+
```
|
139 |
+
python src/make_edit_direction.py \
|
140 |
+
--file_source_sentences sentences/apple.txt \
|
141 |
+
--file_target_sentences sentences/orange.txt \
|
142 |
+
--output_folder assets/embeddings_sd_1.4
|
143 |
+
```
|
144 |
+
- After running the above command, you can set the flag `--task apple2orange` for the new edit.
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
## Comparison
|
149 |
+
Comparisons with different baselines, including, SDEdit + word swap, DDIM + word swap, and prompt-to-propmt. Our method successfully applies the edit, while preserving the structure of the input image.
|
150 |
+
<div>
|
151 |
+
<p align="center">
|
152 |
+
<img src='assets/comparison.jpg' align="center" width=900>
|
153 |
+
</p>
|
154 |
+
</div>
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/comparison.jpg
ADDED
Git LFS Details
|
assets/embeddings_sd_1.4/cat.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334
|
3 |
+
size 237300
|
assets/embeddings_sd_1.4/dog.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c
|
3 |
+
size 237300
|
assets/embeddings_sd_1.4/horse.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7
|
3 |
+
size 118949
|
assets/embeddings_sd_1.4/zebra.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec
|
3 |
+
size 118949
|
assets/grid_cat2dog.jpg
ADDED
Git LFS Details
|
assets/grid_dog2cat.jpg
ADDED
Git LFS Details
|
assets/grid_horse2zebra.jpg
ADDED
Git LFS Details
|
assets/grid_tree2fall.jpg
ADDED
Git LFS Details
|
assets/grid_zebra2horse.jpg
ADDED
Git LFS Details
|
assets/main.gif
ADDED
Git LFS Details
|
assets/method.jpeg
ADDED
Git LFS Details
|
assets/results_real.jpg
ADDED
Git LFS Details
|
assets/results_syn.jpg
ADDED
Git LFS Details
|
assets/results_teaser.jpg
ADDED
assets/test_images/cats/cat_1.png
ADDED
assets/test_images/cats/cat_2.png
ADDED
assets/test_images/cats/cat_3.png
ADDED
assets/test_images/cats/cat_4.png
ADDED
assets/test_images/cats/cat_5.png
ADDED
assets/test_images/cats/cat_6.png
ADDED
assets/test_images/cats/cat_7.png
ADDED
assets/test_images/cats/cat_8.png
ADDED
assets/test_images/cats/cat_9.png
ADDED
assets/test_images/dogs/dog_1.png
ADDED
assets/test_images/dogs/dog_2.png
ADDED
assets/test_images/dogs/dog_3.png
ADDED
assets/test_images/dogs/dog_4.png
ADDED
assets/test_images/dogs/dog_5.png
ADDED
assets/test_images/dogs/dog_6.png
ADDED
assets/test_images/dogs/dog_7.png
ADDED
assets/test_images/dogs/dog_8.png
ADDED
assets/test_images/dogs/dog_9.png
ADDED
environment.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pix2pix-zero
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- pip
|
8 |
+
- pytorch-cuda=11.6
|
9 |
+
- torchvision
|
10 |
+
- pytorch
|
11 |
+
- pip:
|
12 |
+
- accelerate
|
13 |
+
- diffusers
|
14 |
+
- einops
|
15 |
+
- gradio
|
16 |
+
- ipython
|
17 |
+
- numpy
|
18 |
+
- opencv-python-headless
|
19 |
+
- pillow
|
20 |
+
- psutil
|
21 |
+
- tqdm
|
22 |
+
- transformers
|
23 |
+
- salesforce-lavis
|
src/edit_real.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.ddim_inv import DDIMInversion
|
11 |
+
from utils.edit_directions import construct_direction
|
12 |
+
from utils.edit_pipeline import EditingPipeline
|
13 |
+
|
14 |
+
|
15 |
+
if __name__=="__main__":
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--inversion', required=True)
|
18 |
+
parser.add_argument('--prompt', type=str, required=True)
|
19 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
20 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
21 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
22 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
23 |
+
parser.add_argument('--xa_guidance', default=0.1, type=float)
|
24 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
25 |
+
parser.add_argument('--use_float_16', action='store_true')
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
|
30 |
+
os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
|
31 |
+
|
32 |
+
if args.use_float_16:
|
33 |
+
torch_dtype = torch.float16
|
34 |
+
else:
|
35 |
+
torch_dtype = torch.float32
|
36 |
+
|
37 |
+
# if the inversion is a folder, the prompt should also be a folder
|
38 |
+
assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
|
39 |
+
if os.path.isdir(args.inversion):
|
40 |
+
l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
|
41 |
+
l_bnames = [os.path.basename(x) for x in l_inv_paths]
|
42 |
+
l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
|
43 |
+
else:
|
44 |
+
l_inv_paths = [args.inversion]
|
45 |
+
l_prompt_paths = [args.prompt]
|
46 |
+
|
47 |
+
# Make the editing pipeline
|
48 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
49 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
50 |
+
|
51 |
+
|
52 |
+
for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
|
53 |
+
prompt_str = open(prompt_path).read().strip()
|
54 |
+
rec_pil, edit_pil = pipe(prompt_str,
|
55 |
+
num_inference_steps=args.num_ddim_steps,
|
56 |
+
x_in=torch.load(inv_path).unsqueeze(0),
|
57 |
+
edit_dir=construct_direction(args.task_name),
|
58 |
+
guidance_amount=args.xa_guidance,
|
59 |
+
guidance_scale=args.negative_guidance_scale,
|
60 |
+
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
|
61 |
+
)
|
62 |
+
|
63 |
+
bname = os.path.basename(args.inversion).split(".")[0]
|
64 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
|
65 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
|
src/edit_synthetic.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_directions import construct_direction
|
11 |
+
from utils.edit_pipeline import EditingPipeline
|
12 |
+
|
13 |
+
|
14 |
+
if __name__=="__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--prompt_str', type=str, required=True)
|
17 |
+
parser.add_argument('--random_seed', default=0)
|
18 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
19 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
20 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
21 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
22 |
+
parser.add_argument('--xa_guidance', default=0.15, type=float)
|
23 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
24 |
+
parser.add_argument('--use_float_16', action='store_true')
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
os.makedirs(args.results_folder, exist_ok=True)
|
28 |
+
|
29 |
+
if args.use_float_16:
|
30 |
+
torch_dtype = torch.float16
|
31 |
+
else:
|
32 |
+
torch_dtype = torch.float32
|
33 |
+
|
34 |
+
# make the input noise map
|
35 |
+
torch.cuda.manual_seed(args.random_seed)
|
36 |
+
x = torch.randn((1,4,64,64), device="cuda")
|
37 |
+
|
38 |
+
# Make the editing pipeline
|
39 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
40 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
41 |
+
|
42 |
+
rec_pil, edit_pil = pipe(args.prompt_str,
|
43 |
+
num_inference_steps=args.num_ddim_steps,
|
44 |
+
x_in=x,
|
45 |
+
edit_dir=construct_direction(args.task_name),
|
46 |
+
guidance_amount=args.xa_guidance,
|
47 |
+
guidance_scale=args.negative_guidance_scale,
|
48 |
+
negative_prompt="" # use the empty string for the negative prompt
|
49 |
+
)
|
50 |
+
|
51 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
|
52 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))
|
src/inversion.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from lavis.models import load_model_and_preprocess
|
10 |
+
|
11 |
+
from utils.ddim_inv import DDIMInversion
|
12 |
+
from utils.scheduler import DDIMInverseScheduler
|
13 |
+
|
14 |
+
if __name__=="__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
|
17 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
18 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
19 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
20 |
+
parser.add_argument('--use_float_16', action='store_true')
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
# make the output folders
|
24 |
+
os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
|
25 |
+
os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
|
26 |
+
|
27 |
+
if args.use_float_16:
|
28 |
+
torch_dtype = torch.float16
|
29 |
+
else:
|
30 |
+
torch_dtype = torch.float32
|
31 |
+
|
32 |
+
|
33 |
+
# load the BLIP model
|
34 |
+
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
|
35 |
+
# make the DDIM inversion pipeline
|
36 |
+
pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
37 |
+
pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
38 |
+
|
39 |
+
|
40 |
+
# if the input is a folder, collect all the images as a list
|
41 |
+
if os.path.isdir(args.input_image):
|
42 |
+
l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
|
43 |
+
else:
|
44 |
+
l_img_paths = [args.input_image]
|
45 |
+
|
46 |
+
|
47 |
+
for img_path in l_img_paths:
|
48 |
+
bname = os.path.basename(args.input_image).split(".")[0]
|
49 |
+
img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS)
|
50 |
+
# generate the caption
|
51 |
+
_image = vis_processors["eval"](img).unsqueeze(0).cuda()
|
52 |
+
prompt_str = model_blip.generate({"image": _image})[0]
|
53 |
+
x_inv, x_inv_image, x_dec_img = pipe(
|
54 |
+
prompt_str,
|
55 |
+
guidance_scale=1,
|
56 |
+
num_inversion_steps=args.num_ddim_steps,
|
57 |
+
img=img,
|
58 |
+
torch_dtype=torch_dtype
|
59 |
+
)
|
60 |
+
# save the inversion
|
61 |
+
torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
|
62 |
+
# save the prompt string
|
63 |
+
with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
|
64 |
+
f.write(prompt_str)
|
src/make_edit_direction.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_pipeline import EditingPipeline
|
11 |
+
|
12 |
+
|
13 |
+
## convert sentences to sentence embeddings
|
14 |
+
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
|
15 |
+
with torch.no_grad():
|
16 |
+
l_embeddings = []
|
17 |
+
for sent in l_sentences:
|
18 |
+
text_inputs = tokenizer(
|
19 |
+
sent,
|
20 |
+
padding="max_length",
|
21 |
+
max_length=tokenizer.model_max_length,
|
22 |
+
truncation=True,
|
23 |
+
return_tensors="pt",
|
24 |
+
)
|
25 |
+
text_input_ids = text_inputs.input_ids
|
26 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
|
27 |
+
l_embeddings.append(prompt_embeds)
|
28 |
+
return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__=="__main__":
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument('--file_source_sentences', required=True)
|
34 |
+
parser.add_argument('--file_target_sentences', required=True)
|
35 |
+
parser.add_argument('--output_folder', required=True)
|
36 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
# load the model
|
40 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
|
41 |
+
bname_src = os.path.basename(args.file_source_sentences).strip(".txt")
|
42 |
+
outf_src = os.path.join(args.output_folder, bname_src+".pt")
|
43 |
+
if os.path.exists(outf_src):
|
44 |
+
print(f"Skipping source file {outf_src} as it already exists")
|
45 |
+
else:
|
46 |
+
with open(args.file_source_sentences, "r") as f:
|
47 |
+
l_sents = [x.strip() for x in f.readlines()]
|
48 |
+
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
49 |
+
print(mean_emb.shape)
|
50 |
+
torch.save(mean_emb, outf_src)
|
51 |
+
|
52 |
+
bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt")
|
53 |
+
outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
|
54 |
+
if os.path.exists(outf_tgt):
|
55 |
+
print(f"Skipping target file {outf_tgt} as it already exists")
|
56 |
+
else:
|
57 |
+
with open(args.file_target_sentences, "r") as f:
|
58 |
+
l_sents = [x.strip() for x in f.readlines()]
|
59 |
+
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
60 |
+
print(mean_emb.shape)
|
61 |
+
torch.save(mean_emb, outf_tgt)
|
src/utils/base_pipeline.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import inspect
|
4 |
+
from packaging import version
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
8 |
+
from diffusers import DiffusionPipeline
|
9 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
10 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
11 |
+
from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
|
12 |
+
from diffusers import StableDiffusionPipeline
|
13 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class BasePipeline(DiffusionPipeline):
|
18 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
vae: AutoencoderKL,
|
22 |
+
text_encoder: CLIPTextModel,
|
23 |
+
tokenizer: CLIPTokenizer,
|
24 |
+
unet: UNet2DConditionModel,
|
25 |
+
scheduler: KarrasDiffusionSchedulers,
|
26 |
+
safety_checker: StableDiffusionSafetyChecker,
|
27 |
+
feature_extractor: CLIPFeatureExtractor,
|
28 |
+
requires_safety_checker: bool = True,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
33 |
+
deprecation_message = (
|
34 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
35 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
36 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
37 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
38 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
39 |
+
" file"
|
40 |
+
)
|
41 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
42 |
+
new_config = dict(scheduler.config)
|
43 |
+
new_config["steps_offset"] = 1
|
44 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
45 |
+
|
46 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
47 |
+
deprecation_message = (
|
48 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
49 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
50 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
51 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
52 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
53 |
+
)
|
54 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
55 |
+
new_config = dict(scheduler.config)
|
56 |
+
new_config["clip_sample"] = False
|
57 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
58 |
+
|
59 |
+
if safety_checker is None and requires_safety_checker:
|
60 |
+
logger.warning(
|
61 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
62 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
63 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
64 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
65 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
66 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
67 |
+
)
|
68 |
+
|
69 |
+
if safety_checker is not None and feature_extractor is None:
|
70 |
+
raise ValueError(
|
71 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
72 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
73 |
+
)
|
74 |
+
|
75 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
76 |
+
version.parse(unet.config._diffusers_version).base_version
|
77 |
+
) < version.parse("0.9.0.dev0")
|
78 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
79 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
80 |
+
deprecation_message = (
|
81 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
82 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
83 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
84 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
85 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
86 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
87 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
88 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
89 |
+
" the `unet/config.json` file"
|
90 |
+
)
|
91 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
92 |
+
new_config = dict(unet.config)
|
93 |
+
new_config["sample_size"] = 64
|
94 |
+
unet._internal_dict = FrozenDict(new_config)
|
95 |
+
|
96 |
+
self.register_modules(
|
97 |
+
vae=vae,
|
98 |
+
text_encoder=text_encoder,
|
99 |
+
tokenizer=tokenizer,
|
100 |
+
unet=unet,
|
101 |
+
scheduler=scheduler,
|
102 |
+
safety_checker=safety_checker,
|
103 |
+
feature_extractor=feature_extractor,
|
104 |
+
)
|
105 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
106 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
107 |
+
|
108 |
+
@property
|
109 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
110 |
+
def _execution_device(self):
|
111 |
+
r"""
|
112 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
113 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
114 |
+
hooks.
|
115 |
+
"""
|
116 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
117 |
+
return self.device
|
118 |
+
for module in self.unet.modules():
|
119 |
+
if (
|
120 |
+
hasattr(module, "_hf_hook")
|
121 |
+
and hasattr(module._hf_hook, "execution_device")
|
122 |
+
and module._hf_hook.execution_device is not None
|
123 |
+
):
|
124 |
+
return torch.device(module._hf_hook.execution_device)
|
125 |
+
return self.device
|
126 |
+
|
127 |
+
|
128 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
129 |
+
def _encode_prompt(
|
130 |
+
self,
|
131 |
+
prompt,
|
132 |
+
device,
|
133 |
+
num_images_per_prompt,
|
134 |
+
do_classifier_free_guidance,
|
135 |
+
negative_prompt=None,
|
136 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
137 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
138 |
+
):
|
139 |
+
r"""
|
140 |
+
Encodes the prompt into text encoder hidden states.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
prompt (`str` or `List[str]`, *optional*):
|
144 |
+
prompt to be encoded
|
145 |
+
device: (`torch.device`):
|
146 |
+
torch device
|
147 |
+
num_images_per_prompt (`int`):
|
148 |
+
number of images that should be generated per prompt
|
149 |
+
do_classifier_free_guidance (`bool`):
|
150 |
+
whether to use classifier free guidance or not
|
151 |
+
negative_ prompt (`str` or `List[str]`, *optional*):
|
152 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
153 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
154 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
155 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
156 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
157 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
158 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
159 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
160 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
161 |
+
argument.
|
162 |
+
"""
|
163 |
+
if prompt is not None and isinstance(prompt, str):
|
164 |
+
batch_size = 1
|
165 |
+
elif prompt is not None and isinstance(prompt, list):
|
166 |
+
batch_size = len(prompt)
|
167 |
+
else:
|
168 |
+
batch_size = prompt_embeds.shape[0]
|
169 |
+
|
170 |
+
if prompt_embeds is None:
|
171 |
+
text_inputs = self.tokenizer(
|
172 |
+
prompt,
|
173 |
+
padding="max_length",
|
174 |
+
max_length=self.tokenizer.model_max_length,
|
175 |
+
truncation=True,
|
176 |
+
return_tensors="pt",
|
177 |
+
)
|
178 |
+
text_input_ids = text_inputs.input_ids
|
179 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
180 |
+
|
181 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
182 |
+
text_input_ids, untruncated_ids
|
183 |
+
):
|
184 |
+
removed_text = self.tokenizer.batch_decode(
|
185 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
186 |
+
)
|
187 |
+
logger.warning(
|
188 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
189 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
190 |
+
)
|
191 |
+
|
192 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
193 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
194 |
+
else:
|
195 |
+
attention_mask = None
|
196 |
+
|
197 |
+
prompt_embeds = self.text_encoder(
|
198 |
+
text_input_ids.to(device),
|
199 |
+
attention_mask=attention_mask,
|
200 |
+
)
|
201 |
+
prompt_embeds = prompt_embeds[0]
|
202 |
+
|
203 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
204 |
+
|
205 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
206 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
207 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
208 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
209 |
+
|
210 |
+
# get unconditional embeddings for classifier free guidance
|
211 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
212 |
+
uncond_tokens: List[str]
|
213 |
+
if negative_prompt is None:
|
214 |
+
uncond_tokens = [""] * batch_size
|
215 |
+
elif type(prompt) is not type(negative_prompt):
|
216 |
+
raise TypeError(
|
217 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
218 |
+
f" {type(prompt)}."
|
219 |
+
)
|
220 |
+
elif isinstance(negative_prompt, str):
|
221 |
+
uncond_tokens = [negative_prompt]
|
222 |
+
elif batch_size != len(negative_prompt):
|
223 |
+
raise ValueError(
|
224 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
225 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
226 |
+
" the batch size of `prompt`."
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
uncond_tokens = negative_prompt
|
230 |
+
|
231 |
+
max_length = prompt_embeds.shape[1]
|
232 |
+
uncond_input = self.tokenizer(
|
233 |
+
uncond_tokens,
|
234 |
+
padding="max_length",
|
235 |
+
max_length=max_length,
|
236 |
+
truncation=True,
|
237 |
+
return_tensors="pt",
|
238 |
+
)
|
239 |
+
|
240 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
241 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
242 |
+
else:
|
243 |
+
attention_mask = None
|
244 |
+
|
245 |
+
negative_prompt_embeds = self.text_encoder(
|
246 |
+
uncond_input.input_ids.to(device),
|
247 |
+
attention_mask=attention_mask,
|
248 |
+
)
|
249 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
250 |
+
|
251 |
+
if do_classifier_free_guidance:
|
252 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
253 |
+
seq_len = negative_prompt_embeds.shape[1]
|
254 |
+
|
255 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
256 |
+
|
257 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
258 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
259 |
+
|
260 |
+
# For classifier free guidance, we need to do two forward passes.
|
261 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
262 |
+
# to avoid doing two forward passes
|
263 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
264 |
+
|
265 |
+
return prompt_embeds
|
266 |
+
|
267 |
+
|
268 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
269 |
+
def decode_latents(self, latents):
|
270 |
+
latents = 1 / 0.18215 * latents
|
271 |
+
image = self.vae.decode(latents).sample
|
272 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
273 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
274 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
275 |
+
return image
|
276 |
+
|
277 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
278 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
279 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
280 |
+
raise ValueError(
|
281 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
282 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
283 |
+
)
|
284 |
+
|
285 |
+
if latents is None:
|
286 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
287 |
+
else:
|
288 |
+
latents = latents.to(device)
|
289 |
+
|
290 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
291 |
+
latents = latents * self.scheduler.init_noise_sigma
|
292 |
+
return latents
|
293 |
+
|
294 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
295 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
296 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
297 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
298 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
299 |
+
# and should be between [0, 1]
|
300 |
+
|
301 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
302 |
+
extra_step_kwargs = {}
|
303 |
+
if accepts_eta:
|
304 |
+
extra_step_kwargs["eta"] = eta
|
305 |
+
|
306 |
+
# check if the scheduler accepts generator
|
307 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
308 |
+
if accepts_generator:
|
309 |
+
extra_step_kwargs["generator"] = generator
|
310 |
+
return extra_step_kwargs
|
311 |
+
|
312 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
313 |
+
def run_safety_checker(self, image, device, dtype):
|
314 |
+
if self.safety_checker is not None:
|
315 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
316 |
+
image, has_nsfw_concept = self.safety_checker(
|
317 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
has_nsfw_concept = None
|
321 |
+
return image, has_nsfw_concept
|
322 |
+
|
src/utils/cross_attention.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.models.attention import CrossAttention
|
3 |
+
|
4 |
+
class MyCrossAttnProcessor:
|
5 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
6 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
7 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
8 |
+
|
9 |
+
query = attn.to_q(hidden_states)
|
10 |
+
|
11 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
12 |
+
key = attn.to_k(encoder_hidden_states)
|
13 |
+
value = attn.to_v(encoder_hidden_states)
|
14 |
+
|
15 |
+
query = attn.head_to_batch_dim(query)
|
16 |
+
key = attn.head_to_batch_dim(key)
|
17 |
+
value = attn.head_to_batch_dim(value)
|
18 |
+
|
19 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
20 |
+
# new bookkeeping to save the attn probs
|
21 |
+
attn.attn_probs = attention_probs
|
22 |
+
|
23 |
+
hidden_states = torch.bmm(attention_probs, value)
|
24 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
25 |
+
|
26 |
+
# linear proj
|
27 |
+
hidden_states = attn.to_out[0](hidden_states)
|
28 |
+
# dropout
|
29 |
+
hidden_states = attn.to_out[1](hidden_states)
|
30 |
+
|
31 |
+
return hidden_states
|
32 |
+
|
33 |
+
|
34 |
+
"""
|
35 |
+
A function that prepares a U-Net model for training by enabling gradient computation
|
36 |
+
for a specified set of parameters and setting the forward pass to be performed by a
|
37 |
+
custom cross attention processor.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
unet: A U-Net model.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
unet: The prepared U-Net model.
|
44 |
+
"""
|
45 |
+
def prep_unet(unet):
|
46 |
+
# set the gradients for XA maps to be true
|
47 |
+
for name, params in unet.named_parameters():
|
48 |
+
if 'attn2' in name:
|
49 |
+
params.requires_grad = True
|
50 |
+
else:
|
51 |
+
params.requires_grad = False
|
52 |
+
# replace the fwd function
|
53 |
+
for name, module in unet.named_modules():
|
54 |
+
module_name = type(module).__name__
|
55 |
+
if module_name == "CrossAttention":
|
56 |
+
module.set_processor(MyCrossAttnProcessor())
|
57 |
+
return unet
|
src/utils/ddim_inv.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from random import randrange
|
6 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
7 |
+
from diffusers import DDIMScheduler
|
8 |
+
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
9 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
10 |
+
sys.path.insert(0, "src/utils")
|
11 |
+
from base_pipeline import BasePipeline
|
12 |
+
from cross_attention import prep_unet
|
13 |
+
|
14 |
+
|
15 |
+
class DDIMInversion(BasePipeline):
|
16 |
+
|
17 |
+
def auto_corr_loss(self, x, random_shift=True):
|
18 |
+
B,C,H,W = x.shape
|
19 |
+
assert B==1
|
20 |
+
x = x.squeeze(0)
|
21 |
+
# x must be shape [C,H,W] now
|
22 |
+
reg_loss = 0.0
|
23 |
+
for ch_idx in range(x.shape[0]):
|
24 |
+
noise = x[ch_idx][None, None,:,:]
|
25 |
+
while True:
|
26 |
+
if random_shift: roll_amount = randrange(noise.shape[2]//2)
|
27 |
+
else: roll_amount = 1
|
28 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
|
29 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
|
30 |
+
if noise.shape[2] <= 8:
|
31 |
+
break
|
32 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
33 |
+
return reg_loss
|
34 |
+
|
35 |
+
def kl_divergence(self, x):
|
36 |
+
_mu = x.mean()
|
37 |
+
_var = x.var()
|
38 |
+
return _var + _mu**2 - 1 - torch.log(_var+1e-7)
|
39 |
+
|
40 |
+
|
41 |
+
def __call__(
|
42 |
+
self,
|
43 |
+
prompt: Union[str, List[str]] = None,
|
44 |
+
num_inversion_steps: int = 50,
|
45 |
+
guidance_scale: float = 7.5,
|
46 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
47 |
+
num_images_per_prompt: Optional[int] = 1,
|
48 |
+
eta: float = 0.0,
|
49 |
+
output_type: Optional[str] = "pil",
|
50 |
+
return_dict: bool = True,
|
51 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
52 |
+
img=None, # the input image as a PIL image
|
53 |
+
torch_dtype=torch.float32,
|
54 |
+
|
55 |
+
# inversion regularization parameters
|
56 |
+
lambda_ac: float = 20.0,
|
57 |
+
lambda_kl: float = 20.0,
|
58 |
+
num_reg_steps: int = 5,
|
59 |
+
num_ac_rolls: int = 5,
|
60 |
+
):
|
61 |
+
|
62 |
+
# 0. modify the unet to be useful :D
|
63 |
+
self.unet = prep_unet(self.unet)
|
64 |
+
|
65 |
+
# set the scheduler to be the Inverse DDIM scheduler
|
66 |
+
# self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
|
67 |
+
|
68 |
+
device = self._execution_device
|
69 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
70 |
+
self.scheduler.set_timesteps(num_inversion_steps, device=device)
|
71 |
+
timesteps = self.scheduler.timesteps
|
72 |
+
|
73 |
+
# Encode the input image with the first stage model
|
74 |
+
x0 = np.array(img)/255
|
75 |
+
x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
|
76 |
+
x0 = (x0 - 0.5) * 2.
|
77 |
+
with torch.no_grad():
|
78 |
+
x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
|
79 |
+
latents = x0_enc = 0.18215 * x0_enc
|
80 |
+
|
81 |
+
# Decode and return the image
|
82 |
+
with torch.no_grad():
|
83 |
+
x0_dec = self.decode_latents(x0_enc.detach())
|
84 |
+
image_x0_dec = self.numpy_to_pil(x0_dec)
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
|
88 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
|
89 |
+
|
90 |
+
# Do the inversion
|
91 |
+
num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
|
92 |
+
with self.progress_bar(total=num_inversion_steps) as progress_bar:
|
93 |
+
for i, t in enumerate(timesteps.flip(0)[1:-1]):
|
94 |
+
# expand the latents if we are doing classifier free guidance
|
95 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
96 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
97 |
+
|
98 |
+
# predict the noise residual
|
99 |
+
with torch.no_grad():
|
100 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
101 |
+
|
102 |
+
# perform guidance
|
103 |
+
if do_classifier_free_guidance:
|
104 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
105 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
106 |
+
|
107 |
+
# regularization of the noise prediction
|
108 |
+
e_t = noise_pred
|
109 |
+
for _outer in range(num_reg_steps):
|
110 |
+
if lambda_ac>0:
|
111 |
+
for _inner in range(num_ac_rolls):
|
112 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
113 |
+
l_ac = self.auto_corr_loss(_var)
|
114 |
+
l_ac.backward()
|
115 |
+
_grad = _var.grad.detach()/num_ac_rolls
|
116 |
+
e_t = e_t - lambda_ac*_grad
|
117 |
+
if lambda_kl>0:
|
118 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
119 |
+
l_kld = self.kl_divergence(_var)
|
120 |
+
l_kld.backward()
|
121 |
+
_grad = _var.grad.detach()
|
122 |
+
e_t = e_t - lambda_kl*_grad
|
123 |
+
e_t = e_t.detach()
|
124 |
+
noise_pred = e_t
|
125 |
+
|
126 |
+
# compute the previous noisy sample x_t -> x_t-1
|
127 |
+
latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
|
128 |
+
|
129 |
+
# call the callback, if provided
|
130 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
131 |
+
progress_bar.update()
|
132 |
+
|
133 |
+
|
134 |
+
x_inv = latents.detach().clone()
|
135 |
+
# reconstruct the image
|
136 |
+
|
137 |
+
# 8. Post-processing
|
138 |
+
image = self.decode_latents(latents.detach())
|
139 |
+
image = self.numpy_to_pil(image)
|
140 |
+
return x_inv, image, image_x0_dec
|
src/utils/edit_directions.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
"""
|
6 |
+
This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
|
7 |
+
|
8 |
+
Parameters:
|
9 |
+
task_name (str): name of the task for which direction is to be constructed.
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
|
13 |
+
|
14 |
+
Examples:
|
15 |
+
>>> construct_direction("cat2dog")
|
16 |
+
"""
|
17 |
+
def construct_direction(task_name):
|
18 |
+
if task_name=="cat2dog":
|
19 |
+
emb_dir = f"assets/embeddings_sd_1.4"
|
20 |
+
embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
21 |
+
embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
22 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
23 |
+
elif task_name=="dog2cat":
|
24 |
+
emb_dir = f"assets/embeddings_sd_1.4"
|
25 |
+
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
26 |
+
embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
27 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
src/utils/edit_pipeline.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb, sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
7 |
+
sys.path.insert(0, "src/utils")
|
8 |
+
from base_pipeline import BasePipeline
|
9 |
+
from cross_attention import prep_unet
|
10 |
+
|
11 |
+
|
12 |
+
class EditingPipeline(BasePipeline):
|
13 |
+
def __call__(
|
14 |
+
self,
|
15 |
+
prompt: Union[str, List[str]] = None,
|
16 |
+
height: Optional[int] = None,
|
17 |
+
width: Optional[int] = None,
|
18 |
+
num_inference_steps: int = 50,
|
19 |
+
guidance_scale: float = 7.5,
|
20 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
21 |
+
num_images_per_prompt: Optional[int] = 1,
|
22 |
+
eta: float = 0.0,
|
23 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
24 |
+
latents: Optional[torch.FloatTensor] = None,
|
25 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
26 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
27 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
28 |
+
|
29 |
+
# pix2pix parameters
|
30 |
+
guidance_amount=0.1,
|
31 |
+
edit_dir=None,
|
32 |
+
x_in=None,
|
33 |
+
|
34 |
+
):
|
35 |
+
|
36 |
+
x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
37 |
+
|
38 |
+
# 0. modify the unet to be useful :D
|
39 |
+
self.unet = prep_unet(self.unet)
|
40 |
+
|
41 |
+
# 1. setup all caching objects
|
42 |
+
d_ref_t2attn = {} # reference cross attention maps
|
43 |
+
|
44 |
+
# 2. Default height and width to unet
|
45 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
46 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
47 |
+
|
48 |
+
# TODO: add the input checker function
|
49 |
+
# self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
|
50 |
+
|
51 |
+
# 2. Define call parameters
|
52 |
+
if prompt is not None and isinstance(prompt, str):
|
53 |
+
batch_size = 1
|
54 |
+
elif prompt is not None and isinstance(prompt, list):
|
55 |
+
batch_size = len(prompt)
|
56 |
+
else:
|
57 |
+
batch_size = prompt_embeds.shape[0]
|
58 |
+
|
59 |
+
device = self._execution_device
|
60 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
61 |
+
x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
62 |
+
# 3. Encode input prompt = 2x77x1024
|
63 |
+
prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
|
64 |
+
|
65 |
+
# 4. Prepare timesteps
|
66 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
67 |
+
timesteps = self.scheduler.timesteps
|
68 |
+
|
69 |
+
# 5. Prepare latent variables
|
70 |
+
num_channels_latents = self.unet.in_channels
|
71 |
+
|
72 |
+
# randomly sample a latent code if not provided
|
73 |
+
latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
|
74 |
+
|
75 |
+
latents_init = latents.clone()
|
76 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
77 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
78 |
+
|
79 |
+
# 7. First Denoising loop for getting the reference cross attention maps
|
80 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
81 |
+
with torch.no_grad():
|
82 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
83 |
+
for i, t in enumerate(timesteps):
|
84 |
+
# expand the latents if we are doing classifier free guidance
|
85 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
86 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
87 |
+
|
88 |
+
# predict the noise residual
|
89 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
90 |
+
|
91 |
+
# add the cross attention map to the dictionary
|
92 |
+
d_ref_t2attn[t.item()] = {}
|
93 |
+
for name, module in self.unet.named_modules():
|
94 |
+
module_name = type(module).__name__
|
95 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
96 |
+
attn_mask = module.attn_probs # size is num_channel,s*s,77
|
97 |
+
d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
|
98 |
+
|
99 |
+
# perform guidance
|
100 |
+
if do_classifier_free_guidance:
|
101 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
102 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
103 |
+
|
104 |
+
# compute the previous noisy sample x_t -> x_t-1
|
105 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
106 |
+
|
107 |
+
# call the callback, if provided
|
108 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
109 |
+
progress_bar.update()
|
110 |
+
|
111 |
+
# make the reference image (reconstruction)
|
112 |
+
image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
|
113 |
+
|
114 |
+
prompt_embeds_edit = prompt_embeds.clone()
|
115 |
+
#add the edit only to the second prompt, idx 0 is the negative prompt
|
116 |
+
prompt_embeds_edit[1:2] += edit_dir
|
117 |
+
|
118 |
+
latents = latents_init
|
119 |
+
# Second denoising loop for editing the text prompt
|
120 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
121 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
122 |
+
for i, t in enumerate(timesteps):
|
123 |
+
# expand the latents if we are doing classifier free guidance
|
124 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
125 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
126 |
+
|
127 |
+
x_in = latent_model_input.detach().clone()
|
128 |
+
x_in.requires_grad = True
|
129 |
+
|
130 |
+
opt = torch.optim.SGD([x_in], lr=guidance_amount)
|
131 |
+
|
132 |
+
# predict the noise residual
|
133 |
+
noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
|
134 |
+
|
135 |
+
loss = 0.0
|
136 |
+
for name, module in self.unet.named_modules():
|
137 |
+
module_name = type(module).__name__
|
138 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
139 |
+
curr = module.attn_probs # size is num_channel,s*s,77
|
140 |
+
ref = d_ref_t2attn[t.item()][name].detach().cuda()
|
141 |
+
loss += ((curr-ref)**2).sum((1,2)).mean(0)
|
142 |
+
loss.backward(retain_graph=False)
|
143 |
+
opt.step()
|
144 |
+
|
145 |
+
# recompute the noise
|
146 |
+
with torch.no_grad():
|
147 |
+
noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
|
148 |
+
|
149 |
+
latents = x_in.detach().chunk(2)[0]
|
150 |
+
|
151 |
+
# perform guidance
|
152 |
+
if do_classifier_free_guidance:
|
153 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
154 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
155 |
+
|
156 |
+
# compute the previous noisy sample x_t -> x_t-1
|
157 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
158 |
+
|
159 |
+
# call the callback, if provided
|
160 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
161 |
+
progress_bar.update()
|
162 |
+
|
163 |
+
|
164 |
+
# 8. Post-processing
|
165 |
+
image = self.decode_latents(latents.detach())
|
166 |
+
|
167 |
+
# 9. Run safety checker
|
168 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
169 |
+
|
170 |
+
# 10. Convert to PIL
|
171 |
+
image_edit = self.numpy_to_pil(image)
|
172 |
+
|
173 |
+
|
174 |
+
return image_rec, image_edit
|
src/utils/scheduler.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
import os, sys, pdb
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, randn_tensor
|
27 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
32 |
+
class DDIMSchedulerOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Output class for the scheduler's step function output.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
39 |
+
denoising loop.
|
40 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
42 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
43 |
+
"""
|
44 |
+
|
45 |
+
prev_sample: torch.FloatTensor
|
46 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
47 |
+
|
48 |
+
|
49 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
50 |
+
"""
|
51 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
52 |
+
(1-beta) over time from t = [0,1].
|
53 |
+
|
54 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
55 |
+
to that part of the diffusion process.
|
56 |
+
|
57 |
+
|
58 |
+
Args:
|
59 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
60 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
61 |
+
prevent singularities.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
65 |
+
"""
|
66 |
+
|
67 |
+
def alpha_bar(time_step):
|
68 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
69 |
+
|
70 |
+
betas = []
|
71 |
+
for i in range(num_diffusion_timesteps):
|
72 |
+
t1 = i / num_diffusion_timesteps
|
73 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
74 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
75 |
+
return torch.tensor(betas)
|
76 |
+
|
77 |
+
|
78 |
+
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
79 |
+
"""
|
80 |
+
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
81 |
+
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
82 |
+
|
83 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
84 |
+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
85 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
86 |
+
[`~SchedulerMixin.from_pretrained`] functions.
|
87 |
+
|
88 |
+
For more details, see the original paper: https://arxiv.org/abs/2010.02502
|
89 |
+
|
90 |
+
Args:
|
91 |
+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
92 |
+
beta_start (`float`): the starting `beta` value of inference.
|
93 |
+
beta_end (`float`): the final `beta` value.
|
94 |
+
beta_schedule (`str`):
|
95 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
96 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
97 |
+
trained_betas (`np.ndarray`, optional):
|
98 |
+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
99 |
+
clip_sample (`bool`, default `True`):
|
100 |
+
option to clip predicted sample between -1 and 1 for numerical stability.
|
101 |
+
set_alpha_to_one (`bool`, default `True`):
|
102 |
+
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
103 |
+
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
104 |
+
otherwise it uses the value of alpha at step 0.
|
105 |
+
steps_offset (`int`, default `0`):
|
106 |
+
an offset added to the inference steps. You can use a combination of `offset=1` and
|
107 |
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
108 |
+
stable diffusion.
|
109 |
+
prediction_type (`str`, default `epsilon`, optional):
|
110 |
+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
111 |
+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
112 |
+
https://imagen.research.google/video/paper.pdf)
|
113 |
+
"""
|
114 |
+
|
115 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
116 |
+
order = 1
|
117 |
+
|
118 |
+
@register_to_config
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
num_train_timesteps: int = 1000,
|
122 |
+
beta_start: float = 0.0001,
|
123 |
+
beta_end: float = 0.02,
|
124 |
+
beta_schedule: str = "linear",
|
125 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
126 |
+
clip_sample: bool = True,
|
127 |
+
set_alpha_to_one: bool = True,
|
128 |
+
steps_offset: int = 0,
|
129 |
+
prediction_type: str = "epsilon",
|
130 |
+
):
|
131 |
+
if trained_betas is not None:
|
132 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
133 |
+
elif beta_schedule == "linear":
|
134 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
135 |
+
elif beta_schedule == "scaled_linear":
|
136 |
+
# this schedule is very specific to the latent diffusion model.
|
137 |
+
self.betas = (
|
138 |
+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
139 |
+
)
|
140 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
141 |
+
# Glide cosine schedule
|
142 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
143 |
+
else:
|
144 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
145 |
+
|
146 |
+
self.alphas = 1.0 - self.betas
|
147 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
148 |
+
|
149 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
150 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
151 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
152 |
+
# whether we use the final alpha of the "non-previous" one.
|
153 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
154 |
+
|
155 |
+
# standard deviation of the initial noise distribution
|
156 |
+
self.init_noise_sigma = 1.0
|
157 |
+
|
158 |
+
# setable values
|
159 |
+
self.num_inference_steps = None
|
160 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
161 |
+
|
162 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
163 |
+
"""
|
164 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
165 |
+
current timestep.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
sample (`torch.FloatTensor`): input sample
|
169 |
+
timestep (`int`, optional): current timestep
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
`torch.FloatTensor`: scaled input sample
|
173 |
+
"""
|
174 |
+
return sample
|
175 |
+
|
176 |
+
def _get_variance(self, timestep, prev_timestep):
|
177 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
178 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
179 |
+
beta_prod_t = 1 - alpha_prod_t
|
180 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
181 |
+
|
182 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
183 |
+
|
184 |
+
return variance
|
185 |
+
|
186 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
187 |
+
"""
|
188 |
+
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
num_inference_steps (`int`):
|
192 |
+
the number of diffusion steps used when generating samples with a pre-trained model.
|
193 |
+
"""
|
194 |
+
|
195 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
196 |
+
raise ValueError(
|
197 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
198 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
199 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
200 |
+
)
|
201 |
+
|
202 |
+
self.num_inference_steps = num_inference_steps
|
203 |
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
204 |
+
# creates integer timesteps by multiplying by ratio
|
205 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
206 |
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
207 |
+
self.timesteps = torch.from_numpy(timesteps).to(device)
|
208 |
+
self.timesteps += self.config.steps_offset
|
209 |
+
|
210 |
+
def step(
|
211 |
+
self,
|
212 |
+
model_output: torch.FloatTensor,
|
213 |
+
timestep: int,
|
214 |
+
sample: torch.FloatTensor,
|
215 |
+
eta: float = 0.0,
|
216 |
+
use_clipped_model_output: bool = False,
|
217 |
+
generator=None,
|
218 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
219 |
+
return_dict: bool = True,
|
220 |
+
reverse=False
|
221 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
222 |
+
|
223 |
+
|
224 |
+
e_t = model_output
|
225 |
+
|
226 |
+
x = sample
|
227 |
+
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
|
228 |
+
# print(timestep, prev_timestep)
|
229 |
+
a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
|
230 |
+
a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
|
231 |
+
beta_prod_t = 1 - alpha_prod_t
|
232 |
+
|
233 |
+
pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
|
234 |
+
# direction pointing to x_t
|
235 |
+
dir_xt = (1. - a_prev).sqrt() * e_t
|
236 |
+
x = a_prev.sqrt()*pred_x0 + dir_xt
|
237 |
+
if not return_dict:
|
238 |
+
return (x,)
|
239 |
+
return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
def add_noise(
|
246 |
+
self,
|
247 |
+
original_samples: torch.FloatTensor,
|
248 |
+
noise: torch.FloatTensor,
|
249 |
+
timesteps: torch.IntTensor,
|
250 |
+
) -> torch.FloatTensor:
|
251 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
252 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
253 |
+
timesteps = timesteps.to(original_samples.device)
|
254 |
+
|
255 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
256 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
257 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
258 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
259 |
+
|
260 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
261 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
262 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
263 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
264 |
+
|
265 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
266 |
+
return noisy_samples
|
267 |
+
|
268 |
+
def get_velocity(
|
269 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
270 |
+
) -> torch.FloatTensor:
|
271 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
272 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
273 |
+
timesteps = timesteps.to(sample.device)
|
274 |
+
|
275 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
276 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
277 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
278 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
279 |
+
|
280 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
281 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
282 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
283 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
284 |
+
|
285 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
286 |
+
return velocity
|
287 |
+
|
288 |
+
def __len__(self):
|
289 |
+
return self.config.num_train_timesteps
|