mskov committed
Commit 68ed0e8 • 1 Parent(s): 6c847d8

Update app.py

Files changed (1)
  1. app.py +18 -2
app.py CHANGED
@@ -32,8 +32,23 @@ model = WhisperForConditionalGeneration.from_pretrained("mskov/whisper-small-esc
 
 
 # Remove brackets and extra spaces
+def map_to_pred(batch):
+    audio = batch["audio"]
+    input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
+    batch["reference"] = processor.tokenizer._normalize(batch['category'])
+
+    with torch.no_grad():
+        predicted_ids = model.generate(input_features.to("cuda"))[0]
+    transcription = processor.decode(predicted_ids)
+    batch["prediction"] = processor.tokenizer._normalize(transcription)
+    return batch
 
-'''
+result = dataset.map(map_to_pred)
+
+wer = load("wer")
+print(100 * wer.compute(references=result["reference"], predictions=result["prediction"]))
+
+'''
 def map_to_pred(batch):
     cleaned_transcription = re.sub(r'\[[^\]]+\]', '', batch['category']).strip()
     print("cleaned transcript", cleaned_transcription)
@@ -57,6 +72,7 @@ result = dataset.map(map_to_pred)
 wer = load("wer")
 print(100 * wer.compute(references=result["reference"], predictions=result["prediction"]))
 '''
+'''
 with torch.no_grad():
     outputs = model(input_ids=input_ids, attention_mask=attention_mask)
     print("outputs ", outputs)
@@ -74,7 +90,7 @@ wer_score = wer(labels, predicted_text)
 
 # Print or return WER score
 print(f"Word Error Rate (WER): {wer_score}")
-
+'''
 
 def transcribe(audio):
     text = pipe(audio)["text"]
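The new map_to_pred added in the first hunk swaps the old regex cleanup for Whisper's own tokenizer normalizer, runs greedy generation per example, and scores the mapped dataset with the wer metric. A self-contained sketch of that flow is below; the fine-tuned checkpoint name is truncated in the hunk header, so a public Whisper checkpoint and a placeholder dataset id stand in for the Space's own, and the CPU fallback is likewise an assumption rather than something in app.py.

import torch
from datasets import Audio, load_dataset
from evaluate import load
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Assumptions: stand-in checkpoint (the repo's own name is cut off in the hunk header)
# and a hypothetical dataset with "audio" and "category" columns.
MODEL_ID = "openai/whisper-small"
DATASET_ID = "your-namespace/your-dataset"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(device)

dataset = load_dataset(DATASET_ID, split="test")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))  # Whisper expects 16 kHz input

def map_to_pred(batch):
    # Log-mel features for one example, then greedy generation, with Whisper's own
    # normalizer applied to both reference and prediction, as in the committed code.
    audio = batch["audio"]
    input_features = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt"
    ).input_features
    batch["reference"] = processor.tokenizer._normalize(batch["category"])

    with torch.no_grad():
        predicted_ids = model.generate(input_features.to(device))[0]
    transcription = processor.decode(predicted_ids)
    batch["prediction"] = processor.tokenizer._normalize(transcription)
    return batch

result = dataset.map(map_to_pred)

wer = load("wer")
print(100 * wer.compute(references=result["reference"], predictions=result["prediction"]))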
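The value that gets printed is 100 × WER over the whole mapped split. A quick sanity check of the evaluate WER metric on made-up strings (both sentences are hypothetical examples, not data from this repo):

from evaluate import load

wer = load("wer")
# One substitution out of four reference words -> WER = 0.25, reported as 25.0
score = wer.compute(references=["dog barking in yard"], predictions=["dog barking in park"])
print(100 * score)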
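The unchanged transcribe helper at the bottom of the last hunk calls a pipe object that the diff never defines. A typical definition would be a transformers automatic-speech-recognition pipeline; the sketch below is an assumption about that wiring, with a public checkpoint standing in for the Space's own model.

from transformers import pipeline

# Assumption: `pipe` in app.py is an automatic-speech-recognition pipeline.
pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small")

def transcribe(audio):
    # `audio` may be a file path, raw bytes, or a dict with "array"/"sampling_rate".
    text = pipe(audio)["text"]
    return text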