frankenjoe
commited on
Commit
•
4ca0364
1
Parent(s):
5ad7ffe
Update README.md
Browse files
README.md
CHANGED
@@ -92,7 +92,7 @@ class AgeGenderModel(Wav2Vec2PreTrainedModel):
|
|
92 |
hidden_states = outputs[0]
|
93 |
hidden_states = torch.mean(hidden_states, dim=1)
|
94 |
logits_age = self.age(hidden_states)
|
95 |
-
logits_gender = self.gender(hidden_states)
|
96 |
|
97 |
return hidden_states, logits_age, logits_gender
|
98 |
|
@@ -140,10 +140,10 @@ def process_func(
|
|
140 |
|
141 |
print(process_func(signal, sampling_rate))
|
142 |
# Age child female male
|
143 |
-
# [[ 0.3079211
|
144 |
|
145 |
print(process_func(signal, sampling_rate, embeddings=True))
|
146 |
# Pooled hidden states of last transformer layer
|
147 |
-
# [[
|
148 |
-
# 0.
|
149 |
```
|
|
|
92 |
hidden_states = outputs[0]
|
93 |
hidden_states = torch.mean(hidden_states, dim=1)
|
94 |
logits_age = self.age(hidden_states)
|
95 |
+
logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
|
96 |
|
97 |
return hidden_states, logits_age, logits_gender
|
98 |
|
|
|
140 |
|
141 |
print(process_func(signal, sampling_rate))
|
142 |
# Age child female male
|
143 |
+
# [[ 0.3079211 0.00848487 0.0051472 0.9863679 ]]
|
144 |
|
145 |
print(process_func(signal, sampling_rate, embeddings=True))
|
146 |
# Pooled hidden states of last transformer layer
|
147 |
+
# [[ 0.00409924 0.00365688 0.02392936 ... 0.02349018 -0.13294911
|
148 |
+
# 0.1538802 ]]
|
149 |
```
|