|
""" |
|
Common pooling methods |
|
|
|
Authors: |
|
* Leo 2022 |
|
* Haibin Wu 2022 |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import AutoModel |
|
from transformers.modeling_utils import PreTrainedModel ,PretrainedConfig |
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["MeanPooling", "AttentiveStatisticsPooling"]
|
|
|
|
|
class Pooling(nn.Module):
    """Base class for layers that pool frame-level features over time.

    Subclasses implement ``forward``; ``compute_length_from_mask`` converts a
    sample-level mask into the number of leading feature frames to pool over.
    """

    def compute_length_from_mask(self, mask, sample_rate=16000, frame_shift=0.02):
        """Convert a per-sample mask into per-utterance feature-frame counts.

        Args:
            mask: tensor of shape (batch_size, T); it is summed along dim 1 to
                obtain the number of valid samples of each item.
            sample_rate: samples per second of the masked signal. Defaults to
                16000 (16 kHz), the value this method previously hard-coded.
            frame_shift: hop between consecutive feature frames, in seconds.
                Defaults to 0.02 (20 ms), the value previously hard-coded.

        Returns:
            A Python list with one int per batch item, computed as
            ``floor((n_samples - 1) / (sample_rate * frame_shift)) + 1``.

        Raises:
            ValueError: if ``sample_rate * frame_shift`` is not positive.
        """
        samples_per_frame = sample_rate * frame_shift
        if samples_per_frame <= 0:
            raise ValueError(
                "sample_rate * frame_shift must be positive, got "
                f"{samples_per_frame}"
            )

        wav_lens = torch.sum(mask, dim=1)
        # Floor division followed by +1 so any non-empty signal yields at
        # least one frame.
        feat_lens = torch.div(wav_lens - 1, samples_per_frame, rounding_mode="floor") + 1
        return feat_lens.int().tolist()

    def forward(self, x, mask):
        """Pool ``x`` using ``mask``; must be provided by subclasses."""
        raise NotImplementedError
|
|
|
class MeanPooling(Pooling):
    """Average the frame-level features of each utterance over its valid frames."""

    def forward(self, xs, mask):
        """Mean-pool every item of the batch over its leading valid frames.

        xs: (batch_size, T, feat_dim)
        mask: (batch_size, T)

        => output: (batch_size, feat_dim)
        """
        valid_lens = self.compute_length_from_mask(mask)
        # Each item is averaged over its own prefix, so items are handled one
        # at a time and then re-assembled into a batch.
        per_item = [feats[:n].mean(dim=0) for feats, n in zip(xs, valid_lens)]
        return torch.stack(per_item, dim=0)
|
|
|
|
|
class AttentiveStatisticsPooling(Pooling):
    """
    AttentiveStatisticsPooling
    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf

    Each utterance is summarised by the attention-weighted mean and the
    attention-weighted standard deviation of its frames, concatenated.
    """

    def __init__(self, input_size):
        """Create the attention projection and scoring vector.

        input_size: feature dimension of the incoming frames.
        """
        super().__init__()
        self._indim = input_size
        self.sap_linear = nn.Linear(input_size, input_size)
        # Scoring vector: allocated uninitialised, then drawn from N(0, 1).
        scoring_vector = torch.FloatTensor(input_size, 1)
        self.attention = nn.Parameter(scoring_vector)
        nn.init.normal_(self.attention, mean=0, std=1)

    def _pool_utterance(self, frames):
        """Pool one utterance: (t, feat_dim) -> (feat_dim*2,)."""
        frames = frames.unsqueeze(0)  # (1, t, feat_dim)
        hidden = torch.tanh(self.sap_linear(frames))
        scores = torch.matmul(hidden, self.attention).squeeze(dim=2)  # (1, t)
        # Normalise the scores over time and restore a trailing axis so the
        # weights broadcast across the feature dimension.
        weights = F.softmax(scores, dim=1).view(frames.size(0), frames.size(1), 1)
        mean = torch.sum(frames * weights, dim=1)
        second_moment = torch.sum((frames ** 2) * weights, dim=1)
        # Clamp before the square root so a tiny or negative variance cannot
        # produce NaN.
        std = torch.sqrt((second_moment - mean ** 2).clamp(min=1e-5))
        return torch.cat((mean, std), 1).squeeze(0)

    def forward(self, xs, mask):
        """Pool every item of the batch over its leading valid frames.

        xs: (batch_size, T, feat_dim)
        mask: (batch_size, T)

        => output: (batch_size, feat_dim*2)
        """
        valid_lens = self.compute_length_from_mask(mask)
        pooled = [self._pool_utterance(feats[:n]) for feats, n in zip(xs, valid_lens)]
        return torch.stack(pooled)
|
|
|
|
|
|
|
|
|
class EmotionRegression(nn.Module):
    """Feed-forward head: stacked Linear/LayerNorm/ReLU/Dropout blocks plus a linear output.

    Positional arguments (read by index from ``*args``):
        0: input_dim   - size of the incoming feature vector
        1: hidden_dim  - width of every hidden block
        2: num_layers  - number of hidden blocks (at least one is always built)
        3: output_dim  - size of the prediction vector

    Keyword arguments:
        dropout: dropout probability used on the input and inside every
            hidden block. Defaults to 0.5.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        input_dim, hidden_dim, num_layers, output_dim = args[0], args[1], args[2], args[3]
        drop_prob = kwargs.get("dropout", 0.5)

        def make_block(in_features):
            # One hidden block mapping ``in_features`` to ``hidden_dim``.
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(drop_prob),
            )

        # The first block adapts the input width; the remaining
        # ``num_layers - 1`` blocks keep the hidden width.
        blocks = [make_block(input_dim)]
        blocks.extend(make_block(hidden_dim) for _ in range(num_layers - 1))
        self.fc = nn.ModuleList(blocks)

        self.out = nn.Sequential(nn.Linear(hidden_dim, output_dim))
        self.inp_drop = nn.Dropout(drop_prob)

    def get_repr(self, x):
        """Return the hidden representation after input dropout and all hidden blocks."""
        hidden = self.inp_drop(x)
        for block in self.fc:
            hidden = block(hidden)
        return hidden

    def forward(self, x):
        """Map input features to predictions of size ``output_dim``."""
        return self.out(self.get_repr(x))
|
|
|
class SERConfig(PretrainedConfig):
    """Configuration object for the ``"ser"`` model type.

    Every constructor argument is stored as an attribute of the same name
    before the remaining keyword arguments are handed to the base class.

    Args:
        num_classes: size of the prediction vector of the regression head.
        num_attention_heads: stored on the config; not read by the model
            code in this file.
        num_hidden_layers: stored on the config; not read by the model code
            in this file.
        hidden_size: feature width of the backbone output; sizes the pooling
            layer and the regression head.
        classifier_hidden_layers: number of hidden blocks in the head.
        classifier_dropout_prob: dropout probability used in the head.
        ssl_type: identifier passed to ``AutoModel.from_pretrained`` to load
            the backbone.
        torch_dtype: dtype name stored on the config.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ser"

    def __init__(
        self,
        num_classes: int = 3,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 24,
        hidden_size: int = 1024,
        classifier_hidden_layers: int = 1,
        classifier_dropout_prob: float = 0.5,
        ssl_type: str = "microsoft/wavlm-large",
        torch_dtype: str = "float32",
        **kwargs,
    ):
        # Prediction target.
        self.num_classes = num_classes

        # Backbone geometry.
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size = hidden_size

        # Regression head.
        self.classifier_hidden_layers = classifier_hidden_layers
        self.classifier_dropout_prob = classifier_dropout_prob

        # Backbone checkpoint and dtype.
        self.ssl_type = ssl_type
        # NOTE(review): the base-class initializer runs after this assignment
        # and may manage ``torch_dtype`` itself - confirm this value survives.
        self.torch_dtype = torch_dtype

        super().__init__(**kwargs)
|
|
|
class SERModel(PreTrainedModel):
    """Backbone encoder + attentive statistics pooling + regression head.

    The backbone named by ``config.ssl_type`` produces frame-level features,
    which are pooled to one vector per utterance and mapped to
    ``config.num_classes`` outputs.
    """

    config_class = SERConfig

    def __init__(self, config):
        """Build the three sub-modules from ``config`` (an ``SERConfig``)."""
        super().__init__(config)

        # NOTE(review): this loads the backbone checkpoint every time the
        # model is constructed - presumably intended; confirm with the
        # loading path used by callers.
        backbone = AutoModel.from_pretrained(config.ssl_type)
        backbone.freeze_feature_encoder()
        self.ssl_model = backbone

        feat_dim = config.hidden_size
        self.pool_model = AttentiveStatisticsPooling(feat_dim)

        # Pooling concatenates mean and standard deviation, so the head sees
        # twice the backbone feature width.
        self.ser_model = EmotionRegression(
            2 * feat_dim,
            feat_dim,
            config.classifier_hidden_layers,
            config.num_classes,
            dropout=config.classifier_dropout_prob,
        )

    def forward(self, x, mask):
        """Predict from a batch of inputs.

        x: backbone input - presumably raw waveforms of shape
            (batch_size, T); TODO confirm against the caller.
        mask: (batch_size, T) attention mask; also used to derive how many
            feature frames of each item are pooled.

        => output: result of the regression head, one row per batch item.
        """
        frame_feats = self.ssl_model(x, attention_mask=mask).last_hidden_state
        utterance_vec = self.pool_model(frame_feats, mask)
        return self.ser_model(utterance_vec)
|
|
|
|