---
language: th
datasets:
- common_voice
tags:
- audio
- automatic-speech-recognition
- hf-asr-leaderboard
- robust-speech-event
- speech
- xlsr-fine-tuning
license: cc-by-sa-4.0
model-index:
- name: XLS-R-53 - Thai
  results:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice 7
      type: mozilla-foundation/common_voice_7_0
      args: th
    metrics:
    - name: Test WER
      type: wer
      value: 0.9524
    - name: Test SER
      type: ser
      value: 1.2346
    - name: Test CER
      type: cer
      value: 0.1623
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Robust Speech Event - Dev Data
      type: speech-recognition-community-v2/dev_data
      args: sv
    metrics:
    - name: Test WER
      type: wer
      value: null
    - name: Test SER
      type: ser
      value: null
    - name: Test CER
      type: cer
      value: null
---
# `wav2vec2-large-xlsr-53-th`
Finetuning `wav2vec2-large-xlsr-53` on Thai [Common Voice 7.0](https://commonvoice.mozilla.org/en/datasets)
[Read more on our blog](https://medium.com/airesearch-in-th/airesearch-in-th-3c1019a99cd)
We finetune [wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) based on [Fine-tuning Wav2Vec2 for English ASR](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_tuning_Wav2Vec2_for_English_ASR.ipynb) using Thai examples of [Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets). The notebooks and scripts can be found in [vistec-ai/wav2vec2-large-xlsr-53-th](https://github.com/vistec-ai/wav2vec2-large-xlsr-53-th). The pretrained model and processor can be found at [airesearch/wav2vec2-large-xlsr-53-th](https://huggingface.co/airesearch/wav2vec2-large-xlsr-53-th).
## `robust-speech-event`
Add `syllable_tokenize`, `word_tokenize` ([PyThaiNLP](https://github.com/PyThaiNLP/pythainlp)) and [deepcut](https://github.com/rkcosmos/deepcut) tokenizers to `eval.py` from [robust-speech-event](https://github.com/huggingface/transformers/tree/master/examples/research_projects/robust-speech-event#evaluation)
```
> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
```
### Eval results on Common Voice 7 "test":
| | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER |
|---------------------------------|---------------------|-------------|---------|---------|
| Only Tokenization | 0.9524% | 2.5316% | 1.2346% | 0.1623% |
| Cleaning rules and Tokenization | TBD | TBD | TBD | TBD |
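Thai text has no spaces between words, so the WER depends on which tokenizer splits the reference and the prediction. The sketch below illustrates the idea with a plain edit distance over tokens and a pluggable `tokenize` function. The helper names (`edit_distance`, `error_rate`, `char_tokenize`, `space_tokenize`) and the corpus-level aggregation are our assumptions for illustration, not the code in `eval.py`.

```python
from typing import Callable, List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(refs: List[str], hyps: List[str],
               tokenize: Callable[[str], List[str]]) -> float:
    """Total token errors divided by total reference tokens (assumed aggregation)."""
    errors, total = 0, 0
    for ref, hyp in zip(refs, hyps):
        ref_tokens, hyp_tokens = tokenize(ref), tokenize(hyp)
        errors += edit_distance(ref_tokens, hyp_tokens)
        total += len(ref_tokens)
    return errors / total


def char_tokenize(text: str) -> List[str]:
    """Characters without whitespace, for a CER-style score."""
    return [c for c in text if not c.isspace()]


def space_tokenize(text: str) -> List[str]:
    """Stand-in word tokenizer for text that is already space-separated."""
    return text.split()


# Example taken from the Usage section: the prediction differs only by spaces.
refs = ["และเขาก็สัมผัสดีบุก"]
hyps = ["และ เขา ก็ สัมผัส ดีบุก"]
print("CER:", error_rate(refs, hyps, char_tokenize))
```

To approximate the columns of the table, one would pass a Thai word tokenizer such as `pythainlp.tokenize.word_tokenize` or deepcut's tokenizer as `tokenize` instead of `space_tokenize`.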
## Usage
```
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")

# load an audio file and resample it to 16 kHz
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    return batch

# get 2 examples as sample input
# (test_dataset is assumed to be an already loaded dataset with "path" and "sentence" columns)
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)

# infer
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
# Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
# Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
```
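The `argmax` followed by `batch_decode` above is greedy CTC decoding: take the most likely token per frame, collapse consecutive repeats, drop the blank token, and join what remains. Below is a simplified sketch of that step on a toy vocabulary. The function `greedy_ctc_decode` and the toy `id_to_token` mapping are hypothetical; the use of the pad token as the CTC blank and `|` as the word delimiter follows the usual Wav2Vec2 tokenizer convention as we understand it, and this is not the library's actual implementation.

```python
from itertools import groupby
from typing import Dict, List


def greedy_ctc_decode(ids: List[int],
                      id_to_token: Dict[int, str],
                      blank_id: int,
                      word_delimiter: str = "|") -> str:
    """Collapse repeated frame predictions, drop blanks, and join tokens."""
    collapsed = [token_id for token_id, _ in groupby(ids)]    # merge consecutive repeats
    tokens = [id_to_token[i] for i in collapsed if i != blank_id]
    return "".join(tokens).replace(word_delimiter, " ").strip()


# Toy vocabulary (hypothetical): id 0 plays the role of the blank/pad token.
id_to_token = {0: "<pad>", 1: "|", 2: "a", 3: "b"}
frame_ids = [2, 2, 0, 2, 1, 1, 3, 0, 0]
print(greedy_ctc_decode(frame_ids, id_to_token, blank_id=0))
```

The blank between the two runs of `a` is what lets CTC emit a doubled character; without it the repeats would collapse into one.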
## Datasets
[Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) contains 133 validated hours of Thai (255 total hours) at 5GB. We pre-tokenize with `pythainlp.tokenize.word_tokenize`. We preprocess the dataset using cleaning rules described in `notebooks/cv-preprocess.ipynb` by [@tann9949](https://github.com/tann9949). We then deduplicate and split as described in [ekapolc/Thai_commonvoice_split](https://github.com/ekapolc/Thai_commonvoice_split) in order to 1) avoid data leakage due to random splits after cleaning in [Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) and 2) preserve the majority of the data for the training set. The dataset loading script is `scripts/th_common_voice_70.py`. You can use this script together with `train_cleand.tsv`, `validation_cleaned.tsv` and `test_cleaned.tsv` to get the same splits as we do. The resulting dataset is as follows:
```
DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 86586
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 2502
    })
    validation: Dataset({
        features: ['path', 'sentence'],
        num_rows: 3027
    })
})
```
```
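The leakage concern above arises because the same sentence can be recorded by several speakers: a random split over recordings can put one sentence in both train and test. A minimal sketch of a split that avoids this is to group recordings by sentence and assign whole groups to one side. The function `split_by_sentence` and its parameters are hypothetical and only illustrate the idea; the procedure we actually used is the one in [ekapolc/Thai_commonvoice_split](https://github.com/ekapolc/Thai_commonvoice_split) and may differ.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Dict[str, str]


def split_by_sentence(rows: List[Row],
                      test_fraction: float = 0.1,
                      seed: int = 0) -> Tuple[List[Row], List[Row]]:
    """Assign all recordings of a sentence to the same split."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["sentence"]].append(row)

    sentences = sorted(groups)                  # deterministic order before shuffling
    random.Random(seed).shuffle(sentences)
    n_test = max(1, int(len(sentences) * test_fraction))

    test = [r for s in sentences[:n_test] for r in groups[s]]
    train = [r for s in sentences[n_test:] for r in groups[s]]
    return train, test


# Toy rows with the same columns as the dataset above; paths are made up.
rows = [{"path": f"clip_{i}.mp3", "sentence": s}
        for i, s in enumerate(["a", "a", "b", "c", "c", "d", "e", "f", "g", "h", "i", "j"])]
train, test = split_by_sentence(rows, test_fraction=0.2, seed=1)
overlap = {r["sentence"] for r in train} & {r["sentence"] for r in test}
print(len(train), len(test), overlap)
```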
## Training
We finetuned using the following configuration on a single V100 GPU and chose the checkpoint with the lowest validation loss. The finetuning script is `scripts/wav2vec2_finetune.py`.
```
from transformers import TrainingArguments, Wav2Vec2ForCTC

# create model
# (processor is assumed to be a Wav2Vec2Processor built for the Thai vocabulary)
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
# keep the convolutional feature extractor frozen during finetuning
model.freeze_feature_extractor()

training_args = TrainingArguments(
    output_dir="../data/wav2vec2-large-xlsr-53-thai",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=16,
    metric_for_best_model='wer',
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    num_train_epochs=100,
    fp16=True,
    learning_rate=1e-4,
    warmup_steps=1000,
    save_total_limit=3,
    report_to="tensorboard"
)
```
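To get a feel for what this configuration means in optimizer steps, the arithmetic can be worked out from the training set size reported in the Datasets section. This is a back-of-the-envelope sketch, assuming a single GPU (as stated above) and that the last partial batch of each epoch is kept; the variable names are ours.

```python
import math

# Values copied from the configuration and dataset summary above
num_train_rows = 86586
per_device_train_batch_size = 32
gradient_accumulation_steps = 1
num_gpus = 1                      # single V100
num_train_epochs = 100
eval_steps = 1000
warmup_steps = 1000

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
steps_per_epoch = math.ceil(num_train_rows / effective_batch_size)   # 2706
total_steps = steps_per_epoch * num_train_epochs                     # 270600
num_evaluations = total_steps // eval_steps                          # 270
warmup_fraction = warmup_steps / total_steps

print(f"steps per epoch:   {steps_per_epoch}")
print(f"total steps:       {total_steps}")
print(f"evaluations:       {num_evaluations}")
print(f"warmup fraction:   {warmup_fraction:.4f}")
print(f"epochs per eval:   {eval_steps / steps_per_epoch:.2f}")
```

Under these assumptions, evaluating every 1000 steps corresponds to a checkpoint roughly every 0.37 epochs, and the 1000 warmup steps cover well under 1% of the full 100-epoch schedule.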
## Evaluation
We benchmark on the test set using WER with words tokenized by [PyThaiNLP](https://github.com/PyThaiNLP/pythainlp) 2.3.1 and [deepcut](https://github.com/rkcosmos/deepcut), and CER. We also measure performance when spell correction using [TNC](http://www.arts.chula.ac.th/ling/tnc/) ngrams is applied. The evaluation code can be found in `notebooks/wav2vec2_finetuning_tutorial.ipynb`. The benchmark is performed on the `test-unique` split.
| | WER PyThaiNLP 2.3.1 | WER deepcut | CER |
|--------------------------------|---------------------|----------------|----------------|
| [Kaldi from scratch](https://github.com/vistec-AI/commonvoice-th) | 23.04 | | 7.57 |
| Ours without spell correction | 13.634024 | **8.152052** | **2.813019** |
| Ours with spell correction | 17.996397 | 14.167975 | 5.225761 |
| Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340 |
| Microsoft Bing Speech API※ | **12.578819** | 9.620991 | 5.016620 |
| Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562 |
| NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027 |
※ APIs are not finetuned with Common Voice 7.0 data
## LICENSE
[cc-by-sa 4.0](https://github.com/vistec-AI/wav2vec2-large-xlsr-53-th/blob/main/LICENSE)
## Acknowledgements
* model training and validation notebooks/scripts [@cstorm125](https://github.com/cstorm125/)
* dataset cleaning scripts [@tann9949](https://github.com/tann9949)
* dataset splits [@ekapolc](https://github.com/ekapolc/) and [@14mss](https://github.com/14mss)
* running the training [@mrpeerat](https://github.com/mrpeerat)
* spell correction [@wannaphong](https://github.com/wannaphong)