Kartikeyssj2 committed
Commit 5efbe11
1 Parent(s): 813004d

Update main.py

Files changed (1)
  1. main.py +13 -8
main.py CHANGED
@@ -105,19 +105,24 @@ async def upload_audio(file: UploadFile = File(...)):
     print("length of the audio array:" , len(audio))
 
     print("*" * 100)
-
-    # Tokenize audio
+
+    # Tokenization
     print("Tokenizing audio...")
-    input_values = tokenizer(audio, return_tensors="pt").input_values
-
+    input_values = await asyncio.to_thread(tokenizer, audio, return_tensors="pt")
+    input_values = input_values.input_values
+    print("Tokenization complete. Shape of input_values:", input_values.shape)
+
     # Perform inference
     print("Performing inference with Wav2Vec2 model...")
-    logits = model(input_values).logits
-
+    output = await asyncio.to_thread(model, input_values)
+    logits = output.logits
+    print("Inference complete. Shape of logits:", logits.shape)
+
     # Get predictions
     print("Getting predictions...")
     prediction = torch.argmax(logits, dim=-1)
-
+    print("Prediction shape:", prediction.shape)
+
     # Decode predictions
     print("Decoding predictions...")
     transcription = tokenizer.batch_decode(prediction)[0]
@@ -125,8 +130,8 @@ async def upload_audio(file: UploadFile = File(...)):
     # Convert transcription to lowercase
     transcription = transcription.lower()
 
-    # Print transcription and word counts
    print("Decoded transcription:", transcription)
+
     incorrect, correct = count_spelled_words(transcription, english_words)
     print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
 
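
For context, the change above swaps direct (blocking) calls to the tokenizer and model for `asyncio.to_thread`, so the CPU-bound Wav2Vec2 work runs in a worker thread and the async FastAPI endpoint does not stall the event loop while it transcribes. Below is a minimal, self-contained sketch of that pattern; the `facebook/wav2vec2-base-960h` checkpoint, the `Wav2Vec2Processor` API, the librosa-based audio loading, and the endpoint shape are illustrative assumptions and are not taken from this commit.

# Sketch of the pattern adopted in this commit: offload blocking
# tokenization/inference to a worker thread with asyncio.to_thread.
# Model name, audio decoding, and endpoint shape are assumptions.
import asyncio
import io

import librosa
import torch
from fastapi import FastAPI, File, UploadFile
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

app = FastAPI()
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")


def transcribe(audio):
    # Blocking, CPU-bound work: tokenize, run the model, decode to text.
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    prediction = torch.argmax(logits, dim=-1)
    return processor.batch_decode(prediction)[0].lower()


@app.post("/upload")
async def upload_audio(file: UploadFile = File(...)):
    raw = await file.read()
    # Decode the upload to a 16 kHz mono waveform (assumption: librosa).
    audio, _ = librosa.load(io.BytesIO(raw), sr=16000)
    # Run the blocking work in a thread; the event loop stays responsive.
    transcription = await asyncio.to_thread(transcribe, audio)
    return {"transcription": transcription}

`asyncio.to_thread` (Python 3.9+) is a thin wrapper over `loop.run_in_executor` with the default thread pool, which is why the diff can keep the endpoint `async` while still calling the synchronous tokenizer and model objects unchanged.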