yyk19 committed
Commit 73bb868
1 Parent(s): 73a839e

modify the rendertext_tool

Files changed (3):
  1. app.py +3 -3
  2. cldm/cldm.py +47 -342
  3. scripts/rendertext_tool.py +33 -3
app.py CHANGED
@@ -93,7 +93,7 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
 elif model_ckpt == "TextCaps-5K-Epoch-40":
 model = load_model_ckpt(model, "textcaps5K_epoch_40_model_wo_ema.ckpt")
 
- render_tool = Render_Text(model)
+ render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
 output_str = f"already change the model checkpoint to {model_ckpt}"
 print(output_str)
 if torch.cuda.is_available():
@@ -104,14 +104,14 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
 allow_run_generation = False
 return output_str, None, allow_run_generation
 
-
+ SAVE_MEMORY = True
 cfg = OmegaConf.load("config.yaml")
 model = load_model_from_config(cfg, "laion10M_epoch_6_model_wo_ema.ckpt", verbose=True)
 # model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
 # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
 # model = load_model_from_config(cfg, "model.ckpt", verbose=True)
 # ddim_sampler = DDIMSampler(model)
- render_tool = Render_Text(model)
+ render_tool = Render_Text(model, save_memory = SAVE_MEMORY)
 
 
 description = """
 
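Note: the SAVE_MEMORY flag added above is consumed inside Render_Text.process_multi (see the scripts/rendertext_tool.py diff below). When CUDA is available and save_memory is set, the model is shuttled between its conditioning and diffusion stages so that only one stage occupies the GPU at a time. A minimal sketch of that gating pattern; maybe_shift is a hypothetical helper, the only API assumed is the low_vram_shift method this commit calls:

import torch

def maybe_shift(model, is_diffusing: bool, save_memory: bool) -> None:
    # Shift the stage that is about to run onto the GPU and park the rest
    # on the CPU; a no-op when running on CPU or when save_memory is off.
    if torch.cuda.is_available() and save_memory:
        print("low_vram_shift: is_diffusing", is_diffusing)
        model.low_vram_shift(is_diffusing=is_diffusing)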
cldm/cldm.py CHANGED
@@ -28,7 +28,7 @@ def disabled_train(self, mode=True):
 return self
 
 class ControlledUnetModel(UNetModel):
- def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, context_glyph= None, **kwargs):
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
 hs = []
 with torch.no_grad():
 t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
@@ -47,7 +47,7 @@ class ControlledUnetModel(UNetModel):
 h = torch.cat([h, hs.pop()], dim=1)
 else:
 h = torch.cat([h, hs.pop() + control.pop()], dim=1)
- h = module(h, emb, context if context_glyph is None else context_glyph)
+ h = module(h, emb, context)
 
 h = h.type(x.dtype)
 return self.out(h)
@@ -317,16 +317,12 @@ class ControlLDM(LatentDiffusion):
 
 def __init__(self,
 control_stage_config,
- control_key, only_mid_control,
- sd_locked = True, concat_textemb = False,
- trans_textemb=False, trans_textemb_config = None,
+ control_key, only_mid_control,
 learnable_conscale = False, guess_mode=False,
- sep_lr = False, decoder_lr = 1.0**-4,
- add_glyph_control = False, glyph_control_config = None, glycon_wd = 0.2, glycon_lr = 1.0**-4, glycon_sched = "lambda",
- glyph_control_key = "centered_hint", sep_cond_txt = False, exchange_cond_txt = False,
- max_step = None, multiple_optimizers = False, deepspeed = False, trans_glyph_lr = 1.0**-5,
+ sd_locked = True, sep_lr = False, decoder_lr = 1.0**-4,
+ sep_cond_txt = True, exchange_cond_txt = False, concat_all_textemb = False,
 *args, **kwargs
- ): #sep_cap_for_2b = False
+ ):
 use_ema = kwargs.pop("use_ema", False)
 ckpt_path = kwargs.pop("ckpt_path", None)
 reset_ema = kwargs.pop("reset_ema", False)
@@ -336,90 +332,53 @@ class ControlLDM(LatentDiffusion):
 ignore_keys = kwargs.pop("ignore_keys", [])
 
 super().__init__(*args, use_ema=False, **kwargs)
+
+ # Glyph ControlNet
 self.control_model = instantiate_from_config(control_stage_config)
 self.control_key = control_key
 self.only_mid_control = only_mid_control
+
 self.learnable_conscale = learnable_conscale
 conscale_init = [1.0] * 13 if not guess_mode else [(0.825 ** float(12 - i)) for i in range(13)]
 if learnable_conscale:
 # self.control_scales = nn.Parameter(torch.ones(13), requires_grad=True)
 self.control_scales = nn.Parameter(torch.Tensor(conscale_init), requires_grad=True)
- else: # TODO: register the buffer
+ else:
 self.control_scales = conscale_init #[1.0] * 13
+
+ self.optimizer = torch.optim.AdamW
+ # whether to unlock (fine-tune) the decoder parts of SD U-Net
 self.sd_locked = sd_locked
- self.concat_textemb = concat_textemb
- # update
- self.trans_textemb = False
- if trans_textemb and trans_textemb_config is not None:
- self.trans_textemb = True
- self.instantiate_trans_textemb_model(trans_textemb_config)
- # self.sep_cap_for_2b = sep_cap_for_2b
-
 self.sep_lr = sep_lr
 self.decoder_lr = decoder_lr
+
+ # specify the input text embedding of two branches (SD branch and Glyph ControlNet branch)
 self.sep_cond_txt = sep_cond_txt
+ self.concat_all_textemb = concat_all_textemb
 self.exchange_cond_txt = exchange_cond_txt
- # update (4.18)
- self.multiple_optimizers = multiple_optimizers
- self.add_glyph_control = False
- self.glyph_control_key = glyph_control_key
- self.freeze_glyph_image_encoder = True
- self.glyph_image_encoder_type = "CLIP"
- self.max_step = max_step
- self.trans_glyph_embed = False
- self.trans_glyph_lr = trans_glyph_lr
- if deepspeed:
- try:
- from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
- self.optimizer = DeepSpeedCPUAdam #FusedAdam
- except:
- print("could not import FuseAdam from deepspeed")
- self.optimizer = torch.optim.AdamW
- else:
- self.optimizer = torch.optim.AdamW
 
- if add_glyph_control and glyph_control_config is not None:
- self.add_glyph_control = True
- self.glycon_wd = glycon_wd
- self.glycon_lr = glycon_lr
- self.glycon_sched = glycon_sched
- self.instantiate_glyph_control_model(glyph_control_config)
- if self.glyph_control_model.trans_glyph_emb_model is not None:
- self.trans_glyph_embed = True
-
+ # ema
 self.use_ema = use_ema
- if self.use_ema: #TODO: trainable glyph Image encoder
- # assert self.sd_locked == True
+ if self.use_ema:
 self.model_ema = LitEma(self.control_model, init_num_updates= 0)
 print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
- if not self.sd_locked: # Update
+ if not self.sd_locked:
 self.model_diffoutblock_ema = LitEma(self.model.diffusion_model.output_blocks, init_num_updates= 0)
 print(f"Keeping diffoutblock EMAs of {len(list(self.model_diffoutblock_ema.buffers()))}.")
 self.model_diffout_ema = LitEma(self.model.diffusion_model.out, init_num_updates= 0)
 print(f"Keeping diffout EMAs of {len(list(self.model_diffout_ema.buffers()))}.")
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema = LitEma(self.glyph_control_model.image_encoder, init_num_updates=0)
- print(f"Keeping glyphcon EMAs of {len(list(self.model_glyphcon_ema.buffers()))}.")
- if self.trans_glyph_embed:
- self.model_transglyph_ema = LitEma(self.glyph_control_model.trans_glyph_emb_model, init_num_updates=0)
- print(f"Keeping glyphcon EMAs of {len(list(self.model_transglyph_ema.buffers()))}.")
 
+ # initialize the model from the checkpoint
 if ckpt_path is not None:
 ema_num_updates = self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model)
 self.restarted_from_ckpt = True
- # if reset_ema:
- # assert self.use_ema
 if self.use_ema and reset_ema:
 print(
 f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
 self.model_ema = LitEma(self.control_model, init_num_updates= ema_num_updates if keep_num_ema_updates else 0)
- if not self.sd_locked: # Update
+ if not self.sd_locked:
 self.model_diffoutblock_ema = LitEma(self.model.diffusion_model.output_blocks, init_num_updates= ema_num_updates if keep_num_ema_updates else 0)
 self.model_diffout_ema = LitEma(self.model.diffusion_model.out, init_num_updates= ema_num_updates if keep_num_ema_updates else 0)
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema = LitEma(self.glyph_control_model.image_encoder, init_num_updates= ema_num_updates if keep_num_ema_updates else 0)
- if self.trans_glyph_embed:
- self.model_transglyph_ema = LitEma(self.glyph_control_model.trans_glyph_emb_model, init_num_updates= ema_num_updates if keep_num_ema_updates else 0)
 
 if reset_num_ema_updates:
 print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
@@ -428,13 +387,6 @@ class ControlLDM(LatentDiffusion):
 if not self.sd_locked: # Update
 self.model_diffoutblock_ema.reset_num_updates()
 self.model_diffout_ema.reset_num_updates()
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema.reset_num_updates()
- if self.trans_glyph_embed:
- self.model_transglyph_ema.reset_num_updates()
-
-
- # self.freeze_unet()
 
 @contextmanager
 def ema_scope(self, context=None):
@@ -446,12 +398,6 @@ class ControlLDM(LatentDiffusion):
 self.model_diffoutblock_ema.copy_to(self.model.diffusion_model.output_blocks)
 self.model_diffout_ema.store(self.model.diffusion_model.out.parameters())
 self.model_diffout_ema.copy_to(self.model.diffusion_model.out)
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema.store(self.glyph_control_model.image_encoder.parameters())
- self.model_glyphcon_ema.copy_to(self.glyph_control_model.image_encoder)
- if self.trans_glyph_embed:
- self.model_transglyph_ema.store(self.glyph_control_model.trans_glyph_emb_model.parameters())
- self.model_transglyph_ema.copy_to(self.glyph_control_model.trans_glyph_emb_model)
 
 if context is not None:
 print(f"{context}: Switched ControlNet to EMA weights")
@@ -463,10 +409,6 @@ class ControlLDM(LatentDiffusion):
 if not self.sd_locked: # Update
 self.model_diffoutblock_ema.restore(self.model.diffusion_model.output_blocks.parameters())
 self.model_diffout_ema.restore(self.model.diffusion_model.out.parameters())
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema.restore(self.glyph_control_model.image_encoder.parameters())
- if self.trans_glyph_embed:
- self.model_transglyph_ema.restore(self.glyph_control_model.trans_glyph_emb_model.parameters())
 if context is not None:
 print(f"{context}: Restored training weights of ControlNet")
 
@@ -493,14 +435,8 @@ class ControlLDM(LatentDiffusion):
 if k.startswith(ik):
 print("Deleting key {} from state_dict.".format(k))
 del sd[k]
- # missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- # sd, strict=False)
- if not only_model:
- missing, unexpected = self.load_state_dict(sd, strict=False)
- elif path.endswith(".bin"):
- missing, unexpected = self.model.diffusion_model.load_state_dict(sd, strict=False)
- elif path.endswith(".ckpt"):
- missing, unexpected = self.model.load_state_dict(sd, strict=False)
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
 
 print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
 if len(missing) > 0:
@@ -513,28 +449,6 @@ class ControlLDM(LatentDiffusion):
 else:
 return 0
 
- def instantiate_trans_textemb_model(self, config):
- model = instantiate_from_config(config)
- params = []
- for i in range(model.emb_num):
- if model.trans_trainable[i]:
- params += list(model.trans_list[i].parameters())
- else:
- for param in model.trans_list[i].parameters():
- param.requires_grad = False
- self.trans_textemb_model = model
- self.trans_textemb_params = params
-
- # add
- def instantiate_glyph_control_model(self, config):
- model = instantiate_from_config(config)
- # params = []
- self.freeze_glyph_image_encoder = model.freeze_image_encoder #image_encoder.freeze_model
- self.glyph_control_model = model
- self.glyph_image_encoder_type = model.image_encoder_type
-
-
-
 @torch.no_grad()
 def get_input(self, batch, k, bs=None, *args, **kwargs):
 x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
@@ -544,80 +458,42 @@ class ControlLDM(LatentDiffusion):
 control = control.to(self.device)
 control = einops.rearrange(control, 'b h w c -> b c h w')
 control = control.to(memory_format=torch.contiguous_format).float()
-
- if self.add_glyph_control:
- assert self.glyph_control_key in batch.keys()
- glyph_control = batch[self.glyph_control_key]
- if bs is not None:
- glyph_control = glyph_control[:bs]
- glycon_samples = []
- for glycon_sample in glyph_control:
- glycon_sample = glycon_sample.to(self.device)
- glycon_sample = einops.rearrange(glycon_sample, 'b h w c -> b c h w')
- glycon_sample = glycon_sample.to(memory_format=torch.contiguous_format).float()
- glycon_samples.append(glycon_sample)
- # return x, dict(c_crossattn=[c], c_concat=[control])
- return x, dict(c_crossattn=[c] if not isinstance(c, list) else c, c_concat=[control], c_glyph=glycon_samples)
 return x, dict(c_crossattn=[c] if not isinstance(c, list) else c, c_concat=[control])
 
 def apply_model(self, x_noisy, t, cond, *args, **kwargs):
 assert isinstance(cond, dict)
 diffusion_model = self.model.diffusion_model
-
- #update
- embdim_list = []
- for c in cond["c_crossattn"]:
- embdim_list.append(c.shape[-1])
- embdim_list = np.array(embdim_list)
- if np.sum(embdim_list != diffusion_model.context_dim):
- assert self.trans_textemb
-
- if self.trans_textemb:
- assert self.trans_textemb_model
- cond_txt_list = self.trans_textemb_model(cond["c_crossattn"])
- # if len(cond_txt_list) == 2:
- # print("cond_txt_2 max: {}".format(torch.max(torch.abs(cond_txt_list[1]))))
- else:
- cond_txt_list = cond["c_crossattn"]
-
+ cond_txt_list = cond["c_crossattn"]
 
 assert len(cond_txt_list) > 0
- if self.sep_cond_txt:
- cond_txt = cond_txt_list[0]
- cond_txt_2 = None if len(cond_txt_list) == 1 else cond_txt_list[1]
+ # cond_txt: input text embedding of the pretrained SD branch
+ # cond_txt_2: input text embedding of the Glyph ControlNet branch
+ cond_txt = cond_txt_list[0]
+ if len(cond_txt_list) == 1:
+ cond_txt_2 = None
 else:
- if len(cond_txt_list) > 1:
- cond_txt = cond_txt_list[0] # input text embedding of the pretrained SD
- if not self.concat_textemb:
- # currently len(cond_txt_list) <= 2
- cond_txt_2 = torch.cat(cond_txt_list[1:], 1) # input text embedding of the ControlNet branch
+ if self.sep_cond_txt:
+ # use each embedding for each branch separately
+ cond_txt_2 = cond_txt_list[1]
+ else:
+ # concat the embedding for Glyph ControlNet branch
+ if not self.concat_all_textemb:
+ cond_txt_2 = torch.cat(cond_txt_list[1:], 1)
 else:
 cond_txt_2 = torch.cat(cond_txt_list, 1)
- if self.exchange_cond_txt:
- txt_buffer = cond_txt
- cond_txt = cond_txt_2
- cond_txt_2 = txt_buffer
- print("len cond_txt_list: {} | cond_txt_1 shape: {} | cond_txt_2 shape: {}".format(len(cond_txt_list), cond_txt.shape, cond_txt_2.shape))
- else:
- cond_txt = torch.cat(cond_txt_list, 1)
- cond_txt_2 = None
-
- context_glyph = None
- if self.add_glyph_control:
- assert "c_glyph" in cond.keys()
- if cond["c_glyph"] is not None:
- context_glyph = self.glyph_control_model(cond["c_glyph"], text_embed = cond_txt_list[-1] if len(cond_txt_list) == 3 else cond_txt)
- else:
- context_glyph = cond_txt_list[-1] if len(cond_txt_list) == 3 else cond_txt
- # if cond_txt_2 is None:
- # print("cond_txt_2 is None")
+
+ if self.exchange_cond_txt:
+ # exchange the input text embedding of two branches
+ txt_buffer = cond_txt
+ cond_txt = cond_txt_2
+ cond_txt_2 = txt_buffer
 
 if cond['c_concat'] is None:
- eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control, context_glyph = context_glyph)
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
 else:
 control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt if cond_txt_2 is None else cond_txt_2)
 control = [c * scale for c, scale in zip(control, self.control_scales)]
- eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control, context_glyph=context_glyph)
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
 
 return eps
 
@@ -625,96 +501,16 @@ class ControlLDM(LatentDiffusion):
 def get_unconditional_conditioning(self, N):
 return self.get_learned_conditioning([""] * N)
 
- # Maybe not useful: modify the codes to fit the separate input captions
- # @torch.no_grad()
- # def get_unconditional_conditioning(self, N):
- # return self.get_learned_conditioning([""] * N) if not self.sep_cap_for_2b else self.get_learned_conditioning([[""] * N, [""] * N])
- # TODO: adapt to new model
- @torch.no_grad()
- def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
- use_ema_scope=True,
- **kwargs):
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c = self.get_input(batch, self.first_stage_key, bs=N)
- c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
- N = min(z.shape[0], N)
- n_row = min(z.shape[0], n_row)
- log["reconstruction"] = self.decode_first_stage(z)
- log["control"] = c_cat * 2.0 - 1.0
- log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if unconditional_guidance_scale > 1.0:
- uc_cross = self.get_unconditional_conditioning(N)
- uc_cat = c_cat # torch.zeros_like(c_cat)
- uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
- samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc_full,
- )
- x_samples_cfg = self.decode_first_stage(samples_cfg)
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
- return log
- # TODO: adapt to new model
- @torch.no_grad()
- def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
- ddim_sampler = DDIMSampler(self)
- b, c, h, w = cond["c_concat"][0].shape
- shape = (self.channels, h // 8, w // 8)
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
- return samples, intermediates
- # add
 def training_step(self, batch, batch_idx, optimizer_idx=0):
 loss = super().training_step(batch, batch_idx, optimizer_idx)
 if self.use_scheduler and not self.sd_locked and self.sep_lr:
 decoder_lr = self.optimizers().param_groups[1]["lr"]
 self.log('decoder_lr_abs', decoder_lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
- if self.trans_glyph_embed and self.freeze_glyph_image_encoder:
- trans_glyph_embed_lr = self.optimizers().param_groups[2]["lr"]
- self.log('trans_glyph_embed_lr_abs', trans_glyph_embed_lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
 return loss
 
 def configure_optimizers(self):
 lr = self.learning_rate
- params = list(self.control_model.parameters())
- if self.trans_textemb:
- params += self.trans_textemb_params #list(self.trans_textemb_model.parameters())
-
+ params = list(self.control_model.parameters())
 if self.learnable_conscale:
 params += [self.control_scales]
 
@@ -731,34 +527,9 @@ class ControlLDM(LatentDiffusion):
 if decoder_params is not None:
 params_wlr.append({"params": decoder_params, "lr": self.decoder_lr})
 
- if not self.freeze_glyph_image_encoder:
- if self.glyph_image_encoder_type == "CLIP":
- # assert self.sep_lr
- # follow the training codes in the OpenClip repo
- # https://github.com/mlfoundations/open_clip/blob/main/src/training/main.py#L303
- exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
- include = lambda n, p: not exclude(n, p)
-
- # named_parameters = list(model.image_encoder.named_parameters())
- named_parameters = list(self.glyph_control_model.image_encoder.named_parameters())
- gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
- rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
- self.glyph_control_params_wlr = [
- {"params": gain_or_bias_params, "weight_decay": 0., "lr": self.glycon_lr},
- {"params": rest_params, "weight_decay": self.glycon_wd, "lr": self.glycon_lr},
- ]
- if not self.freeze_glyph_image_encoder and not self.multiple_optimizers:
- params_wlr.extend(self.glyph_control_params_wlr)
-
- if self.trans_glyph_embed:
- trans_glyph_params = list(self.glyph_control_model.trans_glyph_emb_model.parameters())
- params_wlr.append({"params": trans_glyph_params, "lr": self.trans_glyph_lr})
 # opt = torch.optim.AdamW(params_wlr)
 opt = self.optimizer(params_wlr)
 opts = [opt]
- if not self.freeze_glyph_image_encoder and self.multiple_optimizers:
- glyph_control_opt = self.optimizer(self.glyph_control_params_wlr) #torch.optim.AdamW(self.glyph_control_params_wlr)
- opts.append(glyph_control_opt)
 
 # updated
 schedulers = []
@@ -776,33 +547,8 @@ class ControlLDM(LatentDiffusion):
 'frequency': 1
 }]
 
- if not self.freeze_glyph_image_encoder and self.multiple_optimizers:
- if self.glycon_sched == "cosine" and self.max_step is not None:
- glyph_scheduler = CosineAnnealingLR(glyph_control_opt, T_max=self.max_step) #: max_step
- elif self.glycon_sched == "onecycle" and self.max_step is not None:
- glyph_scheduler = OneCycleLR(
- glyph_control_opt,
- max_lr=self.glycon_lr,
- total_steps=self.max_step, #: max_step
- pct_start=0.0001,
- anneal_strategy="cos" #'linear'
- )
- # elif self.glycon_sched == "lambda":
- else:
- glyph_scheduler = LambdaLR(
- glyph_control_opt,
- lr_lambda = [scheduler_func.schedule] * len(self.glyph_control_params_wlr)
- )
- schedulers.append(
- {
- "scheduler": glyph_scheduler,
- "interval": 'step',
- 'frequency': 1
- }
- )
 return opts, schedulers
-
- # TODO: adapt to new model
+
 def low_vram_shift(self, is_diffusing):
 if is_diffusing:
 self.model = self.model.cuda()
@@ -822,10 +568,6 @@ class ControlLDM(LatentDiffusion):
 if not self.sd_locked: # Update
 self.model_diffoutblock_ema(self.model.diffusion_model.output_blocks)
 self.model_diffout_ema(self.model.diffusion_model.out)
- if not self.freeze_glyph_image_encoder:
- self.model_glyphcon_ema(self.glyph_control_model.image_encoder)
- if self.trans_glyph_embed:
- self.model_transglyph_ema(self.glyph_control_model.trans_glyph_emb_model)
 if self.log_all_grad_norm:
 zeroconvs = list(self.control_model.input_hint_block.named_parameters())[-2:]
 zeroconvs.extend(
@@ -867,43 +609,6 @@ class ControlLDM(LatentDiffusion):
 prog_bar=False, logger=True, on_step=True, on_epoch=False
 )
 
- if self.trans_textemb:
- for name, p in self.trans_textemb_model.named_parameters():
- if p.requires_grad and p.grad is not None:
- self.log(
- "trans_textemb_gradient_norm/{}".format(name),
- p.grad.cpu().detach().norm().item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- self.log(
- "trans_textemb_params/{}_norm".format(name),
- p.cpu().detach().norm().item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- self.log(
- "trans_textemb_params/{}_abs_max".format(name),
- torch.max(torch.abs(p.cpu().detach())).item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- if self.trans_glyph_embed:
- for name, p in self.glyph_control_model.trans_glyph_emb_model.named_parameters():
- if p.requires_grad and p.grad is not None:
- self.log(
- "trans_glyph_embed_gradient_norm/{}".format(name),
- p.grad.cpu().detach().norm().item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- self.log(
- "trans_glyph_embed_params/{}_norm".format(name),
- p.cpu().detach().norm().item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
- self.log(
- "trans_glyph_embed_params/{}_abs_max".format(name),
- torch.max(torch.abs(p.cpu().detach())).item(),
- prog_bar=False, logger=True, on_step=True, on_epoch=False
- )
-
 if self.learnable_conscale:
 for i in range(len(self.control_scales)):
 self.log(
@@ -912,4 +617,4 @@ class ControlLDM(LatentDiffusion):
 prog_bar=False, logger=True, on_step=True, on_epoch=False
 )
 del gradnorm_list
- del zeroconvs
+ del zeroconvs
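The rewritten apply_model now takes the text embeddings directly from cond["c_crossattn"] and routes them to the two branches with three switches (sep_cond_txt, concat_all_textemb, exchange_cond_txt). A standalone sketch mirroring the new routing; select_branch_embeddings is a hypothetical helper and the tensors are placeholders:

import torch

def select_branch_embeddings(cond_txt_list, sep_cond_txt=True, concat_all_textemb=False, exchange_cond_txt=False):
    # Returns (cond_txt for the pretrained SD branch, cond_txt_2 for the Glyph ControlNet branch).
    assert len(cond_txt_list) > 0
    cond_txt = cond_txt_list[0]
    if len(cond_txt_list) == 1:
        cond_txt_2 = None
    elif sep_cond_txt:
        cond_txt_2 = cond_txt_list[1]                 # one embedding per branch
    elif not concat_all_textemb:
        cond_txt_2 = torch.cat(cond_txt_list[1:], 1)  # concat the rest for the ControlNet branch
    else:
        cond_txt_2 = torch.cat(cond_txt_list, 1)      # concat everything, including the SD embedding
    if exchange_cond_txt:
        cond_txt, cond_txt_2 = cond_txt_2, cond_txt   # swap the two branches, as in the diff
    return cond_txt, cond_txt_2

# e.g. two CLIP-style sequences of shape (batch, tokens, dim):
a, b = torch.randn(2, 77, 768), torch.randn(2, 77, 768)
sd_ctx, controlnet_ctx = select_branch_embeddings([a, b])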
 
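The checkpoint-loading code also collapses the old .bin/.ckpt branching into a single non-strict load. strict=False tolerates key mismatches and reports them instead of raising, which is what lets one checkpoint restore the trimmed-down class. A small sketch of that reporting pattern, matching the print in the diff; report_nonstrict_load is a hypothetical helper:

import torch

def report_nonstrict_load(module: torch.nn.Module, sd: dict, path: str):
    # load_state_dict(strict=False) returns the keys it could not match
    # instead of raising, so a reshuffled checkpoint still loads.
    missing, unexpected = module.load_state_dict(sd, strict=False)
    print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
    if len(missing) > 0:
        print(f"Missing keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected keys: {unexpected}")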
scripts/rendertext_tool.py CHANGED
@@ -75,12 +75,14 @@ class Render_Text:
 def __init__(self,
 model,
 precision_scope=nullcontext,
- transform=ToTensor()
+ transform=ToTensor(),
+ save_memory = False,
 ):
 self.model = model
 self.precision_scope = precision_scope
 self.transform = transform
 self.ddim_sampler = DDIMSampler(model)
+ self.save_memory = save_memory
 
 def process_multi(self,
 rendered_txt_values, shared_prompt,
@@ -138,11 +140,36 @@ class Render_Text:
 shared_seed = random.randint(0, 65535)
 seed_everything(shared_seed)
 
+ if torch.cuda.is_available() and self.save_memory:
+ print("low_vram_shift: is_diffusing", False)
+ self.model.low_vram_shift(is_diffusing=False)
+
 print("control is None: {}".format(control is None))
- print("prompt for the SD branch:", str(shared_prompt), "[t]")
- cond_c_cross = self.model.get_learned_conditioning([shared_prompt + ', ' + shared_a_prompt] * shared_num_samples)
+ if shared_prompt.endswith("."):
+ if shared_a_prompt == "":
+ c_prompt = shared_prompt
+ else:
+ c_prompt = shared_prompt + " " + shared_a_prompt
+ elif shared_prompt.endswith(","):
+ if shared_a_prompt == "":
+ c_prompt = shared_prompt[:-1] + "."
+ else:
+ c_prompt = shared_prompt + " " + shared_a_prompt
+ else:
+ if shared_a_prompt == "":
+ c_prompt = shared_prompt + "."
+ else:
+ c_prompt = shared_prompt + ", " + shared_a_prompt
+
+ # cond_c_cross = self.model.get_learned_conditioning([shared_prompt + ', ' + shared_a_prompt] * shared_num_samples)
+ cond_c_cross = self.model.get_learned_conditioning([c_prompt] * shared_num_samples)
+ print("prompt:", c_prompt)
 un_cond_cross = self.model.get_learned_conditioning([shared_n_prompt] * shared_num_samples)
 
+ if torch.cuda.is_available() and self.save_memory:
+ print("low_vram_shift: is_diffusing", True)
+ self.model.low_vram_shift(is_diffusing=True)
+
 cond = {"c_concat": control, "c_crossattn": [cond_c_cross] if not isinstance(cond_c_cross, list) else cond_c_cross}
 un_cond = {"c_concat": None if shared_guess_mode else control, "c_crossattn": [un_cond_cross] if not isinstance(un_cond_cross, list) else un_cond_cross}
 shape = (4, H // 8, W // 8)
@@ -155,6 +182,9 @@ class Render_Text:
 shape, cond, verbose=False, eta=shared_eta,
 unconditional_guidance_scale=shared_scale,
 unconditional_conditioning=un_cond)
+ if torch.cuda.is_available() and self.save_memory:
+ print("low_vram_shift: is_diffusing", False)
+ self.model.low_vram_shift(is_diffusing=False)
 
 x_samples = self.model.decode_first_stage(samples)
 x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
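The new c_prompt block above replaces the old unconditional ", " join with punctuation-aware composition of the shared prompt and the additional prompt. The same rule as a standalone function; compose_prompt is a hypothetical name for illustration:

def compose_prompt(prompt: str, a_prompt: str) -> str:
    # Mirror of the c_prompt branching in Render_Text.process_multi.
    if prompt.endswith("."):
        return prompt if a_prompt == "" else prompt + " " + a_prompt
    if prompt.endswith(","):
        return prompt[:-1] + "." if a_prompt == "" else prompt + " " + a_prompt
    return prompt + "." if a_prompt == "" else prompt + ", " + a_prompt

assert compose_prompt("A sign that says 'GlyphControl'", "best quality") == "A sign that says 'GlyphControl', best quality"
assert compose_prompt("A storefront,", "") == "A storefront."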
 
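Taken together, callers construct the tool once and opt into the low-VRAM path per instance. A hypothetical driver mirroring app.py, assuming load_model_from_config is importable from scripts/rendertext_tool.py as app.py's usage suggests; the config path and checkpoint name are the ones app.py references:

from omegaconf import OmegaConf
from scripts.rendertext_tool import Render_Text, load_model_from_config

cfg = OmegaConf.load("config.yaml")
model = load_model_from_config(cfg, "laion10M_epoch_6_model_wo_ema.ckpt", verbose=True)
render_tool = Render_Text(model, save_memory=True)  # shuttle stages between CPU and GPU during sampling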