Spaces:
Runtime error
Runtime error
from typing import Optional | |
from torch.nn import Sigmoid, Sequential, Tanh | |
from tha3.nn.conv import create_conv3, create_conv3_from_block_args | |
from tha3.nn.nonlinearity_factory import ReLUFactory | |
from tha3.nn.normalization import InstanceNorm2dFactory | |
from tha3.nn.util import BlockArgs | |
class PoserArgs00:
    """Hold size/channel settings and build small output blocks from them.

    The constructor only records its arguments. Every ``create_*`` method
    returns a freshly built module produced by the project's
    ``create_conv3`` / ``create_conv3_from_block_args`` helpers, always with
    ``start_channels`` input channels.

    NOTE(review): the method names suggest these blocks are the output heads
    of an image "poser" network (alpha mask, colour change, sampling grid).
    That intent is not visible in this file -- confirm against the callers.
    """

    def __init__(self,
                 image_size: int,
                 input_image_channels: int,
                 output_image_channels: int,
                 start_channels: int,
                 num_pose_params: int,
                 block_args: Optional["BlockArgs"] = None):
        """Record the configuration values on the instance.

        Args:
            image_size: Stored as ``self.image_size``; not used by the
                factory methods of this class.
            input_image_channels: Stored as ``self.input_image_channels``;
                not used by the factory methods of this class.
            output_image_channels: Output channel count of the blocks built
                by ``create_all_channel_alpha_block`` and
                ``create_color_change_block``.
            start_channels: Input channel count of every block built by the
                ``create_*`` methods.
            num_pose_params: Stored as ``self.num_pose_params``; not used by
                the factory methods of this class.
            block_args: Shared block settings. When ``None``, a ``BlockArgs``
                with an ``InstanceNorm2dFactory`` normalization factory and
                an in-place ``ReLUFactory`` nonlinearity factory is created.
        """
        self.num_pose_params = num_pose_params
        self.start_channels = start_channels
        self.output_image_channels = output_image_channels
        self.input_image_channels = input_image_channels
        self.image_size = image_size
        if block_args is None:
            # Build the default per instance so no BlockArgs object is shared
            # between separately constructed PoserArgs00 objects.
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=True))
        self.block_args = block_args

    def _create_sigmoid_conv3_block(self, out_channels: int) -> Sequential:
        """Return ``Sequential(create_conv3(...), Sigmoid())``.

        The convolution takes ``start_channels`` inputs, produces
        ``out_channels`` outputs, has a bias, uses the initialization method
        from ``self.block_args`` and never uses spectral normalization.
        """
        return Sequential(
            create_conv3(
                in_channels=self.start_channels,
                out_channels=out_channels,
                bias=True,
                initialization_method=self.block_args.initialization_method,
                use_spectral_norm=False),
            Sigmoid())

    def create_alpha_block(self) -> Sequential:
        """Return a single-output-channel convolution followed by Sigmoid."""
        return self._create_sigmoid_conv3_block(out_channels=1)

    def create_all_channel_alpha_block(self) -> Sequential:
        """Return a convolution followed by Sigmoid.

        Same as ``create_alpha_block`` except that the convolution produces
        ``output_image_channels`` channels instead of one.
        """
        return self._create_sigmoid_conv3_block(
            out_channels=self.output_image_channels)

    def create_color_change_block(self) -> Sequential:
        """Return a convolution followed by Tanh.

        Unlike the alpha blocks, the convolution is configured from the whole
        ``self.block_args`` object via ``create_conv3_from_block_args``. It
        maps ``start_channels`` to ``output_image_channels`` with a bias.
        """
        return Sequential(
            create_conv3_from_block_args(
                in_channels=self.start_channels,
                out_channels=self.output_image_channels,
                bias=True,
                block_args=self.block_args),
            Tanh())

    def create_grid_change_block(self):
        """Return a bare two-output-channel convolution with no activation.

        The layer has no bias and is created with the ``'zero'``
        initialization method, independent of ``self.block_args``.

        NOTE(review): presumably the two channels are an (x, y) offset field
        that starts out as "no change" -- confirm against the caller.
        """
        return create_conv3(
            in_channels=self.start_channels,
            out_channels=2,
            bias=False,
            initialization_method='zero',
            use_spectral_norm=False)