Commit 572f947 (1 parent: 06c5f0c)
Author: Mehdi Cherti

support clip score and higher resolution at test time

Files changed: test_ddgan.py (+45, -13)
test_ddgan.py
CHANGED
@@ -12,7 +12,7 @@ import os
 import json
 import torchvision
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
-import
+from encoder import build_encoder
 
 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):
@@ -130,13 +130,13 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
 def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
     x = x_init
     null = text_encoder([""] * len(x_init), return_only_pooled=False)
-    latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
     with torch.no_grad():
         for i in reversed(range(n_time)):
             t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
             t_time = t
 
-
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
 
             x_0_uncond = generator(x, t_time, latent_z, cond=null)
 
@@ -184,10 +184,8 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
 def sample_and_test(args):
     torch.manual_seed(args.seed)
     device = 'cuda:0'
-    text_encoder
+    text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
-    # cond = text_encoder([str(yi%10) for yi in range(args.batch_size)])
-
     if args.dataset == 'cifar10':
         real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
     elif args.dataset == 'celeba_256':
@@ -201,7 +199,7 @@ def sample_and_test(args):
 
 
     netG = NCSNpp(args).to(device)
-
+    netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
 
     if args.epoch_id == -1:
         epochs = range(1000)
@@ -214,7 +212,7 @@ def sample_and_test(args):
         if not os.path.exists(path):
             continue
         ckpt = torch.load(path, map_location=device)
-        dest = './saved_info/dd_gan/{}/{}/
+        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
 
         if args.compute_fid and os.path.exists(dest):
             continue
@@ -258,6 +256,15 @@ def sample_and_test(args):
         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
         inceptionv3 = InceptionV3([block_idx]).to(device)
 
+        if args.compute_clip_score:
+            import clip
+            CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+            CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+            clip_model, preprocess = clip.load(args.clip_model, device)
+            clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
+            clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)
+
+
         if not args.real_img_dir.endswith("npz"):
             real_mu, real_sigma = compute_statistics_of_path(
                 args.real_img_dir, inceptionv3, args.batch_size, dims, device,
@@ -270,6 +277,9 @@ def sample_and_test(args):
             real_sigma = stats['sigma']
 
         fake_features = []
+        if args.compute_clip_score:
+            clip_scores = []
+
         for b in range(0, len(texts), args.batch_size):
             text = texts[b:b+args.batch_size]
             with torch.no_grad():
@@ -277,6 +287,7 @@ def sample_and_test(args):
                 bs = len(text)
                 t0 = time.time()
                 x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                #print(x_t_1.shape)
                 if args.guidance_scale:
                     fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                 else:
@@ -295,6 +306,17 @@ def sample_and_test(args):
                 pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
                 pred = pred.squeeze(3).squeeze(2).cpu().numpy()
                 fake_features.append(pred)
+
+                if args.compute_clip_score:
+                    with torch.no_grad():
+                        clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
+                        clip_txt = clip.tokenize(text).to(device)
+                        imf = clip_model.encode_image(clip_ims)
+                        txtf = clip_model.encode_text(clip_txt)
+                        imf = torch.nn.functional.normalize(imf, dim=1)
+                        txtf = torch.nn.functional.normalize(txtf, dim=1)
+                        clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
+                break
                 if i % 10 == 0:
                     print('generating batch ', i, time.time() - t0)
                 """
@@ -311,14 +333,17 @@ def sample_and_test(args):
         fake_mu = np.mean(fake_features, axis=0)
         fake_sigma = np.cov(fake_features, rowvar=False)
         fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-        dest = './saved_info/dd_gan/{}/{}/
+        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
         results = {
             "fid": fid,
         }
+        if args.compute_clip_score:
+            clip_score = torch.cat(clip_scores).mean().item()
+            results['clip_score'] = clip_score
         results.update(vars(args))
         with open(dest, "w") as fd:
             json.dump(results, fd)
-        print(
+        print(results)
     else:
         if args.cond_text.endswith(".txt"):
             texts = open(args.cond_text).readlines()
@@ -326,11 +351,13 @@ def sample_and_test(args):
         else:
             texts = [args.cond_text] * args.batch_size
         cond = text_encoder(texts, return_only_pooled=False)
-        x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size, args.image_size).to(device)
+        x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+        t0 = time.time()
         if args.guidance_scale:
             fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
         else:
             fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+        print(time.time() - t0)
         fake_sample = to_range_0_1(fake_sample)
         torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 
@@ -344,11 +371,16 @@ if __name__ == '__main__':
                         help='seed used for initialization')
     parser.add_argument('--compute_fid', action='store_true', default=False,
                             help='whether or not compute FID')
+    parser.add_argument('--compute_clip_score', action='store_true', default=False,
+                            help='whether or not compute CLIP score')
+    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+
     parser.add_argument('--epoch_id', type=int,default=1000)
     parser.add_argument('--guidance_scale', type=float,default=0)
     parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
     parser.add_argument('--cond_text', type=str,default="0")
-
+    parser.add_argument('--scale_factor_h', type=int,default=1)
+    parser.add_argument('--scale_factor_w', type=int,default=1)
     parser.add_argument('--cross_attention', action='store_true',default=False)
 
 
@@ -419,7 +451,7 @@ if __name__ == '__main__':
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--masked_mean', action='store_true',default=False)
     parser.add_argument('--nb_images_for_fid', type=int, default=0)
-
+
 
 
 
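
Note on the CLIP score added above: it is the mean cosine similarity between CLIP embeddings of each generated image and its prompt, computed with OpenAI's clip package (the model name comes from the new --clip_model flag, default ViT-L/14). The hunk defines CLIP_MEAN/CLIP_STD but the lines shown never apply them before encode_image, so the normalization step below is an assumption about the intended preprocessing. This is a minimal self-contained sketch, not the committed code:

import torch
import clip

CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

@torch.no_grad()
def clip_score(images, texts, clip_model, device):
    # images: (B, 3, H, W) tensor in [0, 1]; texts: list of B prompt strings.
    mean = torch.tensor(CLIP_MEAN, device=device).view(1, -1, 1, 1)
    std = torch.tensor(CLIP_STD, device=device).view(1, -1, 1, 1)
    ims = torch.nn.functional.interpolate(images, (224, 224), mode="bicubic")
    ims = (ims - mean) / std  # CLIP input normalization (assumed, see note above)
    toks = clip.tokenize(texts).to(device)
    imf = torch.nn.functional.normalize(clip_model.encode_image(ims), dim=1)
    txtf = torch.nn.functional.normalize(clip_model.encode_text(toks), dim=1)
    return (imf * txtf).sum(dim=1)  # per-pair cosine similarity

# Usage sketch, mirroring the diff: collect per-batch scores, then average.
# clip_model, _ = clip.load("ViT-L/14", device=device)
# clip_scores.append(clip_score(fake_sample, text, clip_model, device).cpu())
# results['clip_score'] = torch.cat(clip_scores).mean().item()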
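
The "higher resolution at test time" part is carried by the two new flags, --scale_factor_h and --scale_factor_w (both default to 1, so existing behaviour is unchanged): the initial noise is drawn at scale_factor_h/scale_factor_w times the training image_size, and netG.attn_resolutions is rescaled because the NCSN++ forward pass applies its attention blocks where the current feature-map width matches an entry of attn_resolutions, which appears to be why only the width factor is used for the rescaling. A condensed sketch of that path, reusing the names from the diff (args, texts, cond, pos_coeff, T, sample_from_model); it is an illustration, not a drop-in excerpt of the script:

import torch
from score_sde.models.ncsnpp_generator_adagn import NCSNpp

netG = NCSNpp(args).to(device)
# Attention placement is keyed to feature-map width, so the trigger
# resolutions are scaled together with the input width.
netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]

x_t_1 = torch.randn(
    len(texts),
    args.num_channels,
    args.image_size * args.scale_factor_h,
    args.image_size * args.scale_factor_w,
    device=device,
)
fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=cond)

For instance, a checkpoint trained at image_size 64 would be asked for 128x128 samples with --scale_factor_h 2 --scale_factor_w 2; only the sampled resolution changes, and the approach relies on the rest of the generator being resolution-agnostic (convolutional).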