import os
import sys
sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import numpy as np

# TODO: be aware of the actual network structures


def get_log(x):
    '''integer log2 of x; raises if x is not a power of 2'''
    log = 0
    while x > 1:
        if x % 2 == 0:
            x = x // 2
            log += 1
        else:
            raise ValueError('x is not a power of 2')
    return log


class ConvNormRelu(nn.Module):
    '''
    (B, C_in, H, W) -> (B, C_out, H, W)
    Note: some kernel_size/stride combinations yield an output length that is
    not exactly H / stride.
    TODO: the residual path may be problematic for such combinations.
    '''

    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        conv-norm-relu
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm

        if kernel_size is None and stride is None:
            if not downsample:
                kernel_size = 3
                stride = 1
            else:
                kernel_size = 4
                stride = 2

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(stride, tuple):
                padding = tuple(int((kernel_size - st) / 2) for st in stride)
            elif isinstance(kernel_size, tuple) and isinstance(stride, int):
                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
                padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
            else:
                padding = int((kernel_size - stride) / 2)

        if self.residual:
            # identity shortcut when the shape is unchanged, otherwise a conv projection
            if (not downsample) and in_channels == out_channels:
                self.residual_layer = nn.Identity()
            elif type == '1d':
                self.residual_layer = nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding
                    )
                )
            elif type == '2d':
                self.residual_layer = nn.Sequential(
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding
                    )
                )

        # grouped conv operates on groups * channels
        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if type == '1d':
            self.conv = nn.Conv1d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(p=p)
        elif type == '2d':
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm2d(out_channels)
            self.dropout = nn.Dropout2d(p=p)

        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)

        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        if self.norm_type == 'ln':
            out = self.dropout(self.conv(x))
            out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        else:
            out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
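# Usage sketch for ConvNormRelu (illustrative only; the name `cnr` and the
# shapes below are assumptions, not part of the original code):
#   cnr = ConvNormRelu(64, 128, type='1d', downsample=True)
#   y = cnr(torch.randn(8, 64, 32))  # -> (8, 128, 16): defaults k=4, s=2, p=1 halve T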
class UNet1D(nn.Module):
    def __init__(self,
                 input_channels,
                 output_channels,
                 max_depth=5,
                 kernel_size=None,
                 stride=None,
                 p=0,
                 groups=1):
        super(UNet1D, self).__init__()
        self.pre_downsampling_conv = nn.ModuleList([])
        self.conv1 = nn.ModuleList([])
        self.conv2 = nn.ModuleList([])
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
        self.max_depth = max_depth
        self.groups = groups

        self.pre_downsampling_conv.append(ConvNormRelu(input_channels, output_channels,
                                                       type='1d', leaky=True, downsample=False,
                                                       kernel_size=kernel_size, stride=stride, p=p, groups=groups))
        self.pre_downsampling_conv.append(ConvNormRelu(output_channels, output_channels,
                                                       type='1d', leaky=True, downsample=False,
                                                       kernel_size=kernel_size, stride=stride, p=p, groups=groups))
        for i in range(self.max_depth):
            self.conv1.append(ConvNormRelu(output_channels, output_channels,
                                           type='1d', leaky=True, downsample=True,
                                           kernel_size=kernel_size, stride=stride, p=p, groups=groups))
        for i in range(self.max_depth):
            self.conv2.append(ConvNormRelu(output_channels, output_channels,
                                           type='1d', leaky=True, downsample=False,
                                           kernel_size=kernel_size, stride=stride, p=p, groups=groups))

    def forward(self, x):
        input_size = x.shape[-1]
        assert get_log(input_size) >= self.max_depth, \
            'num_frames must be a power of 2 whose log2 is at least max_depth'

        for conv in self.pre_downsampling_conv:
            x = conv(x)

        residuals = []
        residuals.append(x)
        for i, conv1 in enumerate(self.conv1):
            x = conv1(x)
            if i < self.max_depth - 1:
                residuals.append(x)

        for i, conv2 in enumerate(self.conv2):
            x = self.upconv(x) + residuals[self.max_depth - i - 1]
            x = conv2(x)

        return x


class UNet2D(nn.Module):
    def __init__(self):
        super(UNet2D, self).__init__()
        raise NotImplementedError('2D UNet is weird')


class AudioPoseEncoder1D(nn.Module):
    '''
    (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layer_nums=None
                 ):
        super(AudioPoseEncoder1D, self).__init__()
        self.C_in = C_in
        self.C_out = C_out

        conv_layers = nn.ModuleList([])
        cur_C = C_in
        num_layers = 0
        while cur_C < self.C_out:
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=cur_C * 2,
                kernel_size=kernel_size,
                stride=stride
            ))
            cur_C *= 2
            num_layers += 1

        # guard min_layer_nums against None before comparing (the original
        # condition raised a TypeError when min_layer_nums was None)
        while (cur_C != C_out) or (min_layer_nums is not None and num_layers < min_layer_nums):
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=C_out,
                kernel_size=kernel_size,
                stride=stride
            ))
            num_layers += 1
            cur_C = C_out

        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        x = self.conv_layers(x)
        return x
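# Usage sketch for UNet1D (illustrative; names and sizes are assumptions):
#   unet = UNet1D(input_channels=64, output_channels=128, max_depth=5)
#   y = unet(torch.randn(8, 64, 32))  # -> (8, 128, 32); T=32 is a power of 2
#                                     #    with log2(32)=5 >= max_depth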
class AudioPoseEncoder2D(nn.Module):
    '''
    (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
    '''

    def __init__(self):
        raise NotImplementedError


class AudioPoseEncoderRNN(nn.Module):
    '''
    (B, C, T) -> (B, T, C) -> (B, T, C_out) -> (B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        super(AudioPoseEncoderRNN, self).__init__()
        if rnn_cell == 'gru':
            self.cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers,
                               batch_first=True, bidirectional=bidirectional)
        elif rnn_cell == 'lstm':
            self.cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers,
                                batch_first=True, bidirectional=bidirectional)
        else:
            raise ValueError('invalid rnn cell: %s' % (rnn_cell))

    def forward(self, x, state=None):
        x = x.permute(0, 2, 1)
        x, state = self.cell(x, state)
        x = x.permute(0, 2, 1)
        return x


class AudioPoseEncoderGraph(nn.Module):
    '''
    (B, C, T) -> (B, 2, V, T) -> (B, 2, T, V) -> (B, D, T, V)
    '''

    def __init__(self,
                 layers_config,  # expected: a list of (C_in, C_out, kernel_size) tuples
                 A,  # adjacency matrix (num_parts, V, V)
                 residual,
                 local_bn=False,
                 share_weights=False
                 ) -> None:
        super().__init__()
        self.A = A
        self.num_joints = A.shape[1]
        self.num_parts = A.shape[0]
        self.C_in = layers_config[0][0]
        self.C_out = layers_config[-1][1]
        # NOTE: GraphConvNormRelu is not defined in this file; it must be
        # imported from elsewhere in the repo.
        self.conv_layers = nn.ModuleList([
            GraphConvNormRelu(
                C_in=c_in,
                C_out=c_out,
                A=self.A,
                residual=residual,
                local_bn=local_bn,
                kernel_size=k,
                share_weights=share_weights
            ) for (c_in, c_out, k) in layers_config
        ])
        self.conv_layers = nn.Sequential(*self.conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T), C should be num_joints * D
        output: (B, D, T, V)
        '''
        B, C, T = x.shape
        x = x.view(B, self.num_joints, self.C_in, T)  # (B, V, D, T); D is the per-joint feature dim, note V comes first here
        x = x.permute(0, 2, 3, 1)  # (B, D, T, V)
        assert x.shape[1] == self.C_in

        x_conved = self.conv_layers(x)
        # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out * self.num_joints, T)  # (B, V*C_out, T)

        return x_conved


class SeqEncoder2D(nn.Module):
    '''
    seq_encoder, encoding a seq to a vector
    (B, C, T) -> (B, 2, V, T) -> (B, 2, T, V) -> (B, 32, ...) -> ... -> (B, C_out)
    '''

    def __init__(self,
                 C_in,  # should be 2
                 T_in,
                 C_out,
                 num_joints,
                 min_layer_num=None,
                 residual=False
                 ):
        super(SeqEncoder2D, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.num_joints = num_joints

        conv_layers = nn.ModuleList([])
        conv_layers.append(ConvNormRelu(
            in_channels=C_in,
            out_channels=32,
            type='2d',
            residual=residual
        ))

        cur_C = 32
        cur_H = T_in
        cur_W = num_joints
        num_layers = 1
        # halve each spatial dim (stride 2) while it is larger than 4,
        # otherwise collapse it to 1 in a single step
        while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
            ks = [3, 3]
            st = [1, 1]

            if cur_H > 1:
                if cur_H > 4:
                    ks[0] = 4
                    st[0] = 2
                else:
                    ks[0] = cur_H
                    st[0] = cur_H

            if cur_W > 1:
                if cur_W > 4:
                    ks[1] = 4
                    st[1] = 2
                else:
                    ks[1] = cur_W
                    st[1] = cur_W

            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(C_out, cur_C * 2),
                type='2d',
                kernel_size=tuple(ks),
                stride=tuple(st),
                residual=residual
            ))

            cur_C = min(cur_C * 2, C_out)
            if cur_H > 1:
                if cur_H > 4:
                    cur_H //= 2
                else:
                    cur_H = 1
            if cur_W > 1:
                if cur_W > 4:
                    cur_W //= 2
                else:
                    cur_W = 1
            num_layers += 1

        if min_layer_num is not None and (num_layers < min_layer_num):
            while num_layers < min_layer_num:
                conv_layers.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='2d',
                    kernel_size=1,
                    stride=1,
                    residual=residual
                ))
                num_layers += 1

        self.conv_layers = nn.Sequential(*conv_layers)
        self.num_layers = num_layers

    def forward(self, x):
        B, C, T = x.shape
        x = x.view(B, self.num_joints, self.C_in, T)  # (B, V, D, T), V in front
        x = x.permute(0, 2, 3, 1)  # (B, D, T, V)
        assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints

        x = self.conv_layers(x)
        return x.squeeze(-1).squeeze(-1)  # (B, C_out); keep the batch dim even when B == 1
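# Usage sketch for AudioPoseEncoderRNN, defined above (illustrative; names and
# sizes are assumptions):
#   enc = AudioPoseEncoderRNN(C_in=64, hidden_size=128, num_layers=2)
#   y = enc(torch.randn(8, 64, 25))  # -> (8, 128, 25): per-step GRU features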
class SeqEncoder1D(nn.Module):
    '''
    (B, C, T) -> (B, D)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 T_in,
                 min_layer_nums=None
                 ):
        super(SeqEncoder1D, self).__init__()
        conv_layers = nn.ModuleList([])
        cur_C = C_in
        cur_T = T_in
        self.num_layers = 0
        while (cur_C < C_out) or (cur_T > 1):
            ks = 3
            st = 1
            if cur_T > 1:
                if cur_T > 4:
                    ks = 4
                    st = 2
                else:
                    ks = cur_T
                    st = cur_T

            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(C_out, cur_C * 2),
                type='1d',
                kernel_size=ks,
                stride=st
            ))
            cur_C = min(cur_C * 2, C_out)
            if cur_T > 1:
                if cur_T > 4:
                    cur_T = cur_T // 2
                else:
                    cur_T = 1
            self.num_layers += 1

        if min_layer_nums is not None and (self.num_layers < min_layer_nums):
            while self.num_layers < min_layer_nums:
                conv_layers.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='1d',
                    kernel_size=1,
                    stride=1
                ))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        x = self.conv_layers(x)
        return x.squeeze(-1)  # (B, C_out); keep the batch dim even when B == 1


class SeqEncoderRNN(nn.Module):
    '''
    (B, C, T) -> (B, T, C) -> (B, D)
    LSTM/GRU-FC
    '''

    def __init__(self,
                 hidden_size,
                 in_size,
                 num_rnn_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        super(SeqEncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.in_size = in_size
        self.num_rnn_layers = num_rnn_layers
        self.bidirectional = bidirectional

        if rnn_cell == 'gru':
            self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size,
                               num_layers=self.num_rnn_layers, batch_first=True, bidirectional=bidirectional)
        elif rnn_cell == 'lstm':
            self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size,
                                num_layers=self.num_rnn_layers, batch_first=True, bidirectional=bidirectional)
        else:
            raise ValueError('invalid rnn cell: %s' % (rnn_cell))

    def forward(self, x, state=None):
        x = x.permute(0, 2, 1)
        B, T, C = x.shape
        x, _ = self.cell(x, state)
        if self.bidirectional:
            # last step of the forward direction, first step of the backward direction
            out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
        else:
            out = x[:, -1, :]
        assert out.shape[0] == B
        return out


class SeqEncoderGraph(nn.Module):
    '''
    (B, C, T) -> (B, embedding_size)
    '''

    def __init__(self,
                 embedding_size,
                 layer_configs,
                 residual,
                 local_bn,
                 A,
                 T,
                 share_weights=False
                 ) -> None:
        super().__init__()
        self.C_in = layer_configs[0][0]
        self.C_out = embedding_size
        self.num_joints = A.shape[1]
        self.graph_encoder = AudioPoseEncoderGraph(
            layers_config=layer_configs,
            A=A,
            residual=residual,
            local_bn=local_bn,
            share_weights=share_weights
        )
        cur_C = layer_configs[-1][1]
        self.spatial_pool = ConvNormRelu(
            in_channels=cur_C,
            out_channels=cur_C,
            type='2d',
            kernel_size=(1, self.num_joints),
            stride=(1, 1),
            padding=(0, 0)
        )

        temporal_pool = nn.ModuleList([])
        cur_H = T
        num_layers = 0
        self.temporal_conv_info = []
        while cur_C < self.C_out or cur_H > 1:
            self.temporal_conv_info.append(cur_C)
            ks = [3, 1]
            st = [1, 1]

            if cur_H > 1:
                if cur_H > 4:
                    ks[0] = 4
                    st[0] = 2
                else:
                    ks[0] = cur_H
                    st[0] = cur_H

            temporal_pool.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(self.C_out, cur_C * 2),
                type='2d',
                kernel_size=tuple(ks),
                stride=tuple(st)
            ))

            cur_C = min(cur_C * 2, self.C_out)
            if cur_H > 1:
                if cur_H > 4:
                    cur_H //= 2
                else:
                    cur_H = 1
            num_layers += 1
        self.temporal_pool = nn.Sequential(*temporal_pool)
        print("graph seq encoder info: temporal pool:", self.temporal_conv_info)

        self.num_layers = num_layers
        # TODO: need an fc head?

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        B, C, T = x.shape
        x = self.graph_encoder(x)
        x = self.spatial_pool(x)
        x = self.temporal_pool(x)
        x = x.view(B, self.C_out)
        return x
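# Usage sketch for SeqEncoder1D, defined above (illustrative; names and sizes
# are assumptions):
#   enc = SeqEncoder1D(C_in=64, C_out=256, T_in=32)
#   v = enc(torch.randn(8, 64, 32))  # -> (8, 256): channels double toward C_out
#                                    #    while T is halved (stride 2) down to 1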
class SeqDecoder2D(nn.Module):
    '''
    (B, D) -> (B, D, 1, 1) -> ... -> (B, C_out, T)
    '''

    def __init__(self):
        super(SeqDecoder2D, self).__init__()
        raise NotImplementedError


class SeqDecoder1D(nn.Module):
    '''
    (B, D) -> (B, D, 1) -> ... -> (B, C_out, T)
    '''

    def __init__(self,
                 D_in,
                 C_out,
                 T_out,
                 min_layer_num=None
                 ):
        super(SeqDecoder1D, self).__init__()
        self.T_out = T_out
        self.min_layer_num = min_layer_num
        cur_t = 1
        self.pre_conv = ConvNormRelu(
            in_channels=D_in,
            out_channels=C_out,
            type='1d'
        )
        self.num_layers = 1
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_layers = nn.ModuleList([])
        cur_t *= 2
        # each conv layer below is preceded by a 2x upsample in forward
        while cur_t <= T_out:
            self.conv_layers.append(ConvNormRelu(
                in_channels=C_out,
                out_channels=C_out,
                type='1d'
            ))
            cur_t *= 2
            self.num_layers += 1

        post_conv = nn.ModuleList([ConvNormRelu(
            in_channels=C_out,
            out_channels=C_out,
            type='1d'
        )])
        self.num_layers += 1
        if min_layer_num is not None and self.num_layers < min_layer_num:
            while self.num_layers < min_layer_num:
                post_conv.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='1d'
                ))
                self.num_layers += 1
        self.post_conv = nn.Sequential(*post_conv)

    def forward(self, x):
        x = x.unsqueeze(-1)
        x = self.pre_conv(x)
        for conv in self.conv_layers:
            x = self.upconv(x)
            x = conv(x)
        # snap to the exact target length before the final convs
        x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
        x = self.post_conv(x)
        return x


class SeqDecoderRNN(nn.Module):
    '''
    (B, D) -> (B, C_out, T)
    '''

    def __init__(self,
                 hidden_size,
                 C_out,
                 T_out,
                 num_layers,
                 rnn_cell='gru'
                 ):
        super(SeqDecoderRNN, self).__init__()
        self.num_steps = T_out
        if rnn_cell == 'gru':
            self.cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers,
                               batch_first=True, bidirectional=False)
        elif rnn_cell == 'lstm':
            self.cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers,
                                batch_first=True, bidirectional=False)
        else:
            raise ValueError('invalid rnn cell: %s' % (rnn_cell))
        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, hidden, frame_0):
        # autoregressive decoding: feed each predicted frame back as the next input
        frame_0 = frame_0.permute(0, 2, 1)
        dec_input = frame_0
        outputs = []
        for i in range(self.num_steps):
            frame_out, hidden = self.cell(dec_input, hidden)
            frame_out = self.fc(frame_out)
            dec_input = frame_out
            outputs.append(frame_out)
        output = torch.cat(outputs, dim=1)
        return output.permute(0, 2, 1)
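# Usage sketch for SeqDecoder1D, defined above (illustrative; names and sizes
# are assumptions):
#   dec = SeqDecoder1D(D_in=256, C_out=64, T_out=25)
#   y = dec(torch.randn(8, 256))  # -> (8, 64, 25): repeated 2x upsampling,
#                                 #    then nearest interpolation snaps T to 25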
class SeqTranslator2D(nn.Module):
    '''
    (B, C, T) -> (B, 1, C, T) -> ... -> (B, 1, C_out, T_out)
    '''

    def __init__(self,
                 C_in=64,
                 C_out=108,
                 T_in=75,
                 T_out=25,
                 residual=True
                 ):
        super(SeqTranslator2D, self).__init__()
        print("Warning: SeqTranslator2D layer shapes are hard-coded for the default arguments")
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.T_out = T_out
        self.residual = residual

        self.conv_layers = nn.Sequential(
            ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
            ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
            ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
            ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
        )

    def forward(self, x):
        assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
        x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
        x = self.conv_layers(x)
        x = x.squeeze(2)
        return x


class SeqTranslator1D(nn.Module):
    '''
    (B, C, T) -> (B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layers_num=None,
                 residual=True,
                 norm='bn'
                 ):
        super(SeqTranslator1D, self).__init__()

        conv_layers = nn.ModuleList([])
        conv_layers.append(ConvNormRelu(
            in_channels=C_in,
            out_channels=C_out,
            type='1d',
            kernel_size=kernel_size,
            stride=stride,
            residual=residual,
            norm=norm
        ))
        self.num_layers = 1
        if min_layers_num is not None and self.num_layers < min_layers_num:
            while self.num_layers < min_layers_num:
                conv_layers.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='1d',
                    kernel_size=kernel_size,
                    stride=stride,
                    residual=residual,
                    norm=norm
                ))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        return self.conv_layers(x)


class SeqTranslatorRNN(nn.Module):
    '''
    (B, C, T) -> (B, C_out, T)
    LSTM-FC
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru'
                 ):
        super(SeqTranslatorRNN, self).__init__()
        if rnn_cell == 'gru':
            self.enc_cell = nn.GRU(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers,
                                   batch_first=True, bidirectional=False)
            self.dec_cell = nn.GRU(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers,
                                   batch_first=True, bidirectional=False)
        elif rnn_cell == 'lstm':
            self.enc_cell = nn.LSTM(input_size=C_in, hidden_size=hidden_size, num_layers=num_layers,
                                    batch_first=True, bidirectional=False)
            self.dec_cell = nn.LSTM(input_size=C_out, hidden_size=hidden_size, num_layers=num_layers,
                                    batch_first=True, bidirectional=False)
        else:
            raise ValueError('invalid rnn cell: %s' % (rnn_cell))

        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, x, frame_0):
        # encode the whole input sequence, then decode autoregressively from frame_0
        num_steps = x.shape[-1]
        x = x.permute(0, 2, 1)
        frame_0 = frame_0.permute(0, 2, 1)
        _, hidden = self.enc_cell(x, None)
        outputs = []
        for i in range(num_steps):
            inputs = frame_0
            output_frame, hidden = self.dec_cell(inputs, hidden)
            output_frame = self.fc(output_frame)
            frame_0 = output_frame
            outputs.append(output_frame)
        outputs = torch.cat(outputs, dim=1)
        return outputs.permute(0, 2, 1)
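# Usage sketch for SeqTranslator1D, defined above (illustrative; names and
# sizes are assumptions):
#   tr = SeqTranslator1D(C_in=64, C_out=108)
#   y = tr(torch.randn(8, 64, 25))  # -> (8, 108, 25): default k=3, s=1 keeps T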
class ResBlock(nn.Module):
    def __init__(self,
                 input_dim,
                 fc_dim,
                 afn,
                 nfn
                 ):
        '''
        afn: activation function name
        nfn: normalization function name
        '''
        super(ResBlock, self).__init__()
        self.input_dim = input_dim
        self.fc_dim = fc_dim
        self.afn = afn
        self.nfn = nfn

        if self.afn != 'relu':
            raise ValueError("only afn='relu' is supported")

        if self.nfn == 'layer_norm':
            raise ValueError("nfn='layer_norm' is not supported")

        self.layers = nn.Sequential(
            nn.Linear(self.input_dim, self.fc_dim // 2),
            nn.ReLU(),
            nn.Linear(self.fc_dim // 2, self.fc_dim // 2),
            nn.ReLU(),
            nn.Linear(self.fc_dim // 2, self.fc_dim),
            nn.ReLU()
        )
        self.shortcut_layer = nn.Sequential(
            nn.Linear(self.input_dim, self.fc_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        return self.layers(inputs) + self.shortcut_layer(inputs)


class AudioEncoder(nn.Module):
    def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
        super(AudioEncoder, self).__init__()
        self.in_channels = channels[0]
        self.augmentation = augmentation
        model = []
        acti = nn.LeakyReLU(0.2)

        nr_layer = len(channels) - 1

        for i in range(nr_layer):
            model.append(nn.ReflectionPad1d(padding))
            model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=conv_stride))
            model.append(acti)
            if conv_pool is not None:
                model.append(conv_pool(kernel_size=2, stride=2))

        if self.augmentation:
            model.append(
                nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
            )
            model.append(acti)

        self.model = nn.Sequential(*model)

    def forward(self, x):
        # keep only the first in_channels feature dims
        x = x[:, :self.in_channels, :]
        x = self.model(x)
        return x


class AudioDecoder(nn.Module):
    def __init__(self, channels, kernel_size=7, ups=25):
        super(AudioDecoder, self).__init__()
        model = []
        pad = (kernel_size - 1) // 2
        acti = nn.LeakyReLU(0.2)

        for i in range(len(channels) - 2):
            model.append(nn.Upsample(scale_factor=2, mode='nearest'))
            model.append(nn.ReflectionPad1d(pad))
            model.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=kernel_size, stride=1))
            if i == 0 or i == 1:
                model.append(nn.Dropout(p=0.2))
            # i never reaches len(channels) - 2 inside this loop, so the
            # activation is always applied here
            model.append(acti)

        model.append(nn.Upsample(size=ups, mode='nearest'))
        model.append(nn.ReflectionPad1d(pad))
        model.append(nn.Conv1d(channels[-2], channels[-1], kernel_size=kernel_size, stride=1))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Audio2Pose(nn.Module):
    def __init__(self, pose_dim, embed_size, augmentation, ups=25):
        super(Audio2Pose, self).__init__()
        self.pose_dim = pose_dim
        self.embed_size = embed_size
        self.augmentation = augmentation

        self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
                                    conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
        if self.augmentation:
            self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
        else:
            self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)

        if self.augmentation:
            self.pose_enc = nn.Sequential(
                nn.Linear(self.embed_size // 2, 256),
                nn.LayerNorm(256)
            )

    def forward(self, audio_feat, dec_input=None):
        B = audio_feat.shape[0]
        aud_embed = self.aud_enc.forward(audio_feat)
        if self.augmentation:
            dec_input = dec_input.squeeze(0)
            dec_embed = self.pose_enc(dec_input)
            dec_embed = dec_embed.unsqueeze(2)
            dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
            aud_embed = torch.cat([aud_embed, dec_embed], dim=1)

        out = self.aud_dec.forward(aud_embed)
        return out
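# Usage sketch for Audio2Pose, defined above (illustrative; the 13-channel
# input is assumed to be MFCC-like audio features, and all sizes here are
# arbitrary assumptions):
#   a2p = Audio2Pose(pose_dim=108, embed_size=512, augmentation=False, ups=25)
#   out = a2p(torch.randn(8, 13, 100))  # -> (8, 108, 25)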
if __name__ == '__main__':
    test_model = SeqEncoder2D(
        C_in=2,
        T_in=25,
        C_out=512,
        num_joints=54,
    )
    print(test_model.num_layers)
    x = torch.randn((64, 108, 25))
    output = test_model(x)
    print(output.shape)
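    # Additional smoke test (illustrative; module sizes are arbitrary
    # assumptions): round-trip a sequence through SeqEncoder1D and SeqDecoder1D.
    enc = SeqEncoder1D(C_in=108, C_out=256, T_in=32)
    dec = SeqDecoder1D(D_in=256, C_out=108, T_out=32)
    z = enc(torch.randn(4, 108, 32))
    print(z.shape)  # expected: (4, 256)
    print(dec(z).shape)  # expected: (4, 108, 32)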