Update app.py: resample audio to 16 kHz and pass the attention mask to the model
Browse files
app.py
CHANGED
@@ -19,14 +19,15 @@ def ASR(audio):
|
|
19 |
# Load the audio file using torchaudio
|
20 |
waveform, sample_rate = torchaudio.load(temp_audio_file.name)
|
21 |
# Resample the audio to 16kHz
|
22 |
-
|
23 |
-
|
24 |
# Convert the PyTorch tensor to a NumPy ndarray
|
25 |
# Preprocess the audio file
|
26 |
input_values = processor(waveform.squeeze().numpy(),sampling_rate=16_000, return_tensors="pt").input_values
|
|
|
27 |
# Transcribe the audio file
|
28 |
with torch.no_grad():
|
29 |
-
logits = model(input_values).logits
|
30 |
# Decode the transcription
|
31 |
transcription = processor.decode(torch.argmax(logits, dim=-1))
|
32 |
return transcription
|
|
|
19 |
# Load the audio file using torchaudio
|
20 |
waveform, sample_rate = torchaudio.load(temp_audio_file.name)
|
21 |
# Resample the audio to 16kHz
|
22 |
+
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
|
23 |
+
waveform = resampler(waveform)
|
24 |
# Convert the PyTorch tensor to a NumPy ndarray
|
25 |
# Preprocess the audio file
|
26 |
input_values = processor(waveform.squeeze().numpy(),sampling_rate=16_000, return_tensors="pt").input_values
|
27 |
+
attention_mask = processor(waveform.squeeze().numpy(),sampling_rate=16_000, return_tensors="pt").attention_mask
|
28 |
# Transcribe the audio file
|
29 |
with torch.no_grad():
|
30 |
+
logits = model(input_values,attention_mask).logits
|
31 |
# Decode the transcription
|
32 |
transcription = processor.decode(torch.argmax(logits, dim=-1))
|
33 |
return transcription
|