JadenFK commited on
Commit
843b14b
1 Parent(s): 7a07ff9
Files changed (8) hide show
  1. LMSDiscreteScheduler.py +97 -0
  2. StableDiffuser.py +276 -0
  3. __init__.py +0 -0
  4. app.py +158 -65
  5. requirements.txt +5 -1
  6. test.py +18 -5
  7. train_esd.py +10 -8
  8. util.py +107 -0
LMSDiscreteScheduler.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ from diffusers.schedulers.scheduling_lms_discrete import \
6
+ LMSDiscreteScheduler as _LMSDiscreteScheduler
7
+ from diffusers.schedulers.scheduling_lms_discrete import \
8
+ LMSDiscreteSchedulerOutput
9
+
10
+
11
+ class LMSDiscreteScheduler(_LMSDiscreteScheduler):
12
+
13
+ def step(
14
+ self,
15
+ model_output: torch.FloatTensor,
16
+ step_index: int,
17
+ sample: torch.FloatTensor,
18
+ order: int = 4,
19
+ return_dict: bool = True,
20
+ ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
21
+ """
22
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
23
+ process from the learned model outputs (most often the predicted noise).
24
+
25
+ Args:
26
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
27
+ timestep (`float`): current timestep in the diffusion chain.
28
+ sample (`torch.FloatTensor`):
29
+ current instance of sample being created by diffusion process.
30
+ order: coefficient for multi-step inference.
31
+ return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
32
+
33
+ Returns:
34
+ [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
35
+ [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
36
+ When returning a tuple, the first element is the sample tensor.
37
+
38
+ """
39
+ if not self.is_scale_input_called:
40
+ warnings.warn(
41
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
42
+ "See `StableDiffusionPipeline` for a usage example."
43
+ )
44
+
45
+ sigma = self.sigmas[step_index]
46
+
47
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
48
+ if self.config.prediction_type == "epsilon":
49
+ pred_original_sample = sample - sigma * model_output
50
+ elif self.config.prediction_type == "v_prediction":
51
+ # * c_out + input * c_skip
52
+ pred_original_sample = model_output * \
53
+ (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
54
+ else:
55
+ raise ValueError(
56
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
57
+ )
58
+
59
+ # 2. Convert to an ODE derivative
60
+ derivative = (sample - pred_original_sample) / sigma
61
+ self.derivatives.append(derivative)
62
+ if len(self.derivatives) > order:
63
+ self.derivatives.pop(0)
64
+
65
+ # 3. Compute linear multistep coefficients
66
+ order = min(step_index + 1, order)
67
+ lms_coeffs = [self.get_lms_coefficient(
68
+ order, step_index, curr_order) for curr_order in range(order)]
69
+
70
+ # 4. Compute previous sample based on the derivatives path
71
+ prev_sample = sample + sum(
72
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
73
+ )
74
+
75
+ if not return_dict:
76
+ return (prev_sample,)
77
+
78
+ return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
79
+
80
+ def scale_model_input(
81
+ self,
82
+ sample: torch.FloatTensor,
83
+ iteration: int
84
+ ) -> torch.FloatTensor:
85
+ """
86
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
87
+
88
+ Args:
89
+ sample (`torch.FloatTensor`): input sample
90
+ timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
91
+
92
+ Returns:
93
+ `torch.FloatTensor`: scaled input sample
94
+ """
95
+ sample = sample / ((self.sigmas[iteration]**2 + 1) ** 0.5)
96
+ self.is_scale_input_called = True
97
+ return sample
StableDiffuser.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from baukit import TraceDict
5
+ from diffusers import AutoencoderKL, UNet2DConditionModel
6
+ from PIL import Image
7
+ from tqdm.auto import tqdm
8
+ from transformers import CLIPTextModel, CLIPTokenizer
9
+
10
+ import util
11
+ from LMSDiscreteScheduler import LMSDiscreteScheduler
12
+
13
+
14
+ def default_parser():
15
+
16
+ parser = argparse.ArgumentParser()
17
+
18
+ parser.add_argument('prompts', type=str, nargs='+')
19
+ parser.add_argument('outpath', type=str)
20
+
21
+ parser.add_argument('--images', type=str, nargs='+', default=None)
22
+ parser.add_argument('--nsteps', type=int, default=1000)
23
+ parser.add_argument('--nimgs', type=int, default=1)
24
+ parser.add_argument('--start_itr', type=int, default=0)
25
+ parser.add_argument('--return_steps', action='store_true', default=False)
26
+ parser.add_argument('--pred_x0', action='store_true', default=False)
27
+ parser.add_argument('--device', type=str, default='cuda:0')
28
+ parser.add_argument('--seed', type=int, default=42)
29
+
30
+ return parser
31
+
32
+
33
+ class StableDiffuser(torch.nn.Module):
34
+
35
+ def __init__(self,
36
+ seed=None
37
+ ):
38
+
39
+ super().__init__()
40
+
41
+ self._seed = seed
42
+
43
+ # Load the autoencoder model which will be used to decode the latents into image space.
44
+ self.vae = AutoencoderKL.from_pretrained(
45
+ "CompVis/stable-diffusion-v1-4", subfolder="vae")
46
+
47
+ # Load the tokenizer and text encoder to tokenize and encode the text.
48
+ self.tokenizer = CLIPTokenizer.from_pretrained(
49
+ "openai/clip-vit-large-patch14")
50
+ self.text_encoder = CLIPTextModel.from_pretrained(
51
+ "openai/clip-vit-large-patch14")
52
+
53
+ # The UNet model for generating the latents.
54
+ self.unet = UNet2DConditionModel.from_pretrained(
55
+ "CompVis/stable-diffusion-v1-4", subfolder="unet")
56
+
57
+ self.scheduler = LMSDiscreteScheduler(
58
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
59
+
60
+ self.generator = torch.Generator()
61
+
62
+ if self._seed is not None:
63
+
64
+ self.seed(seed)
65
+
66
+ self.eval()
67
+
68
+ def seed(self, seed):
69
+
70
+ self.generator = torch.manual_seed(seed)
71
+
72
+ def get_noise(self, batch_size, img_size):
73
+
74
+ param = list(self.parameters())[0]
75
+
76
+ return torch.randn(
77
+ (batch_size, self.unet.in_channels, img_size // 8, img_size // 8),
78
+ generator=self.generator).type(param.dtype).to(param.device)
79
+
80
+ def add_noise(self, latents, noise, step):
81
+
82
+ return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))
83
+
84
+ def text_tokenize(self, prompts):
85
+
86
+ return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
87
+
88
+ def text_detokenize(self, tokens):
89
+
90
+ return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]
91
+
92
+ def text_encode(self, tokens):
93
+
94
+ return self.text_encoder(tokens.input_ids.to(self.unet.device))[0]
95
+
96
+ def decode(self, latents):
97
+
98
+ return self.vae.decode(1 / 0.18215 * latents).sample
99
+
100
+ def encode(self, tensors):
101
+
102
+ return self.vae.encode(tensors).latent_dist.mode() * 0.18215
103
+
104
+ def to_image(self, image):
105
+
106
+ image = (image / 2 + 0.5).clamp(0, 1)
107
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
108
+ images = (image * 255).round().astype("uint8")
109
+ pil_images = [Image.fromarray(image) for image in images]
110
+
111
+ return pil_images
112
+
113
+ def set_scheduler_timesteps(self, n_steps):
114
+ self.scheduler.set_timesteps(n_steps, device=self.unet.device)
115
+
116
+ def get_initial_latents(self, n_imgs, img_size, n_prompts):
117
+
118
+ noise = self.get_noise(n_imgs, img_size).repeat(n_prompts, 1, 1, 1)
119
+
120
+ latents = noise * self.scheduler.init_noise_sigma
121
+
122
+ return latents
123
+
124
+ def get_text_embeddings(self, prompts, n_imgs):
125
+
126
+ text_tokens = self.text_tokenize(prompts)
127
+
128
+ text_embeddings = self.text_encode(text_tokens)
129
+
130
+ unconditional_tokens = self.text_tokenize([""] * len(prompts))
131
+
132
+ unconditional_embeddings = self.text_encode(unconditional_tokens)
133
+
134
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
135
+
136
+ return text_embeddings
137
+
138
+ def predict_noise(self,
139
+ iteration,
140
+ latents,
141
+ text_embeddings,
142
+ guidance_scale=7.5
143
+ ):
144
+
145
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
146
+ latents = torch.cat([latents] * 2)
147
+ latents = self.scheduler.scale_model_input(
148
+ latents, iteration)
149
+
150
+ # predict the noise residual
151
+ noise_prediction = self.unet(
152
+ latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample
153
+
154
+ # perform guidance
155
+ noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
156
+ noise_prediction = noise_prediction_uncond + guidance_scale * \
157
+ (noise_prediction_text - noise_prediction_uncond)
158
+
159
+ return noise_prediction
160
+
161
+ @torch.no_grad()
162
+ def diffusion(self,
163
+ latents,
164
+ text_embeddings,
165
+ end_iteration=1000,
166
+ start_iteration=0,
167
+ return_steps=False,
168
+ pred_x0=False,
169
+ trace_args=None,
170
+ show_progress=True,
171
+ **kwargs):
172
+
173
+ latents_steps = []
174
+ trace_steps = []
175
+
176
+ trace = None
177
+
178
+ for iteration in tqdm(range(start_iteration, end_iteration), disable=not show_progress):
179
+
180
+ if trace_args:
181
+
182
+ trace = TraceDict(self, **trace_args)
183
+
184
+ noise_pred = self.predict_noise(
185
+ iteration,
186
+ latents,
187
+ text_embeddings,
188
+ **kwargs)
189
+
190
+ # compute the previous noisy sample x_t -> x_t-1
191
+ output = self.scheduler.step(noise_pred, iteration, latents)
192
+
193
+ if trace_args:
194
+
195
+ trace.close()
196
+
197
+ trace_steps.append(trace)
198
+
199
+ latents = output.prev_sample
200
+
201
+ if return_steps or iteration == end_iteration - 1:
202
+
203
+ output = output.pred_original_sample if pred_x0 else latents
204
+
205
+ if return_steps:
206
+ latents_steps.append(output.cpu())
207
+ else:
208
+ latents_steps.append(output)
209
+
210
+ return latents_steps, trace_steps
211
+
212
+ @torch.no_grad()
213
+ def __call__(self,
214
+ prompts,
215
+ img_size=512,
216
+ n_steps=50,
217
+ n_imgs=1,
218
+ end_iteration=None,
219
+ reseed=False,
220
+ **kwargs
221
+ ):
222
+
223
+ assert 0 <= n_steps <= 1000
224
+
225
+ if not isinstance(prompts, list):
226
+
227
+ prompts = [prompts]
228
+
229
+ self.set_scheduler_timesteps(n_steps)
230
+
231
+ if reseed:
232
+
233
+ self.seed(self._seed)
234
+
235
+ latents = self.get_initial_latents(n_imgs, img_size, len(prompts))
236
+
237
+ text_embeddings = self.get_text_embeddings(prompts,n_imgs=n_imgs)
238
+
239
+ end_iteration = end_iteration or n_steps
240
+
241
+ latents_steps, trace_steps = self.diffusion(
242
+ latents,
243
+ text_embeddings,
244
+ end_iteration=end_iteration,
245
+ **kwargs
246
+ )
247
+
248
+ latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
249
+ images_steps = [self.to_image(latents) for latents in latents_steps]
250
+
251
+ images_steps = list(zip(*images_steps))
252
+
253
+ if trace_steps:
254
+
255
+ return images_steps, trace_steps
256
+
257
+ return images_steps
258
+
259
+
260
+ if __name__ == '__main__':
261
+
262
+ parser = default_parser()
263
+
264
+ args = parser.parse_args()
265
+
266
+ diffuser = StableDiffuser(seed=args.seed).to(torch.device(args.device)).half()
267
+
268
+ images = diffuser(args.prompts,
269
+ n_steps=args.nsteps,
270
+ n_imgs=args.nimgs,
271
+ start_iteration=args.start_itr,
272
+ return_steps=args.return_steps,
273
+ pred_x0=args.pred_x0
274
+ )
275
+
276
+ util.image_grid(images, args.outpath)
__init__.py ADDED
File without changes
app.py CHANGED
@@ -2,73 +2,166 @@ import sys
2
  sys.path.insert(0,'stable_diffusion')
3
  import gradio as gr
4
  from train_esd import train_esd
5
-
 
 
 
6
 
7
  ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
8
  config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
9
  diffusers_config_path = "stable-diffusion/config.json"
10
 
11
- def train(prompt, train_method, neg_guidance, iterations, lr):
12
-
13
- train_esd(prompt,
14
- train_method,
15
- 3,
16
- neg_guidance,
17
- iterations,
18
- lr,
19
- config_path,
20
- ckpt_path,
21
- diffusers_config_path,
22
- ['cuda']
23
- )
24
-
25
-
26
- with gr.Blocks() as demo:
27
-
28
- prompt_input = gr.Text(
29
- placeholder="Enter prompt...",
30
- label="Prompt",
31
- info="Prompt corresponding to concept to erase"
32
- )
33
- train_method_input = gr.Dropdown(
34
- choices=['noxattn', 'selfattn', 'xattn', 'full'],
35
- value='xattn',
36
- label='Train Method',
37
- info='Method of training'
38
- )
39
-
40
- neg_guidance_input = gr.Number(
41
- value=1,
42
- label="Negative Guidance",
43
- info='Guidance of negative training used to train'
44
- )
45
-
46
- iterations_input = gr.Number(
47
- value=1000,
48
- precision=0,
49
- label="Iterations",
50
- info='iterations used to train'
51
- )
52
-
53
- lr_input = gr.Number(
54
- value=1e-5,
55
- label="Iterations",
56
- info='Learning rate used to train'
57
- )
58
-
59
- train_button = gr.Button(
60
- value="Train",
61
- )
62
- train_button.click(train, inputs = [
63
- prompt_input,
64
- train_method_input,
65
- neg_guidance_input,
66
- iterations_input,
67
- lr_input
68
- ]
69
- )
70
-
71
-
72
-
73
-
74
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  sys.path.insert(0,'stable_diffusion')
3
  import gradio as gr
4
  from train_esd import train_esd
5
+ from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
6
+ from omegaconf import OmegaConf
7
+ from StableDiffuser import StableDiffuser
8
+ from diffusers import UNet2DConditionModel
9
 
10
  ckpt_path = "stable-diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
11
  config_path = "stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
12
  diffusers_config_path = "stable-diffusion/config.json"
13
 
14
+
15
+ class Demo:
16
+
17
+ def __init__(self) -> None:
18
+ demo = self.layout()
19
+ demo.launch()
20
+
21
+
22
+ def layout(self):
23
+
24
+ with gr.Blocks() as demo:
25
+
26
+ with gr.Row():
27
+ with gr.Column() as training_column:
28
+ self.prompt_input = gr.Text(
29
+ placeholder="Enter prompt...",
30
+ label="Prompt",
31
+ info="Prompt corresponding to concept to erase"
32
+ )
33
+ self.train_method_input = gr.Dropdown(
34
+ choices=['noxattn', 'selfattn', 'xattn', 'full'],
35
+ value='xattn',
36
+ label='Train Method',
37
+ info='Method of training'
38
+ )
39
+
40
+ self.neg_guidance_input = gr.Number(
41
+ value=1,
42
+ label="Negative Guidance",
43
+ info='Guidance of negative training used to train'
44
+ )
45
+
46
+ self.iterations_input = gr.Number(
47
+ value=1000,
48
+ precision=0,
49
+ label="Iterations",
50
+ info='iterations used to train'
51
+ )
52
+
53
+ self.lr_input = gr.Number(
54
+ value=1e-5,
55
+ label="Learning Rate",
56
+ info='Learning rate used to train'
57
+ )
58
+
59
+ self.train_button = gr.Button(
60
+ value="Train",
61
+ )
62
+ self.train_button.click(self.train, inputs = [
63
+ self.prompt_input,
64
+ self.train_method_input,
65
+ self.neg_guidance_input,
66
+ self.iterations_input,
67
+ self.lr_input
68
+ ]
69
+ )
70
+ with gr.Column() as inference_column:
71
+
72
+ with gr.Row():
73
+
74
+ self.prompt_input_infr = gr.Text(
75
+ placeholder="Enter prompt...",
76
+ label="Prompt",
77
+ info="Prompt corresponding to concept to erase"
78
+ )
79
+
80
+ with gr.Row():
81
+
82
+ self.image_new = gr.Image(
83
+ label="New Image",
84
+ interactive=False
85
+ )
86
+ self.image_orig = gr.Image(
87
+ label="Orig Image",
88
+ interactive=False
89
+ )
90
+
91
+ with gr.Row():
92
+
93
+ self.infr_button = gr.Button(
94
+ value="Generate",
95
+ )
96
+ self.infr_button.click(self.inference, inputs = [
97
+ self.prompt_input_infr,
98
+ ],
99
+ outputs=[
100
+ self.image_new,
101
+ self.image_orig
102
+ ]
103
+ )
104
+ return demo
105
+
106
+
107
+ def train(self, prompt, train_method, neg_guidance, iterations, lr):
108
+
109
+ model_orig, model_edited = train_esd(prompt,
110
+ train_method,
111
+ 3,
112
+ neg_guidance,
113
+ iterations,
114
+ lr,
115
+ config_path,
116
+ ckpt_path,
117
+ diffusers_config_path,
118
+ ['cuda', 'cuda'],
119
+ gr.Progress()
120
+ )
121
+
122
+ original_config = OmegaConf.load(config_path)
123
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
124
+ unet_config = create_unet_diffusers_config(original_config, image_size=512)
125
+ model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
126
+ model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
127
+
128
+ self.init_inference(model_edited_sd, model_orig_sd, unet_config)
129
+
130
+ def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
131
+
132
+ self.model_edited_sd = model_edited_sd
133
+ self.model_orig_sd = model_orig_sd
134
+
135
+ self.diffuser = StableDiffuser(42)
136
+
137
+ self.diffuser.unet = UNet2DConditionModel(**unet_config)
138
+ self.diffuser.to('cuda')
139
+
140
+
141
+ def inference(self, prompt):
142
+
143
+ self.diffuser.unet.load_state_dict(self.model_orig_sd)
144
+
145
+ images = self.diffuser(
146
+ prompt,
147
+ n_steps=50,
148
+ reseed=True
149
+ )
150
+
151
+ orig_image = images[0][0]
152
+
153
+ self.diffuser.unet.load_state_dict(self.model_edited_sd)
154
+
155
+ images = self.diffuser(
156
+ prompt,
157
+ n_steps=50,
158
+ reseed=True
159
+ )
160
+
161
+ edited_image = images[0][0]
162
+
163
+ return edited_image, orig_image
164
+
165
+
166
+
167
+
requirements.txt CHANGED
@@ -4,4 +4,8 @@ torchvision
4
  einops
5
  diffusers
6
  transformers
7
- pytorch_lightning
 
 
 
 
 
4
  einops
5
  diffusers
6
  transformers
7
+ pytorch_lightning==1.6.5
8
+ taming-transformers
9
+ kornia
10
+ git+https://github.com/openai/CLIP.git@main#egg=clip
11
+ git+https://github.com/davidbau/baukit.git
test.py CHANGED
@@ -1,19 +1,32 @@
1
  import sys
2
  sys.path.insert(0,'stable_diffusion')
3
  from train_esd import train_esd
4
-
5
  ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
6
  config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
7
  diffusers_config_path = "stable_diffusion/config.json"
8
 
9
- train_esd("England",
10
  'xattn',
11
  3,
12
  1,
13
- 1000,
14
  .003,
15
  config_path,
16
  ckpt_path,
17
  diffusers_config_path,
18
- ['cuda', 'cuda']
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import sys
2
  sys.path.insert(0,'stable_diffusion')
3
  from train_esd import train_esd
4
+ import torch
5
  ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
6
  config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
7
  diffusers_config_path = "stable_diffusion/config.json"
8
 
9
+ orig, newm = train_esd("England",
10
  'xattn',
11
  3,
12
  1,
13
+ 2,
14
  .003,
15
  config_path,
16
  ckpt_path,
17
  diffusers_config_path,
18
+ ['cuda', 'cuda'],
19
+ None
20
+ )
21
+
22
+
23
+ from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
24
+ from diffusers import UNet2DConditionModel, AutoencoderKL, LMSDiscreteScheduler
25
+ from omegaconf import OmegaConf
26
+ from transformers import CLIPTextModel, CLIPTokenizer
27
+ original_config = OmegaConf.load(config_path)
28
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
29
+ unet_config = create_unet_diffusers_config(original_config, image_size=512)
30
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(newm.state_dict(), unet_config)
31
+ unet = UNet2DConditionModel(**unet_config)
32
+ unet.load_state_dict(converted_unet_checkpoint)
train_esd.py CHANGED
@@ -102,7 +102,7 @@ def get_models(config_path, ckpt_path, devices):
102
 
103
  return model_orig, sampler_orig, model, sampler
104
 
105
- def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
106
  '''
107
  Function to train diffusion models to erase concepts from model weights
108
 
@@ -251,17 +251,19 @@ def train_esd(prompt, train_method, start_guidance, negative_guidance, iteration
251
  pbar.set_postfix({"loss": loss.item()})
252
  history.append(loss.item())
253
  opt.step()
254
- # save checkpoint and loss curve
255
- if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
256
- save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
257
 
258
- if i % 100 == 0:
259
- save_history(losses, name, word_print)
260
 
261
  model.eval()
262
 
263
- save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
264
- save_history(losses, name, word_print)
 
 
265
 
266
  def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
267
  # SAVE MODEL
 
102
 
103
  return model_orig, sampler_orig, model, sampler
104
 
105
+ def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, progress_bar, seperator=None, image_size=512, ddim_steps=50):
106
  '''
107
  Function to train diffusion models to erase concepts from model weights
108
 
 
251
  pbar.set_postfix({"loss": loss.item()})
252
  history.append(loss.item())
253
  opt.step()
254
+ # # save checkpoint and loss curve
255
+ # if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
256
+ # save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
257
 
258
+ # if i % 100 == 0:
259
+ # save_history(losses, name, word_print)
260
 
261
  model.eval()
262
 
263
+ # save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
264
+ # save_history(losses, name, word_print)
265
+
266
+ return model_orig, model
267
 
268
  def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
269
  # SAVE MODEL
util.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from matplotlib import pyplot as plt
3
+ import textwrap
4
+
5
+
6
+ def to_gif(images, path):
7
+
8
+ images[0].save(path, save_all=True,
9
+ append_images=images[1:], loop=0, duration=len(images) * 20)
10
+
11
+
12
+ def figure_to_image(figure):
13
+
14
+ figure.set_dpi(300)
15
+
16
+ figure.canvas.draw()
17
+
18
+ return Image.frombytes('RGB', figure.canvas.get_width_height(), figure.canvas.tostring_rgb())
19
+
20
+
21
+ def image_grid(images, outpath=None, column_titles=None, row_titles=None):
22
+
23
+ n_rows = len(images)
24
+ n_cols = len(images[0])
25
+
26
+ fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
27
+ figsize=(n_cols, n_rows), squeeze=False)
28
+
29
+ for row, _images in enumerate(images):
30
+
31
+ for column, image in enumerate(_images):
32
+ ax = axs[row][column]
33
+ ax.imshow(image)
34
+ if column_titles and row == 0:
35
+ ax.set_title(textwrap.fill(
36
+ column_titles[column], width=12), fontsize='x-small')
37
+ if row_titles and column == 0:
38
+ ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small', labelpad=1.6 * len(row_titles[row]))
39
+ ax.set_xticks([])
40
+ ax.set_yticks([])
41
+
42
+ plt.subplots_adjust(wspace=0, hspace=0)
43
+
44
+ if outpath is not None:
45
+ plt.savefig(outpath, bbox_inches='tight', dpi=300)
46
+ plt.close()
47
+ else:
48
+ plt.tight_layout(pad=0)
49
+ image = figure_to_image(plt.gcf())
50
+ plt.close()
51
+ return image
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+ def get_module(module, module_name):
60
+
61
+ if isinstance(module_name, str):
62
+ module_name = module_name.split('.')
63
+
64
+ if len(module_name) == 0:
65
+ return module
66
+ else:
67
+ module = getattr(module, module_name[0])
68
+ return get_module(module, module_name[1:])
69
+
70
+
71
+ def set_module(module, module_name, new_module):
72
+
73
+ if isinstance(module_name, str):
74
+ module_name = module_name.split('.')
75
+
76
+ if len(module_name) == 1:
77
+ return setattr(module, module_name[0], new_module)
78
+ else:
79
+ module = getattr(module, module_name[0])
80
+ return set_module(module, module_name[1:], new_module)
81
+
82
+
83
+ def freeze(module):
84
+
85
+ for parameter in module.parameters():
86
+
87
+ parameter.requires_grad = False
88
+
89
+
90
+ def unfreeze(module):
91
+
92
+ for parameter in module.parameters():
93
+
94
+ parameter.requires_grad = True
95
+
96
+
97
+ def get_concat_h(im1, im2):
98
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
99
+ dst.paste(im1, (0, 0))
100
+ dst.paste(im2, (im1.width, 0))
101
+ return dst
102
+
103
+ def get_concat_v(im1, im2):
104
+ dst = Image.new('RGB', (im1.width, im1.height + im2.height))
105
+ dst.paste(im1, (0, 0))
106
+ dst.paste(im2, (0, im1.height))
107
+ return dst