SalahZa committed on
Commit
0d1350d
1 Parent(s): a3b4990

first commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. EnglishCV/common_voice_prepare.py +410 -0
  3. EnglishCV/results/final_cs/hyperparams.yaml +144 -0
  4. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/CKPT.yaml +4 -0
  5. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/brain.ckpt +3 -0
  6. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/counter.ckpt +3 -0
  7. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/dataloader-TRAIN.ckpt +3 -0
  8. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/model.ckpt +3 -0
  9. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/modelopt.ckpt +3 -0
  10. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_encoder.ckpt +3 -0
  11. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_model.ckpt +3 -0
  12. EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/tokenizer.ckpt +3 -0
  13. EnglishCV/results/final_cs/save/label_encoder.txt +80 -0
  14. EnglishCV/results/final_cs/train_mixer.py +756 -0
  15. EnglishCV/results/wav2vec2_ctc_en/1234/hyperparams.yaml +190 -0
  16. EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.model +3 -0
  17. EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.vocab +28 -0
  18. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/CKPT.yaml +4 -0
  19. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/brain.ckpt +3 -0
  20. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/counter.ckpt +3 -0
  21. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/dataloader-TRAIN.ckpt +3 -0
  22. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/model.ckpt +3 -0
  23. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/modelopt.ckpt +3 -0
  24. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_model.ckpt +3 -0
  25. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_wav2vec.ckpt +3 -0
  26. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec2.ckpt +3 -0
  27. EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec_opt.ckpt +3 -0
  28. EnglishCV/results/wav2vec2_ctc_en/1234/train_with_wav2vec.py +388 -0
  29. EnglishCV/train_en_with_wav2vec.yaml +184 -0
  30. EnglishCV/train_with_wav2vec.py +388 -0
  31. README.md +18 -0
  32. TunisianASR/README.md +21 -0
  33. TunisianASR/outdomain.arpa +3 -0
  34. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/hyperparams.yaml +194 -0
  35. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/CKPT.yaml +4 -0
  36. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/brain.ckpt +3 -0
  37. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/counter.ckpt +3 -0
  38. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/dataloader-TRAIN.ckpt +3 -0
  39. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/model.ckpt +3 -0
  40. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/modelopt.ckpt +3 -0
  41. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_model.ckpt +3 -0
  42. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_wav2vec.ckpt +3 -0
  43. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec2.ckpt +3 -0
  44. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec_opt.ckpt +3 -0
  45. TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/label_encoder.txt +44 -0
  46. TunisianASR/train_semi.yaml +175 -0
  47. TunisianASR/train_with_wavlm.py +399 -0
  48. arpas/everything.arpa +3 -0
  49. asr-wav2vec2-commonvoice-fr/README.md +130 -0
  50. asr-wav2vec2-commonvoice-fr/asr.ckpt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.arpa filter=lfs diff=lfs merge=lfs -text
EnglishCV/common_voice_prepare.py ADDED
@@ -0,0 +1,410 @@
"""
Data preparation.
Download: https://voice.mozilla.org/en/datasets
Authors
-------
Titouan Parcollet
Luca Della Libera 2022
Pooneh Mousavi 2022
"""

from dataclasses import dataclass
import os
import csv
import re
import logging
import torchaudio
from tqdm import tqdm
import unicodedata
import functools

torchaudio.set_audio_backend("soundfile")
from speechbrain.utils.parallel import parallel_map
from speechbrain.dataio.dataio import read_audio_info

logger = logging.getLogger(__name__)


def prepare_common_voice(
    data_folder,
    save_folder,
    train_tsv_file=None,
    dev_tsv_file=None,
    test_tsv_file=None,
    accented_letters=False,
    language="en",
    skip_prep=False,
):
    """
    Prepares the csv files for the Mozilla Common Voice dataset.
    Download: https://voice.mozilla.org/en/datasets

    Arguments
    ---------
    data_folder : str
        Path to the folder where the original Common Voice dataset is stored.
        This path should include the lang: /datasets/CommonVoice/<language>/
    save_folder : str
        The directory where to store the csv files.
    train_tsv_file : str, optional
        Path to the Train Common Voice .tsv file (cs)
    dev_tsv_file : str, optional
        Path to the Dev Common Voice .tsv file (cs)
    test_tsv_file : str, optional
        Path to the Test Common Voice .tsv file (cs)
    accented_letters : bool, optional
        Defines if accented letters will be kept as individual letters or
        transformed to the closest non-accented letters.
    language : str
        Specify the language for text normalization.
    skip_prep : bool
        If True, skip data preparation.

    Example
    -------
    >>> from recipes.CommonVoice.common_voice_prepare import prepare_common_voice
    >>> data_folder = '/datasets/CommonVoice/en'
    >>> save_folder = 'exp/CommonVoice_exp'
    >>> train_tsv_file = '/datasets/CommonVoice/en/train.tsv'
    >>> dev_tsv_file = '/datasets/CommonVoice/en/dev.tsv'
    >>> test_tsv_file = '/datasets/CommonVoice/en/test.tsv'
    >>> accented_letters = False
    >>> duration_threshold = 10
    >>> prepare_common_voice( \
             data_folder, \
             save_folder, \
             train_tsv_file, \
             dev_tsv_file, \
             test_tsv_file, \
             accented_letters, \
             language="en" \
             )
    """

    if skip_prep:
        return

    # If not specified, point toward the standard location w.r.t. the CommonVoice tree
    if train_tsv_file is None:
        train_tsv_file = data_folder + "/train.tsv"

    if dev_tsv_file is None:
        dev_tsv_file = data_folder + "/dev.tsv"

    if test_tsv_file is None:
        test_tsv_file = data_folder + "/test.tsv"

    # Setting the save folder
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    # Setting output files
    save_csv_train = save_folder + "/train.csv"
    save_csv_dev = save_folder + "/dev.csv"
    save_csv_test = save_folder + "/test.csv"

    # If the csv files already exist, we skip the data preparation
    if skip(save_csv_train, save_csv_dev, save_csv_test):

        msg = "%s already exists, skipping data preparation!" % (save_csv_train)
        logger.info(msg)

        msg = "%s already exists, skipping data preparation!" % (save_csv_dev)
        logger.info(msg)

        msg = "%s already exists, skipping data preparation!" % (save_csv_test)
        logger.info(msg)

        return

    # Additional checks to make sure the data folder contains Common Voice
    check_commonvoice_folders(data_folder)
    # Creating csv files for {train, dev, test} data
    file_pairs = zip(
        [train_tsv_file, dev_tsv_file, test_tsv_file],
        [save_csv_train, save_csv_dev, save_csv_test],
    )
    for tsv_file, save_csv in file_pairs:
        create_csv(
            tsv_file, save_csv, data_folder, accented_letters, language,
        )


def skip(save_csv_train, save_csv_dev, save_csv_test):
    """
    Detects if the Common Voice data preparation has already been done.
    If the preparation has been done, we can skip it.

    Returns
    -------
    bool
        If True, the preparation phase can be skipped.
        If False, it must be done.
    """

    # Checking folders and save options
    skip = False

    if (
        os.path.isfile(save_csv_train)
        and os.path.isfile(save_csv_dev)
        and os.path.isfile(save_csv_test)
    ):
        skip = True

    return skip


@dataclass
class CVRow:
    snt_id: str
    duration: float
    mp3_path: str
    spk_id: str
    words: str


def process_line(line, data_folder, language, accented_letters):
    # The path is at index 1 in Common Voice tsv files, and .mp3 files
    # are located in datasets/lang/clips/
    mp3_path = data_folder + "/clips/" + line.split("\t")[1]
    file_name = mp3_path.split(".")[-2].split("/")[-1]
    spk_id = line.split("\t")[0]
    snt_id = file_name

    # Setting torchaudio backend to sox-io (needed to read mp3 files)
    """
    if torchaudio.get_audio_backend() != "sox_io":
        logger.warning("This recipe needs the sox-io backend of torchaudio")
        logger.warning("The torchaudio backend is changed to sox_io")
        torchaudio.set_audio_backend("sox_io")
    """
    # Reading the signal (to retrieve duration in seconds)
    if os.path.isfile(mp3_path):
        info = read_audio_info(mp3_path)
    else:
        msg = "\tError loading: %s" % (str(file_name))
        logger.info(msg)
        return None

    duration = info.num_frames / info.sample_rate

    # Getting transcript
    words = line.split("\t")[2]

    # Unicode Normalization
    words = unicode_normalisation(words)

    # !! Language specific cleaning !!
    words = language_specific_preprocess(language, words)

    # Remove accents if specified
    if not accented_letters:
        words = strip_accents(words)
        words = words.replace("'", " ")
        words = words.replace("’", " ")

    # Remove multiple spaces
    words = re.sub(" +", " ", words)

    # Remove spaces at the beginning and the end of the sentence
    words = words.lstrip().rstrip()

    # Getting chars
    chars = words.replace(" ", "_")
    chars = " ".join([char for char in chars][:])

    # Remove too short sentences (or empty):
    if language in ["ja", "ch"]:
        if len(chars) < 3:
            return None
    else:
        if len(words.split(" ")) < 3:
            return None

    # Composition of the csv_line
    return CVRow(snt_id, duration, mp3_path, spk_id, words)


def create_csv(
    orig_tsv_file, csv_file, data_folder, accented_letters=False, language="en"
):
    """
    Creates the csv file given a list of wav files.

    Arguments
    ---------
    orig_tsv_file : str
        Path to the Common Voice tsv file (standard file).
    data_folder : str
        Path of the CommonVoice dataset.
    accented_letters : bool, optional
        Defines if accented letters will be kept as individual letters or
        transformed to the closest non-accented letters.

    Returns
    -------
    None
    """

    # Check if the given files exist
    if not os.path.isfile(orig_tsv_file):
        msg = "\t%s doesn't exist, verify your dataset!" % (orig_tsv_file)
        logger.info(msg)
        raise FileNotFoundError(msg)

    # We load and skip the header
    loaded_csv = open(orig_tsv_file, "r").readlines()[1:]
    nb_samples = len(loaded_csv)

    msg = "Preparing CSV files for %s samples ..." % (str(nb_samples))
    logger.info(msg)

    # Adding some prints
    msg = "Creating csv lists in %s ..." % (csv_file)
    logger.info(msg)

    # Process and write lines
    total_duration = 0.0

    line_processor = functools.partial(
        process_line,
        data_folder=data_folder,
        language=language,
        accented_letters=accented_letters,
    )

    # Stream into a .tmp file, and rename it to the real path at the end.
    csv_file_tmp = csv_file + ".tmp"

    with open(csv_file_tmp, mode="w", encoding="utf-8") as csv_f:
        csv_writer = csv.writer(
            csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
        )

        csv_writer.writerow(["ID", "duration", "wav", "spk_id", "wrd"])
        for line in tqdm(loaded_csv):

            row = line_processor(line)
            if row is not None:
                total_duration += row.duration
                csv_writer.writerow(
                    [
                        row.snt_id,
                        str(row.duration),
                        row.mp3_path,
                        row.spk_id,
                        row.words,
                    ]
                )

    os.replace(csv_file_tmp, csv_file)

    # Final prints
    msg = "%s successfully created!" % (csv_file)
    logger.info(msg)
    msg = "Number of samples: %s " % (str(len(loaded_csv)))
    logger.info(msg)
    msg = "Total duration: %s Hours" % (str(round(total_duration / 3600, 2)))
    logger.info(msg)


def language_specific_preprocess(language, words):
    # !! Language specific cleaning !!
    # Important: feel free to specify the text normalization
    # corresponding to your alphabet.

    if language in ["en", "fr", "it", "rw"]:
        words = re.sub(
            "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
        ).upper()

    if language == "de":
        # this replacement helps preserve the case of ß
        # (and helps retain solitary occurrences of SS)
        # since python's upper() converts ß to SS.
        words = words.replace("ß", "0000ß0000")
        words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
        words = words.replace("'", " ")
        words = words.replace("’", " ")
        words = words.replace(
            "0000SS0000", "ß"
        )  # replace 0000SS0000 back to ß as its initial presence in the corpus

    if language == "fr":
        # Replace J'y D'hui etc by J_ D_hui
        words = words.replace("'", " ")
        words = words.replace("’", " ")

    elif language == "ar":
        HAMZA = "\u0621"
        ALEF_MADDA = "\u0622"
        ALEF_HAMZA_ABOVE = "\u0623"
        letters = (
            "ابتةثجحخدذرزژشسصضطظعغفقكلمنهويىءآأؤإئ"
            + HAMZA
            + ALEF_MADDA
            + ALEF_HAMZA_ABOVE
        )
        words = re.sub("[^" + letters + " ]+", "", words).upper()
    elif language == "fa":
        HAMZA = "\u0621"
        ALEF_MADDA = "\u0622"
        ALEF_HAMZA_ABOVE = "\u0623"
        letters = (
            "ابپتةثجحخچدذرزژسشصضطظعغفقگکلمنهویىءآأؤإئ"
            + HAMZA
            + ALEF_MADDA
            + ALEF_HAMZA_ABOVE
        )
        words = re.sub("[^" + letters + " ]+", "", words).upper()
    elif language == "ga-IE":
        # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase
        def pfxuc(a):
            return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ"

        def galc(w):
            return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower()

        words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words)
        words = " ".join(map(galc, words.split(" ")))
    elif language == "es":
        # Fix the following error in dataset large:
        # KeyError: 'The item En noviembre lanzaron Queen Elizabeth , coproducida por Foreign Noi$e . requires replacements which were not supplied.'
        words = words.replace("$", "s")
    return words


def check_commonvoice_folders(data_folder):
    """
    Check if the data folder actually contains the Common Voice dataset.
    If not, raises an error.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If the data folder doesn't contain the Common Voice dataset.
    """
    files_str = "/clips"
    # Checking clips
    if not os.path.exists(data_folder + files_str):
        err_msg = (
            "the folder %s does not exist (it is expected in "
            "the Common Voice dataset)" % (data_folder + files_str)
        )
        raise FileNotFoundError(err_msg)


def unicode_normalisation(text):
    return str(text)


def strip_accents(text):
    text = (
        unicodedata.normalize("NFD", text)
        .encode("ascii", "ignore")
        .decode("utf-8")
    )
    return str(text)
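
Note: the script above is the standard SpeechBrain CommonVoice preparation. A minimal usage sketch follows; it is not part of this commit and the paths are placeholders, but it shows how the function is typically invoked once per language before training.

# Hypothetical invocation of prepare_common_voice(); data_folder/save_folder are placeholder paths.
from common_voice_prepare import prepare_common_voice

prepare_common_voice(
    data_folder="/datasets/CommonVoice/en",  # local CommonVoice download (must contain clips/ and *.tsv)
    save_folder="exp/CommonVoice_exp",       # train.csv / dev.csv / test.csv are written here
    accented_letters=False,
    language="en",
    skip_prep=False,
)
# Each csv row is: ID, duration, wav (mp3 path), spk_id, wrd, as produced by create_csv() above.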
EnglishCV/results/final_cs/hyperparams.yaml ADDED
@@ -0,0 +1,144 @@
# Generated 2023-09-08 from:
# /gpfsssd/scratch/rech/nou/uzn19yk/switched_data/stac.yaml
# yamllint disable
# Generated 2023-08-03 from:
# /home/salah/new_tunisian_model/hparams/train_tunisian_withwavlm.yaml
# yamllint disable
# ################################
# Model: wav2vec2 + DNN + CTC
# Augmentation: SpecAugment
# Authors: Titouan Parcollet 2021
# ################################

seed: 1994
__set_seed: !!python/object/apply:torch.manual_seed [1234]
output_folder: results/non_semi_final_stac
wer_file: results/non_semi_final_stac/wer.txt
save_folder: results/non_semi_final_stac/save
train_log: results/non_semi_final_stac/train_log.txt

# Data files
data_folder: junk  # e.g., /localscratch/cv-corpus-5.1-2020-06-22/fr
train_tsv_file: junk/train.tsv  # Standard CommonVoice .tsv files
dev_tsv_file: junk/dev.tsv  # Standard CommonVoice .tsv files
test_tsv_file: junk/test.tsv  # Standard CommonVoice .tsv files
accented_letters: true

csv_folder: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean/
train_csv: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean//train.csv
valid_csv: /gpfsscratch/rech/nou/uzn19yk/switched_data/extended_clean//dev.csv
test_csv:
- all_tests/cs_test.csv
- all_tests/stac_test.csv

# We remove utterances longer than the threshold below from the train/dev/test
# sets, as longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 13.0
avoid_if_shorter_than: 0.5

# Training parameters
number_of_epochs: 20
lr: 0.0002
lr_weights: 0.01
sorting: ascending
auto_mix_prec: false
sample_rate: 16000
language_modelling: true
ngram_lm_path:
  /gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/arpas/pluslanguages_everything.arpa

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 3
test_batch_size: 4

# Dataloader options
dataloader_options:
  batch_size: 3
  num_workers: 6

test_dataloader_options:
  batch_size: 4
  num_workers: 6

# Model parameters
activation: !name:torch.nn.Sigmoid
dnn_layers: 1
dnn_neurons: 768
freeze_encoder: true

# Outputs
output_neurons: 76  # BPE size, index(blank/eos/bos) = 0

# Functions and classes
#
epoch_counter: &id006 !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: 20

encoder_dim: 3217
enc: &id001 !new:speechbrain.nnet.RNN.LSTM
  input_shape: [null, null, 3217]
  num_layers: 2
  bidirectional: true
  dropout: 0.2
  hidden_size: 1024

ctc_lin: &id002 !new:speechbrain.nnet.linear.Linear
  input_size: 2048
  n_neurons: 76

log_softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: true

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
  blank_index: 0

modules:
  enc: *id001
  ctc_lin: *id002
model: &id003 !new:torch.nn.ModuleList
- [*id001, *id002]
model_opt_class: !name:torch.optim.Adam
  lr: 0.0002

weights_opt_class: !name:torch.optim.Adam
  lr: 0.01

lr_annealing_model: &id004 !new:speechbrain.nnet.schedulers.NewBobScheduler
  initial_value: 0.0002
  improvement_threshold: 0.0025
  annealing_factor: 0.8
  patient: 0

lr_annealing_weights: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
  initial_value: 0.01
  improvement_threshold: 0.0025
  annealing_factor: 0.9
  patient: 0

label_encoder: &id007 !new:speechbrain.dataio.encoder.CTCTextEncoder

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: results/non_semi_final_stac/save
  recoverables:
    model: *id003
    scheduler_model: *id004
    scheduler_encoder: *id005
    counter: *id006
    tokenizer: *id007
blank_index: 0
unk_index: 1

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: results/non_semi_final_stac/train_log.txt

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
  split_tokens: true
@@ -0,0 +1,4 @@
# yamllint disable
WER: 51.292116454039906
end-of-epoch: true
unixtime: 1694130018.9642384
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5c026fe6fa51700406bd476e131950c797b0b3bacb3daae0854e85689bb4cf9
size 50
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f5ca38f748a1d6eaf726b8a42fb575c3c71f1864a8143301782de13da2d9202b
size 2
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7e1edcac43af8cea1439d222314af06354ae31da6a3d90b8cc6bcebc5c8e397
size 4
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da683a8efa5709a06af9b258452c243da841780a0a7942c196c472a3e21e5010
size 240389017
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:416feb314443cf839f4425fc382e555dec90e3dea26fa52b75e4ac1b702c5078
size 480787579
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_encoder.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e2efd50f0cf28a080e2625fdd8a1852c669841537cdc0a57fce60bc6c1eec11
size 515
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cec54cc9236fa7aa965b397675d24299b973675cc0c6345de038fc70e51629ab
size 703
EnglishCV/results/final_cs/save/CKPT+2023-09-08+01-40-18+00/tokenizer.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21080a140faeb4f39fad188aaf081914ec782be9c4320d6415e8822709e18017
size 39
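
The *.ckpt entries above are Git LFS pointer files for one saved checkpoint. A rough sketch of how the Checkpointer declared in hyperparams.yaml would restore it at runtime (assuming the checkpoint directory is present locally):

# Hedged sketch: restoring the saved state with the SpeechBrain Checkpointer from the YAML above.
from hyperpyyaml import load_hyperpyyaml

with open("EnglishCV/results/final_cs/hyperparams.yaml") as f:
    hparams = load_hyperpyyaml(f)

hparams["checkpointer"].recover_if_possible()  # reloads model, schedulers, counter and tokenizer state
print(hparams["epoch_counter"].current)        # epoch reached by the recovered run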
EnglishCV/results/final_cs/save/label_encoder.txt ADDED
@@ -0,0 +1,80 @@
'و' => 74
'ي' => 1
'ن' => 2
' ' => 3
'م' => 4
'ش' => 5
'ل' => 6
'س' => 7
'ت' => 8
'ا' => 9
'د' => 10
'ر' => 11
'ى' => 12
'ب' => 13
'ح' => 14
'ط' => 15
'ع' => 16
'ك' => 17
'ف' => 18
'ق' => 19
'ذ' => 20
'ث' => 21
'ج' => 22
'ة' => 23
'غ' => 24
'o' => 25
'k' => 26
'b' => 27
'n' => 28
'خ' => 29
'ه' => 30
'v' => 31
'i' => 32
'l' => 33
'à' => 34
'ص' => 35
'ض' => 36
'a' => 37
'u' => 38
't' => 39
'm' => 40
'q' => 41
'e' => 42
'd' => 43
'c' => 44
'p' => 45
'r' => 46
'أ' => 47
'إ' => 48
's' => 49
'j' => 50
'ز' => 51
'ء' => 52
'h' => 53
'f' => 54
'آ' => 55
'ئ' => 56
'ؤ' => 57
'ظ' => 58
'y' => 59
'é' => 60
"'" => 61
'z' => 62
'x' => 63
'w' => 64
'g' => 65
'è' => 66
'û' => 67
'ç' => 68
'ê' => 69
'ô' => 70
'ù' => 71
'î' => 72
'â' => 73
'<blank>' => 0
1 => 75
================
'starting_index' => 0
'unk_label' => 1
'blank_label' => '<blank>'
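
label_encoder.txt is the save format of speechbrain.dataio.encoder.CTCTextEncoder: one 'token' => index pair per line, then metadata after the ==== divider. A small sketch of reading it back into an index-ordered label list (it mirrors read_labels_file() in train_mixer.py further down):

# Hedged sketch: parse label_encoder.txt into a list ordered by index (blank sits at position 0).
def load_labels(path):
    labels = {}
    with open(path, encoding="utf-8") as f:
        for line in f.read().splitlines():
            if "===" in line:              # stop at the metadata section
                break
            token, index = line.split("=>")
            labels[int(index)] = token[1:-2]  # drop the surrounding quotes and trailing space
    return [labels[i] for i in range(len(labels))]

labels = load_labels("EnglishCV/results/final_cs/save/label_encoder.txt")
print(labels[0], len(labels))              # '<blank>' and 76 entries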
EnglishCV/results/final_cs/train_mixer.py ADDED
@@ -0,0 +1,756 @@
# -*- coding: utf-8 -*-

import os
import sys
import torch
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
import torchaudio.transforms as T
from cv_train import ASRCV
import torchaudio
import numpy as np
import kenlm
from pyctcdecode import build_ctcdecoder
import re

# Commented out IPython magic to ensure Python compatibility.
# %cd /content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm
#hparams_file, run_opts, overrides = sb.parse_arguments(["/gpfsstore/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml"])
hparams_file, run_opts, overrides = sb.parse_arguments(["semi_supervised_test_tunisian.yaml"])

# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)
# Dataset prep (parsing Librispeech)

def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions."""

    # 1. Define datasets
    data_folder = hparams["data_folder"]

    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
    )

    if hparams["sorting"] == "ascending":
        # we sort training data to speed up training and get better results.
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting, do not shuffle in the dataloader! otherwise it is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "descending":
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            reverse=True,
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting, do not shuffle in the dataloader! otherwise it is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "random":
        pass

    else:
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
    )
    # We also sort the validation data so it is faster to validate
    valid_data = valid_data.filtered_sorted(sort_key="duration")
    test_datasets = {}
    for csv_file in hparams["test_csv"]:
        name = Path(csv_file).stem
        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_file, replacements={"data_root": data_folder}
        )
        test_datasets[name] = test_datasets[name].filtered_sorted(
            sort_key="duration"
        )

    datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]

    # 2. Define audio pipeline:
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        info = torchaudio.info(wav)
        sig = sb.dataio.dataio.read_audio(wav)
        if len(sig.shape) > 1:
            sig = torch.mean(sig, dim=1)
        resampled = torchaudio.transforms.Resample(
            info.sample_rate, hparams["sample_rate"],
        )(sig)
        return resampled

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
    label_encoder = sb.dataio.encoder.CTCTextEncoder()

    # 3. Define text pipeline:
    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides(
        "wrd", "char_list", "tokens_list", "tokens"
    )
    def text_pipeline(wrd):
        yield wrd
        char_list = list(wrd)
        yield char_list
        tokens_list = label_encoder.encode_sequence(char_list)
        yield tokens_list
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    special_labels = {
        "blank_label": hparams["blank_index"],
        "unk_label": hparams["unk_index"]
    }
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[train_data],
        output_key="char_list",
        special_labels=special_labels,
        sequence_input=True,
    )

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets, ["id", "sig", "wrd", "char_list", "tokens"],
    )
    return train_data, valid_data, test_datasets, label_encoder

class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""

        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # Forward pass
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return p_ctc, wav_lens

    def custom_encode(self, wavs, wav_lens):
        # Encode raw waveforms without a full batch object (used by middle_layer below).
        wavs = wavs.to(self.device)
        if wav_lens is not None:
            wav_lens = wav_lens.to(self.device)

        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return feats, p_ctc

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC) given predictions and targets."""

        p_ctc, wav_lens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            predicted_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
            # Decode token terms to words
            if self.hparams.use_language_modelling:
                predicted_words = []
                for logs in p_ctc:
                    text = decoder.decode(logs.detach().cpu().numpy())
                    predicted_words.append(text.split(" "))
            else:
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]
            # Convert indices to words
            target_words = [wrd.split(" ") for wrd in batch.wrd]

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:

                if not self.hparams.wav2vec2.freeze:
                    self.scaler.unscale_(self.wav2vec_optimizer)
                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.scaler.step(self.wav2vec_optimizer)
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.wav2vec_optimizer.step()
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            if not self.hparams.wav2vec2.freeze:
                sb.nnet.schedulers.update_learning_rate(
                    self.wav2vec_optimizer, new_lr_wav2vec
                )
            self.hparams.train_logger.log_stats(
                stats_meta={
                    "epoch": epoch,
                    "lr_model": old_lr_model,
                    "lr_wav2vec": old_lr_wav2vec,
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):
        "Initializes the wav2vec2 optimizer and model optimizer"

        # If the wav2vec encoder is unfrozen, we create the optimizer
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable(
                    "wav2vec_opt", self.wav2vec_optimizer
                )

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer.zero_grad(set_to_none)
        self.model_optimizer.zero_grad(set_to_none)


"""
label_encoder = sb.dataio.encoder.CTCTextEncoder()

train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
    hparams
)


# We dynamically add the tokenizer to our brain class.
# NB: This tokenizer corresponds to the one used for the LM!!
"""
from speechbrain.pretrained import EncoderASR, EncoderDecoderASR
french_asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr").cuda()
#french_asr_model = "r"

cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["en_cv.yaml"])
with open(cvhparams_file) as cvfin:
    cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
english_asr_model = ASRCV(
    modules=cvhparams["modules"],
    hparams=cvhparams,
    run_opts=cvrun_opts,
    checkpointer=cvhparams["checkpointer"],
)
english_asr_model.checkpointer.recover_if_possible()
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
english_asr_model.modules.eval()
french_asr_model.mods.eval()
"""
asr_brain.tokenizer = label_encoder

# Testing
real = True
if real:
    for k in test_datasets.keys():  # keys are test_clean, test_other etc
        asr_brain.hparams.wer_file = os.path.join(
            hparams["output_folder"], "wer_{}.txt".format(k)
        )
        asr_brain.evaluate(
            test_datasets[k], test_loader_kwargs=hparams["dataloader_options"]
        )
"""

"""
from torch.nn.utils.rnn import pad_sequence
def load_paths(wavs_path):
    waveforms = []
    for path in wavs_path:
        waveform, _ = torchaudio.load(path)
        waveforms.append(waveform.squeeze(0))
    # normalize array length to the bigger arrays by padding with 0's
    padded_arrays = pad_sequence(waveforms, batch_first=True)
    return torch.tensor(padded_arrays)

waveform = load_paths(["/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav","/content/drive/MyDrive/tunisian_corpora/tunisian_without_wavlm/samples/Salah10.wav"])
embeddings, posteriogram = asr_brain.custom_encode(waveform, None)
print(embeddings.shape)
print(posteriogram.shape)
"""

from speechbrain.pretrained import EncoderASR, EncoderDecoderASR
import torchaudio
import speechbrain as sb
import torch
from torch.nn.utils.rnn import pad_sequence
import torch
import speechbrain as sb
import numpy as np
import torch.optim as optim
import torch.nn as nn

# Commented out IPython magic to ensure Python compatibility.
# %ls

# UTILS FUNCTIONS
def get_size_dimensions(arr):
    size_dimensions = []
    while isinstance(arr, list):
        size_dimensions.append(len(arr))
        arr = arr[0]
    return size_dimensions

def scale_array(batch, n):
    scaled_batch = []

    for array in batch:
        if n < len(array):
            raise ValueError("Cannot scale Array down")

        repeat = round(n / len(array)) + 1
        scaled_length_array = []

        for i in array:
            for j in range(repeat):
                if len(scaled_length_array) == n:
                    break
                scaled_length_array.append(i)

        scaled_batch.append(scaled_length_array)

    return torch.tensor(scaled_batch)


def load_paths(wavs_path):
    waveforms = []
    for path in wavs_path:
        waveform, _ = torchaudio.load(path)
        waveforms.append(waveform.squeeze(0))
    # normalize array length to the bigger arrays by padding with 0's
    padded_arrays = pad_sequence(waveforms, batch_first=True)
    return torch.tensor(padded_arrays)


def word_to_vec(input_string):
    mapping = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, 'ا': 27, 'ب': 28, 'ت': 29, 'ث': 30, 'ج': 31, 'ح': 32, 'خ': 33, 'د': 34, 'ذ': 35, 'ر': 36, 'ز': 37, 'س': 38, 'ش': 39, 'ص': 40, 'ض': 41, 'ط': 42, 'ظ': 43, 'ع': 44, 'غ': 45, 'ف': 46, 'ق': 47, 'ك': 48, 'ل': 49, 'م': 50, 'ن': 51, 'ه': 52, 'و': 53, 'ي': 54, ' ': 55}

    numbers = [mapping[word] for word in input_string if word in mapping]
    return numbers

device = 'cuda'
verbose = 0
# FLOW LEVEL FUNCTIONS
def merge_strategy(embeddings1, embeddings2, embeddings3, post1, post2, post3):
    # Concatenate the three models' posteriors and embeddings along the feature dimension.

    post1 = post1.to(device)
    post2 = post2.to(device)
    post3 = post3.to(device)
    embeddings1 = embeddings1.to(device)
    embeddings2 = embeddings2.to(device)
    embeddings3 = embeddings3.to(device)

    posteriograms_merged = torch.cat((post1, post2, post3), dim=2)
    embeddings_merged = torch.cat((embeddings1, embeddings2, embeddings3), dim=2)

    if verbose != 0:
        print('MERGED POST ', posteriograms_merged.shape)
        print('MERGED emb ', embeddings_merged.shape)

    return torch.cat((posteriograms_merged, embeddings_merged), dim=2).to(device)

def decode(model, wavs, wav_lens):

    with torch.no_grad():
        wav_lens = wav_lens.to(model.device)
        encoder_out = model.encode_batch(wavs, wav_lens)
        predictions = model.decoding_function(encoder_out, wav_lens)
    return predictions

def middle_layer(batch, lens):
    # Run the Tunisian, French and English encoders on the same batch and fuse their outputs.

    tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch, None)

    fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
    fr_posteriogram = french_asr_model.encode_batch(batch, lens)
    en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
    x = english_asr_model.modules.enc(en_embeddings)
    en_posteriogram = english_asr_model.modules.ctc_lin(x)
    #scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings, lens)
    if verbose != 0:
        print('[EMBEDDINGS] FR:', fr_embeddings.shape, "EN:", en_embeddings.shape, "TN:", tn_embeddings.shape)
        print('[POSTERIOGRAM] FR:', fr_posteriogram.shape, "EN:", en_posteriogram.shape, "TN:", tn_posteriogram.shape)

    bilangual_sample = merge_strategy(fr_embeddings, en_embeddings, tn_embeddings, fr_posteriogram, en_posteriogram, tn_posteriogram)
    return bilangual_sample

class Mixer(sb.core.Brain):
    # Trains enc + ctc_lin on the fused multilingual features; the three source encoders are not updated here.

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        multi_langual_feats = middle_layer(wavs, wav_lens)
        multi_langual_feats = multi_langual_feats.to(device)
        feats, _ = self.modules.enc(multi_langual_feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)

        if stage != sb.Stage.TRAIN:
            p_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
        else:
            p_tokens = None
        return p_ctc, wav_lens, p_tokens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC) given predictions and targets."""

        p_ctc, wav_lens, predicted_tokens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage == sb.Stage.VALID:
            predicted_words = [
                "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                for utt_seq in predicted_tokens
            ]
            target_words = [wrd.split(" ") for wrd in batch.wrd]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)
        if stage == sb.Stage.TEST:
            if self.hparams.language_modelling:
                predicted_words = []
                for logs in p_ctc:
                    text = decoder.decode(logs.detach().cpu().numpy())
                    predicted_words.append(text.split(" "))
            else:
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]

            target_words = [wrd.split(" ") for wrd in batch.wrd]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:

                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            self.hparams.train_logger.log_stats(
                stats_meta={
                    "epoch": epoch,
                    "lr_model": old_lr_model,
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):

        self.model_optimizer.zero_grad(set_to_none)


hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)
def read_labels_file(labels_file):
    with open(labels_file, "r", encoding="utf-8") as lf:
        lines = lf.read().splitlines()
        division = "==="
        numbers = {}
        for line in lines:
            if division in line:
                break
            string, number = line.split("=>")
            number = int(number)
            string = string[1:-2]
            numbers[number] = string
        return [numbers[x] for x in range(len(numbers))]
train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
    hparams
)


labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
labels = [""] + labels[1:-1] + ["1"]
if hparams["language_modelling"]:
    decoder = build_ctcdecoder(
        labels,
        kenlm_model_path=hparams["ngram_lm_path"],  # either .arpa or .bin file
        alpha=0.5,  # tuned on a val set
        beta=1,  # tuned on a val set
    )


mixer = Mixer(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
mixer.tokenizer = label_encoder


mixer.fit(
    mixer.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["test_dataloader_options"],
)
print(test_datasets.keys())
for k in test_datasets.keys():  # keys are test_clean, test_other etc
    mixer.hparams.wer_file = os.path.join(
        hparams["output_folder"], "wer_{}.txt".format(k)
    )
    mixer.evaluate(
        test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
    )
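
For context, the decoding path in train_mixer.py feeds each utterance's CTC log-probabilities to the pyctcdecode decoder built from label_encoder.txt and the n-gram LM. A self-contained sketch of that call with dummy inputs (the real labels and kenlm path come from the hparams above):

# Hedged sketch: CTC beam search with pyctcdecode, as used in compute_objectives() above.
import numpy as np
from pyctcdecode import build_ctcdecoder

labels = ["", "a", "b", " "]        # placeholder vocabulary; the blank label must sit at index 0
decoder = build_ctcdecoder(labels, kenlm_model_path=None, alpha=0.5, beta=1)

log_probs = np.log(np.full((50, len(labels)), 1.0 / len(labels)))  # dummy (time, vocab) log-probabilities
print(decoder.decode(log_probs))    # in the script, log_probs is p_ctc[i].detach().cpu().numpy()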
EnglishCV/results/wav2vec2_ctc_en/1234/hyperparams.yaml ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated 2023-09-06 from:
2
+ # /gpfsdswork/projects/rech/nou/uzn19yk/final_forke/speechbrain-3/recipes/CommonVoice/ASR/CTC/hparams/train_en_with_wav2vec.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: results/wav2vec2_ctc_en/1234
14
+ wer_file: results/wav2vec2_ctc_en/1234/wer.txt
15
+ save_folder: results/wav2vec2_ctc_en/1234/save
16
+ train_log: results/wav2vec2_ctc_en/1234/train_log.txt
17
+
18
+ # URL for the biggest Fairseq english wav2vec2 model.
19
+ wav2vec2_hub: facebook/wav2vec2-large-lv60
20
+ wav2vec2_folder: results/wav2vec2_ctc_en/1234/save/wav2vec2_checkpoint
21
+
22
+ # Data files
23
+ data_folder:
24
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
25
+ train_tsv_file:
26
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/train.tsv # Standard CommonVoice .tsv files
27
+ dev_tsv_file:
28
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/dev.tsv # Standard CommonVoice .tsv files
29
+ test_tsv_file:
30
+ /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en/test.tsv # Standard CommonVoice .tsv files
31
+ accented_letters: false
32
+ language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
33
+ train_csv: results/wav2vec2_ctc_en/1234/save/train.csv
34
+ valid_csv: results/wav2vec2_ctc_en/1234/save/dev.csv
35
+ test_csv: results/wav2vec2_ctc_en/1234/save/test.csv
36
+ skip_prep: false # Skip data preparation
37
+
38
+ # We remove utterances longer than 10s in the train/dev/test sets as
39
+ # longer sentences certainly correspond to "open microphones".
40
+ avoid_if_longer_than: 10.0
41
+
42
+ # Training parameters
43
+ number_of_epochs: 10
44
+ lr: 1.0
45
+ lr_wav2vec: 0.0001
46
+ sorting: ascending
47
+ auto_mix_prec: false
48
+ sample_rate: 16000
49
+ ckpt_interval_minutes: 30 # save checkpoint every N min
50
+
51
+ # With data_parallel batch_size is split into N jobs
52
+ # With DDP batch_size is multiplied by N jobs
53
+ # Must be 8 per GPU to fit 32GB of VRAM
54
+ batch_size: 8
55
+ test_batch_size: 4
56
+
57
+ dataloader_options:
58
+ batch_size: 8
59
+ num_workers: 6
60
+ test_dataloader_options:
61
+ batch_size: 4
62
+ num_workers: 6
63
+
64
+ # BPE parameters
65
+ token_type: char # ["unigram", "bpe", "char"]
66
+ character_coverage: 1.0
67
+
68
+ # Model parameters
69
+ # activation: !name:torch.nn.LeakyReLU
70
+ wav2vec_output_dim: 1024
71
+ dnn_neurons: 1024
72
+ freeze_wav2vec: false
73
+ freeze_feature_extractor: true
74
+ dropout: 0.15
75
+ warmup_steps: 500
76
+
77
+ # Outputs
78
+ output_neurons: 29 # BPE size, index(blank/eos/bos) = 0
79
+
80
+ # Decoding parameters
81
+ # Be sure that the bos and eos index match with the BPEs ones
82
+ blank_index: 0
83
+ bos_index: 1
84
+ eos_index: 2
85
+
86
+ #
87
+ # Functions and classes
88
+ #
89
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
90
+
91
+ limit: 10
92
+
93
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
94
+ sample_rate: 16000
95
+ speeds: [95, 100, 105]
96
+
97
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
98
+ input_shape: [null, null, 1024]
99
+ linear1: !name:speechbrain.nnet.linear.Linear
100
+ n_neurons: 1024
101
+ bias: true
102
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
103
+ activation: !new:torch.nn.LeakyReLU
104
+ drop: !new:torch.nn.Dropout
105
+ p: 0.15
106
+ linear2: !name:speechbrain.nnet.linear.Linear
107
+ n_neurons: 1024
108
+ bias: true
109
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
110
+ activation2: !new:torch.nn.LeakyReLU
111
+ drop2: !new:torch.nn.Dropout
112
+ p: 0.15
113
+ linear3: !name:speechbrain.nnet.linear.Linear
114
+ n_neurons: 1024
115
+ bias: true
116
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
117
+ activation3: !new:torch.nn.LeakyReLU
118
+
119
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
120
+ source: /gpfsscratch/rech/nou/uzn19yk/wav2vec2-large-lv60/
121
+ output_norm: true
122
+ freeze: false
123
+ freeze_feature_extractor: true
124
+ save_path: results/wav2vec2_ctc_en/1234/save/wav2vec2_checkpoint
125
+
126
+ #####
127
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
128
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
129
+ # Fairseq github for the multilingual XLSR.
130
+ #
131
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
132
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
133
+ # pretrained_path: !ref <wav2vec2_url>
134
+ # output_norm: True
135
+ # freeze: False
136
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
137
+ #####
138
+
139
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
140
+
141
+ input_size: 1024
142
+ n_neurons: 29
143
+
144
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
145
+ apply_log: true
146
+
147
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
148
+ blank_index: 0
149
+
150
+ modules:
151
+ wav2vec2: *id001
152
+ enc: *id002
153
+ ctc_lin: *id003
154
+ model: &id004 !new:torch.nn.ModuleList
155
+ - [*id002, *id003]
156
+ model_opt_class: !name:torch.optim.Adadelta
157
+ lr: 1.0
158
+ rho: 0.95
159
+ eps: 1.e-8
160
+
161
+ wav2vec_opt_class: !name:torch.optim.Adam
162
+ lr: 0.0001
163
+
164
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
165
+ initial_value: 1.0
166
+ improvement_threshold: 0.0025
167
+ annealing_factor: 0.8
168
+ patient: 0
169
+
170
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
171
+ initial_value: 0.0001
172
+ improvement_threshold: 0.0025
173
+ annealing_factor: 0.9
174
+ patient: 0
175
+
176
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
177
+ checkpoints_dir: results/wav2vec2_ctc_en/1234/save
178
+ recoverables:
179
+ wav2vec2: *id001
180
+ model: *id004
181
+ scheduler_model: *id005
182
+ scheduler_wav2vec: *id006
183
+ counter: *id007
184
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
185
+ save_file: results/wav2vec2_ctc_en/1234/train_log.txt
186
+
187
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
188
+
189
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
190
+ split_tokens: true
EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee4214a3ebba9461ca02ca61220a2338412bbf9ef5a5982f2bc40740c4ab91a8
3
+ size 238011
EnglishCV/results/wav2vec2_ctc_en/1234/save/29_char.vocab ADDED
@@ -0,0 +1,28 @@
1
+ <unk> 0
2
+ ▁ -1.786
3
+ E -2.27261
4
+ A -2.6326
5
+ T -2.64317
6
+ I -2.76341
7
+ S -2.81519
8
+ O -2.8189
9
+ N -2.83568
10
+ R -2.87568
11
+ H -3.22802
12
+ L -3.30075
13
+ D -3.43047
14
+ C -3.58554
15
+ U -3.84445
16
+ M -3.84732
17
+ F -4.07023
18
+ P -4.09107
19
+ G -4.16259
20
+ W -4.25412
21
+ Y -4.30147
22
+ B -4.36224
23
+ V -4.71267
24
+ K -5.1744
25
+ X -6.46672
26
+ J -6.5246
27
+ Z -6.95828
28
+ Q -7.12388
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ WER: 18.234978071545488
3
+ end-of-epoch: true
4
+ unixtime: 1694033791.9455216
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06617abf655f8550362b963062fc2a57bd819826ab70e63701676ea09d23618d
3
+ size 51
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f90da3a666eec13ab35
3
+ size 1
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f21c20a479fcc07663ec4255ad1c85466afb791f514f8f3baa174bd56edca2d4
3
+ size 6
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:422a1d7a30720e846d2cb79ff510832fe96c1495f559f08fb37bdd118269ea7b
3
+ size 12769326
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65cda77e4403deb7c8cee3052ac687bfc3bf6e68264dcb0e297e8f88bccf0d66
3
+ size 25485359
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9c36e38dd81971c68387a9f921cf0d61adad21f5b3f6420b6f3015b0f9d20df
3
+ size 511
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0293788921aad16c6e904d7ec0b7dba2dd4778fa3b7f1bfa04276b3965599999
3
+ size 515
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f7073aa70c88927f11cff4f2ba63a026c8ff6c119837391d84013feb229ad3e
3
+ size 1261924189
EnglishCV/results/wav2vec2_ctc_en/1234/save/CKPT+2023-09-06+22-56-31+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42691a96ebaba3dd3baf7e2521763db7f79b37a6bde9b0ea9d1adc2cac5bdf5e
3
+ size 2490156402
EnglishCV/results/wav2vec2_ctc_en/1234/train_with_wav2vec.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ import torchaudio
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
9
+ from speechbrain.utils.data_utils import undo_padding
10
+ from speechbrain.utils.distributed import run_on_main
11
+
12
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
13
+ The system employs a wav2vec2 encoder and a CTC decoder.
14
+ Decoding is performed with greedy decoding (will be extended to beam search).
15
+
16
+ To run this recipe, do the following:
17
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
18
+
19
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
20
+ The wav2vec2 model is pretrained following the model given in the hparams file.
21
+ It may be dependent on the language.
22
+
23
+ The neural network is trained with CTC on sub-word units estimated with
24
+ Byte Pair Encoding (BPE).
25
+
26
+ The experiment file is flexible enough to support a large variety of
27
+ different systems. By properly changing the parameter files, you can try
28
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
29
+ training languages (all CommonVoice languages), and many
30
+ other possible variations.
31
+
32
+ Authors
33
+ * Titouan Parcollet 2021
34
+ """
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # Define training procedure
40
+ class ASR(sb.core.Brain):
41
+ def compute_forward(self, batch, stage):
42
+ """Forward computations from the waveform batches to the output probabilities."""
43
+
44
+ batch = batch.to(self.device)
45
+ wavs, wav_lens = batch.sig
46
+ tokens_bos, _ = batch.tokens_bos
47
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
48
+
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
68
+ tokens, tokens_lens = batch.tokens
69
+
70
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
71
+
72
+ if stage != sb.Stage.TRAIN:
73
+ # Decode token terms to words
74
+ sequence = sb.decoders.ctc_greedy_decode(
75
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
76
+ )
77
+
78
+ predicted_words = self.tokenizer(sequence, task="decode_from_list")
79
+
80
+ # Convert indices to words
81
+ target_words = undo_padding(tokens, tokens_lens)
82
+ target_words = self.tokenizer(target_words, task="decode_from_list")
83
+
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ should_step = self.step % self.grad_accumulation_factor == 0
92
+ # Managing automatic mixed precision
93
+ # TOFIX: CTC fine-tuning currently is unstable
94
+ # This is certainly due to CTC being done in fp16 instead of fp32
95
+ if self.auto_mix_prec:
96
+ with torch.cuda.amp.autocast():
97
+ with self.no_sync():
98
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
99
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
100
+ with self.no_sync(not should_step):
101
+ self.scaler.scale(
102
+ loss / self.grad_accumulation_factor
103
+ ).backward()
104
+ if should_step:
105
+
106
+ if not self.hparams.wav2vec2.freeze:
107
+ self.scaler.unscale_(self.wav2vec_optimizer)
108
+ self.scaler.unscale_(self.model_optimizer)
109
+ if self.check_gradients(loss):
110
+ if not self.hparams.wav2vec2.freeze:
111
+ if self.optimizer_step >= self.hparams.warmup_steps:
112
+ self.scaler.step(self.wav2vec_optimizer)
113
+ self.scaler.step(self.model_optimizer)
114
+ self.scaler.update()
115
+ self.zero_grad()
116
+ self.optimizer_step += 1
117
+ else:
118
+ # This is mandatory because HF models have a weird behavior with DDP
119
+ # on the forward pass
120
+ with self.no_sync():
121
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
122
+
123
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
124
+
125
+ with self.no_sync(not should_step):
126
+ (loss / self.grad_accumulation_factor).backward()
127
+ if should_step:
128
+ if self.check_gradients(loss):
129
+ if not self.hparams.wav2vec2.freeze:
130
+ if self.optimizer_step >= self.hparams.warmup_steps:
131
+ self.wav2vec_optimizer.step()
132
+ self.model_optimizer.step()
133
+ self.zero_grad()
134
+ self.optimizer_step += 1
135
+
136
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
137
+ return loss.detach().cpu()
138
+
139
+ def evaluate_batch(self, batch, stage):
140
+ """Computations needed for validation/test batches"""
141
+ predictions = self.compute_forward(batch, stage=stage)
142
+ with torch.no_grad():
143
+ loss = self.compute_objectives(predictions, batch, stage=stage)
144
+ return loss.detach()
145
+
146
+ def on_stage_start(self, stage, epoch):
147
+ """Gets called at the beginning of each epoch"""
148
+ if stage != sb.Stage.TRAIN:
149
+ self.cer_metric = self.hparams.cer_computer()
150
+ self.wer_metric = self.hparams.error_rate_computer()
151
+
152
+ def on_stage_end(self, stage, stage_loss, epoch):
153
+ """Gets called at the end of an epoch."""
154
+ # Compute/store important stats
155
+ stage_stats = {"loss": stage_loss}
156
+ if stage == sb.Stage.TRAIN:
157
+ self.train_stats = stage_stats
158
+ else:
159
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
160
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
161
+
162
+ # Perform end-of-iteration things, like annealing, logging, etc.
163
+ if stage == sb.Stage.VALID:
164
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
165
+ stage_stats["loss"]
166
+ )
167
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
168
+ stage_stats["loss"]
169
+ )
170
+ sb.nnet.schedulers.update_learning_rate(
171
+ self.model_optimizer, new_lr_model
172
+ )
173
+ if not self.hparams.wav2vec2.freeze:
174
+ sb.nnet.schedulers.update_learning_rate(
175
+ self.wav2vec_optimizer, new_lr_wav2vec
176
+ )
177
+ self.hparams.train_logger.log_stats(
178
+ stats_meta={
179
+ "epoch": epoch,
180
+ "lr_model": old_lr_model,
181
+ "lr_wav2vec": old_lr_wav2vec,
182
+ },
183
+ train_stats=self.train_stats,
184
+ valid_stats=stage_stats,
185
+ )
186
+ self.checkpointer.save_and_keep_only(
187
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
188
+ )
189
+ elif stage == sb.Stage.TEST:
190
+ self.hparams.train_logger.log_stats(
191
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
192
+ test_stats=stage_stats,
193
+ )
194
+ with open(self.hparams.wer_file, "w") as w:
195
+ self.wer_metric.write_stats(w)
196
+
197
+ def init_optimizers(self):
198
+ "Initializes the wav2vec2 optimizer and model optimizer"
199
+
200
+ # If the wav2vec encoder is unfrozen, we create the optimizer
201
+ if not self.hparams.wav2vec2.freeze:
202
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
203
+ self.modules.wav2vec2.parameters()
204
+ )
205
+ if self.checkpointer is not None:
206
+ self.checkpointer.add_recoverable(
207
+ "wav2vec_opt", self.wav2vec_optimizer
208
+ )
209
+
210
+ self.model_optimizer = self.hparams.model_opt_class(
211
+ self.hparams.model.parameters()
212
+ )
213
+
214
+ if self.checkpointer is not None:
215
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
216
+
217
+ def zero_grad(self, set_to_none=False):
218
+ if not self.hparams.wav2vec2.freeze:
219
+ self.wav2vec_optimizer.zero_grad(set_to_none)
220
+ self.model_optimizer.zero_grad(set_to_none)
221
+
222
+
223
+ # Define custom data procedure
224
+ def dataio_prepare(hparams, tokenizer):
225
+ """This function prepares the datasets to be used in the brain class.
226
+ It also defines the data processing pipeline through user-defined functions."""
227
+
228
+ # 1. Define datasets
229
+ data_folder = hparams["data_folder"]
230
+
231
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
232
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
233
+ )
234
+
235
+ if hparams["sorting"] == "ascending":
236
+ # we sort training data to speed up training and get better results.
237
+ train_data = train_data.filtered_sorted(
238
+ sort_key="duration",
239
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
240
+ )
241
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
242
+ hparams["dataloader_options"]["shuffle"] = False
243
+
244
+ elif hparams["sorting"] == "descending":
245
+ train_data = train_data.filtered_sorted(
246
+ sort_key="duration",
247
+ reverse=True,
248
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
249
+ )
250
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
251
+ hparams["dataloader_options"]["shuffle"] = False
252
+
253
+ elif hparams["sorting"] == "random":
254
+ pass
255
+
256
+ else:
257
+ raise NotImplementedError(
258
+ "sorting must be random, ascending or descending"
259
+ )
260
+
261
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
262
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
263
+ )
264
+ # We also sort the validation data so it is faster to validate
265
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
266
+
267
+ test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
268
+ csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
269
+ )
270
+
271
+ # We also sort the validation data so it is faster to validate
272
+ test_data = test_data.filtered_sorted(sort_key="duration")
273
+
274
+ datasets = [train_data, valid_data, test_data]
275
+
276
+ # 2. Define audio pipeline:
277
+ @sb.utils.data_pipeline.takes("wav")
278
+ @sb.utils.data_pipeline.provides("sig")
279
+ def audio_pipeline(wav):
280
+ info = torchaudio.info(wav)
281
+ sig = sb.dataio.dataio.read_audio(wav)
282
+ resampled = torchaudio.transforms.Resample(
283
+ info.sample_rate, hparams["sample_rate"],
284
+ )(sig)
285
+ return resampled
286
+
287
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
288
+
289
+ # 3. Define text pipeline:
290
+ @sb.utils.data_pipeline.takes("wrd")
291
+ @sb.utils.data_pipeline.provides(
292
+ "tokens_list", "tokens_bos", "tokens_eos", "tokens"
293
+ )
294
+ def text_pipeline(wrd):
295
+ tokens_list = tokenizer.sp.encode_as_ids(wrd)
296
+ yield tokens_list
297
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
298
+ yield tokens_bos
299
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
300
+ yield tokens_eos
301
+ tokens = torch.LongTensor(tokens_list)
302
+ yield tokens
303
+
304
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
305
+
306
+ # 4. Set output:
307
+ sb.dataio.dataset.set_output_keys(
308
+ datasets, ["id", "sig", "tokens_bos", "tokens_eos", "tokens"],
309
+ )
310
+ return train_data, valid_data, test_data
311
+
312
+
313
+ if __name__ == "__main__":
314
+
315
+ # Load hyperparameters file with command-line overrides
316
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
317
+ with open(hparams_file) as fin:
318
+ hparams = load_hyperpyyaml(fin, overrides)
319
+
320
+ # If --distributed_launch then
321
+ # create ddp_group with the right communication protocol
322
+ sb.utils.distributed.ddp_init_group(run_opts)
323
+
324
+ # Dataset preparation (parsing CommonVoice)
325
+ from common_voice_prepare import prepare_common_voice # noqa
326
+
327
+ # Create experiment directory
328
+ sb.create_experiment_directory(
329
+ experiment_directory=hparams["output_folder"],
330
+ hyperparams_to_save=hparams_file,
331
+ overrides=overrides,
332
+ )
333
+
334
+ # Due to DDP, we do the preparation ONLY on the main python process
335
+ run_on_main(
336
+ prepare_common_voice,
337
+ kwargs={
338
+ "data_folder": hparams["data_folder"],
339
+ "save_folder": hparams["save_folder"],
340
+ "train_tsv_file": hparams["train_tsv_file"],
341
+ "dev_tsv_file": hparams["dev_tsv_file"],
342
+ "test_tsv_file": hparams["test_tsv_file"],
343
+ "accented_letters": hparams["accented_letters"],
344
+ "language": hparams["language"],
345
+ "skip_prep": hparams["skip_prep"],
346
+ },
347
+ )
348
+
349
+ # Defining tokenizer and loading it
350
+ tokenizer = SentencePiece(
351
+ model_dir=hparams["save_folder"],
352
+ vocab_size=hparams["output_neurons"],
353
+ annotation_train=hparams["train_csv"],
354
+ annotation_read="wrd",
355
+ model_type=hparams["token_type"],
356
+ character_coverage=hparams["character_coverage"],
357
+ )
358
+
359
+ # Create the datasets objects as well as tokenization and encoding :-D
360
+ train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
+
362
+ # Trainer initialization
363
+ asr_brain = ASR(
364
+ modules=hparams["modules"],
365
+ hparams=hparams,
366
+ run_opts=run_opts,
367
+ checkpointer=hparams["checkpointer"],
368
+ )
369
+
370
+ # Adding objects to trainer.
371
+ asr_brain.tokenizer = tokenizer
372
+
373
+ # Training
374
+ asr_brain.fit(
375
+ asr_brain.hparams.epoch_counter,
376
+ train_data,
377
+ valid_data,
378
+ train_loader_kwargs=hparams["dataloader_options"],
379
+ valid_loader_kwargs=hparams["test_dataloader_options"],
380
+ )
381
+
382
+ # Test
383
+ asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
384
+ asr_brain.evaluate(
385
+ test_data,
386
+ min_key="WER",
387
+ test_loader_kwargs=hparams["test_dataloader_options"],
388
+ )
EnglishCV/train_en_with_wav2vec.yaml ADDED
@@ -0,0 +1,184 @@
1
+ # ################################
2
+ # Model: wav2vec2 + DNN + CTC
3
+ # Augmentation: SpecAugment
4
+ # Authors: Titouan Parcollet 2021
5
+ # ################################
6
+
7
+ # Seed needs to be set at top of yaml, before objects with parameters are made
8
+ seed: 1234
9
+ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref results/wav2vec2_ctc_en/<seed>
11
+ wer_file: !ref <output_folder>/wer.txt
12
+ save_folder: !ref <output_folder>/save
13
+ train_log: !ref <output_folder>/train_log.txt
14
+
15
+ # URL for the biggest Fairseq english wav2vec2 model.
16
+ wav2vec2_hub: facebook/wav2vec2-large-lv60
17
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
18
+
19
+ # Data files
20
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/download/cv-corpus-12.0-2022-12-07/en/cv-corpus-12.0-2022-12-07/en # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
21
+ train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
22
+ dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
23
+ test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
24
+ accented_letters: False
25
+ language: en # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
26
+ train_csv: !ref <save_folder>/train.csv
27
+ valid_csv: !ref <save_folder>/dev.csv
28
+ test_csv: !ref <save_folder>/test.csv
29
+ skip_prep: False # Skip data preparation
30
+
31
+ # We remove utterances longer than 10s in the train/dev/test sets as
32
+ # longer sentences certainly correspond to "open microphones".
33
+ avoid_if_longer_than: 10.0
34
+
35
+ # Training parameters
36
+ number_of_epochs: 10
37
+ lr: 1.0
38
+ lr_wav2vec: 0.0001
39
+ sorting: ascending
40
+ auto_mix_prec: False
41
+ sample_rate: 16000
42
+ ckpt_interval_minutes: 30 # save checkpoint every N min
43
+
44
+ # With data_parallel batch_size is split into N jobs
45
+ # With DDP batch_size is multiplied by N jobs
46
+ # Must be 8 per GPU to fit 32GB of VRAM
47
+ batch_size: 8
48
+ test_batch_size: 4
49
+
50
+ dataloader_options:
51
+ batch_size: !ref <batch_size>
52
+ num_workers: 6
53
+ test_dataloader_options:
54
+ batch_size: !ref <test_batch_size>
55
+ num_workers: 6
56
+
57
+ # BPE parameters
58
+ token_type: char # ["unigram", "bpe", "char"]
59
+ character_coverage: 1.0
60
+
61
+ # Model parameters
62
+ # activation: !name:torch.nn.LeakyReLU
63
+ wav2vec_output_dim: 1024
64
+ dnn_neurons: 1024
65
+ freeze_wav2vec: False
66
+ freeze_feature_extractor: True
67
+ dropout: 0.15
68
+ warmup_steps: 500
69
+
70
+ # Outputs
71
+ output_neurons: 29 # BPE size, index(blank/eos/bos) = 0
72
+
73
+ # Decoding parameters
74
+ # Be sure that the bos and eos index match with the BPEs ones
75
+ blank_index: 0
76
+ bos_index: 1
77
+ eos_index: 2
78
+
79
+ #
80
+ # Functions and classes
81
+ #
82
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
83
+ limit: !ref <number_of_epochs>
84
+
85
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
86
+ sample_rate: !ref <sample_rate>
87
+ speeds: [95, 100, 105]
88
+
89
+ enc: !new:speechbrain.nnet.containers.Sequential
90
+ input_shape: [null, null, !ref <wav2vec_output_dim>]
91
+ linear1: !name:speechbrain.nnet.linear.Linear
92
+ n_neurons: !ref <dnn_neurons>
93
+ bias: True
94
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
95
+ activation: !new:torch.nn.LeakyReLU
96
+ drop: !new:torch.nn.Dropout
97
+ p: !ref <dropout>
98
+ linear2: !name:speechbrain.nnet.linear.Linear
99
+ n_neurons: !ref <dnn_neurons>
100
+ bias: True
101
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
102
+ activation2: !new:torch.nn.LeakyReLU
103
+ drop2: !new:torch.nn.Dropout
104
+ p: !ref <dropout>
105
+ linear3: !name:speechbrain.nnet.linear.Linear
106
+ n_neurons: !ref <dnn_neurons>
107
+ bias: True
108
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
109
+ activation3: !new:torch.nn.LeakyReLU
110
+
111
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
112
+ source: /gpfsscratch/rech/nou/uzn19yk/wav2vec2-large-lv60/
113
+ output_norm: True
114
+ freeze: !ref <freeze_wav2vec>
115
+ freeze_feature_extractor: !ref <freeze_feature_extractor>
116
+ save_path: !ref <wav2vec2_folder>
117
+
118
+ #####
119
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
120
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
121
+ # Fairseq github for the multilingual XLSR.
122
+ #
123
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
124
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
125
+ # pretrained_path: !ref <wav2vec2_url>
126
+ # output_norm: True
127
+ # freeze: False
128
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
129
+ #####
130
+
131
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
132
+ input_size: !ref <dnn_neurons>
133
+ n_neurons: !ref <output_neurons>
134
+
135
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
136
+ apply_log: True
137
+
138
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
139
+ blank_index: !ref <blank_index>
140
+
141
+ modules:
142
+ wav2vec2: !ref <wav2vec2>
143
+ enc: !ref <enc>
144
+ ctc_lin: !ref <ctc_lin>
145
+
146
+ model: !new:torch.nn.ModuleList
147
+ - [!ref <enc>, !ref <ctc_lin>]
148
+
149
+ model_opt_class: !name:torch.optim.Adadelta
150
+ lr: !ref <lr>
151
+ rho: 0.95
152
+ eps: 1.e-8
153
+
154
+ wav2vec_opt_class: !name:torch.optim.Adam
155
+ lr: !ref <lr_wav2vec>
156
+
157
+ lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
158
+ initial_value: !ref <lr>
159
+ improvement_threshold: 0.0025
160
+ annealing_factor: 0.8
161
+ patient: 0
162
+
163
+ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
164
+ initial_value: !ref <lr_wav2vec>
165
+ improvement_threshold: 0.0025
166
+ annealing_factor: 0.9
167
+ patient: 0
168
+
169
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
170
+ checkpoints_dir: !ref <save_folder>
171
+ recoverables:
172
+ wav2vec2: !ref <wav2vec2>
173
+ model: !ref <model>
174
+ scheduler_model: !ref <lr_annealing_model>
175
+ scheduler_wav2vec: !ref <lr_annealing_wav2vec>
176
+ counter: !ref <epoch_counter>
177
+
178
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
179
+ save_file: !ref <train_log>
180
+
181
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
182
+
183
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
184
+ split_tokens: True
EnglishCV/train_with_wav2vec.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ import torchaudio
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
9
+ from speechbrain.utils.data_utils import undo_padding
10
+ from speechbrain.utils.distributed import run_on_main
11
+
12
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
13
+ The system employs a wav2vec2 encoder and a CTC decoder.
14
+ Decoding is performed with greedy decoding (will be extended to beam search).
15
+
16
+ To run this recipe, do the following:
17
+ > python train_with_wav2vec2.py hparams/train_with_wav2vec2.yaml
18
+
19
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
20
+ The wav2vec2 model is pretrained following the model given in the hparams file.
21
+ It may be dependent on the language.
22
+
23
+ The neural network is trained with CTC on sub-word units estimated with
24
+ Byte Pair Encoding (BPE).
25
+
26
+ The experiment file is flexible enough to support a large variety of
27
+ different systems. By properly changing the parameter files, you can try
28
+ different encoders, decoders, tokens (e.g, characters instead of BPE),
29
+ training languages (all CommonVoice languages), and many
30
+ other possible variations.
31
+
32
+ Authors
33
+ * Titouan Parcollet 2021
34
+ """
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # Define training procedure
40
+ class ASR(sb.core.Brain):
41
+ def compute_forward(self, batch, stage):
42
+ """Forward computations from the waveform batches to the output probabilities."""
43
+
44
+ batch = batch.to(self.device)
45
+ wavs, wav_lens = batch.sig
46
+ tokens_bos, _ = batch.tokens_bos
47
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
48
+
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
68
+ tokens, tokens_lens = batch.tokens
69
+
70
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
71
+
72
+ if stage != sb.Stage.TRAIN:
73
+ # Decode token terms to words
74
+ sequence = sb.decoders.ctc_greedy_decode(
75
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
76
+ )
77
+
78
+ predicted_words = self.tokenizer(sequence, task="decode_from_list")
79
+
80
+ # Convert indices to words
81
+ target_words = undo_padding(tokens, tokens_lens)
82
+ target_words = self.tokenizer(target_words, task="decode_from_list")
83
+
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ should_step = self.step % self.grad_accumulation_factor == 0
92
+ # Managing automatic mixed precision
93
+ # TOFIX: CTC fine-tuning currently is unstable
94
+ # This is certainly due to CTC being done in fp16 instead of fp32
95
+ if self.auto_mix_prec:
96
+ with torch.cuda.amp.autocast():
97
+ with self.no_sync():
98
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
99
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
100
+ with self.no_sync(not should_step):
101
+ self.scaler.scale(
102
+ loss / self.grad_accumulation_factor
103
+ ).backward()
104
+ if should_step:
105
+
106
+ if not self.hparams.wav2vec2.freeze:
107
+ self.scaler.unscale_(self.wav2vec_optimizer)
108
+ self.scaler.unscale_(self.model_optimizer)
109
+ if self.check_gradients(loss):
110
+ if not self.hparams.wav2vec2.freeze:
111
+ if self.optimizer_step >= self.hparams.warmup_steps:
112
+ self.scaler.step(self.wav2vec_optimizer)
113
+ self.scaler.step(self.model_optimizer)
114
+ self.scaler.update()
115
+ self.zero_grad()
116
+ self.optimizer_step += 1
117
+ else:
118
+ # This is mandatory because HF models have a weird behavior with DDP
119
+ # on the forward pass
120
+ with self.no_sync():
121
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
122
+
123
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
124
+
125
+ with self.no_sync(not should_step):
126
+ (loss / self.grad_accumulation_factor).backward()
127
+ if should_step:
128
+ if self.check_gradients(loss):
129
+ if not self.hparams.wav2vec2.freeze:
130
+ if self.optimizer_step >= self.hparams.warmup_steps:
131
+ self.wav2vec_optimizer.step()
132
+ self.model_optimizer.step()
133
+ self.zero_grad()
134
+ self.optimizer_step += 1
135
+
136
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
137
+ return loss.detach().cpu()
138
+
139
+ def evaluate_batch(self, batch, stage):
140
+ """Computations needed for validation/test batches"""
141
+ predictions = self.compute_forward(batch, stage=stage)
142
+ with torch.no_grad():
143
+ loss = self.compute_objectives(predictions, batch, stage=stage)
144
+ return loss.detach()
145
+
146
+ def on_stage_start(self, stage, epoch):
147
+ """Gets called at the beginning of each epoch"""
148
+ if stage != sb.Stage.TRAIN:
149
+ self.cer_metric = self.hparams.cer_computer()
150
+ self.wer_metric = self.hparams.error_rate_computer()
151
+
152
+ def on_stage_end(self, stage, stage_loss, epoch):
153
+ """Gets called at the end of an epoch."""
154
+ # Compute/store important stats
155
+ stage_stats = {"loss": stage_loss}
156
+ if stage == sb.Stage.TRAIN:
157
+ self.train_stats = stage_stats
158
+ else:
159
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
160
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
161
+
162
+ # Perform end-of-iteration things, like annealing, logging, etc.
163
+ if stage == sb.Stage.VALID:
164
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
165
+ stage_stats["loss"]
166
+ )
167
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
168
+ stage_stats["loss"]
169
+ )
170
+ sb.nnet.schedulers.update_learning_rate(
171
+ self.model_optimizer, new_lr_model
172
+ )
173
+ if not self.hparams.wav2vec2.freeze:
174
+ sb.nnet.schedulers.update_learning_rate(
175
+ self.wav2vec_optimizer, new_lr_wav2vec
176
+ )
177
+ self.hparams.train_logger.log_stats(
178
+ stats_meta={
179
+ "epoch": epoch,
180
+ "lr_model": old_lr_model,
181
+ "lr_wav2vec": old_lr_wav2vec,
182
+ },
183
+ train_stats=self.train_stats,
184
+ valid_stats=stage_stats,
185
+ )
186
+ self.checkpointer.save_and_keep_only(
187
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
188
+ )
189
+ elif stage == sb.Stage.TEST:
190
+ self.hparams.train_logger.log_stats(
191
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
192
+ test_stats=stage_stats,
193
+ )
194
+ with open(self.hparams.wer_file, "w") as w:
195
+ self.wer_metric.write_stats(w)
196
+
197
+ def init_optimizers(self):
198
+ "Initializes the wav2vec2 optimizer and model optimizer"
199
+
200
+ # If the wav2vec encoder is unfrozen, we create the optimizer
201
+ if not self.hparams.wav2vec2.freeze:
202
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
203
+ self.modules.wav2vec2.parameters()
204
+ )
205
+ if self.checkpointer is not None:
206
+ self.checkpointer.add_recoverable(
207
+ "wav2vec_opt", self.wav2vec_optimizer
208
+ )
209
+
210
+ self.model_optimizer = self.hparams.model_opt_class(
211
+ self.hparams.model.parameters()
212
+ )
213
+
214
+ if self.checkpointer is not None:
215
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
216
+
217
+ def zero_grad(self, set_to_none=False):
218
+ if not self.hparams.wav2vec2.freeze:
219
+ self.wav2vec_optimizer.zero_grad(set_to_none)
220
+ self.model_optimizer.zero_grad(set_to_none)
221
+
222
+
223
+ # Define custom data procedure
224
+ def dataio_prepare(hparams, tokenizer):
225
+ """This function prepares the datasets to be used in the brain class.
226
+ It also defines the data processing pipeline through user-defined functions."""
227
+
228
+ # 1. Define datasets
229
+ data_folder = hparams["data_folder"]
230
+
231
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
232
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
233
+ )
234
+
235
+ if hparams["sorting"] == "ascending":
236
+ # we sort training data to speed up training and get better results.
237
+ train_data = train_data.filtered_sorted(
238
+ sort_key="duration",
239
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
240
+ )
241
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
242
+ hparams["dataloader_options"]["shuffle"] = False
243
+
244
+ elif hparams["sorting"] == "descending":
245
+ train_data = train_data.filtered_sorted(
246
+ sort_key="duration",
247
+ reverse=True,
248
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
249
+ )
250
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
251
+ hparams["dataloader_options"]["shuffle"] = False
252
+
253
+ elif hparams["sorting"] == "random":
254
+ pass
255
+
256
+ else:
257
+ raise NotImplementedError(
258
+ "sorting must be random, ascending or descending"
259
+ )
260
+
261
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
262
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
263
+ )
264
+ # We also sort the validation data so it is faster to validate
265
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
266
+
267
+ test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
268
+ csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
269
+ )
270
+
271
+ # We also sort the validation data so it is faster to validate
272
+ test_data = test_data.filtered_sorted(sort_key="duration")
273
+
274
+ datasets = [train_data, valid_data, test_data]
275
+
276
+ # 2. Define audio pipeline:
277
+ @sb.utils.data_pipeline.takes("wav")
278
+ @sb.utils.data_pipeline.provides("sig")
279
+ def audio_pipeline(wav):
280
+ info = torchaudio.info(wav)
281
+ sig = sb.dataio.dataio.read_audio(wav)
282
+ resampled = torchaudio.transforms.Resample(
283
+ info.sample_rate, hparams["sample_rate"],
284
+ )(sig)
285
+ return resampled
286
+
287
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
288
+
289
+ # 3. Define text pipeline:
290
+ @sb.utils.data_pipeline.takes("wrd")
291
+ @sb.utils.data_pipeline.provides(
292
+ "tokens_list", "tokens_bos", "tokens_eos", "tokens"
293
+ )
294
+ def text_pipeline(wrd):
295
+ tokens_list = tokenizer.sp.encode_as_ids(wrd)
296
+ yield tokens_list
297
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
298
+ yield tokens_bos
299
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
300
+ yield tokens_eos
301
+ tokens = torch.LongTensor(tokens_list)
302
+ yield tokens
303
+
304
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
305
+
306
+ # 4. Set output:
307
+ sb.dataio.dataset.set_output_keys(
308
+ datasets, ["id", "sig", "tokens_bos", "tokens_eos", "tokens"],
309
+ )
310
+ return train_data, valid_data, test_data
311
+
312
+
313
+ if __name__ == "__main__":
314
+
315
+ # Load hyperparameters file with command-line overrides
316
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
317
+ with open(hparams_file) as fin:
318
+ hparams = load_hyperpyyaml(fin, overrides)
319
+
320
+ # If --distributed_launch then
321
+ # create ddp_group with the right communication protocol
322
+ sb.utils.distributed.ddp_init_group(run_opts)
323
+
324
+ # Dataset preparation (parsing CommonVoice)
325
+ from common_voice_prepare import prepare_common_voice # noqa
326
+
327
+ # Create experiment directory
328
+ sb.create_experiment_directory(
329
+ experiment_directory=hparams["output_folder"],
330
+ hyperparams_to_save=hparams_file,
331
+ overrides=overrides,
332
+ )
333
+
334
+ # Due to DDP, we do the preparation ONLY on the main python process
335
+ run_on_main(
336
+ prepare_common_voice,
337
+ kwargs={
338
+ "data_folder": hparams["data_folder"],
339
+ "save_folder": hparams["save_folder"],
340
+ "train_tsv_file": hparams["train_tsv_file"],
341
+ "dev_tsv_file": hparams["dev_tsv_file"],
342
+ "test_tsv_file": hparams["test_tsv_file"],
343
+ "accented_letters": hparams["accented_letters"],
344
+ "language": hparams["language"],
345
+ "skip_prep": hparams["skip_prep"],
346
+ },
347
+ )
348
+
349
+ # Defining tokenizer and loading it
350
+ tokenizer = SentencePiece(
351
+ model_dir=hparams["save_folder"],
352
+ vocab_size=hparams["output_neurons"],
353
+ annotation_train=hparams["train_csv"],
354
+ annotation_read="wrd",
355
+ model_type=hparams["token_type"],
356
+ character_coverage=hparams["character_coverage"],
357
+ )
358
+
359
+ # Create the datasets objects as well as tokenization and encoding :-D
360
+ train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
361
+
362
+ # Trainer initialization
363
+ asr_brain = ASR(
364
+ modules=hparams["modules"],
365
+ hparams=hparams,
366
+ run_opts=run_opts,
367
+ checkpointer=hparams["checkpointer"],
368
+ )
369
+
370
+ # Adding objects to trainer.
371
+ asr_brain.tokenizer = tokenizer
372
+
373
+ # Training
374
+ asr_brain.fit(
375
+ asr_brain.hparams.epoch_counter,
376
+ train_data,
377
+ valid_data,
378
+ train_loader_kwargs=hparams["dataloader_options"],
379
+ valid_loader_kwargs=hparams["test_dataloader_options"],
380
+ )
381
+
382
+ # Test
383
+ asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
384
+ asr_brain.evaluate(
385
+ test_data,
386
+ min_key="WER",
387
+ test_loader_kwargs=hparams["test_dataloader_options"],
388
+ )
README.md CHANGED
@@ -1,3 +1,21 @@
1
  ---
2
  license: apache-2.0
3
---
1
  ---
2
  license: apache-2.0
3
  ---
4
+ # Tunisian Arabic ASR Model with wav2vec2 and code switching
5
+ This repository provides all the necessary tools to perform automatic speech recognition with an end-to-end system pretrained on the Tunisian Arabic dialect. The model uses a code-switching approach and can process English, French, and Tunisian Arabic.
6
+ ## Performance
7
+ The performance of the model is:
8
+ | Release Version | WER (%) | CER (%) |
9
+ |-----------------|---------|---------|
10
+ | v1.0 | 29.47 | 12.44 |
11
+ ## Pipeline
12
+ The architecture comprises three components:
13
+ * French ASR pretrained with wav2vec2 on French corpora
14
+ * English ASR pretrained with wav2vec2 on English corpora
15
+ * A custom Tunisian ASR pretrained using wav2vec on a Tunisian Arabic corpus
16
+ All three models will process the audio data. Subsequently, the resulting posteriorgrams will be combined and utilized as input for the Mixer, which will produce the final posteriorgrams.
17
+ ## Install
18
+ ```python
19
+ pip install speechbrain transformers
20
+ ```
21
+
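+ ## Decoding example
+ The mixer's CTC posteriors can be decoded with a KenLM n-gram language model through `pyctcdecoder`'s `pyctcdecode` interface, as set up in the training script of this repository. The sketch below reproduces that decoding step in isolation; the label-encoder and LM paths are illustrative placeholders, `log_probs` stands in for the mixer's output, and `alpha`/`beta` simply reuse the values tuned here on a validation set.
+ ```python
+ # Minimal sketch of LM-rescored CTC decoding, mirroring the training script.
+ # Paths and the posteriorgram below are placeholders.
+ import numpy as np
+ from pyctcdecode import build_ctcdecoder
+
+ def read_labels_file(labels_file):
+     # Same parsing as the training script: "'label' => index" lines,
+     # terminated by a "===" divider.
+     with open(labels_file, "r", encoding="utf-8") as lf:
+         lines = lf.read().splitlines()
+     numbers = {}
+     for line in lines:
+         if "===" in line:
+             break
+         string, number = line.split("=>")
+         numbers[int(number)] = string[1:-2]
+     return [numbers[x] for x in range(len(numbers))]
+
+ labels = read_labels_file("path/to/save/label_encoder.txt")  # illustrative path
+ labels = [""] + labels[1:-1] + ["1"]
+
+ decoder = build_ctcdecoder(
+     labels,
+     kenlm_model_path="TunisianASR/outdomain.arpa",  # either .arpa or .bin file
+     alpha=0.5,  # LM weight, tuned on a validation set
+     beta=1,     # word-insertion bonus, tuned on a validation set
+ )
+
+ # log_probs: (time, vocab) array of CTC log-posteriors produced by the mixer
+ log_probs = np.zeros((100, len(labels)), dtype=np.float32)  # placeholder
+ print(decoder.decode(log_probs))
+ ```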
TunisianASR/README.md ADDED
@@ -0,0 +1,21 @@
1
+ # Tunisian Arabic ASR Model with wav2vec2
2
+
3
+ This repository provides all the necessary tools to perform automatic speech recognition with an end-to-end system pretrained on the Tunisian Arabic dialect.
4
+
5
+ ## Performance
6
+ The performance of the model is:
7
+ | Release Version | Decoding | WER (%) | CER (%) |
8
+ |-----------------|------------|---------|---------|
9
+ | v1.0 | Without LM | 11.82 | 6.33 |
10
+ ## Dataset
11
+ This ASR model was trained on:
12
+ * TARIC: the Tunisian Arabic Railway Interaction Corpus, a collection of audio recordings and transcriptions from dialogues in the Tunisian Railway Transport Network - [TARIC Corpus](https://aclanthology.org/L14-1385/)
13
+ * STAC: a corpus of spoken Tunisian Arabic - [STAC Corpus](https://www.researchgate.net/publication/307583782_Spoken_Tunisian_Arabic_Corpus_STAC_Transcription_and_Annotation)
14
+ * IWSLT: a Tunisian conversational speech corpus - [IWSLT Corpus](https://iwslt.org/2022/dialect)
15
+ * Tunspeech: our custom dataset
16
+
17
+ ## Install
18
+ ```python
19
+ pip install speechbrain transformers
20
+ ```
21
+
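+ ## Expected audio format
+ The recipe's data pipeline reads each wav file and resamples it to the 16 kHz sample rate expected by the model. Below is a minimal sketch of that preprocessing step, using the same torchaudio/SpeechBrain calls as the training recipe; the file path is a placeholder.
+ ```python
+ # Minimal sketch of the recipe's audio preprocessing:
+ # read a wav file and resample it to 16 kHz.
+ import torchaudio
+ import speechbrain as sb
+
+ wav = "path/to/audio.wav"  # placeholder
+ info = torchaudio.info(wav)
+ sig = sb.dataio.dataio.read_audio(wav)  # 1-D float tensor
+ resampled = torchaudio.transforms.Resample(
+     info.sample_rate, 16000,
+ )(sig)
+ print(resampled.shape)
+ ```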
TunisianASR/outdomain.arpa ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24654c1d236bb1bd367125131c847c4a734e69914eda71a6786964c20440d8fe
3
+ size 324243244
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/hyperparams.yaml ADDED
@@ -0,0 +1,194 @@
1
+ # Generated 2023-09-08 from:
2
+ # /gpfsdsstore/projects/rech/nou/uzn19yk/switched_code_tunisian/train/tunisian_asr/hparams/train_semi.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Titouan Parcollet 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1234
12
+ __set_seed: !!python/object/apply:torch.manual_seed [1234]
13
+ output_folder: results/semi_wavlm_large_tunisian_ctc/1234
14
+ wer_file: results/semi_wavlm_large_tunisian_ctc/1234/wer.txt
15
+ save_folder: results/semi_wavlm_large_tunisian_ctc/1234/save
16
+ train_log: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
17
+
18
+ # URL for the biggest LeBenchmark wav2vec french.
19
+ wav2vec2_folder: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
20
+
21
+ # Data files
22
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk # e.g, /localscratch/cv-corpus-5.1-2020-06-22/fr
23
+ train_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/train.tsv # Standard CommonVoice .tsv files
24
+ dev_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/dev.tsv # Standard CommonVoice .tsv files
25
+ test_tsv_file: /gpfsscratch/rech/nou/uzn19yk/tunisian_junk/test.tsv # Standard CommonVoice .tsv files
26
+ accented_letters: true
27
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for english
28
+ train_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/train_enhanced.csv
29
+ valid_csv: /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/dev.csv
30
+ test_csv:
31
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/full_annotation_test.csv
32
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/iwslt_test.csv
33
+ - /gpfsscratch/rech/nou/uzn19yk/tunisian_csvs/good_final/taric_test.csv
34
+
35
+ skip_prep: true # Skip data preparation
36
+
37
+ use_language_modelling: true
38
+ ngram_lm_path: arpas/outdomain.arpa
39
+
40
+ # We remove utterance slonger than 10s in the train/dev/test sets as
41
+ # longer sentences certainly correspond to "open microphones".
42
+ avoid_if_longer_than: 10.0
43
+ avoid_if_shorter_than: 1.2
44
+
45
+
46
+ # Training parameters
47
+ number_of_epochs: 12
48
+ lr: 1.0
49
+ lr_wav2vec: 0.0001
50
+ sorting: ascending
51
+ auto_mix_prec: false
52
+ sample_rate: 16000
53
+ ckpt_interval_minutes: 30 # save checkpoint every N min
54
+
55
+ # With data_parallel batch_size is split into N jobs
56
+ # With DDP batch_size is multiplied by N jobs
57
+ # Must be 6 per GPU to fit 16GB of VRAM
58
+ batch_size: 10
59
+ test_batch_size: 4
60
+
61
+ dataloader_options:
62
+ batch_size: 10
63
+ num_workers: 6
64
+ test_dataloader_options:
65
+ batch_size: 4
66
+ num_workers: 6
67
+
68
+ # BPE parameters
69
+ token_type: char # ["unigram", "bpe", "char"]
70
+ character_coverage: 1.0
71
+
72
+ # Model parameters
73
+ # activation: !name:torch.nn.LeakyReLU
74
+ wav2vec_output_dim: 1024
75
+ dnn_neurons: 1024
76
+ freeze_wav2vec: false
77
+ freeze_feature_extractor: true
78
+ dropout: 0.15
79
+ warmup_steps: 500 # The wav2vec 2 model isn't updated for this amount of steps
80
+
81
+ # Outputs
82
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
83
+
84
+ # Decoding parameters
85
+ # Be sure that the bos and eos index match with the BPEs ones
86
+ blank_index: 0
87
+ unk_index: 1
88
+
89
+ #
90
+ # Functions and classes
91
+ #
92
+ epoch_counter: &id007 !new:speechbrain.utils.epoch_loop.EpochCounter
93
+
94
+ limit: 12
95
+
96
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
97
+ sample_rate: 16000
98
+ speeds: [95, 100, 105]
99
+
100
+ enc: &id002 !new:speechbrain.nnet.containers.Sequential
101
+ input_shape: [null, null, 1024]
102
+ linear1: !name:speechbrain.nnet.linear.Linear
103
+ n_neurons: 1024
104
+ bias: true
105
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
106
+ activation: !new:torch.nn.LeakyReLU
107
+ drop: !new:torch.nn.Dropout
108
+ p: 0.15
109
+ linear2: !name:speechbrain.nnet.linear.Linear
110
+ n_neurons: 1024
111
+ bias: true
112
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
113
+ activation2: !new:torch.nn.LeakyReLU
114
+ drop2: !new:torch.nn.Dropout
115
+ p: 0.15
116
+ linear3: !name:speechbrain.nnet.linear.Linear
117
+ n_neurons: 1024
118
+ bias: true
119
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
120
+ activation3: !new:torch.nn.LeakyReLU
121
+
122
+ wav2vec2: &id001 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
123
+ source: /gpfsstore/rech/nou/uzn19yk/wavlm/
124
+ output_norm: false
125
+ freeze: false
126
+ freeze_feature_extractor: true
127
+ save_path: results/semi_wavlm_large_tunisian_ctc/1234/save/wav2vec2_checkpoint
128
+
129
+ #####
130
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
131
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
132
+ # Fairseq github for the multilingual XLSR.
133
+ #
134
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
135
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
136
+ # pretrained_path: !ref <wav2vec2_url>
137
+ # output_norm: True
138
+ # freeze: False
139
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
140
+ #####
141
+
142
+
143
+ ctc_lin: &id003 !new:speechbrain.nnet.linear.Linear
144
+
145
+ input_size: 1024
146
+ n_neurons: 40
147
+
148
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
149
+ apply_log: true
150
+
151
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
152
+ blank_index: 0
153
+
154
+ modules:
155
+ wav2vec2: *id001
156
+ enc: *id002
157
+ ctc_lin: *id003
158
+ model: &id004 !new:torch.nn.ModuleList
159
+ - [*id002, *id003]
160
+ model_opt_class: !name:torch.optim.Adadelta
161
+ lr: 1.0
162
+ rho: 0.95
163
+ eps: 1.e-8
164
+
165
+ wav2vec_opt_class: !name:torch.optim.Adam
166
+ lr: 0.0001
167
+
168
+ lr_annealing_model: &id005 !new:speechbrain.nnet.schedulers.NewBobScheduler
169
+ initial_value: 1.0
170
+ improvement_threshold: 0.0025
171
+ annealing_factor: 0.8
172
+ patient: 0
173
+
174
+ lr_annealing_wav2vec: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
175
+ initial_value: 0.0001
176
+ improvement_threshold: 0.0025
177
+ annealing_factor: 0.9
178
+ patient: 0
179
+
180
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
+ checkpoints_dir: results/semi_wavlm_large_tunisian_ctc/1234/save
182
+ recoverables:
183
+ wav2vec2: *id001
184
+ model: *id004
185
+ scheduler_model: *id005
186
+ scheduler_wav2vec: *id006
187
+ counter: *id007
188
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
189
+ save_file: results/semi_wavlm_large_tunisian_ctc/1234/train_log.txt
190
+
191
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
192
+
193
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
194
+ split_tokens: true
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ WER: 27.83210816487267
3
+ end-of-epoch: true
4
+ unixtime: 1693868963.5220973
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3947a24e8dff5a14299b9cf2fe66ffb4d738cb88717de7f0cf7e8547a76e9776
3
+ size 51
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918
3
+ size 2
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b363886c229e536bd3c84e0c3e89312d70e00422578e076a62df1b45c9390793
3
+ size 5
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1dbeca1e1f1340b08d8ebea6e492f474708dddbbe8cabbcdde5ee9660704f2
3
+ size 12814446
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3af1791eb9a5bfbfc087d2c10b94634df24cad3ac503ce9ba280a3ecc4737781
3
+ size 25575663
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c275ab9245b440d1586f72058d9edaac1a2fb3e7a52712aa9a9ad022b99a1c0d
3
+ size 639
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a88187f7882dc3e10c108f1b7abfbd819285b34bded4e88e91c4ff699c1bb5d2
3
+ size 643
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:788267bd25ef37623715fa21a975090e5e316fff05971375cd3f62e5160f0743
3
+ size 1262005979
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/CKPT+2023-09-05+01-09-23+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efa967fdd8067be7d88c18cd197980c9c91f344a3dff2b2518b8381c49f28b1e
3
+ size 2490361859
TunisianASR/semi_wavlm_large_tunisian_ctc/1234/save/label_encoder.txt ADDED
@@ -0,0 +1,44 @@
1
+ 'ب' => 38
2
+ 'ا' => 1
3
+ 'ه' => 2
4
+ 'ي' => 3
5
+ 'و' => 4
6
+ 'ن' => 5
7
+ 'أ' => 6
8
+ ' ' => 7
9
+ 'م' => 8
10
+ 'ش' => 9
11
+ 'ل' => 10
12
+ 'س' => 11
13
+ 'ت' => 12
14
+ 'د' => 13
15
+ 'ر' => 14
16
+ 'ى' => 15
17
+ 'ح' => 16
18
+ 'ط' => 17
19
+ 'ع' => 18
20
+ 'ك' => 19
21
+ 'ف' => 20
22
+ 'ق' => 21
23
+ 'آ' => 22
24
+ 'ة' => 23
25
+ 'ج' => 24
26
+ 'ض' => 25
27
+ 'ز' => 26
28
+ 'ص' => 27
29
+ 'إ' => 28
30
+ 'ث' => 29
31
+ 'خ' => 30
32
+ 'ڨ' => 31
33
+ 'ذ' => 32
34
+ 'ظ' => 33
35
+ 'ء' => 34
36
+ 'غ' => 35
37
+ 'ئ' => 36
38
+ 'ؤ' => 37
39
+ '<blank>' => 0
40
+ 1 => 39
41
+ ================
42
+ 'starting_index' => 0
43
+ 'unk_label' => 1
44
+ 'blank_label' => '<blank>'
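The table above is the character inventory predicted by the CTC head: 38 characters including the space, the `<blank>` token reserved at index 0, and the unknown label (written `1`) at index 39, followed by the encoder metadata. A minimal, self-contained sketch (hypothetical transcript; the file is assumed to be available locally) of how this mapping turns a transcription into CTC target indices; during training this job is handled by `speechbrain.dataio.encoder.CTCTextEncoder` in `train_with_wavlm.py` further below:

```python
# Minimal sketch: parse label_encoder.txt and map a transcript to target indices.
char2index = {}
with open("label_encoder.txt", encoding="utf-8") as f:
    for line in f:
        if "=>" not in line:
            continue  # skips the "================" separator
        key, value = line.rsplit("=>", 1)
        try:
            char2index[key.strip().strip("'")] = int(value)
        except ValueError:
            pass  # metadata such as 'blank_label' => '<blank>'

sentence = "شنوة الأحوال"  # hypothetical transcript
targets = [char2index[ch] for ch in sentence if ch in char2index]
print(targets)  # indices fed to the CTC loss; the blank stays reserved at index 0
```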
TunisianASR/train_semi.yaml ADDED
@@ -0,0 +1,175 @@
1
+ # ################################
2
+ # Model: wav2vec2 + DNN + CTC
3
+ # Augmentation: SpecAugment
4
+ # Authors: Titouan Parcollet 2021
5
+ # ################################
6
+
7
+ # Seed needs to be set at top of yaml, before objects with parameters are made
8
+ seed: 1234
9
+ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref semi_wavlm_large_tunisian_ctc/<seed>
11
+ wer_file: !ref <output_folder>/wer.txt
12
+ save_folder: !ref <output_folder>/save
13
+ train_log: !ref <output_folder>/train_log.txt
14
+
15
+ # Folder where the pretrained wav2vec2 / WavLM checkpoint is stored.
16
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
17
+
18
+ # Data files
19
+ data_folder: /path/to/data # e.g., /localscratch/cv-corpus-5.1-2020-06-22/fr
20
+ train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
21
+ dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
22
+ test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
23
+ accented_letters: True
24
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
25
+ test_csv:
26
+ - /path/to/test_data
27
+
28
+ skip_prep: True # Skip data preparation
29
+
30
+ use_language_modelling: True
31
+ ngram_lm_path: outdomain.arpa
32
+
33
+ # We remove utterances longer than 10s in the train/dev/test sets, as
34
+ # longer sentences almost certainly correspond to "open microphones".
35
+ avoid_if_longer_than: 10.0
36
+ avoid_if_shorter_than: 1.2
37
+
38
+
39
+ # Training parameters
40
+ number_of_epochs: 12
41
+ lr: 1.0
42
+ lr_wav2vec: 0.0001
43
+ sorting: ascending
44
+ auto_mix_prec: False
45
+ sample_rate: 16000
46
+ ckpt_interval_minutes: 30 # save checkpoint every N min
47
+
48
+ # With data_parallel batch_size is split into N jobs
49
+ # With DDP batch_size is multiplied by N jobs
50
+ # Must be 6 per GPU to fit 16GB of VRAM
51
+ batch_size: 10
52
+ test_batch_size: 4
53
+
54
+ dataloader_options:
55
+ batch_size: !ref <batch_size>
56
+ num_workers: 6
57
+ test_dataloader_options:
58
+ batch_size: !ref <test_batch_size>
59
+ num_workers: 6
60
+
61
+ # BPE parameters
62
+ token_type: char # ["unigram", "bpe", "char"]
63
+ character_coverage: 1.0
64
+
65
+ # Model parameters
66
+ # activation: !name:torch.nn.LeakyReLU
67
+ wav2vec_output_dim: 1024
68
+ dnn_neurons: 1024
69
+ freeze_wav2vec: False
70
+ freeze_feature_extractor: True
71
+ dropout: 0.15
72
+ warmup_steps: 500 # The wav2vec 2.0 model is not updated for this number of steps
73
+
74
+ # Outputs
75
+ output_neurons: 40 # Number of output tokens (characters + blank + unk), index(blank) = 0
76
+
77
+ # Decoding parameters
78
+ # Be sure that the blank and unk indices match the label encoder's
79
+ blank_index: 0
80
+ unk_index: 1
81
+
82
+ #
83
+ # Functions and classes
84
+ #
85
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
86
+ limit: !ref <number_of_epochs>
87
+
88
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
89
+ sample_rate: !ref <sample_rate>
90
+ speeds: [95, 100, 105]
91
+
92
+ enc: !new:speechbrain.nnet.containers.Sequential
93
+ input_shape: [null, null, !ref <wav2vec_output_dim>]
94
+ linear1: !name:speechbrain.nnet.linear.Linear
95
+ n_neurons: !ref <dnn_neurons>
96
+ bias: True
97
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
98
+ activation: !new:torch.nn.LeakyReLU
99
+ drop: !new:torch.nn.Dropout
100
+ p: !ref <dropout>
101
+ linear2: !name:speechbrain.nnet.linear.Linear
102
+ n_neurons: !ref <dnn_neurons>
103
+ bias: True
104
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
105
+ activation2: !new:torch.nn.LeakyReLU
106
+ drop2: !new:torch.nn.Dropout
107
+ p: !ref <dropout>
108
+ linear3: !name:speechbrain.nnet.linear.Linear
109
+ n_neurons: !ref <dnn_neurons>
110
+ bias: True
111
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
112
+ activation3: !new:torch.nn.LeakyReLU
113
+
114
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
115
+ source: /gpfsstore/rech/nou/uzn19yk/wavlm/
116
+ output_norm: False
117
+ freeze: !ref <freeze_wav2vec>
118
+ freeze_feature_extractor: !ref <freeze_feature_extractor>
119
+ save_path: !ref <wav2vec2_folder>
120
+
121
+
122
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
123
+ input_size: !ref <dnn_neurons>
124
+ n_neurons: !ref <output_neurons>
125
+
126
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
127
+ apply_log: True
128
+
129
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
130
+ blank_index: !ref <blank_index>
131
+
132
+ modules:
133
+ wav2vec2: !ref <wav2vec2>
134
+ enc: !ref <enc>
135
+ ctc_lin: !ref <ctc_lin>
136
+
137
+ model: !new:torch.nn.ModuleList
138
+ - [!ref <enc>, !ref <ctc_lin>]
139
+
140
+ model_opt_class: !name:torch.optim.Adadelta
141
+ lr: !ref <lr>
142
+ rho: 0.95
143
+ eps: 1.e-8
144
+
145
+ wav2vec_opt_class: !name:torch.optim.Adam
146
+ lr: !ref <lr_wav2vec>
147
+
148
+ lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
149
+ initial_value: !ref <lr>
150
+ improvement_threshold: 0.0025
151
+ annealing_factor: 0.8
152
+ patient: 0
153
+
154
+ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
155
+ initial_value: !ref <lr_wav2vec>
156
+ improvement_threshold: 0.0025
157
+ annealing_factor: 0.9
158
+ patient: 0
159
+
160
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
161
+ checkpoints_dir: !ref <save_folder>
162
+ recoverables:
163
+ wav2vec2: !ref <wav2vec2>
164
+ model: !ref <model>
165
+ scheduler_model: !ref <lr_annealing_model>
166
+ scheduler_wav2vec: !ref <lr_annealing_wav2vec>
167
+ counter: !ref <epoch_counter>
168
+
169
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
170
+ save_file: !ref <train_log>
171
+
172
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
173
+
174
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
175
+ split_tokens: True
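The recipe file above leans on the HyperPyYAML tags used throughout SpeechBrain configs: `!ref` resolves another key, `!new:` instantiates a class immediately, and `!name:` stores a callable to be invoked later (this is how `model_opt_class` and `wav2vec_opt_class` are consumed). A small, self-contained illustration with a toy YAML string; it is not the recipe file itself, which additionally needs the data paths and the local WavLM checkpoint to exist:

```python
# Toy illustration of the !ref / !new: / !name: tags used in train_semi.yaml.
from hyperpyyaml import load_hyperpyyaml

toy_yaml = """
dnn_neurons: 1024
dropout: 0.15
drop: !new:torch.nn.Dropout
    p: !ref <dropout>
opt_class: !name:torch.optim.Adam
    lr: 0.0001
"""
hparams = load_hyperpyyaml(toy_yaml)
print(hparams["drop"])                 # an instantiated Dropout(p=0.15)
make_optimizer = hparams["opt_class"]  # a partial; later called with model.parameters()
```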
TunisianASR/train_with_wavlm.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from pathlib import Path
7
+ import os
8
+ import torchaudio
9
+ from hyperpyyaml import load_hyperpyyaml
10
+ from speechbrain.tokenizers.SentencePiece import SentencePiece
11
+ from speechbrain.utils.data_utils import undo_padding
12
+ from speechbrain.utils.distributed import run_on_main
13
+
14
+ """Recipe for training a sequence-to-sequence ASR system with CommonVoice.
15
+ The system employs a wav2vec2 encoder and a CTC decoder.
16
+ Decoding is performed with greedy decoding (will be extended to beam search).
17
+
18
+ To run this recipe, do the following:
19
+ > python train_with_wavlm.py train_semi.yaml
20
+
21
+ With the default hyperparameters, the system employs a pretrained wav2vec2 encoder.
22
+ The wav2vec2 checkpoint to fine-tune is specified in the hparams file
23
+ and may depend on the language.
24
+
25
+ The neural network is trained with CTC on sub-word units estimated with
26
+ Byte Pair Encoding (BPE).
27
+
28
+ The experiment file is flexible enough to support a large variety of
29
+ different systems. By properly changing the parameter files, you can try
30
+ different encoders, decoders, tokens (e.g., characters instead of BPE),
31
+ training languages (all CommonVoice languages), and many
32
+ other possible variations.
33
+
34
+ Authors
35
+ * Titouan Parcollet 2021
36
+ """
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ # Define training procedure
42
+ class ASR(sb.core.Brain):
43
+ def compute_forward(self, batch, stage):
44
+ """Forward computations from the waveform batches to the output probabilities."""
45
+
46
+ batch = batch.to(self.device)
47
+ wavs, wav_lens = batch.sig
48
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
49
+ if stage == sb.Stage.TRAIN:
50
+ if hasattr(self.hparams, "augmentation"):
51
+ wavs = self.hparams.augmentation(wavs, wav_lens)
52
+
53
+ # Forward pass
54
+ feats = self.modules.wav2vec2(wavs, wav_lens)
55
+ x = self.modules.enc(feats)
56
+ logits = self.modules.ctc_lin(x)
57
+ p_ctc = self.hparams.log_softmax(logits)
58
+
59
+ return p_ctc, wav_lens
60
+
61
+ def compute_objectives(self, predictions, batch, stage):
62
+ """Computes the loss (CTC) given predictions and targets."""
63
+
64
+ p_ctc, wav_lens = predictions
65
+
66
+ ids = batch.id
67
+ tokens, tokens_lens = batch.tokens
68
+
69
+ loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
70
+
71
+ if stage != sb.Stage.TRAIN:
72
+ predicted_tokens = sb.decoders.ctc_greedy_decode(
73
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
74
+ )
75
+ # Decode token terms to words
76
+ if self.hparams.use_language_modelling:
77
+ predicted_words = []
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+ else:
82
+ predicted_words = [
83
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
84
+ for utt_seq in predicted_tokens
85
+ ]
86
+ # Convert indices to words
87
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
88
+
89
+ self.wer_metric.append(ids, predicted_words, target_words)
90
+ self.cer_metric.append(ids, predicted_words, target_words)
91
+
92
+ return loss
93
+
94
+ def fit_batch(self, batch):
95
+ """Train the parameters given a single batch in input"""
96
+ should_step = self.step % self.grad_accumulation_factor == 0
97
+ # Managing automatic mixed precision
98
+ # TOFIX: CTC fine-tuning currently is unstable
99
+ # This is certainly due to CTC being done in fp16 instead of fp32
100
+ if self.auto_mix_prec:
101
+ with torch.cuda.amp.autocast():
102
+ with self.no_sync():
103
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
104
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
105
+ with self.no_sync(not should_step):
106
+ self.scaler.scale(
107
+ loss / self.grad_accumulation_factor
108
+ ).backward()
109
+ if should_step:
110
+
111
+ if not self.hparams.wav2vec2.freeze:
112
+ self.scaler.unscale_(self.wav2vec_optimizer)
113
+ self.scaler.unscale_(self.model_optimizer)
114
+ if self.check_gradients(loss):
115
+ if not self.hparams.wav2vec2.freeze:
116
+ if self.optimizer_step >= self.hparams.warmup_steps:
117
+ self.scaler.step(self.wav2vec_optimizer)
118
+ self.scaler.step(self.model_optimizer)
119
+ self.scaler.update()
120
+ self.zero_grad()
121
+ self.optimizer_step += 1
122
+ else:
123
+ # This is mandatory because HF models have a weird behavior with DDP
124
+ # on the forward pass
125
+ with self.no_sync():
126
+ outputs = self.compute_forward(batch, sb.Stage.TRAIN)
127
+
128
+ loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
129
+
130
+ with self.no_sync(not should_step):
131
+ (loss / self.grad_accumulation_factor).backward()
132
+ if should_step:
133
+ if self.check_gradients(loss):
134
+ if not self.hparams.wav2vec2.freeze:
135
+ if self.optimizer_step >= self.hparams.warmup_steps:
136
+ self.wav2vec_optimizer.step()
137
+ self.model_optimizer.step()
138
+ self.zero_grad()
139
+ self.optimizer_step += 1
140
+
141
+ self.on_fit_batch_end(batch, outputs, loss, should_step)
142
+ return loss.detach().cpu()
143
+
144
+ def evaluate_batch(self, batch, stage):
145
+ """Computations needed for validation/test batches"""
146
+ predictions = self.compute_forward(batch, stage=stage)
147
+ with torch.no_grad():
148
+ loss = self.compute_objectives(predictions, batch, stage=stage)
149
+ return loss.detach()
150
+
151
+ def on_stage_start(self, stage, epoch):
152
+ """Gets called at the beginning of each epoch"""
153
+ if stage != sb.Stage.TRAIN:
154
+ self.cer_metric = self.hparams.cer_computer()
155
+ self.wer_metric = self.hparams.error_rate_computer()
156
+
157
+ def on_stage_end(self, stage, stage_loss, epoch):
158
+ """Gets called at the end of an epoch."""
159
+ # Compute/store important stats
160
+ stage_stats = {"loss": stage_loss}
161
+ if stage == sb.Stage.TRAIN:
162
+ self.train_stats = stage_stats
163
+ else:
164
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
165
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
166
+
167
+ # Perform end-of-iteration things, like annealing, logging, etc.
168
+ if stage == sb.Stage.VALID:
169
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
170
+ stage_stats["loss"]
171
+ )
172
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
173
+ stage_stats["loss"]
174
+ )
175
+ sb.nnet.schedulers.update_learning_rate(
176
+ self.model_optimizer, new_lr_model
177
+ )
178
+ if not self.hparams.wav2vec2.freeze:
179
+ sb.nnet.schedulers.update_learning_rate(
180
+ self.wav2vec_optimizer, new_lr_wav2vec
181
+ )
182
+ self.hparams.train_logger.log_stats(
183
+ stats_meta={
184
+ "epoch": epoch,
185
+ "lr_model": old_lr_model,
186
+ "lr_wav2vec": old_lr_wav2vec,
187
+ },
188
+ train_stats=self.train_stats,
189
+ valid_stats=stage_stats,
190
+ )
191
+ self.checkpointer.save_and_keep_only(
192
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
193
+ )
194
+ elif stage == sb.Stage.TEST:
195
+ self.hparams.train_logger.log_stats(
196
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
197
+ test_stats=stage_stats,
198
+ )
199
+ with open(self.hparams.wer_file, "w") as w:
200
+ self.wer_metric.write_stats(w)
201
+
202
+ def init_optimizers(self):
203
+ "Initializes the wav2vec2 optimizer and model optimizer"
204
+
205
+ # If the wav2vec encoder is unfrozen, we create the optimizer
206
+ if not self.hparams.wav2vec2.freeze:
207
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
208
+ self.modules.wav2vec2.parameters()
209
+ )
210
+ if self.checkpointer is not None:
211
+ self.checkpointer.add_recoverable(
212
+ "wav2vec_opt", self.wav2vec_optimizer
213
+ )
214
+
215
+ self.model_optimizer = self.hparams.model_opt_class(
216
+ self.hparams.model.parameters()
217
+ )
218
+
219
+ if self.checkpointer is not None:
220
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
221
+
222
+ def zero_grad(self, set_to_none=False):
223
+ if not self.hparams.wav2vec2.freeze:
224
+ self.wav2vec_optimizer.zero_grad(set_to_none)
225
+ self.model_optimizer.zero_grad(set_to_none)
226
+
227
+
228
+ # Define custom data procedure
229
+ def dataio_prepare(hparams):
230
+ """This function prepares the datasets to be used in the brain class.
231
+ It also defines the data processing pipeline through user-defined functions."""
232
+
233
+ # 1. Define datasets
234
+ data_folder = hparams["data_folder"]
235
+
236
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
237
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
238
+ )
239
+
240
+ if hparams["sorting"] == "ascending":
241
+ # we sort training data to speed up training and get better results.
242
+ train_data = train_data.filtered_sorted(
243
+ sort_key="duration",
244
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
245
+ )
246
+ # when sorting, do not shuffle in the dataloader! Otherwise sorting is pointless
247
+ hparams["dataloader_options"]["shuffle"] = False
248
+
249
+ elif hparams["sorting"] == "descending":
250
+ train_data = train_data.filtered_sorted(
251
+ sort_key="duration",
252
+ reverse=True,
253
+ key_max_value={"duration": hparams["avoid_if_longer_than"]},
254
+ )
255
+ # when sorting, do not shuffle in the dataloader! Otherwise sorting is pointless
256
+ hparams["dataloader_options"]["shuffle"] = False
257
+
258
+ elif hparams["sorting"] == "random":
259
+ pass
260
+
261
+ else:
262
+ raise NotImplementedError(
263
+ "sorting must be random, ascending or descending"
264
+ )
265
+
266
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
267
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
268
+ )
269
+ # We also sort the validation data so it is faster to validate
270
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
271
+ test_datasets = {}
272
+ for csv_file in hparams["test_csv"]:
273
+ name = Path(csv_file).stem
274
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
275
+ csv_path=csv_file, replacements={"data_root": data_folder}
276
+ )
277
+ test_datasets[name] = test_datasets[name].filtered_sorted(
278
+ sort_key="duration"
279
+ )
280
+
281
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
282
+
283
+
284
+ # 2. Define audio pipeline:
285
+ @sb.utils.data_pipeline.takes("wav")
286
+ @sb.utils.data_pipeline.provides("sig")
287
+ def audio_pipeline(wav):
288
+ info = torchaudio.info(wav)
289
+ sig = sb.dataio.dataio.read_audio(wav)
290
+ resampled = torchaudio.transforms.Resample(
291
+ info.sample_rate, hparams["sample_rate"],
292
+ )(sig)
293
+ return resampled
294
+
295
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
296
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
297
+
298
+ # 3. Define text pipeline:
299
+ @sb.utils.data_pipeline.takes("wrd")
300
+ @sb.utils.data_pipeline.provides(
301
+ "wrd", "char_list", "tokens_list", "tokens"
302
+ )
303
+ def text_pipeline(wrd):
304
+ yield wrd
305
+ char_list = list(wrd)
306
+ yield char_list
307
+ tokens_list = label_encoder.encode_sequence(char_list)
308
+ yield tokens_list
309
+ tokens = torch.LongTensor(tokens_list)
310
+ yield tokens
311
+
312
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
313
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
314
+ special_labels = {
315
+ "blank_label": hparams["blank_index"],
316
+ "unk_label": hparams["unk_index"]
317
+ }
318
+ label_encoder.load_or_create(
319
+ path=lab_enc_file,
320
+ from_didatasets=[train_data],
321
+ output_key="char_list",
322
+ special_labels=special_labels,
323
+ sequence_input=True,
324
+ )
325
+
326
+ # 4. Set output:
327
+ sb.dataio.dataset.set_output_keys(
328
+ datasets, ["id", "sig", "wrd", "char_list", "tokens"],
329
+ )
330
+ return train_data, valid_data, test_datasets, label_encoder
331
+
332
+
333
+ if __name__ == "__main__":
334
+
335
+ # Load hyperparameters file with command-line overrides
336
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
337
+ with open(hparams_file) as fin:
338
+ hparams = load_hyperpyyaml(fin, overrides)
339
+
340
+ # If --distributed_launch then
341
+ # create ddp_group with the right communication protocol
342
+ sb.utils.distributed.ddp_init_group(run_opts)
343
+
344
+
345
+ # Create experiment directory
346
+ sb.create_experiment_directory(
347
+ experiment_directory=hparams["output_folder"],
348
+ hyperparams_to_save=hparams_file,
349
+ overrides=overrides,
350
+ )
351
+
352
+ # Due to DDP, we do the preparation ONLY on the main python process
353
+ # Defining tokenizer and loading it
354
+ # Create the datasets objects as well as tokenization and encoding :-D
355
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)
356
+ if hparams["use_language_modelling"]:
357
+ print("using language modelling")
358
+ from pyctcdecode import build_ctcdecoder
359
+ ind2lab = label_encoder.ind2lab
360
+ print(ind2lab)
361
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
362
+ labels = [""] + labels[1:-1] + ["1"]
363
+ # Replace the <blank> token with a blank character, needed for PyCTCdecode
364
+ print(labels)
365
+ decoder = build_ctcdecoder(
366
+ labels,
367
+ kenlm_model_path=hparams["ngram_lm_path"], # .arpa or .bin
368
+ alpha=0.5, # Default by KenLM
369
+ beta=1.0, # Default by KenLM
370
+ )
371
+ # Trainer initialization
372
+ asr_brain = ASR(
373
+ modules=hparams["modules"],
374
+ hparams=hparams,
375
+ run_opts=run_opts,
376
+ checkpointer=hparams["checkpointer"],
377
+ )
378
+
379
+ # Adding objects to trainer.
380
+ asr_brain.tokenizer = label_encoder
381
+
382
+ # Training
383
+ asr_brain.fit(
384
+ asr_brain.hparams.epoch_counter,
385
+ train_data,
386
+ valid_data,
387
+ train_loader_kwargs=hparams["dataloader_options"],
388
+ valid_loader_kwargs=hparams["test_dataloader_options"],
389
+ )
390
+
391
+ # Test
392
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
393
+ asr_brain.hparams.wer_file = os.path.join(
394
+ hparams["output_folder"], "wer_{}.txt".format(k)
395
+ )
396
+ asr_brain.evaluate(
397
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
398
+ )
399
+
arpas/everything.arpa ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bada7d41f63b1e5fd661ba66bccdfa93c3e5c391038ac6e52615a42ec0e0174
3
+ size 345991397
asr-wav2vec2-commonvoice-fr/README.md ADDED
@@ -0,0 +1,130 @@
1
+ ---
2
+ language:
3
+ - fr
4
+ thumbnail: null
5
+ pipeline_tag: automatic-speech-recognition
6
+ tags:
7
+ - CTC
8
+ - pytorch
9
+ - speechbrain
10
+ - Transformer
11
+ - hf-asr-leaderboard
12
+ license: apache-2.0
13
+ datasets:
14
+ - commonvoice
15
+ metrics:
16
+ - wer
17
+ - cer
18
+ model-index:
19
+ - name: asr-wav2vec2-commonvoice-fr
20
+ results:
21
+ - task:
22
+ name: Automatic Speech Recognition
23
+ type: automatic-speech-recognition
24
+ dataset:
25
+ name: CommonVoice 6.1 (French)
26
+ type: mozilla-foundation/common_voice_6_1
27
+ config: fr
28
+ split: test
29
+ args:
30
+ language: fr
31
+ metrics:
32
+ - name: Test WER
33
+ type: wer
34
+ value: '9.96'
35
+ ---
36
+
37
+ <iframe src="https://ghbtns.com/github-btn.html?user=speechbrain&repo=speechbrain&type=star&count=true&size=large&v=2" frameborder="0" scrolling="0" width="170" height="30" title="GitHub"></iframe>
38
+ <br/><br/>
39
+
40
+ # wav2vec 2.0 with CTC/Attention trained on CommonVoice French (No LM)
41
+
42
+ This repository provides all the necessary tools to perform automatic speech
43
+ recognition from an end-to-end system pretrained on CommonVoice (French Language) within
44
+ SpeechBrain. For a better experience, we encourage you to learn more about
45
+ [SpeechBrain](https://speechbrain.github.io).
46
+
47
+ The performance of the model is the following:
48
+
49
+ | Release | Test CER | Test WER | GPUs |
50
+ |:-------------:|:--------------:|:--------------:| :--------:|
51
+ | 24-08-21 | 3.19 | 9.96 | 2xV100 32GB |
52
+
53
+ ## Pipeline description
54
+
55
+ This ASR system is composed of 2 different but linked blocks:
56
+ - Tokenizer (unigram) that transforms words into subword units and trained with
57
+ the train transcriptions (train.tsv) of CommonVoice (FR).
58
+ - Acoustic model (wav2vec2.0 + CTC). A pretrained wav2vec 2.0 model ([LeBenchmark/wav2vec2-FR-7K-large](https://huggingface.co/LeBenchmark/wav2vec2-FR-7K-large)) is combined with two DNN layers and finetuned on CommonVoice FR.
59
+ The obtained final acoustic representation is given to the CTC greedy decoder.
60
+
61
+ The system is trained with recordings sampled at 16kHz (single channel).
62
+ The code will automatically normalize your audio (i.e., resampling + mono channel selection) when calling *transcribe_file* if needed.
63
+
64
+ ## Install SpeechBrain
65
+
66
+ First of all, please install tranformers and SpeechBrain with the following command:
67
+
68
+ ```
69
+ pip install speechbrain transformers
70
+ ```
71
+
72
+ Please notice that we encourage you to read our tutorials and learn more about
73
+ [SpeechBrain](https://speechbrain.github.io).
74
+
75
+ ### Transcribing your own audio files (in French)
76
+
77
+ ```python
78
+ from speechbrain.pretrained import EncoderASR
79
+
80
+ asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-fr", savedir="pretrained_models/asr-wav2vec2-commonvoice-fr")
81
+ asr_model.transcribe_file('speechbrain/asr-wav2vec2-commonvoice-fr/example-fr.wav')
82
+
83
+ ```
84
+ ### Inference on GPU
85
+ To perform inference on the GPU, add `run_opts={"device":"cuda"}` when calling the `from_hparams` method.
86
+
87
+ ### Training
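For instance, a minimal variation of the transcription snippet above (same model source and save directory) that places the model on a CUDA device:

```python
from speechbrain.pretrained import EncoderASR

# Same model card paths as above, with run_opts selecting the GPU.
asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-fr",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",
    run_opts={"device": "cuda"},
)
asr_model.transcribe_file("speechbrain/asr-wav2vec2-commonvoice-fr/example-fr.wav")
```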
88
+ The model was trained with SpeechBrain.
89
+ To train it from scratch follow these steps:
90
+ 1. Clone SpeechBrain:
91
+ ```bash
92
+ git clone https://github.com/speechbrain/speechbrain/
93
+ ```
94
+ 2. Install it:
95
+ ```bash
96
+ cd speechbrain
97
+ pip install -r requirements.txt
98
+ pip install -e .
99
+ ```
100
+
101
+ 3. Run Training:
102
+ ```bash
103
+ cd recipes/CommonVoice/ASR/CTC/
104
+ python train_with_wav2vec.py hparams/train_fr_with_wav2vec.yaml --data_folder=your_data_folder
105
+ ```
106
+
107
+ You can find our training results (models, logs, etc) [here](https://drive.google.com/drive/folders/1T9DfdZwcNI9CURxhLCi8GA5JVz8adiY8?usp=sharing).
108
+
109
+ ### Limitations
110
+ The SpeechBrain team does not provide any warranty on the performance achieved by this model when used on other datasets.
111
+
112
+ #### Referencing SpeechBrain
113
+
114
+ ```
115
+ @misc{SB2021,
116
+ author = {Ravanelli, Mirco and Parcollet, Titouan and Rouhe, Aku and Plantinga, Peter and Rastorgueva, Elena and Lugosch, Loren and Dawalatabad, Nauman and Ju-Chieh, Chou and Heba, Abdel and Grondin, Francois and Aris, William and Liao, Chien-Feng and Cornell, Samuele and Yeh, Sung-Lin and Na, Hwidong and Gao, Yan and Fu, Szu-Wei and Subakan, Cem and De Mori, Renato and Bengio, Yoshua },
117
+ title = {SpeechBrain},
118
+ year = {2021},
119
+ publisher = {GitHub},
120
+ journal = {GitHub repository},
121
+ howpublished = {\\\\url{https://github.com/speechbrain/speechbrain}},
122
+ }
123
+ ```
124
+
125
+ #### About SpeechBrain
126
+ SpeechBrain is an open-source and all-in-one speech toolkit. It is designed to be simple, extremely flexible, and user-friendly. Competitive or state-of-the-art performance is obtained in various domains.
127
+
128
+ Website: https://speechbrain.github.io/
129
+
130
+ GitHub: https://github.com/speechbrain/speechbrain
asr-wav2vec2-commonvoice-fr/asr.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64ba475ed7be735d4ac054c2d537f22251b80f6ecb65cb04217eb0d1ed50a143
3
+ size 12963902