Inference
- LMSDiscreteScheduler.py +97 -0
- StableDiffuser.py +276 -0
- __init__.py +0 -0
- app.py +158 -65
- requirements.txt +5 -1
- test.py +18 -5
- train_esd.py +10 -8
- util.py +107 -0
LMSDiscreteScheduler.py
ADDED
@@ -0,0 +1,97 @@
import warnings
from typing import Tuple, Union

import torch
from diffusers.schedulers.scheduling_lms_discrete import \
    LMSDiscreteScheduler as _LMSDiscreteScheduler
from diffusers.schedulers.scheduling_lms_discrete import \
    LMSDiscreteSchedulerOutput


class LMSDiscreteScheduler(_LMSDiscreteScheduler):

    def step(
        self,
        model_output: torch.FloatTensor,
        step_index: int,
        sample: torch.FloatTensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            step_index (`int`): index of the current step in the scheduler's timestep schedule.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            order: coefficient for multi-step inference.
            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * \
                (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        order = min(step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(
            order, step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        if not return_dict:
            return (prev_sample,)

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        iteration: int
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

        Args:
            sample (`torch.FloatTensor`): input sample
            iteration (`int`): index of the current step in the scheduler's timestep schedule

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        sample = sample / ((self.sigmas[iteration]**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample
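Note: the subclass changes the `step` and `scale_model_input` signatures to take an explicit step index rather than a timestep value, so the caller drives the scheduler by position in the schedule. A minimal call-pattern sketch follows; the latent shape and the zero "noise prediction" standing in for a UNet call are illustrative assumptions, not part of this commit.

# Hedged sketch: driving the patched scheduler by step index.
import torch
from LMSDiscreteScheduler import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma  # assumed latent shape

for i in range(50):
    model_input = scheduler.scale_model_input(latents, i)  # index, not timestep value
    noise_pred = torch.zeros_like(model_input)              # stand-in for the UNet call
    latents = scheduler.step(noise_pred, i, latents).prev_sample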
StableDiffuser.py
ADDED
@@ -0,0 +1,276 @@
import argparse

import torch
from baukit import TraceDict
from diffusers import AutoencoderKL, UNet2DConditionModel
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import util
from LMSDiscreteScheduler import LMSDiscreteScheduler


def default_parser():

    parser = argparse.ArgumentParser()

    parser.add_argument('prompts', type=str, nargs='+')
    parser.add_argument('outpath', type=str)

    parser.add_argument('--images', type=str, nargs='+', default=None)
    parser.add_argument('--nsteps', type=int, default=1000)
    parser.add_argument('--nimgs', type=int, default=1)
    parser.add_argument('--start_itr', type=int, default=0)
    parser.add_argument('--return_steps', action='store_true', default=False)
    parser.add_argument('--pred_x0', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)

    return parser


class StableDiffuser(torch.nn.Module):

    def __init__(self,
                 seed=None
                 ):

        super().__init__()

        self._seed = seed

        # Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="vae")

        # Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14")

        # The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet")

        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

        self.generator = torch.Generator()

        if self._seed is not None:
            self.seed(seed)

        self.eval()

    def seed(self, seed):

        self.generator = torch.manual_seed(seed)

    def get_noise(self, batch_size, img_size):

        param = list(self.parameters())[0]

        return torch.randn(
            (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
            generator=self.generator).type(param.dtype).to(param.device)

    def add_noise(self, latents, noise, step):

        return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))

    def text_tokenize(self, prompts):

        return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")

    def text_detokenize(self, tokens):

        return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]

    def text_encode(self, tokens):

        return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]

    def decode(self, latents):

        return self.vae.decode(1 / 0.18215 * latents).sample

    def encode(self, tensors):

        return self.vae.encode(tensors).latent_dist.mode() * 0.18215

    def to_image(self, image):

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def set_scheduler_timesteps(self, n_steps):
        self.scheduler.set_timesteps(n_steps, device=self.unet.device)

    def get_initial_latents(self, n_imgs, img_size, n_prompts):

        noise = self.get_noise(n_imgs, img_size).repeat(n_prompts, 1, 1, 1)

        latents = noise * self.scheduler.init_noise_sigma

        return latents

    def get_text_embeddings(self, prompts, n_imgs):

        text_tokens = self.text_tokenize(prompts)

        text_embeddings = self.text_encode(text_tokens)

        unconditional_tokens = self.text_tokenize([""] * len(prompts))

        unconditional_embeddings = self.text_encode(unconditional_tokens)

        text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)

        return text_embeddings

    def predict_noise(self,
                      iteration,
                      latents,
                      text_embeddings,
                      guidance_scale=7.5
                      ):

        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latents = torch.cat([latents] * 2)
        latents = self.scheduler.scale_model_input(
            latents, iteration)

        # predict the noise residual
        noise_prediction = self.unet(
            latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
        noise_prediction = noise_prediction_uncond + guidance_scale * \
            (noise_prediction_text - noise_prediction_uncond)

        return noise_prediction

    @torch.no_grad()
    def diffusion(self,
                  latents,
                  text_embeddings,
                  end_iteration=1000,
                  start_iteration=0,
                  return_steps=False,
                  pred_x0=False,
                  trace_args=None,
                  show_progress=True,
                  **kwargs):

        latents_steps = []
        trace_steps = []

        trace = None

        for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):

            if trace_args:

                trace = TraceDict(self, **trace_args)

            noise_pred = self.predict_noise(
                iteration,
                latents,
                text_embeddings,
                **kwargs)

            # compute the previous noisy sample x_t -> x_t-1
            output = self.scheduler.step(noise_pred, iteration, latents)

            if trace_args:

                trace.close()

                trace_steps.append(trace)

            latents = output.prev_sample

            if return_steps or iteration == end_iteration - 1:

                output = output.pred_original_sample if pred_x0 else latents

                if return_steps:
                    latents_steps.append(output.cpu())
                else:
                    latents_steps.append(output)

        return latents_steps, trace_steps

    @torch.no_grad()
    def __call__(self,
                 prompts,
                 img_size=512,
                 n_steps=50,
                 n_imgs=1,
                 end_iteration=None,
                 reseed=False,
                 **kwargs
                 ):

        assert 0 <= n_steps <= 1000

        if not isinstance(prompts, list):

            prompts = [prompts]

        self.set_scheduler_timesteps(n_steps)

        if reseed:

            self.seed(self._seed)

        latents = self.get_initial_latents(n_imgs, img_size, len(prompts))

        text_embeddings = self.get_text_embeddings(prompts, n_imgs=n_imgs)

        end_iteration = end_iteration or n_steps

        latents_steps, trace_steps = self.diffusion(
            latents,
            text_embeddings,
            end_iteration=end_iteration,
            **kwargs
        )

        latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
        images_steps = [self.to_image(latents) for latents in latents_steps]

        images_steps = list(zip(*images_steps))

        if trace_steps:

            return images_steps, trace_steps

        return images_steps


if __name__ == '__main__':

    parser = default_parser()

    args = parser.parse_args()

    diffuser = StableDiffuser(seed=args.seed).to(torch.device(args.device)).half()

    images = diffuser(args.prompts,
                      n_steps=args.nsteps,
                      n_imgs=args.nimgs,
                      start_iteration=args.start_itr,
                      return_steps=args.return_steps,
                      pred_x0=args.pred_x0
                      )

    util.image_grid(images, args.outpath)
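Note: `__call__` decodes the kept latents and regroups them with `zip`, so the return value is indexed first by generated sample and then by kept diffusion step, which is why `app.py` below reads `images[0][0]`. A hedged usage sketch; the prompt and output path are placeholders, not part of the commit.

# Hedged sketch: one prompt, one image, saved as a grid.
import torch

import util
from StableDiffuser import StableDiffuser

diffuser = StableDiffuser(seed=42).to(torch.device('cuda:0')).half()

images = diffuser("a photo of a lighthouse", n_steps=50, n_imgs=1)  # placeholder prompt

first_image = images[0][0]           # [sample index][kept step index]
util.image_grid(images, 'out.png')   # placeholder output path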
__init__.py
ADDED
File without changes
app.py
CHANGED
@@ -2,73 +2,166 @@ import sys
 sys.path.insert(0,'stable_diffusion')
 import gradio as gr
 from train_esd import train_esd
+from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
+from omegaconf import OmegaConf
+from StableDiffuser import StableDiffuser
+from diffusers import UNet2DConditionModel

 ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
 config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
 diffusers_config_path = "stable-diffusion/config.json"

+
+class Demo:
+
+    def __init__(self) -> None:
+        demo = self.layout()
+        demo.launch()
+
+    def layout(self):
+
+        with gr.Blocks() as demo:
+
+            with gr.Row():
+                with gr.Column() as training_column:
+                    self.prompt_input = gr.Text(
+                        placeholder="Enter prompt...",
+                        label="Prompt",
+                        info="Prompt corresponding to concept to erase"
+                    )
+                    self.train_method_input = gr.Dropdown(
+                        choices=['noxattn', 'selfattn', 'xattn', 'full'],
+                        value='xattn',
+                        label='Train Method',
+                        info='Method of training'
+                    )
+
+                    self.neg_guidance_input = gr.Number(
+                        value=1,
+                        label="Negative Guidance",
+                        info='Guidance of negative training used to train'
+                    )
+
+                    self.iterations_input = gr.Number(
+                        value=1000,
+                        precision=0,
+                        label="Iterations",
+                        info='iterations used to train'
+                    )
+
+                    self.lr_input = gr.Number(
+                        value=1e-5,
+                        label="Learning Rate",
+                        info='Learning rate used to train'
+                    )
+
+                    self.train_button = gr.Button(
+                        value="Train",
+                    )
+                    self.train_button.click(self.train, inputs=[
+                        self.prompt_input,
+                        self.train_method_input,
+                        self.neg_guidance_input,
+                        self.iterations_input,
+                        self.lr_input
+                    ]
+                    )
+                with gr.Column() as inference_column:
+
+                    with gr.Row():
+
+                        self.prompt_input_infr = gr.Text(
+                            placeholder="Enter prompt...",
+                            label="Prompt",
+                            info="Prompt corresponding to concept to erase"
+                        )
+
+                    with gr.Row():
+
+                        self.image_new = gr.Image(
+                            label="New Image",
+                            interactive=False
+                        )
+                        self.image_orig = gr.Image(
+                            label="Orig Image",
+                            interactive=False
+                        )
+
+                    with gr.Row():
+
+                        self.infr_button = gr.Button(
+                            value="Generate",
+                        )
+                        self.infr_button.click(self.inference, inputs=[
+                            self.prompt_input_infr,
+                        ],
+                            outputs=[
+                                self.image_new,
+                                self.image_orig
+                            ]
+                        )
+        return demo
+
+    def train(self, prompt, train_method, neg_guidance, iterations, lr):
+
+        model_orig, model_edited = train_esd(prompt,
+                                             train_method,
+                                             3,
+                                             neg_guidance,
+                                             iterations,
+                                             lr,
+                                             config_path,
+                                             ckpt_path,
+                                             diffusers_config_path,
+                                             ['cuda', 'cuda'],
+                                             gr.Progress()
+                                             )
+
+        original_config = OmegaConf.load(config_path)
+        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
+        unet_config = create_unet_diffusers_config(original_config, image_size=512)
+        model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
+        model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
+
+        self.init_inference(model_edited_sd, model_orig_sd, unet_config)
+
+    def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
+
+        self.model_edited_sd = model_edited_sd
+        self.model_orig_sd = model_orig_sd
+
+        self.diffuser = StableDiffuser(42)
+
+        self.diffuser.unet = UNet2DConditionModel(**unet_config)
+        self.diffuser.to('cuda')
+
+    def inference(self, prompt):
+
+        self.diffuser.unet.load_state_dict(self.model_orig_sd)
+
+        images = self.diffuser(
+            prompt,
+            n_steps=50,
+            reseed=True
+        )
+
+        orig_image = images[0][0]
+
+        self.diffuser.unet.load_state_dict(self.model_edited_sd)
+
+        images = self.diffuser(
+            prompt,
+            n_steps=50,
+            reseed=True
+        )
+
+        edited_image = images[0][0]
+
+        return edited_image, orig_image
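Note: the hunk above defines the `Demo` class, but no instantiation is visible in the extracted diff. Assuming nothing elsewhere in the file launches it, running the app would come down to something like the following hedged sketch.

# Hedged sketch: constructing Demo builds the gr.Blocks layout and launches it.
if __name__ == '__main__':
    Demo()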
requirements.txt
CHANGED
@@ -4,4 +4,8 @@ torchvision
 einops
 diffusers
 transformers
-pytorch_lightning
+pytorch_lightning==1.6.5
+taming-transformers
+kornia
+git+https://github.com/openai/CLIP.git@main#egg=clip
+git+https://github.com/davidbau/baukit.git
test.py
CHANGED
@@ -1,19 +1,32 @@
 import sys
 sys.path.insert(0,'stable_diffusion')
 from train_esd import train_esd
+import torch
 ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
 config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
 diffusers_config_path = "stable_diffusion/config.json"

-train_esd("England",
+orig, newm = train_esd("England",
           'xattn',
           3,
           1,
+          2,
           .003,
           config_path,
           ckpt_path,
           diffusers_config_path,
-          ['cuda', 'cuda']
+          ['cuda', 'cuda'],
+          None
+          )
+
+
+from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
+from diffusers import UNet2DConditionModel, AutoencoderKL, LMSDiscreteScheduler
+from omegaconf import OmegaConf
+from transformers import CLIPTextModel, CLIPTokenizer
+original_config = OmegaConf.load(config_path)
+original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
+unet_config = create_unet_diffusers_config(original_config, image_size=512)
+converted_unet_checkpoint = convert_ldm_unet_checkpoint(newm.state_dict(), unet_config)
+unet = UNet2DConditionModel(**unet_config)
+unet.load_state_dict(converted_unet_checkpoint)
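Note: the test script stops after loading the converted checkpoint into a bare `UNet2DConditionModel`. A hedged continuation that swaps it into the `StableDiffuser` wrapper, mirroring what `app.py` does in `init_inference`/`inference`, could look like the following; the output path is a placeholder.

# Hedged sketch: sample from the edited UNet produced by the test script.
from StableDiffuser import StableDiffuser

diffuser = StableDiffuser(seed=42)
diffuser.unet = unet          # the converted UNet2DConditionModel built above
diffuser.to('cuda')

images = diffuser("England", n_steps=50)
images[0][0].save('erased_england.png')  # placeholder output path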
train_esd.py
CHANGED
@@ -102,7 +102,7 @@ def get_models(config_path, ckpt_path, devices):
 
     return model_orig, sampler_orig, model, sampler
 
-def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
+def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, progress_bar, seperator=None, image_size=512, ddim_steps=50):
     '''
     Function to train diffusion models to erase concepts from model weights
 
@@ -251,17 +251,19 @@ def train_esd(prompt, train_method, start_guidance, negative_guidance, iteration
         pbar.set_postfix({"loss": loss.item()})
         history.append(loss.item())
         opt.step()
-        # save checkpoint and loss curve
-        if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
+        # # save checkpoint and loss curve
+        # if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
+        #     save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
 
-        if i % 100 == 0:
+        # if i % 100 == 0:
+        #     save_history(losses, name, word_print)
 
     model.eval()
 
-    save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
-    save_history(losses, name, word_print)
+    # save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
+    # save_history(losses, name, word_print)
+
+    return model_orig, model
 
 def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
     # SAVE MODEL
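Note: with the new signature, callers pass a `progress_bar` argument (a `gr.Progress()` in the Gradio app, `None` in the test script) and consume the returned pair of models instead of relying on checkpoints written to disk. A hedged call sketch, with paths and hyperparameters copied from the scripts above:

# Hedged sketch: the new train_esd contract as used by app.py and test.py.
import sys
sys.path.insert(0, 'stable_diffusion')
from train_esd import train_esd

ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
diffusers_config_path = "stable_diffusion/config.json"

model_orig, model_edited = train_esd(
    "England",         # prompt describing the concept to erase
    'xattn',           # train_method
    3,                 # start_guidance
    1,                 # negative_guidance
    1000,              # iterations
    1e-5,              # lr
    config_path,
    ckpt_path,
    diffusers_config_path,
    ['cuda', 'cuda'],  # devices
    None,              # progress_bar (gr.Progress() when called from the Gradio app)
)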
util.py
ADDED
@@ -0,0 +1,107 @@
from PIL import Image
from matplotlib import pyplot as plt
import textwrap


def to_gif(images, path):

    images[0].save(path, save_all=True,
                   append_images=images[1:], loop=0, duration=len(images) * 20)


def figure_to_image(figure):

    figure.set_dpi(300)

    figure.canvas.draw()

    return Image.frombytes('RGB', figure.canvas.get_width_height(), figure.canvas.tostring_rgb())


def image_grid(images, outpath=None, column_titles=None, row_titles=None):

    n_rows = len(images)
    n_cols = len(images[0])

    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(n_cols, n_rows), squeeze=False)

    for row, _images in enumerate(images):

        for column, image in enumerate(_images):
            ax = axs[row][column]
            ax.imshow(image)
            if column_titles and row == 0:
                ax.set_title(textwrap.fill(
                    column_titles[column], width=12), fontsize='x-small')
            if row_titles and column == 0:
                ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small', labelpad=1.6 * len(row_titles[row]))
            ax.set_xticks([])
            ax.set_yticks([])

    plt.subplots_adjust(wspace=0, hspace=0)

    if outpath is not None:
        plt.savefig(outpath, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.tight_layout(pad=0)
        image = figure_to_image(plt.gcf())
        plt.close()
        return image


def get_module(module, module_name):

    if isinstance(module_name, str):
        module_name = module_name.split('.')

    if len(module_name) == 0:
        return module
    else:
        module = getattr(module, module_name[0])
        return get_module(module, module_name[1:])


def set_module(module, module_name, new_module):

    if isinstance(module_name, str):
        module_name = module_name.split('.')

    if len(module_name) == 1:
        return setattr(module, module_name[0], new_module)
    else:
        module = getattr(module, module_name[0])
        return set_module(module, module_name[1:], new_module)


def freeze(module):

    for parameter in module.parameters():

        parameter.requires_grad = False


def unfreeze(module):

    for parameter in module.parameters():

        parameter.requires_grad = True


def get_concat_h(im1, im2):
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


def get_concat_v(im1, im2):
    dst = Image.new('RGB', (im1.width, im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst
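Note: `get_module` and `set_module` walk dotted attribute paths, which is how a nested submodule (for example an attention block inside a loaded UNet) can be looked up or swapped out by name. A small hedged example on a toy module; the dotted path and the toy model are illustrative only.

# Hedged sketch: resolve and replace a nested submodule by dotted name.
import torch
from util import get_module, set_module, freeze

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Sequential(torch.nn.Linear(4, 2)),
)

inner = get_module(model, '1.0')                           # same object as model[1][0]
set_module(model, '1.0', torch.nn.Linear(4, 2, bias=False))  # swap the submodule in place
freeze(model)                                              # turn off gradients for every parameter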