do not apply clip()
Browse files
README.md
CHANGED
@@ -41,22 +41,23 @@ Florian Eyben, Felix Burkhardt, Björn Schuller.
|
|
41 |
|
42 |
|
43 |
|
44 |
-
#
|
45 |
```python
|
46 |
-
|
47 |
-
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
48 |
-
Wav2Vec2Model,
|
49 |
-
Wav2Vec2PreTrainedModel
|
50 |
-
)
|
51 |
import torch
|
52 |
import types
|
53 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
|
|
|
|
|
56 |
device = 'cpu'
|
57 |
|
58 |
-
class
|
59 |
-
r"""A/D/V"""
|
60 |
|
61 |
def __init__(self, config):
|
62 |
|
@@ -81,14 +82,15 @@ class Dawn(Wav2Vec2PreTrainedModel):
|
|
81 |
super().__init__(config)
|
82 |
|
83 |
self.wav2vec2 = Wav2Vec2Model(config)
|
84 |
-
self.classifier =
|
85 |
|
86 |
def forward(self, x):
|
87 |
'''x: (batch, audio-samples-16KHz)'''
|
88 |
-
x
|
89 |
variance = (x * x).mean(1, keepdim=True) + 1e-7
|
90 |
-
|
91 |
-
|
|
|
92 |
|
93 |
|
94 |
def _infer(self, x):
|
@@ -125,9 +127,7 @@ dawn = Dawn.from_pretrained(
|
|
125 |
def wav2small(x):
|
126 |
return .5 * dawn(x) + .5 * base(x)
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
print(f'\nArousal = {pred[0, 0]} Dominance = {pred[0, 1]}',
|
132 |
-
f' Valence = {pred[0, 2]}')
|
133 |
```
|
|
|
41 |
|
42 |
|
43 |
|
44 |
+
# How To
|
45 |
```python
|
46 |
+
import librosa
|
|
|
|
|
|
|
|
|
47 |
import torch
|
48 |
import types
|
49 |
import torch.nn as nn
|
50 |
+
from transformers import AutoModelForAudioClassification
|
51 |
+
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
52 |
+
Wav2Vec2Model,
|
53 |
+
Wav2Vec2PreTrainedModel)
|
54 |
|
55 |
+
|
56 |
+
signal = torch.from_numpy(
|
57 |
+
librosa.load('test.wav', sr=16000)[0])[None, :]
|
58 |
device = 'cpu'
|
59 |
|
60 |
+
class ADV(nn.Module):
|
|
|
61 |
|
62 |
def __init__(self, config):
|
63 |
|
|
|
82 |
super().__init__(config)
|
83 |
|
84 |
self.wav2vec2 = Wav2Vec2Model(config)
|
85 |
+
self.classifier = ADV(config)
|
86 |
|
87 |
def forward(self, x):
|
88 |
'''x: (batch, audio-samples-16KHz)'''
|
89 |
+
x -= x.mean(1, keepdim=True)
|
90 |
variance = (x * x).mean(1, keepdim=True) + 1e-7
|
91 |
+
x = self.wav2vec2(x / variance.sqrt()
|
92 |
+
).last_hidden_state
|
93 |
+
return self.classifier(x.mean(1))
|
94 |
|
95 |
|
96 |
def _infer(self, x):
|
|
|
127 |
def wav2small(x):
|
128 |
return .5 * dawn(x) + .5 * base(x)
|
129 |
|
130 |
+
pred = wav2small(signal.to(device))
|
131 |
+
print(f'\nArousal = {pred[:, 0]} Dominance = {pred[:, 1]}',
|
132 |
+
f' Valence = {pred[:, 2]}')
|
|
|
|
|
133 |
```
|