# NOTE(review): removed non-Python extraction artifacts ("Spaces:", "Runtime error")
# that preceded the module and made the file unparseable.
import collections

from torch import nn
from torch.cuda.amp import autocast

from NTED.base_module import Encoder, Decoder
class Generator(nn.Module):
    """Render a target image from a reference image and a target skeleton map.

    The reference image is encoded into "neural textures" (collected in a
    recorder dict by the reference encoder), the skeleton/semantic map is
    encoded into a spatial feature, and the decoder combines both to
    synthesise the output image.
    """

    def __init__(
        self,
        size,
        semantic_dim,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=None,
    ):
        """Build the two encoders and the decoder.

        Args:
            size: base spatial resolution passed to the sub-modules.
            semantic_dim: channel count of the skeleton/semantic input.
            channels: channel configuration shared by encoder/decoder.
            num_labels: label count forwarded to the texture path.
            match_kernels: matching-kernel configuration for the texture path.
            blur_kernel: anti-aliasing kernel; defaults to [1, 3, 3, 1].
                (``None`` sentinel avoids a shared mutable default argument.)
        """
        super().__init__()
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        self.size = size
        # Reference (RGB, hence in_channels=3) encoder producing neural textures.
        self.reference_encoder = Encoder(
            size, 3, channels, num_labels, match_kernels, blur_kernel
        )
        # Skeleton encoder; relies on Encoder's defaults for the texture args.
        self.skeleton_encoder = Encoder(
            size, semantic_dim, channels,
        )
        self.target_image_renderer = Decoder(
            size, channels, num_labels, match_kernels, blur_kernel
        )

    def _cal_temp(self, module):
        """Return the number of trainable parameters of ``module``."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def _forward_impl(self, source_image, skeleton):
        """Shared forward path, run with or without autocast by ``forward``."""
        recoder = collections.defaultdict(list)
        skeleton_feature = self.skeleton_encoder(skeleton)
        # Side effect: the reference encoder fills recoder["neural_textures"];
        # its return value is intentionally discarded.
        _ = self.reference_encoder(source_image, recoder)
        neural_textures = recoder["neural_textures"]
        fake_image = self.target_image_renderer(
            skeleton_feature, neural_textures, recoder
        )
        return {'fake_image': fake_image, 'info': recoder}

    def forward(
        self,
        source_image,
        skeleton,
        amp_flag=False,
    ):
        """Generate the target image.

        Args:
            source_image: reference appearance image.
            skeleton: target skeleton/semantic map.
            amp_flag: when True, run under ``torch.cuda.amp.autocast``.

        Returns:
            dict with ``'fake_image'`` (rendered output) and ``'info'``
            (the recorder holding intermediate values, incl. neural textures).
        """
        # The two branches previously duplicated the whole body verbatim;
        # the shared logic now lives in _forward_impl.
        if amp_flag:
            with autocast():
                return self._forward_impl(source_image, skeleton)
        return self._forward_impl(source_image, skeleton)