Update README.md
Browse files
README.md
CHANGED
@@ -117,3 +117,60 @@ print("Reference:", test_dataset["sentence"][:2])
|
|
117 |
```
|
118 |
|
119 |
### Evaluation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
```
|
118 |
|
119 |
### Evaluation
|
120 |
+
|
121 |
+
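The model can be evaluated on the Luganda (`lg`) test split of Common Voice as follows. The script moves the model and its inputs to a CUDA device, so a GPU is required.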
|
122 |
+
|
123 |
+
```python
|
124 |
+
import torch
|
125 |
+
import torchaudio
|
126 |
+
from datasets import Audio, load_dataset, load_metric
|
127 |
+
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
|
128 |
+
import re
|
129 |
+
|
130 |
+
# Luganda ("lg") test split of Common Voice, plus the word-error-rate metric.
test_dataset = load_dataset("common_voice", "lg", split="test")
# Bound to `wer_metric` because the final report calls `wer_metric.compute(...)`;
# binding it to any other name leaves that call raising NameError.
wer_metric = load_metric("wer")
|
132 |
+
|
133 |
+
# The CTC model and its processor are loaded from the same checkpoint id.
MODEL_ID = "dmusingu/w2v-bert-2.0-luganda-CV-train-validation-7.0"
processor = Wav2Vec2BertProcessor.from_pretrained(MODEL_ID)
model = AutoModelForCTC.from_pretrained(MODEL_ID)
model = model.to('cuda')
|
135 |
+
|
136 |
+
# Character class of punctuation and stray symbols stripped from transcripts.
# Written as a raw string: the pattern is full of backslash sequences that are
# not valid Python string escapes, which a plain literal triggers warnings for
# on newer interpreters. The regex engine receives the same set of characters.
chars_to_remove_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\»\«]'
|
137 |
+
|
138 |
+
# Re-declare the "audio" column as an `Audio` feature with a 16 kHz sampling rate.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))
|
139 |
+
|
140 |
+
def remove_special_characters(batch):
    """Normalize a transcript: drop special characters, then lowercase.

    Every match of ``chars_to_remove_regex`` is removed from
    ``batch["sentence"]``; the example is updated in place and returned.
    """
    stripped = re.sub(chars_to_remove_regex, '', batch["sentence"])
    batch["sentence"] = stripped.lower()
    return batch
|
145 |
+
|
146 |
+
# Normalize the reference transcript of every example in the test split.
test_dataset = test_dataset.map(remove_special_characters)
|
147 |
+
|
148 |
+
def prepare_dataset(batch):
    """Turn one example into model inputs and label ids.

    Writes three keys on the example and returns it:
    ``input_features`` (first item of the processor output for the clip),
    ``input_length`` (``len`` of those features) and
    ``labels`` (processor ids for ``batch["sentence"]``).
    """
    clip = batch["audio"]
    encoded = processor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = encoded.input_features[0]
    batch["input_length"] = len(batch["input_features"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch
|
155 |
+
|
156 |
+
# Featurize the split; the pre-existing columns are removed so only the keys
# written by prepare_dataset remain.
test_dataset = test_dataset.map(prepare_dataset, remove_columns=test_dataset.column_names)
|
157 |
+
|
158 |
+
# Evaluation is carried out with a batch size of 1
def map_to_result(batch):
    """Transcribe one example and attach the prediction and reference text.

    Adds ``pred_str`` (argmax over the model logits, decoded by the
    processor) and ``text`` (the label ids decoded with
    ``group_tokens=False``), then returns the example.
    """
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension of size 1.
        features = torch.tensor(batch["input_features"], device="cuda").unsqueeze(0)
        logits = model(features).logits

    predicted_ids = logits.argmax(dim=-1)
    batch["pred_str"] = processor.batch_decode(predicted_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch
|
169 |
+
|
170 |
+
# Run inference over the whole split, one example per call to map_to_result.
results = test_dataset.map(map_to_result)
|
171 |
+
|
172 |
+
# Score the predicted strings against the reference strings and report the WER.
test_wer = wer_metric.compute(predictions=results["pred_str"], references=results["text"])
print("Test WER: {:.3f}".format(test_wer))
|
173 |
+
```
|
174 |
+
|
175 |
+
### Test Result: 19.4% WER
|
176 |
+
|