Commit b8eb4fa (1 parent: 177ebd7) committed by sadrasabouri

Update README.md

Files changed (1):
  1. README.md +17 -8
README.md CHANGED
@@ -101,7 +101,10 @@ processor = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")
 def speech_file_to_array_fn(batch):
     speech_array, sampling_rate = torchaudio.load(batch["path"])
     speech_array = speech_array.squeeze().numpy()
-    speech_array = librosa.resample(np.asarray(speech_array), sampling_rate, processor.feature_extractor.sampling_rate)
+    speech_array = librosa.resample(
+        np.asarray(speech_array),
+        sampling_rate,
+        processor.feature_extractor.sampling_rate)
     batch["speech"] = speech_array
     return batch

@@ -112,24 +115,30 @@ def predict(batch):
         return_tensors="pt",
         padding=True
     )
-
-    input_values = features.input_values
-    attention_mask = features.attention_mask

     with torch.no_grad():
-        logits = model(input_values, attention_mask=attention_mask).logits  # when loading the model with an LM we have to use logits instead of argmax(logits)
+        logits = model(
+            features.input_values,
+            attention_mask=features.attention_mask).logits
     batch["prediction"] = processor.batch_decode(logits.numpy()).text
     return batch

-dataset = load_dataset("csv", data_files={"test":"path/to/your.csv"}, delimiter=",")["test"]
+dataset = load_dataset(
+    "csv",
+    data_files={"test":"dataset.eval.csv"},
+    delimiter=",")["test"]
 dataset = dataset.map(speech_file_to_array_fn)

 result = dataset.map(predict, batched=True, batch_size=4)
 wer = load_metric("wer")
 cer = load_metric("cer")

-print("WER: {:.2f}".format(100 * wer.compute(predictions=result["prediction"], references=result["reference"])))
-print("CER: {:.2f}".format(100 * cer.compute(predictions=result["prediction"], references=result["reference"])))
+print("WER: {:.2f}".format(wer.compute(
+    predictions=result["prediction"],
+    references=result["reference"])))
+print("CER: {:.2f}".format(cer.compute(
+    predictions=result["prediction"],
+    references=result["reference"])))
 ```

 *Result (WER) on common-voice 6.1*:
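
A note on the resample call in the first hunk: in recent librosa releases (0.10 and later) `orig_sr` and `target_sr` appear to be keyword-only, so the positional form in the snippet raises a `TypeError` there. A minimal sketch of the keyword form, assuming librosa >= 0.10; the helper name `resample_to_model_rate` is illustrative and not part of the repository:

```python
import numpy as np
import librosa

def resample_to_model_rate(speech_array, sampling_rate, target_rate=16000):
    """Resample a waveform to the model's expected rate (16 kHz by default)."""
    # orig_sr / target_sr are keyword-only in librosa >= 0.10; the keyword
    # form also works on older releases that still accept positionals.
    return librosa.resample(
        np.asarray(speech_array),
        orig_sr=sampling_rate,
        target_sr=target_rate,
    )
```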
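The comment dropped with the old `logits = model(...)` line still explains why raw logits are decoded here: `Wav2Vec2ProcessorWithLM.batch_decode` runs an LM-boosted beam search over the full logit distribution, while a plain CTC processor decodes greedy `argmax` token IDs. A minimal sketch contrasting the two paths; it assumes the checkpoint also loads as a plain `Wav2Vec2Processor` and feeds a silent placeholder waveform instead of real audio:

```python
import torch
from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

model = AutoModelForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
processor_lm = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")

# Placeholder input: one second of silence at 16 kHz (stands in for real speech).
speech = torch.zeros(16000).numpy()
features = processor_lm(speech, sampling_rate=16000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(features.input_values, attention_mask=features.attention_mask).logits

# LM-boosted decoding: beam search over the raw logits (what the README does).
text_with_lm = processor_lm.batch_decode(logits.numpy()).text

# Plain CTC decoding for comparison: greedy argmax, then detokenize.
# Assumption: the repo's tokenizer/feature extractor also load via Wav2Vec2Processor.
processor_ctc = Wav2Vec2Processor.from_pretrained("SLPL/Sharif-wav2vec2")
predicted_ids = torch.argmax(logits, dim=-1)
text_greedy = processor_ctc.batch_decode(predicted_ids)
```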
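On the `load_dataset` change: whatever file is passed as `dataset.eval.csv`, the rest of the snippet expects at least a `path` column (audio file locations read by `speech_file_to_array_fn`) and a `reference` column (ground-truth transcripts consumed by the metrics). A small sketch that writes such a file; the paths and transcripts are placeholders only:

```python
import csv

# Illustrative only: the column names are the ones the evaluation snippet reads.
rows = [
    {"path": "/path/to/clip_0001.wav", "reference": "..."},  # "..." = real transcript
    {"path": "/path/to/clip_0002.wav", "reference": "..."},
]
with open("dataset.eval.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["path", "reference"])
    writer.writeheader()
    writer.writerows(rows)
```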
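On the metric lines: the commit also drops the `100 *` factor, so the script now prints WER/CER as fractions in [0, 1] rather than percentages. If `datasets.load_metric` is unavailable (newer `datasets` releases removed it), the standalone `evaluate` package computes the same scores; a sketch assuming `evaluate` and `jiwer` are installed, with stand-in lists in place of `result["prediction"]` and `result["reference"]`:

```python
import evaluate  # pip install evaluate jiwer

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

# Stand-ins for result["prediction"] / result["reference"] from the snippet above.
predictions = ["this is a test"]
references = ["this is the test"]

# compute() returns a fraction; multiply by 100 to report a percentage.
print("WER: {:.2f}".format(100 * wer_metric.compute(predictions=predictions, references=references)))
print("CER: {:.2f}".format(100 * cer_metric.compute(predictions=predictions, references=references)))
```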