init space
- .gitignore +5 -0
- README.md +2 -1
- ditail/.DS_Store +0 -0
- ditail/__init__.py +1 -0
- ditail/lora/.DS_Store +0 -0
- ditail/lora/animeoutline.jpeg +0 -0
- ditail/lora/animeoutline.safetensors +3 -0
- ditail/lora/film.jpeg +0 -0
- ditail/lora/film.safetensors +3 -0
- ditail/lora/flat.jpeg +0 -0
- ditail/lora/flat.safetensors +3 -0
- ditail/lora/impressionism.jpeg +0 -0
- ditail/lora/impressionism.safetensors +3 -0
- ditail/lora/minecraft.jpeg +0 -0
- ditail/lora/minecraft.safetensors +3 -0
- ditail/lora/none.jpeg +0 -0
- ditail/lora/pop.jpeg +0 -0
- ditail/lora/pop.safetensors +3 -0
- ditail/lora/shinkai_makoto.jpeg +0 -0
- ditail/lora/shinkai_makoto.safetensors +3 -0
- ditail/lora/snow.jpeg +0 -0
- ditail/lora/snow.safetensors +3 -0
- ditail/src/ditail_demo.py +233 -0
- ditail/src/ditail_utils.py +121 -0
- requirements.txt +10 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+__pycache__
+output_demo/
+cache/
+gradio_cached_examples
+secrets.sh
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
 title: Diffusion Cocktail
-emoji:
+emoji: 🍸
 colorFrom: indigo
 colorTo: indigo
 sdk: gradio
 sdk_version: 4.9.0
 app_file: app.py
 pinned: false
+python: 3.8
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ditail/.DS_Store
ADDED
Binary file (6.15 kB)
ditail/__init__.py
ADDED
@@ -0,0 +1 @@
+from .src.ditail_demo import DitailDemo, seed_everything
ditail/lora/.DS_Store
ADDED
Binary file (6.15 kB)
ditail/lora/animeoutline.jpeg
ADDED
ditail/lora/animeoutline.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2fcd88e6aa07b2db73befca54b3956131e1decc8e6e719508ce32c28768f9b91
+size 18986312
ditail/lora/film.jpeg
ADDED
ditail/lora/film.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11c5b684b502915273e40a8d0a50473c9ab3ec98e6fa5baed307b672da5fcf08
+size 37871065
ditail/lora/flat.jpeg
ADDED
ditail/lora/flat.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:315ffc392c322a3768c0a6837e333e30447581ba3687b6379af998d90e1ce21d
+size 151114856
ditail/lora/impressionism.jpeg
ADDED
ditail/lora/impressionism.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28af0aafb8344fbeab124bc6ec9addbe4f014de37ac3fb8174effbd83de3777c
+size 151110218
ditail/lora/minecraft.jpeg
ADDED
ditail/lora/minecraft.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e265490127bee6eea5a26d9f6caa75899a616e61890cb387cc571af5db666f9
+size 37870517
ditail/lora/none.jpeg
ADDED
ditail/lora/pop.jpeg
ADDED
ditail/lora/pop.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7cc7ae5fb9f74efd4bd366e57fad4b48f45479fd67bb7dd944b104ee6819e84b
+size 151115176
ditail/lora/shinkai_makoto.jpeg
ADDED
ditail/lora/shinkai_makoto.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef6ba90e343502f6a8bf6da0d9f8f4e2571d0248d11d14aa577b7ddc490bbd48
+size 151108831
ditail/lora/snow.jpeg
ADDED
ditail/lora/snow.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e25bdd1155498b3bf7b02d0243d53e166eafe543b36dc77588dbfc7f03fd555a
+size 75612254
ditail/src/ditail_demo.py
ADDED
@@ -0,0 +1,233 @@
+import os
+import yaml
+import argparse
+import warnings
+from PIL import Image
+from tqdm import tqdm
+from datetime import datetime
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from transformers import logging
+from diffusers import DDIMScheduler, StableDiffusionPipeline
+
+from .ditail_utils import *
+
+# suppress warnings
+logging.set_verbosity_error()
+warnings.filterwarnings("ignore", message=".*LoRA backend.*")
+
+class DitailDemo(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        if isinstance(self.args, dict):
+            for k, v in args.items():
+                setattr(self, k, v)
+        else:
+            for k, v in vars(args).items():
+                setattr(self, k, v)
+
+    def load_inv_model(self):
+        self.scheduler = DDIMScheduler.from_pretrained(self.inv_model, subfolder='scheduler')
+        self.scheduler.set_timesteps(self.inv_steps, device=self.device)
+        print(f'[INFO] Loading inversion model: {self.inv_model}')
+        pipe = StableDiffusionPipeline.from_pretrained(
+            self.inv_model, torch_dtype=torch.float16
+        ).to(self.device)
+        pipe.enable_xformers_memory_efficient_attention()
+        self.text_encoder = pipe.text_encoder
+        self.tokenizer = pipe.tokenizer
+        self.unet = pipe.unet
+        self.vae = pipe.vae
+        self.tokenizer_kwargs = dict(
+            truncation=True,
+            return_tensors='pt',
+            padding='max_length',
+            max_length=self.tokenizer.model_max_length
+        )
+
+    def load_spl_model(self):
+        self.scheduler = DDIMScheduler.from_pretrained(self.spl_model, subfolder='scheduler')
+        self.scheduler.set_timesteps(self.spl_steps, device=self.device)
+        print(f'[INFO] Loading sampling model: {self.spl_model}')
+        # only reload the pipeline if the sampling model or LoRA differs from step 1
+        if (self.lora != 'none') or (self.inv_model != self.spl_model):
+            pipe = StableDiffusionPipeline.from_pretrained(
+                self.spl_model, torch_dtype=torch.float16
+            ).to(self.device)
+            if self.lora != 'none':
+                pipe.unfuse_lora()
+                pipe.unload_lora_weights()
+                pipe.load_lora_weights(self.lora_dir, weight_name=f'{self.lora}.safetensors')
+                pipe.fuse_lora(lora_scale=self.lora_scale)
+            pipe.enable_xformers_memory_efficient_attention()
+            self.text_encoder = pipe.text_encoder
+            self.tokenizer = pipe.tokenizer
+            self.unet = pipe.unet
+            self.vae = pipe.vae
+            self.tokenizer_kwargs = dict(
+                truncation=True,
+                return_tensors='pt',
+                padding='max_length',
+                max_length=self.tokenizer.model_max_length
+            )
+
+    @torch.no_grad()
+    def encode_image(self, image_pil):
+        # image_pil = T.Resize(512)(img.convert('RGB'))
+        image_pil = T.Resize(512)(image_pil)
+        image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
+        with torch.autocast(device_type=self.device, dtype=torch.float32):
+            image = 2 * image - 1
+            posterior = self.vae.encode(image).latent_dist
+            latent = posterior.mean * 0.18215
+        return latent
+
+    @torch.no_grad()
+    def invert_image(self, cond, latent):
+        # DDIM inversion: deterministically map the clean latent forward through
+        # increasing noise levels, caching the latent at every timestep so the
+        # sampling stage can inject source features later
+        self.latents = {}
+        timesteps = reversed(self.scheduler.timesteps)
+        with torch.autocast(device_type=self.device, dtype=torch.float32):
+            for i, t in enumerate(tqdm(timesteps)):
+                cond_batch = cond.repeat(latent.shape[0], 1, 1)
+                alpha_prod_t = self.scheduler.alphas_cumprod[t]
+                alpha_prod_t_prev = (
+                    self.scheduler.alphas_cumprod[timesteps[i-1]]
+                    if i > 0 else self.scheduler.final_alpha_cumprod
+                )
+                mu = alpha_prod_t ** 0.5
+                mu_prev = alpha_prod_t_prev ** 0.5
+                sigma = (1 - alpha_prod_t) ** 0.5
+                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+                eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample
+                pred_x0 = (latent - sigma_prev * eps) / mu_prev
+                latent = mu * pred_x0 + sigma * eps
+                self.latents[t.item()] = latent
+        self.noisy_latent = latent
+
+    @torch.no_grad()
+    def extract_latents(self):
+        # get the embeddings for pos & neg prompts
+        # self.pos_prompt = ' ,'.join(LORA_TRIGGER_WORD.get(self.lora, [''])+[self.pos_prompt])
+        # print('the prompt after adding trigger word:', self.pos_prompt)
+        text_pos = self.tokenizer(self.pos_prompt, **self.tokenizer_kwargs)
+        text_neg = self.tokenizer(self.neg_prompt, **self.tokenizer_kwargs)
+        self.emb_pos = self.text_encoder(text_pos.input_ids.to(self.device))[0]
+        self.emb_neg = self.text_encoder(text_neg.input_ids.to(self.device))[0]
+        # apply condition scaling
+        cond = self.alpha * self.emb_pos - self.beta * self.emb_neg
+        # encode source image & apply DDIM inversion
+        self.invert_image(cond, self.encode_image(self.img))
+
+    @torch.no_grad()
+    def latent_to_image(self, latent, save_path=None):
+        with torch.autocast(device_type=self.device, dtype=torch.float32):
+            latent = 1 / 0.18215 * latent
+            image = self.vae.decode(latent).sample[0]
+            image = (image / 2 + 0.5).clamp(0, 1)
+        # T.ToPILImage()(image).save(save_path)
+        return T.ToPILImage()(image)
+
+    def init_injection(self, attn_ratio=0.5, conv_ratio=0.8):
+        attn_thresh = int(attn_ratio * self.spl_steps)
+        conv_thresh = int(conv_ratio * self.spl_steps)
+        self.attn_inj_timesteps = self.scheduler.timesteps[:attn_thresh]
+        self.conv_inj_timesteps = self.scheduler.timesteps[:conv_thresh]
+        register_attn_inj(self, self.attn_inj_timesteps)
+        register_conv_inj(self, self.conv_inj_timesteps)
+
+    @torch.no_grad()
+    def sampling_loop(self):
+        # init text embeddings
+        text_ept = self.tokenizer('', **self.tokenizer_kwargs)
+        self.emb_ept = self.text_encoder(text_ept.input_ids.to(self.device))[0]
+        self.emb_spl = torch.cat([self.emb_ept, self.emb_pos, self.emb_neg], dim=0)
+        with torch.autocast(device_type=self.device, dtype=torch.float16):
+            # use noisy latent as starting point
+            x = self.latents[self.scheduler.timesteps[0].item()]
+            # sampling loop
+            for t in tqdm(self.scheduler.timesteps):
+                # concat latents & register timestep
+                src_latent = self.latents[t.item()]
+                latents = torch.cat([src_latent, x, x])
+                register_time(self, t.item())
+                # apply U-Net for denoising
+                noise_pred = self.unet(latents, t, encoder_hidden_states=self.emb_spl).sample
+                # classifier-free guidance
+                _, noise_pred_pos, noise_pred_neg = noise_pred.chunk(3)
+                noise_pred = noise_pred_neg + self.omega * (noise_pred_pos - noise_pred_neg)
+                # denoise step
+                x = self.scheduler.step(noise_pred, t, x).prev_sample
+        # save output latent
+        self.output_latent = x
+
+    def run_ditail(self):
+        # init output dir & dump config
+        os.makedirs(self.output_dir, exist_ok=True)
+        # self.save_dir = get_save_dir(self.output_dir)
+        # os.makedirs(self.save_dir, exist_ok=True)
+        # with open(os.path.join(self.output_dir, 'config.yaml'), 'w') as f:
+        #     if isinstance(self.args, dict):
+        #         f.write(yaml.dump(self.args))
+        #     else:
+        #         f.write(yaml.dump(vars(self.args)))
+        # step 1: inversion stage
+        self.load_inv_model()
+        self.extract_latents()
+        # self.latent_to_image(
+        #     latent=self.noisy_latent,
+        #     save_path=os.path.join(self.save_dir, 'noise.png')
+        # )
+        # step 2: sampling stage
+        self.load_spl_model()
+        if not self.no_injection:
+            self.init_injection()
+        self.sampling_loop()
+        return self.latent_to_image(
+            latent=self.output_latent,
+            # save_path=os.path.join(self.save_dir, 'output.png')
+        )
+
+def main(args):
+    seed_everything(args.seed)
+    ditail = DitailDemo(args)
+    ditail.run_ditail()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--output_dir', type=str, default='./output_demo')
+    parser.add_argument('--inv_model', type=str, default='runwayml/stable-diffusion-v1-5',
+                        help='Pre-trained inversion model name or path (step 1)')
+    parser.add_argument('--spl_model', type=str, default='runwayml/stable-diffusion-v1-5',
+                        help='Pre-trained sampling model name or path (step 2)')
+    parser.add_argument('--inv_steps', type=int, default=50,
+                        help='Number of inversion steps (step 1)')
+    parser.add_argument('--spl_steps', type=int, default=50,
+                        help='Number of sampling steps (step 2)')
+    # parser.add_argument('--img_path', type=str, required=True,
+    #                     help='Path to the source image')
+    parser.add_argument('--pos_prompt', type=str, required=True,
+                        help='Positive prompt for inversion')
+    parser.add_argument('--neg_prompt', type=str, default='worst quality, blurry, low res, NSFW',
+                        help='Negative prompt for inversion')
+    parser.add_argument('--alpha', type=float, default=2.0,
+                        help='Positive prompt scaling factor')
+    parser.add_argument('--beta', type=float, default=1.0,
+                        help='Negative prompt scaling factor')
+    parser.add_argument('--omega', type=float, default=15,
+                        help='Classifier-free guidance factor')
+    parser.add_argument('--mask', type=str, default='none',
+                        help='Optional mask for regional injection')
+    parser.add_argument('--lora', type=str, default='none',
+                        help='Optional LoRA for the sampling stage')
+    parser.add_argument('--lora_dir', type=str, default='./lora',
+                        help='Optional LoRA storing directory')
+    parser.add_argument('--lora_scale', type=float, default=0.7,
+                        help='Optional LoRA scaling weight')
+    parser.add_argument('--no_injection', action="store_true",
+                        help='Do not use PnP injection')
+    args = parser.parse_args()
+    main(args)
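For reference, `DitailDemo.__init__` also accepts a plain dict, so the demo can be driven from Python as well as from the CLI above. A minimal sketch — the values mirror the argparse defaults, while the prompt, the image path, and the LoRA directory are illustrative assumptions; note that `extract_latents` reads the source image from the `img` attribute, which the caller must supply:

from PIL import Image
from ditail import DitailDemo, seed_everything

# hypothetical config mirroring the argparse defaults above
config = dict(
    seed=42, device='cuda', output_dir='./output_demo',
    inv_model='runwayml/stable-diffusion-v1-5',
    spl_model='runwayml/stable-diffusion-v1-5',
    inv_steps=50, spl_steps=50,
    pos_prompt='a cat sitting on a sofa',  # illustrative prompt
    neg_prompt='worst quality, blurry, low res, NSFW',
    alpha=2.0, beta=1.0, omega=15,
    mask='none', lora='impressionism', lora_dir='./ditail/lora', lora_scale=0.7,
    no_injection=False,
    img=Image.open('cat.png').convert('RGB'),  # hypothetical source image
)

seed_everything(config['seed'])
output = DitailDemo(config).run_ditail()  # returns a PIL image
output.save('ditail_output.png')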
ditail/src/ditail_utils.py
ADDED
@@ -0,0 +1,121 @@
+# credits: https://github.com/MichalGeyer/pnp-diffusers/blob/main/pnp_utils.py
+
+import os
+import torch
+import random
+import numpy as np
+
+def seed_everything(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+# def get_save_dir(output_dir, img_path):
+#     os.makedirs(output_dir, exist_ok=True)
+#     file = os.path.basename(img_path)
+#     indices = [d for d in os.listdir(output_dir) if d.startswith(file)]
+#     return os.path.join(output_dir, f'{file}_{len(indices)}')
+
+def register_time(model, t):
+    conv_module = model.unet.up_blocks[1].resnets[1]
+    setattr(conv_module, 't', t)
+    down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
+    up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
+    for res in up_res_dict:
+        for block in up_res_dict[res]:
+            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
+            setattr(module, 't', t)
+    for res in down_res_dict:
+        for block in down_res_dict[res]:
+            module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
+            setattr(module, 't', t)
+    module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
+    setattr(module, 't', t)
+
+def register_attn_inj(model, injection_schedule):
+    def sa_forward(self):
+        to_out = self.to_out
+        if type(to_out) is torch.nn.modules.container.ModuleList:
+            to_out = self.to_out[0]
+        else:
+            to_out = self.to_out
+        def forward(x, encoder_hidden_states=None, attention_mask=None):
+            batch_size, sequence_length, dim = x.shape
+            h = self.heads
+            is_cross = encoder_hidden_states is not None
+            encoder_hidden_states = encoder_hidden_states if is_cross else x
+            q = self.to_q(x)
+            k = self.to_k(encoder_hidden_states)
+            v = self.to_v(encoder_hidden_states)
+            if not is_cross and self.injection_schedule is not None and (
+                    self.t in self.injection_schedule or self.t == 1000):
+                source_batch_size = int(q.shape[0] // 3)
+                # inject pos chunk
+                q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
+                k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
+                # inject neg chunk
+                q[2 * source_batch_size:] = q[:source_batch_size]
+                k[2 * source_batch_size:] = k[:source_batch_size]
+            q = self.head_to_batch_dim(q)
+            k = self.head_to_batch_dim(k)
+            v = self.head_to_batch_dim(v)
+            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+            if attention_mask is not None:
+                attention_mask = attention_mask.reshape(batch_size, -1)
+                max_neg_value = -torch.finfo(sim.dtype).max
+                attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
+                sim.masked_fill_(~attention_mask, max_neg_value)
+            attn = sim.softmax(dim=-1)
+            out = torch.einsum("b i j, b j d -> b i d", attn, v)
+            out = self.batch_to_head_dim(out)
+            return to_out(out)
+        return forward
+    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
+    for res in res_dict:
+        for block in res_dict[res]:
+            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
+            module.forward = sa_forward(module)
+            setattr(module, 'injection_schedule', injection_schedule)
+
+def register_conv_inj(model, injection_schedule):
+    def conv_forward(self):
+        def forward(input_tensor, temb, scale):
+            hidden_states = input_tensor
+            hidden_states = self.norm1(hidden_states)
+            hidden_states = self.nonlinearity(hidden_states)
+            if self.upsample is not None:
+                if hidden_states.shape[0] >= 64:
+                    input_tensor = input_tensor.contiguous()
+                    hidden_states = hidden_states.contiguous()
+                input_tensor = self.upsample(input_tensor, scale=scale)
+                hidden_states = self.upsample(hidden_states, scale=scale)
+            elif self.downsample is not None:
+                input_tensor = self.downsample(input_tensor, scale=scale)
+                hidden_states = self.downsample(hidden_states, scale=scale)
+            hidden_states = self.conv1(hidden_states, scale)
+            if temb is not None:
+                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+            if temb is not None and self.time_embedding_norm == "default":
+                hidden_states = hidden_states + temb
+            hidden_states = self.norm2(hidden_states)
+            if temb is not None and self.time_embedding_norm == "scale_shift":
+                scale, shift = torch.chunk(temb, 2, dim=1)
+                hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.nonlinearity(hidden_states)
+            hidden_states = self.dropout(hidden_states)
+            hidden_states = self.conv2(hidden_states, scale)
+            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
+                source_batch_size = int(hidden_states.shape[0] // 3)
+                # inject pos chunk
+                hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
+                # inject neg chunk
+                hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
+            if self.conv_shortcut is not None:
+                input_tensor = self.conv_shortcut(input_tensor, scale)
+            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+            return output_tensor
+        return forward
+    conv_module = model.unet.up_blocks[1].resnets[1]
+    conv_module.forward = conv_forward(conv_module)
+    setattr(conv_module, 'injection_schedule', injection_schedule)
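Both hooks assume the batch layout built in `sampling_loop`: three equal chunks, `torch.cat([src_latent, x, x])`, paired with the stacked [empty, pos, neg] prompt embeddings, so `source_batch_size = batch // 3` and the first chunk carries the source features. A toy sketch of that slicing convention, with hypothetical shapes:

import torch

# rows stand in for the [src, pos, neg] chunks of a batch of 3
feat = torch.arange(3.0).repeat_interleave(4).reshape(3, 4)
src = feat.shape[0] // 3
# copy the source chunk into the pos and neg chunks, as done for
# q and k in register_attn_inj and hidden_states in register_conv_inj
feat[src:2 * src] = feat[:src]
feat[2 * src:] = feat[:src]
assert torch.equal(feat[1], feat[0]) and torch.equal(feat[2], feat[0])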
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+gradio
+accelerate
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.1.0
+torchvision
+transformers==4.35.2
+diffusers==0.24.0
+xformers
+open_clip_torch
+clip-interrogator