---
license: cc-by-nc-sa-4.0
language:
- en
pipeline_tag: audio-classification
tags:
- wavlm
- wav2vec2
- msp-podcast
- emotion-recognition
- speech
- valence
- arousal
- dominance
- speech-emotion-recognition
- dkounadis
---
# Arousal - Dominance - Valence
Dimensional Speech Emotion Recognition model that ensembles [WavLM](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) and [wav2vec2.0](https://github.com/audeering/w2v2-how-to) by averaging their predictions.
It achieves a valence CCC of `0.6760566` on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast) and serves as the teacher for [wav2small ..]().
**[PapersWithCode](https://paperswithcode.com/dataset/msp-podcast) / [arXiv](https://arxiv.org/abs/2408.13920)**
```
Wav2Small: Distilling Wav2Vec2 to 72K parameters for low-resource
speech emotion recognition.
D. Kounadis-Bastian, O. Schrüfer, A. Derington, H. Wierstorf,
F. Eyben, F. Burkhardt, B.W. Schuller. 2024, arXiv preprint
```
<table style="width:500px">
<tr><th colspan=6 align="center" >CCC MSP Podcast v1.7</th></tr>
<tr><th colspan=3 align="center">Test 1</th><th colspan=3 align="center">Test 2</th></tr>
<tr> <th>Val</th> <th>Dom</th> <th>Aro</th> <th>Val</th> <th>Dom</th> <th>Aro</th> </tr>
<tr> <td> 0.6760566 </td> <td>0.6840044</td> <td>0.7620181</td> <td>0.4229267</td> <td>0.4684658</td> <td>0.4857733</td> </tr>
</table>
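
The scores above are Concordance Correlation Coefficients (CCC). For reference, a minimal NumPy sketch of the standard CCC definition (this function is illustrative only and not part of this repository):

```python
import numpy as np

def ccc(pred, true):
    # Concordance Correlation Coefficient (Lin, 1989):
    # 2*cov(pred, true) / (var(pred) + var(true) + (mean gap)^2)
    pred, true = np.asarray(pred, float), np.asarray(true, float)
    cov = ((pred - pred.mean()) * (true - true.mean())).mean()
    return 2 * cov / (pred.var() + true.var()
                      + (pred.mean() - true.mean()) ** 2)
```

Unlike plain Pearson correlation, CCC also penalises shifts in mean and scale, which is why it is the standard metric for dimensional emotion labels.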
# HowTo
```python
import types

import librosa
import torch
import torch.nn as nn
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

device = 'cpu'

# Load audio at 16 kHz and add a batch dimension -> (1, samples)
signal = torch.from_numpy(
    librosa.load('test.wav', sr=16000)[0])[None, :]


class ADV(nn.Module):
    """Regression head that outputs arousal / dominance / valence."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        return self.out_proj(x)


class Dawn(Wav2Vec2PreTrainedModel):
    r"""https://arxiv.org/abs/2203.07378"""

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        # Zero-mean / unit-variance normalisation of the raw waveform
        x -= x.mean(1, keepdim=True)
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        # Mean-pool the hidden states over time before the regression head
        return self.classifier(x.last_hidden_state.mean(1))


def _forward(self, x):
    '''x: (batch, audio-samples-16KHz)'''
    # Shift / scale the waveform with the stats stored in the model config
    x = (x + self.config.mean) / self.config.std
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # Attentive statistics pooling: attention-weighted mean and std
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention).softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)


# WavLM
base = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
    trust_remote_code=True).to(device).eval()
base.forward = types.MethodType(_forward, base)

# Wav2Vec2
dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(device).eval()


def wav2small(x):
    # Equal-weight average of the two models' A/D/V predictions
    return .5 * dawn(x) + .5 * base(x)


with torch.no_grad():
    pred = wav2small(signal.to(device))
print(f'Arousal={pred[0, 0]} '
      f'Dominance={pred[0, 1]} '
      f'Valence={pred[0, 2]}')
```
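
For convenience, a small hypothetical wrapper (`predict_file` is not part of this model card's API) that loads an arbitrary audio file, resamples it to 16 kHz mono via librosa, and runs the ensemble without tracking gradients:

```python
def predict_file(path):
    # Hypothetical helper: librosa resamples any input file to 16 kHz mono
    x, _ = librosa.load(path, sr=16000, mono=True)
    with torch.no_grad():
        adv = wav2small(torch.from_numpy(x)[None, :].to(device))
    return {'arousal': adv[0, 0].item(),
            'dominance': adv[0, 1].item(),
            'valence': adv[0, 2].item()}

print(predict_file('test.wav'))
```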