mhkj0 committed on
Commit
185cba8
1 Parent(s): 1021988
Files changed (1)
  1. mk +352 -0
mk ADDED
@@ -0,0 +1,352 @@
+ import argparse, os, sys, glob
+ import cv2
+ import torch
+ import numpy as np
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from tqdm import tqdm, trange
+ from imwatermark import WatermarkEncoder
+ from itertools import islice
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ import time
+ from pytorch_lightning import seed_everything
+ from torch import autocast
+ from contextlib import contextmanager, nullcontext
+
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from transformers import AutoFeatureExtractor
+
+
+ # load safety model
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
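+ # NOTE: from_pretrained fetches the checker weights from the Hugging Face Hub on
+ # first run; both globals are consumed by check_safety() below.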
+
+
+ def chunk(it, size):
+     it = iter(it)
+     return iter(lambda: tuple(islice(it, size)), ())
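+ # chunk usage: list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]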
+
+
+ def numpy_to_pil(images):
+     """
+     Convert a numpy image or a batch of images to a PIL image.
+     """
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+
+     return pil_images
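+ # numpy_to_pil assumes float inputs in [0, 1], shaped (H, W, C) or (N, H, W, C).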
+
+
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cpu")
+     if "global_step" in pl_sd:
+         print(f"Global Step: {pl_sd['global_step']}")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     if torch.cuda.is_available():  # guard so CPU-only hosts don't crash; main() moves the model to the selected device anyway
+         model.cuda()
+     model.eval()
+     return model
+
+
+ def put_watermark(img, wm_encoder=None):
+     if wm_encoder is not None:
+         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+         img = wm_encoder.encode(img, 'dwtDct')
+         img = Image.fromarray(img[:, :, ::-1])
+     return img
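+ # put_watermark converts to OpenCV-style BGR for invisible-watermark's encode(),
+ # then flips the channel order back to RGB for PIL.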
+
+
+ def load_replacement(x):
+     try:
+         hwc = x.shape
+         y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+         y = (np.array(y)/255.0).astype(x.dtype)
+         assert y.shape == x.shape
+         return y
+     except Exception:
+         return x
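+ # load_replacement swaps a flagged image for assets/rick.jpeg, resized to match;
+ # if the asset is missing, the original array is returned unchanged.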
+
+
+ def check_safety(x_image):
+     safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+     x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+     assert x_checked_image.shape[0] == len(has_nsfw_concept)
+     for i in range(len(has_nsfw_concept)):
+         if has_nsfw_concept[i]:
+             x_checked_image[i] = load_replacement(x_checked_image[i])
+     return x_checked_image, has_nsfw_concept
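+ # NOTE: check_safety is defined but never invoked in main() below, which assigns
+ # the decoded samples straight to x_checked_image (i.e. the filter is bypassed).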
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         nargs="?",
+         default="a painting of a virus monster playing guitar",
+         help="the prompt to render"
+     )
+     parser.add_argument(
+         "--outdir",
+         type=str,
+         nargs="?",
+         help="dir to write results to",
+         default="outputs/txt2img-samples"
+     )
+     parser.add_argument(
+         "--skip_grid",
+         action='store_true',
+         help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+     )
+     parser.add_argument(
+         "--skip_save",
+         action='store_true',
+         help="do not save individual samples. For speed measurements.",
+     )
+     parser.add_argument(
+         "--ddim_steps",
+         type=int,
+         default=50,
+         help="number of ddim sampling steps",
+     )
+     parser.add_argument(
+         "--plms",
+         action='store_true',
+         help="use plms sampling",
+     )
+     parser.add_argument(
+         "--dpm_solver",
+         action='store_true',
+         help="use dpm_solver sampling",
+     )
+     parser.add_argument(
+         "--laion400m",
+         action='store_true',
+         help="uses the LAION400M model",
+     )
+     parser.add_argument(
+         "--fixed_code",
+         action='store_true',
+         help="if enabled, uses the same starting code across samples",
+     )
+     parser.add_argument(
+         "--ddim_eta",
+         type=float,
+         default=0.0,
+         help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+     )
+     parser.add_argument(
+         "--n_iter",
+         type=int,
+         default=2,
+         help="number of sampling iterations (each iteration produces a full batch)",
+     )
+     parser.add_argument(
+         "--H",
+         type=int,
+         default=512,
+         help="image height, in pixel space",
+     )
+     parser.add_argument(
+         "--W",
+         type=int,
+         default=512,
+         help="image width, in pixel space",
+     )
+     parser.add_argument(
+         "--C",
+         type=int,
+         default=4,
+         help="latent channels",
+     )
+     parser.add_argument(
+         "--f",
+         type=int,
+         default=8,
+         help="downsampling factor",
+     )
+     parser.add_argument(
+         "--n_samples",
+         type=int,
+         default=3,
+         help="how many samples to produce for each given prompt. A.k.a. batch size",
+     )
+     parser.add_argument(
+         "--n_rows",
+         type=int,
+         default=0,
+         help="rows in the grid (default: n_samples)",
+     )
+     parser.add_argument(
+         "--scale",
+         type=float,
+         default=7.5,
+         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+     )
+     parser.add_argument(
+         "--from-file",
+         type=str,
+         help="if specified, load prompts from this file",
+     )
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="configs/stable-diffusion/v1-inference.yaml",
+         help="path to config which constructs model",
+     )
+     parser.add_argument(
+         "--ckpt",
+         type=str,
+         default="models/ldm/stable-diffusion-v1/model.ckpt",
+         help="path to checkpoint of model",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=42,
+         help="the seed (for reproducible sampling)",
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         help="evaluate at this precision",
+         choices=["full", "autocast"],
+         default="autocast"
+     )
+     opt = parser.parse_args()
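+     # Example invocation (hypothetical prompt; assumes this script sits next to the
+     # ldm package and the v1 checkpoint is already in place):
+     #   python mk --prompt "a photograph of an astronaut riding a horse" \
+     #       --plms --n_samples 2 --n_iter 1 --skip_grid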
+
+     if opt.laion400m:
+         print("Falling back to LAION 400M model...")
+         opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
+         opt.ckpt = "models/ldm/text2img-large/model.ckpt"
+         opt.outdir = "outputs/txt2img-samples-laion400m"
+
+     seed_everything(opt.seed)
+
+     config = OmegaConf.load(f"{opt.config}")
+     model = load_model_from_config(config, f"{opt.ckpt}")
+
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model = model.to(device)
+
+     if opt.dpm_solver:
+         sampler = DPMSolverSampler(model)
+     elif opt.plms:
+         sampler = PLMSSampler(model)
+     else:
+         sampler = DDIMSampler(model)
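+     # Sampler precedence: --dpm_solver wins over --plms; DDIM is the default.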
+
+     os.makedirs(opt.outdir, exist_ok=True)
+     outpath = opt.outdir
+
+     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+     wm = "StableDiffusionV1"
+     wm_encoder = WatermarkEncoder()
+     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+     batch_size = opt.n_samples
+     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+     if not opt.from_file:
+         prompt = opt.prompt
+         assert prompt is not None
+         data = [batch_size * [prompt]]
+
+     else:
+         print(f"reading prompts from {opt.from_file}")
+         with open(opt.from_file, "r") as f:
+             data = f.read().splitlines()
+             data = list(chunk(data, batch_size))
+
+     sample_path = os.path.join(outpath, "samples")
+     os.makedirs(sample_path, exist_ok=True)
+     base_count = len(os.listdir(sample_path))
+     grid_count = len(os.listdir(outpath)) - 1
+
+     start_code = None
+     if opt.fixed_code:
+         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
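+     # With --fixed_code every batch starts from the same initial latent x_T, so
+     # outputs differ only through the conditioning (and, if eta > 0, DDIM noise).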
+
+     precision_scope = autocast if opt.precision == "autocast" else nullcontext
+     with torch.no_grad():
+         with precision_scope("cuda"):
+             with model.ema_scope():
+                 tic = time.time()
+                 all_samples = list()
+                 for n in trange(opt.n_iter, desc="Sampling"):
+                     for prompts in tqdm(data, desc="data"):
+                         uc = None
+                         if opt.scale != 1.0:
+                             uc = model.get_learned_conditioning(batch_size * [""])
+                         if isinstance(prompts, tuple):
+                             prompts = list(prompts)
+                         c = model.get_learned_conditioning(prompts)
+                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                         samples_ddim, _ = sampler.sample(
+                             S=opt.ddim_steps,
+                             conditioning=c,
+                             batch_size=opt.n_samples,
+                             shape=shape,
+                             verbose=False,
+                             unconditional_guidance_scale=opt.scale,
+                             unconditional_conditioning=uc,
+                             eta=opt.ddim_eta,
+                             x_T=start_code)
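+                         # samples_ddim holds latents of shape [n_samples, C, H/f, W/f];
+                         # uc (the empty-prompt embedding) drives the unconditional branch
+                         # of classifier-free guidance per the --scale formula.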
+
+                         x_samples_ddim = model.decode_first_stage(samples_ddim)
+                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                         x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+                         x_checked_image = x_samples_ddim
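+                         # The decoder returns pixels in [-1, 1], rescaled to [0, 1] above.
+                         # check_safety() is never applied here, so samples reach the save
+                         # path unfiltered.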
+
+                         x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+                         if not opt.skip_save:
+                             for x_sample in x_checked_image_torch:
+                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                 img = Image.fromarray(x_sample.astype(np.uint8))
+                                 img = put_watermark(img, wm_encoder)
+                                 img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                                 base_count += 1
+
+                         if not opt.skip_grid:
+                             all_samples.append(x_checked_image_torch)
+
+                 if not opt.skip_grid:
+                     # additionally, save as grid
+                     grid = torch.stack(all_samples, 0)
+                     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                     grid = make_grid(grid, nrow=n_rows)
+
+                     # to image
+                     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                     img = Image.fromarray(grid.astype(np.uint8))
+                     img = put_watermark(img, wm_encoder)
+                     img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                     grid_count += 1
+
+                 toc = time.time()
+
+     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+           f"Sampling took {toc - tic:.2f} seconds. \nEnjoy.")
+
+
+ if __name__ == "__main__":
+     main()