rosyvs committed on
Commit
6d504a5
1 Parent(s): 563b9ce

Upload folder using huggingface_hub

Browse files
LEVI_whisper_benchmark.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#%% imports
import os
from benchmark_utils import ASRmanifest, wer_from_csv


#%% setup paths
corpora_root = '/shared/corpora/forSAGA/'  # root path where audio files are, inserted in place of $DATAROOT in manifest
manif_root = '/shared/corpora/forSAGA/data_manifests/'  # path to dir containing data manifest csvs
output_dir = './ASR_output/'  # where to save ASR output
manifest = 'LEVI_LoFi_v2_TEST_norm_wer_isat'  # name of test manifest
model_name = 'LEVI_whisper_medium.en'  # name of save directory of model you want to evaluate
hf_org = 'levicu'
model_path = f'{hf_org}/{model_name}'

#%% setup paths for Rosy TESTING:
# NOTE: when the file is run top to bottom, this cell overrides every value
# assigned in the previous cell; run only one of the two cells interactively.
corpora_root = '/shared/corpora/'  # root path where audio files are, inserted in place of $DATAROOT in manifest
manif_root = '/shared/corpora/data_manifests/ASR/'  # path to dir containing data manifest csvs
output_dir = '/home/rosy/whisat-output/'  # where to save ASR output
manifest = 'LEVI_LoFi_v2_TEST_punc+cased'  # name of test manifest
model_name = 'LEVI_LoFi_v2_MediumEN_Lora_Int8'  # name of save directory of model you want to evaluate
model_path = '/shared/models/LEVI_LoFi_v2_MediumEN_Lora_Int8/final/'
# The last assignment wins. The hub id uses a hyphen ("whisper-medium.en");
# the previous "whisper_medium.en" spelling does not resolve on the hub.
model_path = 'openai/whisper-medium.en'

#%%
# generate paths
manifest_csv = os.path.join(manif_root, f'{manifest}.csv')
out_csv = os.path.join(output_dir, f'{model_name}_on_{manifest}.csv')
# ASRmanifest opens out_csv for writing, so the directory has to exist first.
os.makedirs(output_dir, exist_ok=True)

#%% Inference
ASRmanifest(
    manifest_csv=manifest_csv,
    out_csv=out_csv,
    corpora_root=corpora_root,
    model_path=model_path,
)

#%% Evaluation
print(f'reading results from {out_csv}')
print(f'{model_name} on {manifest}')
wer_meas = wer_from_csv(
    out_csv,
    refcol='transcript',
    hypcol='asr',
    printout=True,
    text_norm_method='levi'
)
46
+
47
+
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (128 Bytes). View file
 
__pycache__/benchmark_utils.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
__pycache__/converters.cpython-310.pyc ADDED
Binary file (5.32 kB). View file
 
__pycache__/renamers.cpython-310.pyc ADDED
Binary file (1.61 kB). View file
 
__pycache__/trimmers.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
benchmark_utils.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%% imports
2
+ import pandas as pd
3
+ import time
4
+ from tqdm import tqdm
5
+ import torch
6
+ from torch.cuda.amp import autocast
7
+ import transformers
8
+ from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
9
+ from transformers import pipeline, AutomaticSpeechRecognitionPipeline
10
+ from peft import PeftModel, PeftConfig
11
+ import warnings
12
+ import jiwer
13
+ from jiwer.process import WordOutput
14
+ import pandas as pd
15
+ import numpy as np
16
+ from pathlib import Path
17
+ import os
18
+ import math
19
+ from decimal import InvalidOperation
20
+ import contractions
21
+ from whisper.normalizers.english import EnglishTextNormalizer
22
+ from num2words import num2words
23
+ import csv
24
+ import re
25
+ import string
26
+
27
+ #%% define functions
28
def ASRmanifest(
    manifest_csv: str,
    out_csv: str,
    corpora_root: str,
    model_path: str,
):
    """Run Whisper ASR on a dataset specified in a manifest.

    Every manifest row is copied to ``out_csv`` with an extra ``asr`` column
    holding the transcription. Rows whose audio raises ``FileNotFoundError``
    or ``ValueError`` during inference are reported and left out of the output.

    Args:
        manifest_csv (str): path to manifest csv listing files to transcribe;
            the ``wav`` column holds the audio path.
        out_csv (str): path to write output csv
        corpora_root (str): root path where audio files are, inserted in place
            of $DATAROOT in manifest
        model_path (str): path to model directory / huggingface model name
    """
    # keep_default_na=False: empty cells stay '' instead of becoming NaN
    df = pd.read_csv(manifest_csv, keep_default_na=False)
    fieldnames = list(df.columns) + ['asr']

    asr_pipeline = prepare_pipeline(
        model_path=model_path,
        generate_opts={
            'max_new_tokens': 448,
            'num_beams': 1,  # greedy
            'repetition_penalty': 1,
            'do_sample': False,
        },
    )

    message = "This may take a while on CPU." if asr_pipeline.device.type == "cpu" else "Using GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')
    st = time.time()

    with open(out_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            try:
                with torch.no_grad(), autocast():
                    asrtext = asr_pipeline(audiofile)['text']
            except (FileNotFoundError, ValueError) as e:
                # report why the file was skipped instead of dropping the reason
                print(f'SKIPPED: {audiofile} ({e})')
                continue
            row['asr'] = asrtext
            writer.writerow(row.to_dict())

    compute_time = time.time() - st
    print(f'...transcription complete in {compute_time:.1f} sec')
79
+
80
def load_model(
        model_path: str,
        language='english',
        use_int8=False,
        device_map='auto'):
    """Load a Whisper model and its processor from a directory or hub name.

    A full checkpoint at ``model_path`` is tried first. If that raises
    ``OSError`` and ``model_path/adapter_model`` is a directory, the base model
    named in the adapter config is loaded and the adapter weights are merged
    into it. Otherwise the original ``OSError`` is re-raised.

    Args:
        model_path (str): path to model directory / huggingface model name
        language (str): language handed to the processor
        use_int8 (bool): forwarded as ``load_in_8bit``
        device_map: forwarded to ``from_pretrained``

    Returns:
        tuple: (model in eval mode, processor)
    """
    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    model_kwargs = dict(load_in_8bit=use_int8, device_map=device_map, use_cache=False)
    processor_kwargs = dict(language=language, task="transcribe")
    adapter_dir = os.path.join(model_path, "adapter_model")

    try:
        model = WhisperForConditionalGeneration.from_pretrained(model_path, **model_kwargs)
        try:
            processor = WhisperProcessor.from_pretrained(model_path, **processor_kwargs)
        except OSError:
            print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
            parent_dir = os.path.join(model_path, '..')
            processor = WhisperProcessor.from_pretrained(parent_dir, **processor_kwargs)

    except OSError as e:
        print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
        # not a PEFT checkpoint either: surface the original error
        if not os.path.isdir(adapter_dir):
            raise e
        print('found adapter...loading PEFT model')
        # checkpoint dir needs adapter_model subdir with adapter_model.bin and adapter_config.json
        peft_config = PeftConfig.from_pretrained(adapter_dir)
        base_name = peft_config.base_model_name_or_path
        print(f'...loading and merging LORA weights to base model {peft_config.base_model_name_or_path}')
        model = WhisperForConditionalGeneration.from_pretrained(base_name, **model_kwargs)
        model = PeftModel.from_pretrained(model, adapter_dir)
        model = model.merge_and_unload()
        processor = WhisperProcessor.from_pretrained(base_name, **processor_kwargs)

    model.eval()
    return (model, processor)
122
+
123
def prepare_pipeline(model_path, generate_opts):
    """Build an automatic-speech-recognition pipeline around a loaded model.

    Args:
        model_path (str): path to model directory / huggingface model name
        generate_opts (dict): passed to the pipeline as ``generate_kwargs``

    Returns:
        pipeline: ASR pipeline
    """
    model, processor = load_model(model_path=model_path)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_opts,
    )
142
+
143
+ #%% WER evaluation functions
144
def get_normalizer(text_norm_method='isat'):
    """Return the text-normalisation callable registered under a method name.

    Args:
        text_norm_method (str): one of 'whisper', 'whisper_keep_tags', 'isat', 'levi'.

    Returns:
        callable: maps a transcript to its normalised form.

    Raises:
        NotImplementedError: for any other method name.
    """
    if text_norm_method == 'whisper':
        return whisper_norm_text_for_wer
    if text_norm_method == 'whisper_keep_tags':
        # plain Whisper normaliser, without the tag stripping done in whisper_norm_text_for_wer
        return EnglishTextNormalizer()
    if text_norm_method == 'isat':
        return norm_text_for_wer
    if text_norm_method == 'levi':
        return levi_norm_text_for_wer
    raise NotImplementedError(f'unrecognized normalizer method: {text_norm_method}')
156
+
157
def strip_punct(instr, keep_math=False):
    """Strip punctuation from each whitespace-separated word of ``instr``.

    Args:
        instr (str): text to clean.
        keep_math (bool): when True, a reduced set of characters is stripped
            from word edges and hyphens inside words are kept.

    Returns:
        str: cleaned words joined by single spaces, with the ends trimmed.
    """
    math_safe_edge_chars = '!"#$&\',.:;<=>?@[\\]^_`{|}~'
    edge_chars = math_safe_edge_chars if keep_math else string.punctuation

    cleaned = []
    for token in instr.split():
        # delete punct from start and end of word
        token = token.strip(edge_chars)
        # delete commas inside numbers (only when the word starts like one)
        if re.match(r'(\d*),(\d)', token) is not None:
            token = token.replace(',', '')
        # commas inside words become space
        token = token.replace(',', ' ')
        # hyphens inside words become space, unless math symbols are kept
        if not keep_math:
            token = token.replace('-', ' ')
        cleaned.append(token.strip())
    return ' '.join(cleaned).strip()
180
+
181
def remove_in_brackets(text):
    """Replace every bracketed clause, brackets included, with a single space.

    A clause opens with ``(``, ``[`` or ``<`` and runs to the nearest run of
    ``)``, ``]`` or ``>``; opener and closer types are not required to match.
    """
    # raw string: same pattern as before, without invalid escape sequences
    return re.sub(r"[\(\[\<].*?[\)\]\>]+", " ", text)
184
+
185
def caught_num2words(text):
    """Spell out numbers in ``text``, returning the text unchanged on failure.

    Currency amounts and percentages are first rewritten ("$5" -> "5 dollars",
    "5%" -> "5 percent"), edge punctuation is stripped, then each word is passed
    to ``num2words``. Words that ``num2words`` rejects with ``InvalidOperation``
    or ``ValueError`` are returned as they are.
    """
    # first do currency replacements #TODO: plurals vs singular
    if '$' in text:
        text = re.sub(r'\$([0-9]+)', r'\g<1> dollars', text)
    if '€' in text:
        # fixed: this branch used to search for '$', so euro amounts were never rewritten
        text = re.sub(r'€([0-9]+)', r'\g<1> euro', text)
    if '£' in text:
        # fixed: same copy-paste error as the euro branch
        text = re.sub(r'£([0-9]+)', r'\g<1> pounds', text)
    if '%' in text:
        text = re.sub(r'([0-9]+)\%', r'\g<1> percent', text)

    # strip punctuation
    text = strip_punct(text, keep_math=True)
    text = text.strip('*=/')
    # catch strings that might be converted to infinity or NaN and return as is...
    naughty_words = ['INF', 'Inf', 'inf', 'NAN', 'NaN', 'nan', 'NONE', 'None', 'none', 'Infinity', 'infinity']
    if text in naughty_words:
        return text
    try:
        words = text.split()
        if len(words) > 1:
            return ' '.join([caught_num2words(word) for word in words])
        return num2words(text)
    except (InvalidOperation, ValueError):
        return text
210
+
211
def spell_math(text):
    """Spell out mathematical operators in ``text`` as words.

    Handles unary minus, binary minus, plus, times (``*`` and digit-x-digit),
    divided by and equals. The substitutions keep the whitespace captured
    around an operator, so results can contain doubled spaces.
    """
    # numerals preceded by hyphen become negative -- but only when the hyphen
    # does not directly follow a word character. Without the lookbehind,
    # "5-3" became "5minus 3" and "COVID-19" became "COVIDminus 19".
    text = re.sub(r'(?<!\w)-(\d+)', r'minus \g<1>', text)
    text = re.sub(r'(\d+\s?)\-(\s?\d?)', r'\g<1> minus \g<2>', text)
    # need to be more careful with - as this could be a hyphenated word not minus
    text = re.sub(r'(\w+\s+)\-(\s?\w+)', r'\g<1> minus \g<2>', text)
    text = re.sub(r'(\w+\s?)\+(\s?\w+)', r'\g<1> plus \g<2>', text)
    text = re.sub(r'(\w+\s?)\*(\s?\w+)', r'\g<1> times \g<2>', text)
    # need to be more careful with x as this could be a variable not times
    text = re.sub(r'(\d+\s?)x(\s?\d)', r'\g<1> times \g<2>', text)
    text = re.sub(r'(\w+\s?)\/(\s?\w+)', r'\g<1> divided by \g<2>', text)
    text = re.sub(r'(\w+\s?)\=(\s?\w+)', r'\g<1> equals \g<2>', text)
    return text
223
+
224
def expand_contractions(str):
    """Expand contractions word by word and re-join the words with single spaces."""
    # each whitespace-separated word is fixed on its own
    return ' '.join(contractions.fix(token) for token in str.split())
230
+
231
def norm_text_for_wer(text):
    """Normalise text, or a list of text, for WER computation.

    A list is joined into one string. The text then has bracketed annotations
    and ``%TAG`` tokens removed, numbers spelled out, contractions expanded,
    punctuation stripped, case lowered and whitespace collapsed.

    Note that the clean_REV_transcript function should be applied first to
    remove REV-specific keywords and extract text from docx format tables.
    """
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    # spell out numbers (loop variable no longer shadows the builtin ``str``)
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])
    text = expand_contractions(text)
    text = strip_punct(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple space with single
    return text
249
+
250
def levi_norm_text_for_wer(text):
    """Normalise text, or a list of text, for WER computation on math language.

    Same pipeline as ``norm_text_for_wer`` with two differences: mathematical
    operators are spelled out first, and punctuation is stripped with
    ``keep_math=True``.
    """
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = spell_math(text)
    # spell out numbers (loop variable no longer shadows the builtin ``str``)
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])
    text = expand_contractions(text)
    text = strip_punct(text, keep_math=True)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple space with single
    return text
267
+
268
def whisper_norm_text_for_wer(text):
    """Normalise text for WER computation with the Whisper English normaliser.

    Corpus-specific ``%TAG`` tokens and bracketed annotations are stripped
    before ``EnglishTextNormalizer`` is applied. A list input is joined into
    one string first.
    """
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    normalizer = EnglishTextNormalizer()
    return normalizer(text)
281
+
282
def wer_from_df(
        df,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):
    """Compute WER from a dataframe containing a ref col and a hyp col.

    WER is computed on the edit operation counts over the whole df,
    not averaged over single utterances. Rows whose reference is empty
    (after normalisation, if applied) are excluded.

    Args:
        df (pandas DataFrame): containing rows per utterance
        refcol (str, optional): column name containing reference transcript. Defaults to 'ref'.
        hypcol (str, optional): column name containing hypothesis transcript. Defaults to 'hyp'.
        return_alignments (bool, optional): Return full word-level alignments. Defaults to False.
        normalise (bool, optional): Apply text normalisation to ref and hyp. Defaults to True.
        text_norm_method (str, optional): normaliser name passed to get_normalizer. Defaults to 'isat'.
        printout (bool, optional): Print WER metrics. Defaults to True.

    Returns:
        dict: jiwer measures plus 'word_count', 'sub_rate', 'del_rate' and 'ins_rate'.
    """
    normalizer = get_normalizer(text_norm_method)

    refs = df[refcol].astype(str)
    hyps = df[hypcol].astype(str)
    if normalise:
        refs, hyps = refs.apply(normalizer), hyps.apply(normalizer)

    # drop utterances with an empty reference
    keep = refs.str.len() > 0
    if not keep.all():
        refs, hyps = refs[keep], hyps[keep]

    wer_meas = jiwer.compute_measures(list(refs), list(hyps))

    if not return_alignments:
        # remove alignments
        for key in ('ops', 'truth', 'hypothesis'):
            del wer_meas[key]

    # reference word count = every reference word is a hit, substitution or deletion
    word_count = wer_meas['substitutions'] + wer_meas['deletions'] + wer_meas['hits']
    wer_meas['word_count'] = word_count
    for rate_key, count_key in (('sub_rate', 'substitutions'),
                                ('del_rate', 'deletions'),
                                ('ins_rate', 'insertions')):
        wer_meas[rate_key] = wer_meas[count_key] / word_count

    if printout:
        for key in ['wer', 'sub_rate', 'del_rate', 'ins_rate']:
            print((f"{key}={100*wer_meas[key]:.1f}"))
        print(f"word_count={int(wer_meas['word_count'])}")
    return wer_meas
333
+
334
+
335
def wer_from_csv(
        csv_path,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):
    """Compute WER from a csv file containing a ref col and a hyp col.

    Args:
        csv_path (str): csv with one row per utterance.
        refcol (str, optional): column name containing reference transcript. Defaults to 'ref'.
        hypcol (str, optional): column name containing hypothesis transcript. Defaults to 'hyp'.
        return_alignments (bool, optional): Return full word-level alignments. Defaults to False.
        normalise (bool, optional): Apply text normalisation to ref and hyp. Defaults to True.
        text_norm_method (str, optional): normaliser name. Defaults to 'isat'.
        printout (bool, optional): Print WER metrics. Defaults to True.

    Returns:
        dict: the measures returned by wer_from_df.
    """
    # keep_default_na=False: an empty transcript or ASR output must stay ''.
    # With the default NA handling it is read as NaN and astype(str) turns it
    # into the word "nan", which is then scored as a real word.
    res = pd.read_csv(csv_path, keep_default_na=False).astype(str)

    wer_meas = wer_from_df(res,
                           refcol=refcol,
                           hypcol=hypcol,
                           return_alignments=return_alignments,
                           normalise=normalise,
                           text_norm_method=text_norm_method,
                           printout=printout)
    return wer_meas
converters.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import re
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import numpy as np
7
+ # functions to convert between different transcript/annotation formats
8
+
9
+ #######
10
+ # "table" refers to a pd.Dataframe w the following cols
11
+ # [uttID, speaker, transcript, start_sec, end_sec]
12
+ #########
13
+
14
+ # separate function to write to csv, tsv or ELAN compatible (ELAN interprets ALL commas as delimiter so we need to use tab instead)
15
+
16
def HHMMSS_to_sec(time_str):
    """Get seconds from a timestamp string, with optional fractional part.

    Supported string forms: ``HH:MM:SS[.fff]``, ``HH:MM:SS:f`` / ``:ff`` /
    ``:fff`` (tenths / hundredths / thousandths after a fourth colon),
    ``MM:SS[.fff]``, ``SS.fff`` and whole seconds such as ``"45"``.
    Numbers are accepted as seconds and returned as float.

    Returns:
        float | None: seconds, or None for None, NaN, empty strings and
        unsupported formats (a message is printed for unsupported formats).
    """
    import math
    import numbers

    if time_str is None:
        return None
    if isinstance(time_str, numbers.Real):
        # already in seconds; NaN marks a missing value in pandas columns
        if isinstance(time_str, float) and math.isnan(time_str):
            return None
        return float(time_str)

    time_str = time_str.strip()
    if not time_str:
        return None

    n_colons = time_str.count(':')
    if n_colons == 2:
        h, m, s = time_str.split(':')
    elif n_colons == 3:
        # weird timestamps where there is a field following seconds delimited by colon
        h, m, s, u = time_str.split(':')
        # the digit count tells whether the field is tenths, hundredths or thousandths
        if len(u) not in (1, 2, 3):
            print(f'input string format not supported: {time_str}')
            return None
        if len(u) == 1:
            print('Weird time format detected - HH:MM:SS:tenths - please verify this is how you want the time interpreted')
        s = int(s) + float(u) / 10 ** len(u)
    elif n_colons == 1:
        # missing HH from timestamp, assuming MM:SS
        m, s = time_str.split(':')
        h = 0
    elif n_colons == 0 and (time_str.count('.') == 1 or time_str.isdigit()):
        # missing HH:MM from timestamp, assuming SS.ms or whole seconds
        s = float(time_str)
        h = 0
        m = 0
    else:
        print(f'input string format not supported: {time_str}')
        return None
    return int(h) * 3600 + int(m) * 60 + float(s)
50
+
51
def sec_to_timecode(time_sec):
    """Convert seconds to ``H:MM:SS:hh`` (hundredths), as used in .xlsx transcripts.

    The value is rounded to whole hundredths first and then split into fields,
    so a fraction that rounds up carries into the seconds (1.999 ->
    ``0:00:02:00``) instead of producing a three-digit hundredths field.
    """
    total_hundredths = int(round(100 * time_sec))
    total_sec, u = divmod(total_hundredths, 100)
    h, rem = divmod(total_sec, 3600)
    m, s = divmod(rem, 60)
    timecode = f'{h}:{m:02}:{s:02}:{u:02}'
    return timecode
59
+
60
def docx_scraped_tsv_to_table(ooona_file):
    """Read a tsv scraped from an ooona docx table into a standard table.

    ooona output is a table in a word docx, for now manually copied out and
    saved as tsv. Input cols are SHOT START END SPEAKER DIALOGUE; the header
    row is skipped and the START/END timestamps are converted to seconds.

    Returns:
        pd.DataFrame: columns uttID, speaker, transcript, start_sec, end_sec.
    """
    rows = []
    # newline='' lets the csv module handle line endings inside quoted fields
    with open(ooona_file, newline='') as in_file:
        reader = csv.reader(in_file, delimiter="\t")
        next(reader)  # skip header
        for line in reader:
            utt_ix, start_time, end_time, speaker, transcript = line
            start_sec = HHMMSS_to_sec(start_time)
            end_sec = HHMMSS_to_sec(end_time)
            rows.append([utt_ix, speaker, transcript, start_sec, end_sec])
    utt_table = pd.DataFrame(rows, columns=['uttID', 'speaker', 'transcript', 'start_sec', 'end_sec'])
    return utt_table
77
+ # table = pd.read_csv(ooona_file, sep='\t')
78
+
79
def molly_xlsx_to_table(xl_file):
    """Read a contractor transcript xlsx into a standard table.

    Contractor transcribers provide an xlsx with the following columns
    (matched case-insensitively):
        #: int utterance index
        Timecode: "HH:MM:SS:ss - HH:MM:SS:ss"
        Duration: HH:MM:SS:ss
        Speaker: str
        Dialogue: str
        Annotations: blank
        Error Type: blank
    Only the first sheet is read.

    Returns:
        pd.DataFrame: columns uttID, speaker, transcript, start_sec, end_sec,
        without the rows where both speaker and transcript are missing.
    """
    with pd.ExcelFile(xl_file) as xls:
        table = pd.read_excel(xls, xls.sheet_names[0])
    table.columns = table.columns.str.lower()
    table[['start_time', 'end_time']] = table['timecode'].str.split('-', expand=True)
    table['start_sec'] = table['start_time'].str.strip().apply(HHMMSS_to_sec)
    table['end_sec'] = table['end_time'].str.strip().apply(HHMMSS_to_sec)
    # selecting the wanted columns already leaves annotations/error type/duration behind
    table = table[['#', 'speaker', 'dialogue', 'start_sec', 'end_sec']].rename(
        columns={'#': 'uttID', 'dialogue': 'transcript'})
    # drop rows with missing values in speaker and utterance. The subset must
    # name the renamed column: 'dialogue' no longer exists here and raised KeyError.
    table = table.replace('', np.nan).dropna(subset=['speaker', 'transcript'], how='all')
    return table.reset_index(drop=True)
101
+
102
def LoFi_xlsx_to_table(xl_file):
    """Read a LoFi transcript xlsx into a standard table.

    LoFi transcripts have the following columns:
        #: int utterance index
        Timecode: "HH:MM:SS:ss - HH:MM:SS:ss"
        Duration: HH:MM:SS:ss
        Speaker: str
        Dialogue: str
        Annotations: blank
        Error Type: blank
    Only the first sheet is read.

    Returns:
        pd.DataFrame: columns uttID, speaker, transcript, start_sec, end_sec.
    """
    with pd.ExcelFile(xl_file) as xls:
        first_sheet = xls.sheet_names[0]
        table = pd.DataFrame(pd.read_excel(xls, first_sheet))

    # "start - end" timecode -> two columns, each converted to seconds
    table[['start_time', 'end_time']] = table['Timecode'].str.split('-', expand=True)
    for edge in ('start', 'end'):
        table[f'{edge}_sec'] = table[f'{edge}_time'].str.strip().apply(HHMMSS_to_sec)

    table.drop(labels=['Annotations', 'Error Type', 'Duration'], axis=1, inplace=True)
    wanted = ['#', 'Speaker', 'Dialogue', 'start_sec', 'end_sec']
    new_names = {'#': 'uttID', 'Speaker': 'speaker', 'Dialogue': 'transcript'}
    return table[wanted].rename(columns=new_names)
122
+
123
def saga_to_table(saga_txt):
    """Read one of saga's own txt transcripts into a standard table.

    The txt is expected as repeating three-line records:

        speaker (start time MM:SS)
        utterance
        <blank line>

    A header line without "(" is read as a bare timestamp and the speaker of
    the previous record is reused. ``end_sec`` is always None.

    Returns:
        pd.DataFrame: columns uttID, speaker, transcript, start_sec, end_sec.
    """
    # TODO: make more robust by pattern matching instead of modulo
    # NOTE(review): csv.reader still applies quote handling to each line, so an
    # unbalanced '"' in a transcript may merge lines -- confirm with real data.
    with open(saga_txt) as in_file:
        reader = csv.reader(in_file, delimiter="\n")
        count = 0
        rows = []
        for i, line in enumerate(reader):
            print((count, line))
            if count % 3 == 0:
                spk_time = line[0].split('(')
                if len(spk_time) < 2:
                    # speaker not changed: the line holds only the timestamp
                    timestamp = spk_time[0].strip('):( ')
                    # fixed: speaker sits at index 1 of a row; index 0 is the uttID
                    speaker = rows[-1][1]  # prev speaker
                else:
                    speaker = spk_time[0]
                    timestamp = spk_time[1].replace('):', '')
                start_sec = HHMMSS_to_sec(timestamp)

            if count % 3 == 1:
                transcript = line[0]
            if count % 3 == 2:
                rows.append([i, speaker, transcript, start_sec, None])
            count += 1
        if count % 3 == 2:
            # the file ended right after an utterance, without the closing blank
            # line; keep that last record instead of dropping it
            rows.append([count, speaker, transcript, start_sec, None])
    utt_table = pd.DataFrame(rows, columns=['uttID', 'speaker', 'transcript', 'start_sec', 'end_sec'])
    return utt_table
160
+
161
def table_to_ELAN_tsv(table: pd.DataFrame, path: str):
    """Write ``table`` as a tab-separated file compatible with ELAN import.

    The index is not written and floats are formatted with three decimals.
    """
    table.to_csv(path, sep='\t', float_format='%.3f', index=False)
164
+
165
def table_to_standard_csv(table: pd.DataFrame, path: str):
    """Write ``table`` to the standard csv format agreed upon by the whole team.

    The index is not written and floats are formatted with three decimals.
    """
    # TODO: convert times in seconds back to HH:MM:SS?
    # TODO: split utterances into sentences?
    table.to_csv(path, float_format='%.3f', index=False)
171
+
172
def table_to_utt_labels_csv(table: pd.DataFrame, path: str):
    """Write ``table`` to the utt_labels csv format compatible with rosy's isatasr lib.

    'transcript' is written as 'utterance' and 'uttID' as 'seg'. Rows where
    both speaker and utterance are missing are left out. The caller's
    DataFrame is not modified.
    """
    # rename on a copy: the in-place rename used to change the caller's columns
    out = table.rename(columns={'transcript': 'utterance', 'uttID': 'seg'})
    # drop rows with missing values in speaker and utterance
    out = out.replace('', np.nan).dropna(subset=['speaker', 'utterance'], how='all')
    out.to_csv(path, index=False, float_format='%.3f')
177
+
178
def table_to_molly_xlsx(tbl: pd.DataFrame, path: str):
    """Write a standard table to the contractor xlsx layout.

    Output columns: #, Timecode ("start - end"), Duration, Speaker, Dialogue,
    Annotations (blank), Error Type (blank). The sheet is named after the
    file stem of ``path``. The caller's DataFrame is not modified.
    """
    # rename() returns a new frame; the former ``tblx = tbl`` alias meant the
    # caller's table was renamed and gained helper columns
    tblx = tbl.rename(columns={'uttID': '#', 'speaker': 'Speaker', 'transcript': 'Dialogue'})
    tblx['dur_s'] = tblx['end_sec'] - tblx['start_sec']
    tblx['start_timecode'] = tblx['start_sec'].apply(sec_to_timecode)
    tblx['end_timecode'] = tblx['end_sec'].apply(sec_to_timecode)
    tblx['Duration'] = tblx['dur_s'].apply(sec_to_timecode)
    tblx['Timecode'] = [' - '.join(i) for i in zip(tblx['start_timecode'], tblx['end_timecode'])]
    tblx['Annotations'] = ''
    tblx['Error Type'] = ''
    tblx = tblx[['#', 'Timecode', 'Duration', 'Speaker', 'Dialogue', 'Annotations', 'Error Type']]
    tblx.to_excel(path, sheet_name=Path(path).stem, index=False)
190
+
191
def utt_labels_csv_to_table(label_csv: str):
    """Read an utt_labels csv into a standard table.

    utt_labels_csv is the usual format used for diarized, timed transcripts in
    this repo. There are several versions with different columns (with/without
    segment and/or utterance index). The ``uttID`` column is taken from 'utt'
    if present (any 'seg' column is then discarded), else from 'seg', else
    from the row number.

    Returns:
        pd.DataFrame: the csv contents with an ``uttID`` column.
    """
    table = pd.read_csv(label_csv, keep_default_na=False)
    # choose which column to use for uttID in table
    if 'utt' in table.columns:
        # errors='ignore': files with 'utt' but no 'seg' used to raise KeyError
        table = table.rename(columns={"utt": "uttID"}).drop(columns='seg', errors='ignore')
    elif 'seg' in table.columns:
        table = table.rename(columns={"seg": "uttID"})
    else:
        table = table.reset_index().rename(columns={"index": "uttID"})

    return table
renamers.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import glob
4
+ import shutil
5
+ import re
6
+
7
+
8
+ # rename files from original filename (hexadecimal salad) to Session_ID (human readable) and back
9
# Default location of the File_Name <-> Session_ID csv. (A module-level
# ``global`` statement is a no-op, so none is needed here.)
DEFAULT_MAP_PATH = '../../SessionIDs_from_catalog.csv'


def make_SessionID_map(path=DEFAULT_MAP_PATH):
    """Build lookup dictionaries between file names and session IDs.

    The csv has columns for File_Name (or Conference_ID) and Session_ID,
    copied from columns 1 & 2 of the Catalog on OneDrive. Only the first two
    columns are read. File extensions are removed from the file name (text
    after the first '.'). Rows that are blank, or where either value is empty,
    are skipped.

    Args:
        path (str): csv to read. Defaults to DEFAULT_MAP_PATH.

    Returns:
        tuple: (SID_to_FN, FN_to_SID) dictionaries.

    Raises:
        ValueError: if the first two headers are not as expected.
    """
    SID_to_FN = {}
    FN_to_SID = {}
    # utf-8-sig drops a leading BOM; newline='' is what the csv module expects
    with open(path, encoding='utf-8-sig', newline='') as f:
        reader = csv.reader(f)
        headers = next(reader)
        headers_ok = (
            len(headers) >= 2
            and headers[0] in ('File_Name', 'Conference_ID')
            and headers[1] == 'Session_ID'
        )
        if not headers_ok:
            # explicit raise: an assert would vanish under ``python -O``
            raise ValueError("Headers are wrong, expected ('File_Name' or 'Conference_ID') and 'Session_ID'")

        for line in reader:
            if len(line) < 2:
                continue  # blank or truncated row
            filename, sessionID = line[:2]
            filename = filename.split('.')[0]  # remove extensions
            if filename.strip() and sessionID.strip():
                SID_to_FN[sessionID] = filename
                FN_to_SID[filename] = sessionID
    return (SID_to_FN, FN_to_SID)
30
+
31
+
32
def rename_files_SID_to_FN(path, recursive=True, overwrite=False):
    """Rename files under ``path`` from Session_ID to original file name.

    For every session ID in the default map, files matching ``*<sID>.*`` are
    copied (or moved, when ``overwrite`` is True) to the same directory with
    the session ID in the file name replaced by the mapped file name.

    Args:
        path (str): directory to search.
        recursive (bool): passed to glob for the ``**`` component.
        overwrite (bool): move the files instead of copying them.

    Returns:
        list: the new paths created.
    """
    SID_to_FN, _ = make_SessionID_map()
    # TODO: deal with matching nested sIDs
    newpaths = []
    for sID, filename in SID_to_FN.items():
        pattern = os.path.join(path, '**', f'*{sID}.*')
        for srcpath in glob.glob(pattern, recursive=recursive):
            # substitute in the file name only: replacing across the whole path
            # also rewrote directory names, giving a destination that does not exist
            src_dir, src_name = os.path.split(srcpath)
            newpath = os.path.join(src_dir, src_name.replace(sID, filename))
            print(newpath)
            if overwrite:
                shutil.move(srcpath, newpath)
            else:
                shutil.copy(srcpath, newpath)
            newpaths.append(newpath)
    return newpaths
49
+
50
+
51
+ # # get sessnames
52
+ # sesslist = [s for s in os.listdir(path) ]
53
+ # srclist = [os.path.join(src_dir, filename) for filename in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, filename))]
54
+ # for src in srclist:
55
+ # sessname_matches = [sessname in src for sessname in sesslist]
56
+ # if sum(sessname_matches)>1:
57
+ # print('!!!! multiple matches, will take longest match. TODO: implement this you dope')
58
+ # elif not any(sessname_matches):
59
+ # print(f'!!!! no sessname matches for file {src}')
60
+ # else:
61
+ # sessname = sesslist[sessname_matches.index(True)]
62
+ # print(f'...copying to {sessname}')
63
+ # shutil.copy(src, os.path.join(dest_dir,sessname))
64
+
65
def rename_files_FN_to_SID(path, recursive=True):
    """Unfinished counterpart of rename_files_SID_to_FN.

    Currently only loads the file-name -> session-ID map; no file is touched,
    ``path`` and ``recursive`` are unused, and None is returned.
    """
    # TODO: implement the renaming itself
    FN_to_SID = make_SessionID_map()[1]
67
+
68
def extract_conferenceID_from_filename(filename):
    """Extract conferenceID from filename.

    Takes the text before the first space, then removes in order: a trailing
    ``_<letters>.xlsx``-style suffix, the words 'TMcoded' / 'Transcript', a
    ``_start<N>_end<N>`` marker, and a ``<5 digits>_<YYYY-MM-DD>_`` prefix.
    """
    conferenceID = filename.split(' ')[0]
    # raw strings: same patterns as before, without invalid escape sequences
    conferenceID = re.sub(r'_?[a-zA-Z]*(\.*[a-zA-Z]*).xlsx', '', conferenceID)
    conferenceID = re.sub(r'TMcoded|Transcript', '', conferenceID)
    conferenceID = re.sub(r'_start\d+_end\d+_?', '', conferenceID)
    conferenceID = re.sub(r'\d{5}_\d{4}-\d{2}-\d{2}_', '', conferenceID)
    return conferenceID
trimmers.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+ import csv
4
+ import subprocess
5
+ import pandas as pd
6
+ import sys
7
+ sys.path.append('..')
8
+
9
+ from levi.converters import HHMMSS_to_sec
10
+
11
def trim_media(media_in,
               media_out,
               start,
               end):
    """Trim a media file with ffmpeg, converting to WAV when requested.

    When ``media_out`` ends in '.wav' the audio is re-encoded as 16-bit PCM,
    16000 Hz, 1 channel. For any other extension the streams are copied
    without re-encoding.

    Args:
        media_in (str): input media path.
        media_out (str): output path; its extension selects the mode.
        start: start time, as a timestamp string (parsed by HHMMSS_to_sec) or
            a number of seconds.
        end: end time, same forms as ``start``.
    """
    # options for writing out audio if converting
    WAV_CHANNELS = 1
    WAV_SAMPLE_RATE = 16000

    ext = Path(media_out).suffix

    start_sec = HHMMSS_to_sec(start) if isinstance(start, str) else float(start)
    end_sec = HHMMSS_to_sec(end) if isinstance(end, str) else float(end)

    # global options first, then input and trim window
    cmd = ['ffmpeg',
           '-y',
           '-hide_banner',
           '-loglevel', 'warning',
           '-i', media_in,
           '-ss', f'{start_sec}',
           '-to', f'{end_sec}']

    if ext == '.wav':
        # convert to wav with standard format for audio models
        print(f'...Using ffmpeg to trim video from {start} to {end} \n and convert to {WAV_SAMPLE_RATE}Hz WAV with {WAV_CHANNELS} channels...')
        # every argv entry must be a string; the bare ints raised TypeError
        cmd += ['-acodec', 'pcm_s16le',
                '-ac', str(WAV_CHANNELS),
                '-ar', str(WAV_SAMPLE_RATE)]
    else:
        print(f'...Using ffmpeg to trim video from {start_sec} to {end_sec}...')
        cmd += ['-c', 'copy']

    print(f'...generating {media_out}...')
    cmd.append(media_out)
    subprocess.call(cmd, shell=False)
77
+
78
def trim_media_batch(extract_timings_csv,
                     outpath,
                     suffix='',
                     convert_to=False):
    """Trim a batch of media files given a csv of timings.

    Args:
        extract_timings_csv (str): path to csv with a header row and columns:
            filepath, start (HH:MM:SS), end (HH:MM:SS)
        outpath (str): output directory; created if missing, '~' is expanded.
        suffix (str, optional): save output trimmed files with this suffix. Defaults to ''.
        convert_to (optional): 'wav' or 'mp4' to force that output extension;
            any other value keeps the input extension. Defaults to False.

    Returns:
        outfiles (list): list of output file paths, one per input that was found.
    """
    # expand '~' once so the directory that is created is the same one the
    # output files are written to
    outpath = os.path.expanduser(outpath)
    os.makedirs(outpath, exist_ok=True)

    samples_df = pd.read_csv(
        extract_timings_csv,
        skip_blank_lines=True,
        index_col=False,
        names=['media_in', 'startHMS', 'endHMS'],
        header=0
    ).dropna().sort_values(
        by='media_in', ignore_index=True).reset_index(drop=True)

    print(f'TRIMMING {len(samples_df.index)} FILES...')

    # enumerate samples by session and check if there are multiple samples from a given session
    samples_df['count'] = samples_df.groupby('media_in').cumcount()

    outfiles = []
    for _, rec in samples_df.iterrows():
        media_in, startHMS, endHMS, count = rec.values
        # if multiple samples per recording, give a different name
        suffix_use = f'{suffix}{count}' if count > 0 else suffix

        if not os.path.exists(media_in):
            print(f'!!!WARNING: media not found: {media_in}')
            continue

        media_type = Path(media_in).suffix
        sessname = Path(media_in).stem
        print(f'...Input media: {media_in}')

        if convert_to == 'wav':
            ext = '.wav'
        elif convert_to == 'mp4':
            ext = '.mp4'
        else:
            ext = media_type

        outfile = os.path.join(outpath, f'{sessname}{suffix_use}{ext}')

        trim_media(media_in, outfile, HHMMSS_to_sec(startHMS), HHMMSS_to_sec(endHMS))

        outfiles.append(outfile)
    return outfiles
139
+ return(outfiles)