init
- .gitattributes +1 -0
- .gitignore +4 -0
- app.py +31 -0
- imgs/all-results.png +3 -0
- imgs/boromir.png +3 -0
- imgs/bread.jpg +0 -0
- imgs/cat.jpg +0 -0
- imgs/gigachad.jpg +0 -0
- imgs/shrek.jpg +0 -0
- imgs/wazovski.png +3 -0
- requirements.txt +10 -0
- research/README.md +37 -0
- research/imgs/inst-gigachad.png +3 -0
- run.py +403 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+.DS_Store
+.vscode
+.venv
+__pycache__
app.py ADDED
@@ -0,0 +1,31 @@
+import gradio as gr
+import torch
+from run import StableRemix, run_remixing
+
+pipe = StableRemix.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-unclip",
+    torch_dtype=torch.float16,
+    variant="fp16"
+)
+pipe = pipe.to('cuda')
+pipe.enable_attention_slicing()
+
+print('pipe loaded')
+
+
+def remix(image1):
+    # images = run_remixing(pipe, image1, image1, [0.6, 0.65, 0.7])
+    images = run_remixing(pipe, image1, image1, [0.6])
+    return images[0]
+
+
+demo = gr.Interface(
+    fn=remix, inputs=[gr.Image(image_mode='RGB', type='pil')], outputs="image")
+demo.launch()
imgs/all-results.png ADDED (Git LFS)
imgs/boromir.png ADDED (Git LFS)
imgs/bread.jpg ADDED
imgs/cat.jpg ADDED
imgs/gigachad.jpg ADDED
imgs/shrek.jpg ADDED
imgs/wazovski.png ADDED (Git LFS)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+gradio==3.36.1
+torch==2.0.1
+diffusers==0.17.1
+Pillow==10.0.0
+tokenizers==0.13.3
+transformers==4.30.2
+accelerate==0.20.3
+scipy==1.11.1
+safetensors==0.3.1
+# xformers
research/README.md ADDED
@@ -0,0 +1,37 @@
+# Stable Diffusion Remix research
+
+__This document is in early draft form.__
+
+This directory contains research and experiments related to the Stable Diffusion Remix.
+
+## Something to read
+
+- [Midjourney Remix docs](https://docs.midjourney.com/docs/remix)
+- [Great discussion on Midjourney Remix reproduction](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/4595)
+- [InST paper](https://arxiv.org/abs/2211.13203)
+
+## Early experiments
+
+### InST
+
+![InST](imgs/inst-gigachad.png)
+
+- [Code](https://github.com/zyxElsa/InST)
+- [Paper](https://arxiv.org/abs/2211.13203)
+
+This method does pretty much what we want, but it requires training a converter from the image embedding space to the text embedding space. For now I have tried to stick to training-free approaches.
+
+### CLIP Interrogator
+
+This idea comes straight from the GitHub discussion mentioned above. A straightforward solution is to use vanilla Stable Diffusion: we somehow obtain textual representations of the content and style images and use them as prompts. This is what [CLIP Interrogator](https://huggingface.co/spaces/fffiloni/CLIP-Interrogator-2) does. It is a great tool, but it is not reliable and it adds to the time and memory requirements.
+
+- Gigachad prompt: "a man with a beard and a beard is smiling and looking at the camera, octane trending on cgsociety, chris evans as a bodybuilder, monochrome 3 d model, 40 years old women"
+- Shrek prompt: "a close up of a cartoonish looking green trolly trolly trolly trolly trolly troll troll, elon musk as jabba the hutt, full og shrek, 3d cg, 240p"
+
+## Image embedding interpolation
+
+TODO
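The CLIP Interrogator experiment described in the README above amounts to captioning each input image and reusing the captions as prompts for vanilla Stable Diffusion. A minimal sketch of that idea, assuming the `clip-interrogator` package (not listed in requirements.txt); the model names and output path are illustrative, not part of this repo:

```python
# Hypothetical reproduction of the CLIP Interrogator experiment from research/README.md.
# Assumes `pip install clip-interrogator`; model names are illustrative assumptions.
import torch
from PIL import Image
from clip_interrogator import Config, Interrogator
from diffusers import StableDiffusionPipeline

# Caption the content and style images.
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
content_prompt = ci.interrogate(Image.open("imgs/gigachad.jpg").convert("RGB"))
style_prompt = ci.interrogate(Image.open("imgs/shrek.jpg").convert("RGB"))

# Feed the concatenated captions to a plain text-to-image pipeline.
sd = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
image = sd(f"{content_prompt}, {style_prompt}").images[0]
image.save("clip_interrogator_remix.png")
```

As the README notes, the captions are noisy (see the Gigachad and Shrek prompts above), which is why this path was not pursued further.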
research/imgs/inst-gigachad.png ADDED (Git LFS)
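The core of run.py below is spherical linear interpolation (slerp) between the CLIP image embeddings of the content and style images; in the notation of its `slerp` helper (with `val` = t, `low` = u, `high` = v):

```latex
% Spherical linear interpolation as implemented by slerp() in run.py.
\mathrm{slerp}(t; u, v) = \frac{\sin\bigl((1-t)\,\Omega\bigr)}{\sin\Omega}\, u
                        + \frac{\sin(t\,\Omega)}{\sin\Omega}\, v,
\qquad
\Omega = \arccos\!\left(\frac{u}{\lVert u\rVert}\cdot\frac{v}{\lVert v\rVert}\right)
```

At t = 0 the blend equals the content embedding and at t = 1 the style embedding; the pipeline is then conditioned on the blended embedding.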
run.py ADDED
@@ -0,0 +1,403 @@
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import PIL
+from PIL import Image
+
+from diffusers import StableUnCLIPImg2ImgPipeline, ImagePipelineOutput
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import deprecate, randn_tensor, PIL_INTERPOLATION
+
+
+def center_resize_crop(image, size=224):
+    w, h = image.size
+    if h < w:
+        h, w = size, size * w // h
+    else:
+        h, w = size * h // w, size
+
+    image = image.resize((w, h))
+
+    box = ((w - size) // 2, (h - size) // 2, (w + size) // 2, (h + size) // 2)
+    return image.crop(box)
+
+
+def encode_image(image, pipe):
+    device = pipe._execution_device
+    dtype = next(pipe.image_encoder.parameters()).dtype
+
+    if not isinstance(image, torch.Tensor):
+        image = pipe.feature_extractor(
+            images=image, return_tensors="pt").pixel_values
+
+    image = image.to(device=device, dtype=dtype)
+    image_embeds = pipe.image_encoder(image).image_embeds
+
+    return image_embeds
+
+
+def generate_latents(pipe):
+    shape = (1, pipe.unet.in_channels, pipe.unet.config.sample_size,
+             pipe.unet.config.sample_size)
+    device = pipe._execution_device
+    dtype = next(pipe.image_encoder.parameters()).dtype
+
+    return torch.randn(shape, device=device, dtype=dtype)
+
+
+# https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
+def slerp(val, low, high):
+    low_norm = low / torch.norm(low, dim=1, keepdim=True)
+    high_norm = high / torch.norm(high, dim=1, keepdim=True)
+    omega = torch.acos((low_norm * high_norm).sum(1))
+    so = torch.sin(omega)
+    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * \
+        low + (torch.sin(val * omega) / so).unsqueeze(1) * high
+    return res
+
+
+class StableRemixImageProcessor(VaeImageProcessor):
+    def __init__(self, w, h):
+        super().__init__()
+        self.w = w
+        self.h = h
+
+    def resize(self, image):
+        image = center_resize_crop(image, self.w)
+        return image
+
+    def preprocess(self, image):
+        image = super().preprocess(image)
+        # image = randomize_color(image)
+
+        return image
+
+
+class StableRemix(StableUnCLIPImg2ImgPipeline):
+    # prepare_latents is adapted from pipeline_stable_diffusion_img2img.py
+    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, noise=None):
+        if not isinstance(image, (torch.Tensor, Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        image = image.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if isinstance(generator, list):
+            init_latents = [
+                self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+            ]
+            init_latents = torch.cat(init_latents, dim=0)
+        else:
+            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        init_latents = self.vae.config.scaling_factor * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            deprecation_message = (
+                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                " your script to pass as many initial images as text prompts to suppress this warning."
+            )
+            deprecate("len(prompt) != len(image)", "1.0.0",
+                      deprecation_message, standard_warn=False)
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat(
+                [init_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
+        shape = init_latents.shape
+        if noise is None:
+            noise = randn_tensor(shape, generator=generator,
+                                 device=device, dtype=dtype)
+
+        # get latents
+        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+        latents = init_latents
+
+        return latents
+
+    # The original method has a bug; this one is fixed.
+    def _encode_image(
+        self,
+        image,
+        device,
+        batch_size,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        noise_level,
+        generator,
+        image_embeds,
+    ):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if isinstance(image, PIL.Image.Image):
+            # the image embedding should be repeated so it matches the total batch size of the prompt
+            repeat_by = batch_size
+        else:
+            # assume the image input is already properly batched and just needs to be repeated so
+            # it matches the num_images_per_prompt.
+            #
+            # NOTE(will) this is probably missing a few side cases, i.e. batched/non-batched
+            # `image_embeds`. If those happen to be common use cases, let's think harder about
+            # what the expected dimensions of inputs should be and how we handle the encoding.
+            repeat_by = num_images_per_prompt
+
+        if image_embeds is None:
+            if not isinstance(image, torch.Tensor):
+                image = self.feature_extractor(
+                    images=image, return_tensors="pt").pixel_values
+
+            image = image.to(device=device, dtype=dtype)
+            image_embeds = self.image_encoder(image).image_embeds
+
+        image_embeds = self.noise_image_embeddings(
+            image_embeds=image_embeds,
+            noise_level=noise_level,
+            generator=generator,
+        )
+
+        # duplicate image embeddings for each generation per prompt, using mps friendly method
+        image_embeds = image_embeds.unsqueeze(1)
+        bs_embed, seq_len, _ = image_embeds.shape
+        image_embeds = image_embeds.repeat(1, repeat_by, 1)
+        image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
+        image_embeds = image_embeds.squeeze(1)
+
+        if do_classifier_free_guidance:
+            negative_prompt_embeds = torch.zeros_like(image_embeds)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
+
+        return image_embeds
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 40,
+        guidance_scale: float = 10,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[
+            int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        noise_level: int = 0,
+        image_embeds=None,
+        timestep=0,
+    ):
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        if prompt is None and prompt_embeds is None:
+            prompt = len(image) * [""] if isinstance(image, list) else ""
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            image=None,
+            height=height,
+            width=width,
+            callback_steps=callback_steps,
+            noise_level=noise_level,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image_embeds=image_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        batch_size = batch_size * num_images_per_prompt
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+
+        # 4. Encode input image
+        noise_level = torch.tensor([noise_level], device=device)
+        image_embeds = self._encode_image(
+            image=None,
+            device=device,
+            batch_size=batch_size,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            noise_level=noise_level,
+            generator=generator,
+            image_embeds=image_embeds,
+        )
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+        latent_timestep = timesteps[timestep:timestep +
+                                    1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        image_processor = StableRemixImageProcessor(width, height)
+        image = image_processor.preprocess(image)
+
+        num_channels_latents = self.unet.in_channels
+
+        latents = self.prepare_latents(
+            image=image,
+            timestep=latent_timestep,
+            batch_size=batch_size,
+            dtype=prompt_embeds.dtype,
+            num_images_per_prompt=num_images_per_prompt,
+            device=device,
+            generator=generator,
+            noise=latents
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps[timestep:])):
+            latent_model_input = torch.cat(
+                [latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(
+                latent_model_input, t)
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                class_labels=image_embeds,
+                cross_attention_kwargs=cross_attention_kwargs,
+            ).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * \
+                    (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            if callback is not None and i % callback_steps == 0:
+                callback(i, t, latents)
+
+        # 9. Post-processing
+        image = self.decode_latents(latents)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        # 10. Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
+
+
+def run_remixing(pipe, content_img, style_img, alphas, **kwargs):
+    images = []
+
+    content_emb = encode_image(content_img, pipe)
+    style_emb = encode_image(style_img, pipe)
+
+    for alpha in alphas:
+        emb = slerp(alpha, content_emb, style_emb)
+        image = pipe(image=content_img, image_embeds=emb, **kwargs).images[0]
+        images.append(image)
+
+    return images
+
+
+def parse_args():
+    parser = ArgumentParser()
+
+    parser.add_argument('content_img', type=Path, help='Path to content image')
+    parser.add_argument('style_img', type=Path, help='Path to style image')
+    parser.add_argument('--device', type=torch.device, default=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
+                        help='Which device to use ("cpu", "cuda", "cuda:1", ...)')
+    parser.add_argument('save_dir', type=Path, nargs='?', default=Path('.'),
+                        help='Path to dir where to save remixes')
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    print('Using device:', args.device)
+
+    pipe = StableRemix.from_pretrained(
+        "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
+    )
+    pipe = pipe.to(args.device)
+    pipe.enable_xformers_memory_efficient_attention()
+
+    content_img = Image.open(args.content_img).convert('RGB')
+    style_img = Image.open(args.style_img).convert('RGB')
+
+    images = run_remixing(pipe, content_img, style_img, [0.6, 0.65, 0.7])
+    for idx, image in enumerate(images):
+        path = args.save_dir / f'remix_{idx}.png'
+        print('Saving remix to', path)
+        image.save(path)
+
+
+if __name__ == '__main__':
+    main()
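For reference, a usage sketch of `run_remixing` mirroring `main()` above; the image paths and the `remixes` output directory are illustrative, and higher alpha values pull the result toward the style image's CLIP embedding:

```python
# Illustrative use of run.py's pipeline outside the CLI; paths are examples only.
from pathlib import Path

import torch
from PIL import Image

from run import StableRemix, run_remixing

pipe = StableRemix.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
pipe.enable_attention_slicing()

content = Image.open("imgs/gigachad.jpg").convert("RGB")
style = Image.open("imgs/shrek.jpg").convert("RGB")

# Blend the two CLIP image embeddings at three slerp weights and save each remix.
out_dir = Path("remixes")
out_dir.mkdir(exist_ok=True)
for idx, image in enumerate(run_remixing(pipe, content, style, [0.6, 0.65, 0.7])):
    image.save(out_dir / f"remix_{idx}.png")
```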