versae committed on
Commit
8c2ab3f
1 Parent(s): 81e9cf2

Add eval script

Browse files
Files changed (1) hide show
  1. eval.py +19 -4
eval.py CHANGED
@@ -5,6 +5,7 @@ from typing import Dict
5
 
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
 
8
 
9
  from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
10
  # from pyctcdecode import BeamSearchDecoderCTC
@@ -57,7 +58,7 @@ def normalize_text(text: str, dataset: str) -> str:
57
 
58
  if dataset.lower().endswith("nst"):
59
  text = text.lower()
60
- text = text.replace("(...Vær stille under dette opptaket...)", "")
61
  text = re.sub('[áàâ]', 'a', text)
62
  text = re.sub('[ä]', 'æ', text)
63
  text = re.sub('[éèëê]', 'e', text)
@@ -77,7 +78,18 @@ def normalize_text(text: str, dataset: str) -> str:
77
  text = re.sub('[ö]', 'ø', text)
78
  text = re.sub('[ç]', 'c', text)
79
  text = re.sub('[úùüû]', 'u', text)
80
- text = re.sub('\s', ' ', text)
 
 
 
 
 
 
 
 
 
 
 
81
  text = re.sub("<ee(eh)?>", "e", text)
82
  text = re.sub("<mmm?>", "m", text)
83
  text = re.sub("<qq>", "q", text)
@@ -140,8 +152,8 @@ def main(args):
140
  batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
141
  )
142
 
143
- batch["prediction"] = prediction["text"]
144
- batch["target"] = normalize_text(batch["text"], args.dataset)
145
  return batch
146
 
147
  # run inference on all examples
@@ -168,6 +180,9 @@ if __name__ == "__main__":
168
  "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
169
  )
170
  parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
 
 
 
171
  parser.add_argument(
172
  "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to 5 seconds."
173
  )
 
5
 
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
+ from num2words import num2words as n2w
9
 
10
  from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
11
  # from pyctcdecode import BeamSearchDecoderCTC
 
58
 
59
  if dataset.lower().endswith("nst"):
60
  text = text.lower()
61
+ text = text.replace("(...vær stille under dette opptaket...)", "")
62
  text = re.sub('[áàâ]', 'a', text)
63
  text = re.sub('[ä]', 'æ', text)
64
  text = re.sub('[éèëê]', 'e', text)
 
78
  text = re.sub('[ö]', 'ø', text)
79
  text = re.sub('[ç]', 'c', text)
80
  text = re.sub('[úùüû]', 'u', text)
81
+ text = re.sub('\s+', ' ', text)
82
+ elif dataset.lower().endswith("fleurs"):
83
+ text = re.sub('[áàâ]', 'a', text)
84
+ text = re.sub('[ä]', 'æ', text)
85
+ text = re.sub('[éèëê]', 'e', text)
86
+ text = re.sub('[íìïî]', 'i', text)
87
+ text = re.sub('[óòöô]', 'o', text)
88
+ text = re.sub('[ö]', 'ø', text)
89
+ text = re.sub('[ç]', 'c', text)
90
+ text = re.sub('[úùüû]', 'u', text)
91
+ text = re.compile(r"-?[1-9][\d.]*").sub(lambda x: n2w(x.group(0), lang="no"), text)
92
+ text = re.sub('\s+', ' ', text)
93
  text = re.sub("<ee(eh)?>", "e", text)
94
  text = re.sub("<mmm?>", "m", text)
95
  text = re.sub("<qq>", "q", text)
 
152
  batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
153
  )
154
 
155
+ batch["prediction"] = prediction[args.text_column]
156
+ batch["target"] = normalize_text(args.text_column, args.dataset)
157
  return batch
158
 
159
  # run inference on all examples
 
180
  "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
181
  )
182
  parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
183
+ parser.add_argument(
184
+ "--text_column", type=str, default="text", help="Column name containing the transcription."
185
+ )
186
  parser.add_argument(
187
  "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to 5 seconds."
188
  )