Kartikeyssj2 committed · Commit 5efbe11
Parent: 813004d

Update main.py
main.py CHANGED
@@ -105,19 +105,24 @@ async def upload_audio(file: UploadFile = File(...)):
     print("length of the audio array:" , len(audio))
 
     print("*" * 100)
-
-    #
+
+    # Tokenization
     print("Tokenizing audio...")
-    input_values = tokenizer
-
+    input_values = await asyncio.to_thread(tokenizer, audio, return_tensors="pt")
+    input_values = input_values.input_values
+    print("Tokenization complete. Shape of input_values:", input_values.shape)
+
     # Perform inference
     print("Performing inference with Wav2Vec2 model...")
-
-
+    output = await asyncio.to_thread(model, input_values)
+    logits = output.logits
+    print("Inference complete. Shape of logits:", logits.shape)
+
     # Get predictions
     print("Getting predictions...")
     prediction = torch.argmax(logits, dim=-1)
-
+    print("Prediction shape:", prediction.shape)
+
     # Decode predictions
     print("Decoding predictions...")
     transcription = tokenizer.batch_decode(prediction)[0]
@@ -125,8 +130,8 @@ async def upload_audio(file: UploadFile = File(...)):
     # Convert transcription to lowercase
     transcription = transcription.lower()
 
-    # Print transcription and word counts
     print("Decoded transcription:", transcription)
+
     incorrect, correct = count_spelled_words(transcription, english_words)
     print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
 
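The substance of the change: the blocking tokenizer call and the Wav2Vec2 forward pass are moved onto a worker thread with asyncio.to_thread, so the async upload_audio handler no longer stalls the event loop while the CPU-bound work runs. Below is a minimal sketch of that pattern, not the Space's actual code: the facebook/wav2vec2-base-960h checkpoint, the loading lines, and the transcribe wrapper are assumptions for illustration; only the tokenizer/model usage mirrors the diff.

import asyncio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

# Assumed checkpoint; the commit does not show how the Space loads its model.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

async def transcribe(audio) -> str:
    # Tokenization runs in a worker thread instead of on the event loop.
    inputs = await asyncio.to_thread(tokenizer, audio, return_tensors="pt")
    input_values = inputs.input_values

    # Forward pass likewise offloaded; model(input_values) is compute bound.
    output = await asyncio.to_thread(model, input_values)
    logits = output.logits

    # Greedy CTC decoding, unchanged by this commit.
    prediction = torch.argmax(logits, dim=-1)
    return tokenizer.batch_decode(prediction)[0].lower()

asyncio.to_thread does not change how tokenizer(...) or model(...) are called; it runs them in the default thread-pool executor and awaits the result, which is why the surrounding handler only needed await added in front of the existing calls.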