parokshsaxena commited on
Commit
034254b
β€’
1 Parent(s): fd52149

adding enhanced garmet net

Browse files
src/enhanced_garment_net.py CHANGED
@@ -1,4 +1,3 @@
1
- """
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
@@ -122,5 +121,3 @@ class EnhancedGarmentNetWithTimestep(nn.Module):
122
  combined_features.append(combined)
123
 
124
  return garment_out, combined_features
125
-
126
- """
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
121
  combined_features.append(combined)
122
 
123
  return garment_out, combined_features
 
 
src/tryon_pipeline.py CHANGED
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
57
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
58
 
59
  # Commenting out for now
60
- #from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
61
 
62
 
63
 
@@ -401,7 +401,7 @@ class StableDiffusionXLInpaintPipeline(
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
- # self.garment_net = EnhancedGarmentNetWithTimestep()
405
 
406
  self.register_modules(
407
  vae=vae,
@@ -1786,8 +1786,8 @@ class StableDiffusionXLInpaintPipeline(
1786
  added_cond_kwargs["image_embeds"] = image_embeds
1787
  print("Calling unet encoder for garment feature extraction")
1788
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1789
- down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1790
- #garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
1791
  print(type(reference_features))
1792
  print(reference_features)
1793
  reference_features = list(reference_features)
 
57
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
58
 
59
  # Commenting out for now
60
+ from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
61
 
62
 
63
 
 
401
  force_zeros_for_empty_prompt: bool = True,
402
  ):
403
  super().__init__()
404
+ self.garment_net = EnhancedGarmentNetWithTimestep()
405
 
406
  self.register_modules(
407
  vae=vae,
 
1786
  added_cond_kwargs["image_embeds"] = image_embeds
1787
  print("Calling unet encoder for garment feature extraction")
1788
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1789
+ #down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1790
+ garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
1791
  print(type(reference_features))
1792
  print(reference_features)
1793
  reference_features = list(reference_features)