# NTED encoder/decoder network definitions.
import functools
import math
import sys

import torch
import torch.nn as nn

from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
from NTED.edge_attention_layer import Edge_Attn
class Encoder(nn.Module):
    """Feature encoder: a 1x1 stem followed by a stack of downsampling
    EncoderLayers that halve the spatial resolution from ``size`` down
    to 16x16, with optional per-resolution extraction stages."""

    def __init__(
        self,
        size,
        input_dim,
        channels,
        num_labels=None,
        match_kernels=None,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        # Stem maps the raw input to the channel width at full resolution.
        self.first = EncoderLayer(input_dim, channels[size], 1)
        self.convs = nn.ModuleList()
        self.log_size = int(math.log(size, 2))

        cur_channels = channels[size]
        # One downsampling stage per halving, stopping above 16x16 (2**4).
        for power in range(self.log_size - 1, 3, -1):
            resolution = 2 ** power
            next_channels = channels[resolution]
            label_count = num_labels[resolution] if num_labels is not None else None
            kernel = match_kernels[resolution] if match_kernels is not None else None
            self.convs.append(
                EncoderLayer(
                    cur_channels,
                    next_channels,
                    kernel_size=3,
                    downsample=True,
                    blur_kernel=blur_kernel,
                    # Extraction is active only when both a label count and a
                    # matching kernel are configured for this resolution.
                    use_extraction=label_count and kernel,
                    num_label=label_count,
                    match_kernel=kernel,
                )
            )
            cur_channels = next_channels

    def forward(self, input, recoder=None):
        """Run the stem, then every downsampling stage in order, threading
        ``recoder`` through each layer; returns the 16x16 feature map."""
        out = self.first(input)
        for layer in self.convs:
            out = layer(out, recoder)
        return out
class Decoder(nn.Module):
    """Image decoder: pairs of DecoderLayers that upsample from 16x16 to
    ``size``, accumulating an RGB skip image that is finally refined by an
    edge-attention block."""

    def __init__(
        self,
        size,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.log_size = int(math.log(size, 2))

        # Decoding starts from the 16x16 feature map.
        cur_channels = channels[16]
        for power in range(4, self.log_size + 1):
            resolution = 2 ** power
            next_channels = channels[resolution]
            label_count = num_labels[resolution]
            kernel = match_kernels[resolution]
            # Keyword arguments shared by both DecoderLayers of this stage.
            shared = dict(
                out_channel=next_channels,
                kernel_size=3,
                blur_kernel=blur_kernel,
                # Distribution is used only when both settings are present.
                use_distribution=label_count and kernel,
                num_label=label_count,
                match_kernel=kernel,
            )
            # The first (16x16) stage keeps its resolution; later ones upsample.
            needs_upsample = power != 4
            stage = nn.Module()
            stage.conv0 = DecoderLayer(in_channel=cur_channels, upsample=needs_upsample, **shared)
            stage.conv1 = DecoderLayer(in_channel=next_channels, upsample=False, **shared)
            stage.to_rgb = ToRGB(next_channels, upsample=needs_upsample)
            self.convs.append(stage)
            cur_channels = next_channels

        self.num_labels, self.match_kernels = num_labels, match_kernels
        self.edge_attn_block = Edge_Attn(in_channels=3)

    def forward(self, input, neural_textures, recoder):
        """Decode ``input`` into an image, consuming ``neural_textures`` two at
        a time for every stage whose resolution has distribution enabled."""
        texture_idx = 0
        out, skip = input, None
        for stage_num, stage in enumerate(self.convs):
            resolution = 2 ** (stage_num + 4)
            if self.num_labels[resolution] and self.match_kernels[resolution]:
                tex0 = neural_textures[texture_idx]
                tex1 = neural_textures[texture_idx + 1]
                texture_idx += 2
            else:
                tex0 = tex1 = None
            out = stage.conv0(out, neural_texture=tex0, recoder=recoder)
            out = stage.conv1(out, neural_texture=tex1, recoder=recoder)
            skip = stage.to_rgb(out, skip)
        # The accumulated RGB skip passes through edge-attention refinement.
        return self.edge_attn_block(skip)