dkounadis committed
Commit 8d889ff
1 Parent(s): 94f36bf

Set dawn to eval, verify CCC

Files changed (1): README.md (+21 -34)
README.md CHANGED
@@ -16,8 +16,8 @@ tags:
 - speech-emotion-recognition
 - dkounadis
 ---
-Tecaher model based on [Wavlm](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) and [wav2vec2](https://hf.rst.im/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim) for arousal/dominance/valence prediction.
-Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) Test1.
+Model based on [WavLM](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) and [wav2vec2](https://hf.rst.im/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim) for arousal/dominance/valence prediction.
+Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html) Test 1. Used as teacher for [wav2small]().
 
 
@@ -26,7 +26,7 @@ Achieves `0.6760566` valence CCC on [MSP-Podcast](https://ecs.utdallas.edu/resea
 <tr><th colspan=6 align="center">CCC MSP Podcast v1.7</th></tr>
 <tr><th colspan=3 align="center">Test 1</th><th colspan=3 align="center">Test 2</th></tr>
 <tr> <td>Val</td> <td>Dom</td> <td>Aro</td> <td>Val</td> <td>Dom</td> <td>Aro</td> </tr>
-<tr> <td>0.6760566</td> <td>0.6840190</td> <td>0.7620374</td> <td>0.4229267</td> <td>0.4684658</td> <td>0.4857733</td> </tr>
+<tr> <td>0.6760566</td> <td>0.6840044</td> <td>0.7620181</td> <td>0.4229267</td> <td>0.4684658</td> <td>0.4857733</td> </tr>
 </table>
 
 
@@ -43,7 +43,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2PreTrainedModel,
 )
 
-device = 'cuda:0'
+device = 'cpu'
 
 
 class RegressionHead(nn.Module):
@@ -74,67 +74,54 @@ class Dawn(Wav2Vec2PreTrainedModel):
         self.wav2vec2 = Wav2Vec2Model(config)
         self.classifier = RegressionHead(config)
 
-    def forward(
-        self,
-        x,
-    ):
-
+    def forward(self, x):
+        '''x: (batch, audio-samples-16KHz)'''
         x = x - x.mean(1, keepdim=True)
         variance = (x * x).mean(1, keepdim=True) + 1e-7
         out = self.wav2vec2(x / variance.sqrt())
-
         return self.classifier(out[0].mean(1)).clip(0, 1)
 
 
 def _infer(self, x):
-    '''re-definition for less cpu'''
-    # x = (x + 8.278621631819787e-05) / 0.08485610250851999
-    x = (x + self.config.mean) / self.config.std
+    '''x: (batch, audio-samples-16KHz)'''
+    x = (x + self.config.mean) / self.config.std  # note the plus sign
     x = self.ssl_model(x, attention_mask=None).last_hidden_state
     # attentive statistics pooling
     h = self.pool_model.sap_linear(x).tanh()
     w = torch.matmul(h, self.pool_model.attention)
     w = w.softmax(1)
-    mu = torch.sum(x * w, 1)
+    mu = (x * w).sum(1)
     x = torch.cat(
-        [
-            mu,
-            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-5).sqrt()
-        ], 1)
-    return self.ser_model(x).clip(0, 1)
+        [
+            mu,
+            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
+        ], 1)
+    return self.ser_model(x)
+
 
 # WavLM
 
-# https://lab-msp.com/MSP-Podcast_Competition/leaderboard.php
 base = AutoModelForAudioClassification.from_pretrained(
     '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
-    trust_remote_code=True  # extra definitions see above repository
+    trust_remote_code=True  # function definitions: see the 3loi/SER-.. repo
 ).to(device).eval()
 base.forward = types.MethodType(_infer, base)
 
-# Wav2Vec2.0
+# Wav2Vec2
 
 dawn = Dawn.from_pretrained(
     'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
-).to(device)
+).to(device).eval()
 
-# Teacher
 
 def wav2small(x):
-    '''average predctions'''
     return .5 * dawn(x) + .5 * base(x)
 
 
 x, _ = librosa.load('test.wav', sr=base.config.sampling_rate)
 
 with torch.no_grad():
-    pred = wav2small(
-        torch.from_numpy(x[None, :]
-        ).to(device))
-
-
-    print(f'\narousal = {pred[0, 0]}',
-          f'\ndominance= {pred[0, 1]}',
-          f'\nvalence = {pred[0, 2]}')
-
+    pred = wav2small(torch.from_numpy(x[None, :]).to(device))
+    print(f'\nArousal = {pred[0, 0]} Dominance= {pred[0, 1]}',
+          f' Valence = {pred[0, 2]}')
 ```
 
 
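A note on the metric: CCC is the concordance correlation coefficient between predicted and gold attribute values. A minimal sketch of how the table's numbers can be computed; the helper `ccc` and the `preds`/`labels` arrays are ours, not part of this repo:

```python
import numpy as np

def ccc(pred, lab):
    '''Concordance correlation coefficient of two 1-D arrays.'''
    m_p, m_l = pred.mean(), lab.mean()
    cov = ((pred - m_p) * (lab - m_l)).mean()
    return 2 * cov / (pred.var() + lab.var() + (m_p - m_l) ** 2)

# preds, labels: (N, 3) arrays over a test set, columns ordered
# arousal / dominance / valence as in the print() above, e.g.
# valence_ccc = ccc(preds[:, 2], labels[:, 2])
```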
 
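For orientation, `_infer` pools the WavLM frames with attentive statistics: a softmax over time yields weights `w`, and the utterance embedding concatenates the weighted mean with the weighted standard deviation. A standalone sketch with made-up shapes (the real hidden size comes from the 3loi config):

```python
import torch

B, T, D = 2, 50, 1024               # batch, frames, hidden size (assumed)
x = torch.randn(B, T, D)            # frame-level features
w = torch.rand(B, T, 1).softmax(1)  # attention weights over time

mu = (x * w).sum(1)                 # weighted mean, (B, D)
# weighted std via E[x^2] - E[x]^2, floored at 1e-7 as in the commit
sd = ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
pooled = torch.cat([mu, sd], 1)     # (B, 2*D), fed to ser_model
print(pooled.shape)                 # torch.Size([2, 2048])
```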
 
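The commit's other change, calling `.eval()` on `dawn`, matters because `Wav2Vec2Model` applies dropout (and time masking) while in train mode, so repeated calls on the same audio would not agree. A quick sanity check, assuming the script above has already run:

```python
# in eval mode, two no_grad passes on identical input agree exactly
x = torch.zeros(1, 16000)  # one second of silence at 16 kHz
with torch.no_grad():
    a = dawn(x.to(device))
    b = dawn(x.to(device))
print(torch.equal(a, b))   # True after .eval(); dropout would break this
```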
 
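Since this ensemble serves as a teacher for wav2small, a typical use is labelling a whole corpus with pseudo-targets. A sketch, where `files` is a hypothetical list of wav paths:

```python
import numpy as np

files = ['a.wav', 'b.wav']  # hypothetical paths
preds = []
with torch.no_grad():
    for f in files:
        wav, _ = librosa.load(f, sr=base.config.sampling_rate)
        p = wav2small(torch.from_numpy(wav[None, :]).to(device))
        preds.append(p[0].cpu().numpy())
preds = np.stack(preds)     # (N, 3): arousal, dominance, valence
```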
 
 