Spaces: Runtime error
diff --git a/mapper/latent_mappers.py b/mapper/latent_mappers.py
index 56b9c55..f0dd005 100644
--- a/mapper/latent_mappers.py
+++ b/mapper/latent_mappers.py
@@ ... @@ class ModulationModule(Module):
     def forward(self, x, embedding, cut_flag):
         x = self.fc(x)
-        x = self.norm(x)
+        x = self.norm(x)
         if cut_flag == 1:
             return x
         gamma = self.gamma_function(embedding.float())
@@ ... @@ class SubHairMapper(Module):
     def forward(self, x, embedding, cut_flag=0):
         x = self.pixelnorm(x)
         for modulation_module in self.modulation_module_list:
-            x = modulation_module(x, embedding, cut_flag)
+            x = modulation_module(x, embedding, cut_flag)
         return x
-class HairMapper(Module):
+class HairMapper(Module):
     def __init__(self, opts):
         super(HairMapper, self).__init__()
         self.opts = opts
-        self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
         self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
         self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
         self.hairstyle_cut_flag = 0
         self.color_cut_flag = 0
-        if not opts.no_coarse_mapper:
+        if not opts.no_coarse_mapper:
             self.course_mapping = SubHairMapper(opts, 4)
         if not opts.no_medium_mapper:
             self.medium_mapping = SubHairMapper(opts, 4)
@@ ... @@ class HairMapper(Module):
         elif hairstyle_tensor.shape[1] != 1:
             hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
         if color_text_inputs.shape[1] != 1:
             color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach()
         elif color_tensor.shape[1] != 1:
             color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
         else:
-            color_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+            color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
         if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1):
@@ ... @@ class HairMapper(Module):
             x_fine = torch.zeros_like(x_fine)
         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
-        return out
\ No newline at end of file
+        return out
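The hunks above all make the same fix: the device is no longer hard-coded to CUDA, so the mapper can also run on a CPU-only host such as this Space. Below is a minimal sketch of the pattern the patch relies on, assuming opts.device is a plain device string picked once at startup; the Namespace stand-in, batch_size, and the print call are illustrative, while clip.load, torch.cuda.is_available, and Tensor.to are the real APIs the diff uses.

from argparse import Namespace

import clip
import torch

# Pick whatever device the host actually has; free Spaces have no GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
opts = Namespace(device=device)  # illustrative stand-in for the real opts object

# CLIP loads onto the selected device instead of a hard-coded "cuda".
clip_model, preprocess = clip.load("ViT-B/32", device=opts.device)

# Placeholder embeddings are created on the same device, mirroring the
# torch.ones(x.shape[0], 18, 512).to(self.opts.device) lines in the patch.
batch_size = 1
hairstyle_embedding = torch.ones(batch_size, 18, 512).to(opts.device)
print(hairstyle_embedding.device)  # cpu on a CPU-only host, cuda:0 otherwise

Because .to(device) accepts both "cpu" and "cuda", the same code path works with or without a GPU, which is presumably what the Runtime error status above was about: clip.load(..., device="cuda") and .cuda() raise at startup on hosts without CUDA.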