Spaces:
Running
on
Zero
Running
on
Zero
parokshsaxena
commited on
Commit
β’
034254b
1
Parent(s):
fd52149
adding enhanced garmet net
Browse files- src/enhanced_garment_net.py +0 -3
- src/tryon_pipeline.py +4 -4
src/enhanced_garment_net.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
"""
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
@@ -122,5 +121,3 @@ class EnhancedGarmentNetWithTimestep(nn.Module):
|
|
122 |
combined_features.append(combined)
|
123 |
|
124 |
return garment_out, combined_features
|
125 |
-
|
126 |
-
"""
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
|
|
121 |
combined_features.append(combined)
|
122 |
|
123 |
return garment_out, combined_features
|
|
|
|
src/tryon_pipeline.py
CHANGED
@@ -57,7 +57,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
|
57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
58 |
|
59 |
# Commenting out for now
|
60 |
-
|
61 |
|
62 |
|
63 |
|
@@ -401,7 +401,7 @@ class StableDiffusionXLInpaintPipeline(
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
-
|
405 |
|
406 |
self.register_modules(
|
407 |
vae=vae,
|
@@ -1786,8 +1786,8 @@ class StableDiffusionXLInpaintPipeline(
|
|
1786 |
added_cond_kwargs["image_embeds"] = image_embeds
|
1787 |
print("Calling unet encoder for garment feature extraction")
|
1788 |
# down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
|
1789 |
-
down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
|
1790 |
-
|
1791 |
print(type(reference_features))
|
1792 |
print(reference_features)
|
1793 |
reference_features = list(reference_features)
|
|
|
57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
58 |
|
59 |
# Commenting out for now
|
60 |
+
from src.enhanced_garment_net import EnhancedGarmentNetWithTimestep
|
61 |
|
62 |
|
63 |
|
|
|
401 |
force_zeros_for_empty_prompt: bool = True,
|
402 |
):
|
403 |
super().__init__()
|
404 |
+
self.garment_net = EnhancedGarmentNetWithTimestep()
|
405 |
|
406 |
self.register_modules(
|
407 |
vae=vae,
|
|
|
1786 |
added_cond_kwargs["image_embeds"] = image_embeds
|
1787 |
print("Calling unet encoder for garment feature extraction")
|
1788 |
# down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
|
1789 |
+
#down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
|
1790 |
+
garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
|
1791 |
print(type(reference_features))
|
1792 |
print(reference_features)
|
1793 |
reference_features = list(reference_features)
|