linoyts HF staff commited on
Commit
6f3bf64
1 Parent(s): 3c27e97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -28
app.py CHANGED
@@ -66,7 +66,7 @@ from diffusers.utils import (
66
  )
67
  from diffusers.utils.torch_utils import randn_tensor
68
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
69
- from ledits.pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
70
 
71
 
72
  if is_invisible_watermark_available():
@@ -428,10 +428,11 @@ class LEditsPPPipelineStableDiffusionXL(
428
  editing_prompt: Optional[str] = None,
429
  editing_prompt_embeds: Optional[torch.Tensor] = None,
430
  editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
431
- avg_diff = None,
432
- avg_diff_2 = None,
433
- correlation_weight_factor = 0.7,
434
  scale=2,
 
435
  ) -> object:
436
  r"""
437
  Encodes the prompt into text encoder hidden states.
@@ -551,9 +552,8 @@ class LEditsPPPipelineStableDiffusionXL(
551
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
552
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
553
 
554
- if avg_diff is not None and avg_diff_2 is not None:
555
- #scale=3
556
- print("SHALOM neg")
557
  normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
558
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
559
  if j == 0:
@@ -562,15 +562,26 @@ class LEditsPPPipelineStableDiffusionXL(
562
  standard_weights = torch.ones_like(weights)
563
 
564
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
565
- edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
 
 
 
 
 
 
566
  else:
567
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
568
 
569
  standard_weights = torch.ones_like(weights)
570
 
571
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
572
- edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
 
573
 
 
 
 
 
574
 
575
  negative_prompt_embeds_list.append(negative_prompt_embeds)
576
  j+=1
@@ -858,8 +869,8 @@ class LEditsPPPipelineStableDiffusionXL(
858
 
859
  self.unet.set_attn_processor(attn_procs)
860
 
861
- @spaces.GPU
862
  @torch.no_grad()
 
863
  @replace_example_docstring(EXAMPLE_DOC_STRING)
864
  def __call__(
865
  self,
@@ -892,10 +903,12 @@ class LEditsPPPipelineStableDiffusionXL(
892
  clip_skip: Optional[int] = None,
893
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
894
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
895
- avg_diff = None,
896
- avg_diff_2 = None,
897
- correlation_weight_factor = 0.7,
898
  scale=2,
 
 
899
  init_latents: [torch.Tensor] = None,
900
  zs: [torch.Tensor] = None,
901
  **kwargs,
@@ -1102,9 +1115,10 @@ class LEditsPPPipelineStableDiffusionXL(
1102
  editing_prompt_embeds=editing_prompt_embeddings,
1103
  editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1104
  avg_diff = avg_diff,
1105
- avg_diff_2 = avg_diff_2,
1106
  correlation_weight_factor = correlation_weight_factor,
1107
  scale=scale,
 
1108
  )
1109
 
1110
  # 4. Prepare timesteps
@@ -1475,7 +1489,6 @@ class LEditsPPPipelineStableDiffusionXL(
1475
 
1476
  @torch.no_grad()
1477
  # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
1478
- @spaces.GPU
1479
  def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
1480
  image = self.image_processor.preprocess(
1481
  image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
@@ -1504,8 +1517,8 @@ class LEditsPPPipelineStableDiffusionXL(
1504
  x0 = self.vae.config.scaling_factor * x0
1505
  return x0, resized
1506
 
1507
- @spaces.GPU
1508
  @torch.no_grad()
 
1509
  def invert(
1510
  self,
1511
  image: PipelineImageInput,
@@ -1669,20 +1682,17 @@ class LEditsPPPipelineStableDiffusionXL(
1669
  t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1670
  xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1671
 
1672
- print("pre loop 1")
1673
  for t in reversed(timesteps):
1674
  idx = num_inversion_steps - t_to_idx[int(t)] - 1
1675
  noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
1676
  xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0))
1677
  xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
1678
- print("post loop 1")
1679
-
1680
  # noise maps
1681
  zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1682
 
1683
  self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1684
 
1685
- print("pre loop 2")
1686
  for t in self.progress_bar(timesteps):
1687
  idx = num_inversion_steps - t_to_idx[int(t)] - 1
1688
  # 1. predict noise residual
@@ -1714,21 +1724,18 @@ class LEditsPPPipelineStableDiffusionXL(
1714
 
1715
  # correction to avoid error accumulation
1716
  xts[idx] = xtm1_corrected
1717
- print("post loop 2")
1718
 
1719
- #self.init_latents = xts[-1]
1720
  zs = zs.flip(0)
1721
- print("post 3")
1722
  if num_zero_noise_steps > 0:
1723
  zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
1724
- print("post 4")
1725
- #self.zs = zs
1726
  #return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
1727
  return xts[-1], zs
1728
 
1729
 
1730
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
1731
- @spaces.GPU
1732
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
1733
  """
1734
  Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -1783,7 +1790,6 @@ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, e
1783
 
1784
 
1785
  # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd
1786
- @spaces.GPU
1787
  def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1788
  def first_order_update(model_output, sample): # timestep, prev_timestep, sample):
1789
  sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index]
@@ -1869,7 +1875,6 @@ def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noi
1869
 
1870
 
1871
  # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise
1872
- @spaces.GPU
1873
  def compute_noise(scheduler, *args):
1874
  if isinstance(scheduler, DDIMScheduler):
1875
  return compute_noise_ddim(scheduler, *args)
 
66
  )
67
  from diffusers.utils.torch_utils import randn_tensor
68
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
69
+ from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
70
 
71
 
72
  if is_invisible_watermark_available():
 
428
  editing_prompt: Optional[str] = None,
429
  editing_prompt_embeds: Optional[torch.Tensor] = None,
430
  editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
431
+ avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
432
+ avg_diff_2nd=None, # text encoder 1,2
433
+ correlation_weight_factor=0.7,
434
  scale=2,
435
+ scale_2nd=2,
436
  ) -> object:
437
  r"""
438
  Encodes the prompt into text encoder hidden states.
 
552
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
553
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
554
 
555
+ if avg_diff is not None:
556
+ # scale=3
 
557
  normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
558
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
559
  if j == 0:
 
562
  standard_weights = torch.ones_like(weights)
563
 
564
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
565
+ edit_concepts_embeds = negative_prompt_embeds + (
566
+ weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
567
+
568
+ if avg_diff_2nd is not None:
569
+ edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
570
+ self.pipe.tokenizer.model_max_length,
571
+ 1) * scale_2nd)
572
  else:
573
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
574
 
575
  standard_weights = torch.ones_like(weights)
576
 
577
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
578
+ edit_concepts_embeds = negative_prompt_embeds + (
579
+ weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
580
 
581
+ if avg_diff_2nd is not None:
582
+ edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
583
+ self.pipe.tokenizer_2.model_max_length,
584
+ 1) * scale_2nd)
585
 
586
  negative_prompt_embeds_list.append(negative_prompt_embeds)
587
  j+=1
 
869
 
870
  self.unet.set_attn_processor(attn_procs)
871
 
 
872
  @torch.no_grad()
873
+ @spaces.GPU
874
  @replace_example_docstring(EXAMPLE_DOC_STRING)
875
  def __call__(
876
  self,
 
903
  clip_skip: Optional[int] = None,
904
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
905
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
906
+ avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
907
+ avg_diff_2nd=None, # text encoder 1,2
908
+ correlation_weight_factor=0.7,
909
  scale=2,
910
+ scale_2nd=2,
911
+ correlation_weight_factor = 0.7,
912
  init_latents: [torch.Tensor] = None,
913
  zs: [torch.Tensor] = None,
914
  **kwargs,
 
1115
  editing_prompt_embeds=editing_prompt_embeddings,
1116
  editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1117
  avg_diff = avg_diff,
1118
+ avg_diff_2nd = avg_diff_2nd,
1119
  correlation_weight_factor = correlation_weight_factor,
1120
  scale=scale,
1121
+ scale_2nd=scale_2nd
1122
  )
1123
 
1124
  # 4. Prepare timesteps
 
1489
 
1490
  @torch.no_grad()
1491
  # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
 
1492
  def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
1493
  image = self.image_processor.preprocess(
1494
  image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
 
1517
  x0 = self.vae.config.scaling_factor * x0
1518
  return x0, resized
1519
 
 
1520
  @torch.no_grad()
1521
+ @spaces.GPU
1522
  def invert(
1523
  self,
1524
  image: PipelineImageInput,
 
1682
  t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1683
  xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1684
 
 
1685
  for t in reversed(timesteps):
1686
  idx = num_inversion_steps - t_to_idx[int(t)] - 1
1687
  noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
1688
  xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0))
1689
  xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
1690
+
 
1691
  # noise maps
1692
  zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1693
 
1694
  self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1695
 
 
1696
  for t in self.progress_bar(timesteps):
1697
  idx = num_inversion_steps - t_to_idx[int(t)] - 1
1698
  # 1. predict noise residual
 
1724
 
1725
  # correction to avoid error accumulation
1726
  xts[idx] = xtm1_corrected
 
1727
 
1728
+ self.init_latents = xts[-1]
1729
  zs = zs.flip(0)
1730
+
1731
  if num_zero_noise_steps > 0:
1732
  zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
1733
+ self.zs = zs
 
1734
  #return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
1735
  return xts[-1], zs
1736
 
1737
 
1738
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
 
1739
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
1740
  """
1741
  Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
 
1790
 
1791
 
1792
  # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd
 
1793
  def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1794
  def first_order_update(model_output, sample): # timestep, prev_timestep, sample):
1795
  sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index]
 
1875
 
1876
 
1877
  # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise
 
1878
  def compute_noise(scheduler, *args):
1879
  if isinstance(scheduler, DDIMScheduler):
1880
  return compute_noise_ddim(scheduler, *args)