import torch import torch.nn as nn from torch import Tensor, BoolTensor from typing import Optional, Tuple, Iterable from model.diffusion import SinusoidalPosEmb from torch.nn.functional import pad import math def silu(input): ''' Applies the Sigmoid Linear Unit (SiLU) function element-wise: SiLU(x) = x * sigmoid(x) ''' return input * torch.sigmoid(input) # use torch.sigmoid to make sure that we created the most efficient implemetation based on builtin PyTorch functions class RelPositionMultiHeadedAttention(nn.Module): """Multi-Head Self-Attention layer with relative position encoding. Paper: https://arxiv.org/abs/1901.02860 Args: n_head: The number of heads. d: The number of features. dropout: Dropout rate. zero_triu: Whether to zero the upper triangular part of attention matrix. """ def __init__( self, d: int, n_head: int, dropout: float ): super().__init__() assert d % n_head == 0 self.c = d // n_head self.h = n_head self.linear_q = nn.Linear(d, d) self.linear_k = nn.Linear(d, d) self.linear_v = nn.Linear(d, d) self.linear_out = nn.Linear(d, d) self.p_attn = None self.dropout = nn.Dropout(p=dropout) # linear transformation for positional encoding self.linear_pos = nn.Linear(d, d, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 self.u = nn.Parameter(torch.Tensor(self.h, self.c)) self.v = nn.Parameter(torch.Tensor(self.h, self.c)) # [H, C] torch.nn.init.xavier_uniform_(self.u) torch.nn.init.xavier_uniform_(self.v) def forward_qkv(self, query, key, value) -> Tuple[Tensor, ...]: """Transform query, key and value. Args: query (Tensor): [B, S, D]. key (Tensor): [B, T, D]. value (Tensor): [B, T, D]. Returns: q (Tensor): [B, H, S, C]. k (Tensor): [B, H, T, C]. v (Tensor): [B, H, T, C]. """ n_batch = query.size(0) q = self.linear_q(query).view(n_batch, -1, self.h, self.c) k = self.linear_k(key).view(n_batch, -1, self.h, self.c) v = self.linear_v(value).view(n_batch, -1, self.h, self.c) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) return q, k, v def forward_attention(self, v, scores, mask, causal=False) -> Tensor: """Compute attention context vector. Args: v (Tensor): [B, H, T, C]. scores (Tensor): [B, H, S, T]. mask (BoolTensor): [B, T], True values are masked from scores. Returns: result (Tensor): [B, S, D]. Attention result weighted by the score. """ n_batch, H, S, T = scores.shape if mask is not None: scores = scores.masked_fill( mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"), # [B, H, S, T] ) if causal: k_grid = torch.arange(0, S, dtype=torch.int32, device=scores.device) v_grid = torch.arange(0, T, dtype=torch.int32, device=scores.device) kk, vv = torch.meshgrid(k_grid, v_grid, indexing="ij") causal_mask = vv > kk scores = scores.masked_fill( causal_mask.view(1, 1, S, T), float("-inf") ) p_attn = self.p_attn = torch.softmax(scores, dim=-1) # [B, H, S, T] p_attn = self.dropout(p_attn) # [B, H, S, T] x = torch.matmul(p_attn, v) # [B, H, S, C] x = ( x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.c) ) # [B, S, D] return self.linear_out(x) # [B, S, D] def rel_shift(self, x): """Converting (..., i, i - j) matrix into (..., i, j) matrix. Args: x (Tensor): [B, H, S, 2S-1]. Returns: x (Tensor): [B, H, S, S]. Example: Take S = 2 for example, larger values work similarly. x = [ [(0, -1), (0, 0), (0, 1)], [(1, 0), (1, 1), (1, 2)] ] x_padded = [ [(x, x), (0, -1), (0, 0), (0, 1)], [(x, x), (1, 0), (1, 1), (1, 2)]] ] x_padded = [ [(x, x), (0, -1)], [(0, 0), (0, 1)], [(x, x), (1, 0)], [(1, 1), (1, 2)] ] x = [ [(0, 0), (0, 1)], [(1, 0), (1, 1)] ] """ B, H, S, _ = x.shape zero_pad = torch.zeros((B, H, S, 1), device=x.device, dtype=x.dtype) # [B, H, S, 1] x_padded = torch.cat([zero_pad, x], dim=-1) # [B, H, S, 2S] x_padded = x_padded.view(B, H, 2 * S, S) # [B, H, 2S, S] x = x_padded[:, :, 1:].view_as(x)[:, :, :, :S] # only keep the positions from 0 to S # [B, H, 2S-1, S] [B, H, S, 2S - 1] [B, H, S, S] return x def forward( self, query, key, value, pos_emb, mask=None, causal=False): """Compute self-attention with relative positional embedding. Args: query (Tensor): [B, S, D]. key (Tensor): [B, S, D]. value (Tensor): [B, S, D]. pos_emb (Tensor): [1/B, 2S-1, D]. Positional embedding. mask (BoolTensor): [B, S], True for masked. causal (bool): True for applying causal mask. Returns: output (Tensor): [B, S, D]. """ # Splitting Q, K, V: q, k, v = self.forward_qkv(query, key, value) # [B, H, S, C], [B, H, S, C], [B, H, S, C] # Adding per head & channel biases to the query vectors: q_u = q + self.u.unsqueeze(1) q_v = q + self.v.unsqueeze(1) # [B, H, S, C] # Splitting relative positional coding: n_batch_pos = pos_emb.size(0) # [1/B, 2S-1, D] p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.c) # [1/B, 2S-1, H, C] p = p.transpose(1, 2) # [1/B, H, 2S-1, C]. # Compute query, key similarity: matrix_ac = torch.matmul(q_u, k.transpose(-2, -1)) # [B, H, S, C] x [B, H, C, S] -> [B, H, S, S] matrix_bd = torch.matmul(q_v, p.transpose(-2, -1)) # [B, H, S, C] x [1/B, H, C, 2S-1] -> [B, H, S, 2S-1] matrix_bd = self.rel_shift(matrix_bd) scores = (matrix_ac + matrix_bd) / math.sqrt(self.c) # [B, H, S, S] return self.forward_attention(v, scores, mask, causal) # [B, S, D] class ConditionalBiasScale(nn.Module): def __init__(self, channels: int, cond_channels: int): super().__init__() self.scale_transform = nn.Linear( cond_channels, channels, bias=True ) self.bias_transform = nn.Linear( cond_channels, channels, bias=True ) self.init_parameters() def init_parameters(self): torch.nn.init.constant_(self.scale_transform.weight, 0.0) torch.nn.init.constant_(self.scale_transform.bias, 1.0) torch.nn.init.constant_(self.bias_transform.weight, 0.0) torch.nn.init.constant_(self.bias_transform.bias, 0.0) def forward(self, x: Tensor, cond: Tensor) -> Tensor: """Applying conditional bias and scale. Args: x (Tensor): [..., channels]. cond (Tensor): [..., cond_channels]. Returns: y (Tensor): [..., channels]. """ a = self.scale_transform.forward(cond) b = self.bias_transform.forward(cond) return x * a + b class FeedForwardModule(torch.nn.Module): """Positionwise feed forward layer used in conformer""" def __init__( self, d_in: int, d_hidden: int, dropout: float, bias: bool = True, d_cond: int = 0 ): """ Args: d_in (int): Input feature dimension. d_hidden (int): Hidden unit dimension. dropout (float): dropout value for first Linear Layer. bias (bool): If linear layers should have bias. d_cond (int, optional): The channels of conditional tensor. """ super(FeedForwardModule, self).__init__() self.layer_norm = torch.nn.LayerNorm(d_in) if d_cond > 0: self.cond_layer = ConditionalBiasScale(d_in, d_cond) self.w_1 = torch.nn.Linear(d_in, d_hidden, bias=bias) self.w_2 = torch.nn.Linear(d_hidden, d_in, bias=bias) self.dropout = torch.nn.Dropout(dropout) def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor: """ Args: x (Tensor): [..., D]. Returns: y (Tensor): [..., D]. cond (Tensor): [..., D_cond] """ x = self.layer_norm(x) if cond is not None: x = self.cond_layer.forward(x, cond) x = self.w_1(x) x = silu(x) x = self.dropout(x) x = self.w_2(x) return self.dropout(x) class RelPositionalEncoding(nn.Module): """Relative positional encoding cache. Args: d_model: Embedding dimension. dropout_rate: Dropout rate. max_len: Default maximum input length. """ def __init__(self, max_len: int, d_model: int): super().__init__() self.d_model = d_model self.cached_code = None self.l = 0 self.gen_code(torch.tensor(0.0).expand(1, max_len)) def gen_code(self, x: Tensor): """Generate positional encoding with a reference tensor x. Args: x (Tensor): [B, L, ...], we extract the device, length, and dtype from it. Effects: self.cached_code (Tensor): [1, >=(2L-1), D]. """ l = x.size(1) if self.l >= l: if self.cached_code.dtype != x.dtype or self.cached_code.device != x.device: self.cached_code = self.cached_code.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tensor: """Get positional encoding of appropriate shape given a reference Tensor. Args: x (Tensor): [B, L, ...]. Returns: y (Tensor): [1, 2L-1, D]. """ self.gen_code(x) l = x.size(1) pos_emb = self.cached_code[ :, self.l - l: self.l + l - 1, ] return pos_emb class ConformerBlock(torch.nn.Module): """Conformer block based on https://arxiv.org/abs/2005.08100.""" def __init__( self, d: int, d_hidden: int, attention_heads: int, dropout: float, depthwise_conv_kernel_size: int = 7, causal: bool = False, d_cond: int = 0 ): """ Args: d (int): Block input output channel number. d_hidden (int): FFN layer dimension. attention_heads (int): Number of attention heads. dropout (float): dropout value. depthwise_conv_kernel_size (int): Size of kernel in depthwise conv. d_cond (int, optional): The channels of conditional tensor. """ super(ConformerBlock, self).__init__() self.causal = causal self.ffn1 = FeedForwardModule( d, d_hidden, dropout, bias=True, d_cond=d_cond ) self.self_attn_layer_norm = torch.nn.LayerNorm(d) if d_cond > 0: self.cond_layer = ConditionalBiasScale(d, d_cond) self.self_attn = RelPositionMultiHeadedAttention( d, attention_heads, dropout=dropout ) self.self_attn_dropout = torch.nn.Dropout(dropout) self.conv_module = ConvolutionModule( d_in=d, d_hidden=d, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, d_cond=d_cond ) self.ffn2 = FeedForwardModule( d, d_hidden, dropout, bias=True, d_cond=d_cond ) self.final_layer_norm = torch.nn.LayerNorm(d) def forward( self, x: Tensor, mask: BoolTensor, pos_emb: Tensor, cond: Optional[Tensor] = None ) -> Tensor: """ Args: x (Tensor): [B, T, D_in]. mask (BoolTensor): [B, T], True for masked. pos_emb (Tensor): [1 or B, 2T-1, D]. cond (Tensor, optional): [B, ?, D_cond]. Returns: y (Tensor): [B, T, D_in]. """ y = x x = self.ffn1(x) * 0.5 + y y = x # [B, T, D_in] x = self.self_attn_layer_norm(x) if cond is not None: x = self.cond_layer.forward(x, cond) x = self.self_attn.forward( query=x, key=x, value=x, pos_emb=pos_emb, mask=mask, causal=self.causal ) x = self.self_attn_dropout(x) + y y = x # [B, T, D_in] x = self.conv_module.forward(x, mask) + y y = x # [B, T, D_in] x = self.ffn2(x) * 0.5 + y x = self.final_layer_norm(x) x.masked_fill(mask.unsqueeze(-1), 0.0) return x class ConvolutionModule(torch.nn.Module): """Convolution Block inside a Conformer Block.""" def __init__( self, d_in: int, d_hidden: int, depthwise_kernel_size: int, dropout: float, bias: bool = False, causal: bool = False, d_cond: int = 0 ): """ Args: d_in (int): Embedding dimension. d_hidden (int): Number of channels in depthwise conv layers. depthwise_kernel_size (int): Depthwise conv layer kernel size. dropout (float): dropout value. bias (bool): If bias should be added to conv layers. conditional (bool): Whether to use conditional LayerNorm. """ super(ConvolutionModule, self).__init__() assert (depthwise_kernel_size - 1) % 2 == 0, "kernel_size should be odd" self.causal = causal self.causal_padding = (depthwise_kernel_size - 1, 0) self.layer_norm = torch.nn.LayerNorm(d_in) # Optional conditional LayerNorm: self.d_cond = d_cond if d_cond > 0: self.cond_layer = ConditionalBiasScale(d_in, d_cond) self.pointwise_conv1 = torch.nn.Conv1d( d_in, 2 * d_hidden, kernel_size=1, stride=1, padding=0, bias=bias ) self.glu = torch.nn.GLU(dim=1) self.depthwise_conv = torch.nn.Conv1d( d_hidden, d_hidden, kernel_size=depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2 if not causal else 0, groups=d_hidden, bias=bias ) self.pointwise_conv2 = torch.nn.Conv1d( d_hidden, d_in, kernel_size=1, stride=1, padding=0, bias=bias, ) self.dropout = torch.nn.Dropout(dropout) def forward(self, x: Tensor, mask: BoolTensor, cond: Optional[Tensor] = None) -> Tensor: """ Args: x (Tensor): [B, T, D_in]. mask (BoolTensor): [B, T], True for masked. cond (Tensor): [B, T, D_cond]. Returns: y (Tensor): [B, T, D_in]. """ x = self.layer_norm(x) if cond is not None: x = self.cond_layer.forward(x, cond) x = x.transpose(-1, -2) # [B, D_in, T] x = self.pointwise_conv1(x) # [B, 2C, T] x = self.glu(x) # [B, C, T] # Take care of masking the input tensor: if mask is not None: x = x.masked_fill(mask.unsqueeze(1), 0.0) # 1D Depthwise Conv if self.causal: # Causal padding x = pad(x, self.causal_padding) x = self.depthwise_conv(x) # FIXME: BatchNorm should not be used in variable length training. x = silu(x) # [B, C, T] if mask is not None: x = x.masked_fill(mask.unsqueeze(1), 0.0) x = self.pointwise_conv2(x) x = self.dropout(x) return x.transpose(-1, -2) # [B, T, D_in] class Conformer(torch.nn.Module): def __init__( self, d: int, d_hidden: int, n_heads: int, n_layers: int, dropout: float, depthwise_conv_kernel_size: int, causal: bool = False, d_cond: int = 0 ): super().__init__() self.pos_encoding = RelPositionalEncoding(1024, d) self.causal = causal self.blocks = torch.nn.ModuleList( [ ConformerBlock( d=d, d_hidden=d_hidden, attention_heads=n_heads, dropout=dropout, depthwise_conv_kernel_size=depthwise_conv_kernel_size, causal=causal, d_cond=d_cond ) for _ in range(n_layers) ] ) # type: Iterable[ConformerBlock] def forward( self, x: Tensor, mask: BoolTensor, cond: Tensor = None ) -> Tensor: """Conformer forwarding. Args: x (Tensor): [B, T, D]. mask (BoolTensor): [B, T], with True for masked. cond (Tensor, optional): [B, T, D_cond]. Returns: y (Tensor): [B, T, D] """ pos_emb = self.pos_encoding(x) # [1, 2T-1, D] for block in self.blocks: x = block.forward(x, mask, pos_emb, cond) return x class CNNBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout, cond_dim, kernel_size, stride): super(CNNBlock, self).__init__() self.layers = nn.Sequential( nn.Conv1d(in_dim, out_dim, kernel_size, stride), nn.ReLU(), nn.BatchNorm1d(out_dim,), nn.Dropout(p=dropout) ) def forward(self, inp): out = self.layers(inp) return out class CNNClassifier(nn.Module): def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim): super(CNNClassifier, self).__init__() self.cnn = nn.Sequential( CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2), ) # receptive field is 180, frame shift is 64 self.cond_layer = nn.Sequential( nn.Linear(cond_dim, in_dim), nn.LeakyReLU(), nn.Linear(in_dim, in_dim) ) def forward(self, inp, mask, cond): inp = inp.transpose(-1, -2) cond = cond.transpose(-1, -2) inp.masked_fill_(mask.unsqueeze(1), 0.0) cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2) cond.masked_fill_(mask.unsqueeze(1), 0.0) inp = inp + cond return self.cnn(inp) class CNNClassifierWithTime(nn.Module): def __init__(self, in_dim, d_decoder, decoder_dropout, cond_dim, time_emb_dim=512): super(CNNClassifierWithTime, self).__init__() self.cnn = nn.Sequential( CNNBlock(in_dim, d_decoder, decoder_dropout, cond_dim, 8, 4), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 8, 4), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2), CNNBlock(d_decoder, d_decoder, decoder_dropout, cond_dim, 4, 2), ) # receptive field is 180, frame shift is 64 self.cond_layer = nn.Sequential( nn.Linear(cond_dim, in_dim), nn.LeakyReLU(), nn.Linear(in_dim, in_dim) ) self.time_emb = SinusoidalPosEmb(time_emb_dim) self.time_layer = nn.Sequential( nn.Linear(time_emb_dim, in_dim), nn.LeakyReLU(), nn.Linear(in_dim, in_dim) ) def forward(self, inp, mask, cond, t): time_emb = self.time_emb(t) # [B, T] time_emb = self.time_layer(time_emb.unsqueeze(1)).transpose(-1, -2) inp = inp.transpose(-1, -2) cond = cond.transpose(-1, -2) inp.masked_fill_(mask.unsqueeze(1), 0.0) cond = self.cond_layer(cond.transpose(-1, -2)).transpose(-1, -2) cond.masked_fill_(mask.unsqueeze(1), 0.0) inp = inp + cond + time_emb return self.cnn(inp) class SpecClassifier(nn.Module): def __init__(self, in_dim, d_decoder, h_decoder, l_decoder, decoder_dropout, k_decoder, n_class, cond_dim, model_type='conformer'): super(SpecClassifier, self).__init__() self.model_type = model_type self.prenet = nn.Sequential( nn.Linear(in_features=in_dim, out_features=d_decoder) ) if model_type == 'conformer': self.conformer = Conformer(d=d_decoder, d_hidden=d_decoder, n_heads=h_decoder, n_layers=l_decoder, dropout=decoder_dropout, depthwise_conv_kernel_size=k_decoder, d_cond=cond_dim) elif model_type == 'CNN': self.conformer = CNNClassifier(in_dim=d_decoder, d_decoder=d_decoder, decoder_dropout=decoder_dropout, cond_dim=cond_dim) elif model_type == 'CNN-with-time': self.conformer = CNNClassifierWithTime(in_dim=d_decoder, d_decoder=d_decoder, decoder_dropout=decoder_dropout, cond_dim=cond_dim, time_emb_dim=256) self.classifier = nn.Linear(d_decoder, n_class) def forward(self, noisy_mel, condition, mask, **kwargs): """ Args: noisy_mel: [B, T, D] condition: [B, T, D] mask: [B, T] with True for un-masked (real-values) Returns: classification logits (un-softmaxed) """ # print(noisy_mel.shape) noisy_mel = noisy_mel.masked_fill(~mask.unsqueeze(-1), 0.0) # print(self.prenet, noisy_mel.shape) hiddens = self.prenet(noisy_mel) if self.model_type == 'CNN-with-time': hiddens = self.conformer.forward(hiddens, ~mask, condition, kwargs['t']) else: hiddens = self.conformer.forward(hiddens, ~mask, condition) # [B, T, D] if self.model_type == 'conformer': averaged_hiddens = torch.mean(hiddens, dim=1) # [B, D] logits = self.classifier(averaged_hiddens) return logits elif self.model_type == 'CNN' or self.model_type == 'CNN-with-time': hiddens = hiddens.transpose(-1, -2) return self.classifier(hiddens) # [B, T', C] @property def nparams(self): return sum([p.numel() for p in self.parameters()])