tolgacangoz commited on
Commit
96d2b4d
1 Parent(s): 83e9f29

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. matryoshka.py +17 -2
matryoshka.py CHANGED
@@ -3782,8 +3782,6 @@ class MatryoshkaPipeline(
3782
  else:
3783
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
 
3785
- unet = unet.to(self.device)
3786
-
3787
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3788
  deprecation_message = (
3789
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -3840,10 +3838,27 @@ class MatryoshkaPipeline(
3840
  feature_extractor=feature_extractor,
3841
  image_encoder=image_encoder,
3842
  )
 
3843
  if hasattr(unet, "nest_ratio"):
3844
  scheduler.scales = unet.nest_ratio + [1]
3845
  self.image_processor = VaeImageProcessor(do_resize=False)
3846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3847
  def encode_prompt(
3848
  self,
3849
  prompt,
 
3782
  else:
3783
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3784
 
 
 
3785
  if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
3786
  deprecation_message = (
3787
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
 
3838
  feature_extractor=feature_extractor,
3839
  image_encoder=image_encoder,
3840
  )
3841
+ self.register_to_config(nesting_level=nesting_level)
3842
  if hasattr(unet, "nest_ratio"):
3843
  scheduler.scales = unet.nest_ratio + [1]
3844
  self.image_processor = VaeImageProcessor(do_resize=False)
3845
 
3846
+ def change_nesting_level(self, nesting_level: int):
3847
+ if nesting_level == 0:
3848
+ self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3849
+ subfolder="unet/nesting_level_0")
3850
+ self.config.nesting_level = 0
3851
+ elif nesting_level == 1:
3852
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3853
+ subfolder="unet/nesting_level_1")
3854
+ self.config.nesting_level = 1
3855
+ elif nesting_level == 2:
3856
+ self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3857
+ subfolder="unet/nesting_level_2")
3858
+ self.config.nesting_level = 2
3859
+ else:
3860
+ raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3861
+
3862
  def encode_prompt(
3863
  self,
3864
  prompt,