Set dawn to eval, verify CCC
README.md
CHANGED
@@ -16,8 +16,8 @@ tags:
 - speech-emotion-recognition
 - dkounadis
 ---
-
-Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) Test1.
+Model based on [Wavlm](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) and [wav2vec2](https://hf.rst.im/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim) for arousal/dominance/valence prediction.
+Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) Test1. Used as teacher for [wav2small]().



@@ -26,7 +26,7 @@ Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/resea
 <tr><th colspan=6 align="center" >CCC MSP Podcast v1.7</th></tr>
 <tr><th colspan=3 align="center">Test 1</th><th colspan=3 align="center">Test 2</th></tr>
 <tr> <td>Val</td> <td>Dom</td> <td>Aro</td> <td>Val</td> <td>Dom</td> <td>Aro</td> </tr>
-<tr> <td> 0.6760566 </td> <td>0.
+<tr> <td> 0.6760566 </td> <td>0.6840044</td> <td>0.7620181</td> <td>0.4229267</td> <td>0.4684658</td> <td>0.4857733</td> </tr>
 </table>


@@ -43,7 +43,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2PreTrainedModel,
 )

-device = '
+device = 'cpu'


 class RegressionHead(nn.Module):
@@ -74,67 +74,54 @@ class Dawn(Wav2Vec2PreTrainedModel):
         self.wav2vec2 = Wav2Vec2Model(config)
         self.classifier = RegressionHead(config)

-    def forward(
-
-            x,
-    ):
-
+    def forward(self, x):
+        '''x: (batch, audio-samples-16KHz)'''
         x = x - x.mean(1, keepdim=True)
         variance = (x * x).mean(1, keepdim=True) + 1e-7
         out = self.wav2vec2(x / variance.sqrt())
-
         return self.classifier(out[0].mean(1)).clip(0, 1)


 def _infer(self, x):
-    '''
-
-    x = (x + self.config.mean) / self.config.std
+    '''x: (batch, audio-samples-16KHz)'''
+    x = (x + self.config.mean) / self.config.std # plus
     x = self.ssl_model(x, attention_mask=None).last_hidden_state
     # pool
     h = self.pool_model.sap_linear(x).tanh()
     w = torch.matmul(h, self.pool_model.attention)
     w = w.softmax(1)
-    mu =
+    mu = (x * w).sum(1)
     x = torch.cat(
-
-
-
-
-    return self.ser_model(x)
+        [
+            mu,
+            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
+        ], 1)
+    return self.ser_model(x)
+

 # WavLM

-# https://lab-msp.com/MSP-Podcast_Competition/leaderboard.php
 base = AutoModelForAudioClassification.from_pretrained(
     '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
-    trust_remote_code=True #
+    trust_remote_code=True # fun definitions see 3loi/SER-.. repo
 ).to(device).eval()
 base.forward = types.MethodType(_infer, base)

-# Wav2Vec2
+# Wav2Vec2

 dawn = Dawn.from_pretrained(
     'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
-).to(device)
+).to(device).eval()

-# Teacher

 def wav2small(x):
-    '''average predctions'''
     return .5 * dawn(x) + .5 * base(x)


 x, _ = librosa.load('test.wav', sr=base.config.sampling_rate)

 with torch.no_grad():
-    pred = wav2small(
-
-
-
-
-    print(f'\narousal = {pred[0, 0]}',
-          f'\ndominance= {pred[0, 1]}',
-          f'\nvalence = {pred[0, 2]}')
-
+    pred = wav2small(torch.from_numpy(x[None, :]).to(device))
+    print(f'\nArousal = {pred[0, 0]} Dominance= {pred[0, 1]}',
+          f' Valence = {pred[0, 2]}')
 ```
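The table filled in by this commit reports concordance correlation coefficients (CCC), which is the quantity the commit title refers to verifying. As a quick aside, below is a minimal sketch of Lin's CCC for checking such figures against one's own predictions; the `ccc` helper and the synthetic arrays are illustrative and not part of this repository.

```python
import numpy as np

def ccc(pred, gold):
    '''Lin's concordance correlation coefficient for 1-D score arrays,
    e.g. per-utterance valence predictions vs. labels.'''
    pred_mean, gold_mean = pred.mean(), gold.mean()
    covar = ((pred - pred_mean) * (gold - gold_mean)).mean()
    return 2 * covar / (pred.var() + gold.var() + (pred_mean - gold_mean) ** 2)

# Illustrative check with synthetic scores
# (real use: ensemble outputs vs. MSP-Podcast Test1 labels).
rng = np.random.default_rng(0)
gold = rng.uniform(0, 1, size=1000)
pred = np.clip(gold + rng.normal(0, 0.1, size=1000), 0, 1)
print(f'valence CCC = {ccc(pred, gold):.7f}')
```

A value such as `0.6760566` in the Val column would then be the CCC between the ensemble's valence outputs and the Test1 labels, computed this way over the whole split.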