Kartikeyssj2 commited on
Commit
656776b
1 Parent(s): 10dd4bf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +144 -150
main.py CHANGED
@@ -1,6 +1,97 @@
1
- import soundfile as sf
2
- import numpy as np
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  @app.post('/fluency_score')
5
  async def fluency_scoring(file: UploadFile = File(...)):
6
  with sf.SoundFile(file.file, 'r') as sound_file:
@@ -13,178 +104,81 @@ async def fluency_scoring(file: UploadFile = File(...)):
13
 
14
  print(audio_array)
15
  return audio_array[:5].tolist()
16
-
17
-
18
- # import re
19
- # import requests
20
- # import pyarrow as pa
21
- # import librosa
22
- # import torch
23
- # from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
24
- # from fastapi import FastAPI, File, UploadFile
25
- # import warnings
26
- # from starlette.formparsers import MultiPartParser
27
- # import io
28
- # import random
29
- # import tempfile
30
- # import os
31
- # import numba
32
- # import soundfile as sf
33
- # import asyncio
34
-
35
- # MultiPartParser.max_file_size = 200 * 1024 * 1024
36
-
37
- # # Initialize FastAPI app
38
- # app = FastAPI()
39
-
40
- # # Load Wav2Vec2 tokenizer and model
41
- # tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
42
- # model = Wav2Vec2ForCTC.from_pretrained("./models/model")
43
-
44
-
45
- # # Function to download English word list
46
- # def download_word_list():
47
- # print("Downloading English word list...")
48
- # url = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
49
- # response = requests.get(url)
50
- # words = set(response.text.split())
51
- # print("Word list downloaded.")
52
- # return words
53
-
54
- # english_words = download_word_list()
55
-
56
- # # Function to count correctly spelled words in text
57
- # def count_spelled_words(text, word_list):
58
- # print("Counting spelled words...")
59
- # # Split the text into words
60
- # words = re.findall(r'\b\w+\b', text.lower())
61
-
62
- # correct = sum(1 for word in words if word in word_list)
63
- # incorrect = len(words) - correct
64
-
65
- # print("Spelling check complete.")
66
- # return incorrect, correct
67
-
68
- # # Function to apply spell check to an item (assuming it's a dictionary)
69
- # def apply_spell_check(item, word_list):
70
- # print("Applying spell check...")
71
- # if isinstance(item, dict):
72
- # # This is a single item
73
- # text = item['transcription']
74
- # incorrect, correct = count_spelled_words(text, word_list)
75
- # item['incorrect_words'] = incorrect
76
- # item['correct_words'] = correct
77
- # print("Spell check applied to single item.")
78
- # return item
79
- # else:
80
- # # This is likely a batch
81
- # texts = item['transcription']
82
- # results = [count_spelled_words(text, word_list) for text in texts]
83
-
84
- # incorrect_counts, correct_counts = zip(*results)
85
-
86
- # item = item.append_column('incorrect_words', pa.array(incorrect_counts))
87
- # item = item.append_column('correct_words', pa.array(correct_counts))
88
-
89
- # print("Spell check applied to batch of items.")
90
- # return item
91
-
92
- # # FastAPI routes
93
- # @app.get('/')
94
- # async def root():
95
- # return "Welcome to the pronunciation scoring API!"
96
-
97
- # @app.post('/check_post')
98
- # async def rnc(number):
99
- # return {
100
- # "your value:" , number
101
- # }
102
-
103
- # @app.get('/check_get')
104
- # async def get_rnc():
105
- # return random.randint(0 , 10)
106
-
107
-
108
- # @app.post('/fluency_score')
109
- # async def fluency_scoring(file: UploadFile = File(...)):
110
- # audio_array, sample_rate = librosa.load(file.file, sr=16000)
111
- # print(audio_array)
112
- # return audio_array[:5]
113
 
114
 
115
- # @app.post('/pronunciation_score')
116
- # async def pronunciation_scoring(file: UploadFile = File(...)):
117
- # print("loading the file")
118
- # url = "https://speech-processing-6.onrender.com/process_audio"
119
- # files = {'file': await file.read()}
120
 
121
- # print("file loaded")
122
 
123
- # # print(files)
124
 
125
- # print("making a POST request on speech processor")
126
 
127
- # # Make the POST request
128
- # response = requests.post(url, files=files)
129
 
130
- # audio = response.json().get('audio_array')
131
 
132
- # print("audio:" , audio[:5])
133
 
134
 
135
 
136
- # print("length of the audio array:" , len(audio))
137
 
138
- # print("*" * 100)
139
 
140
- # # Tokenization
141
- # print("Tokenizing audio...")
142
- # input_values = tokenizer(
143
- # audio,
144
- # return_tensors="pt",
145
- # padding="max_length",
146
- # max_length= 386380,
147
- # truncation=True
148
- # ).input_values
149
 
150
- # print(input_values.shape)
151
 
152
- # print("Tokenization complete. Shape of input_values:", input_values.shape)
153
 
154
- # return "tokenization successful"
155
 
156
- # # Perform inference
157
- # print("Performing inference with Wav2Vec2 model...")
158
 
159
- # logits = model(input_values).logits
160
 
161
- # print("Inference complete. Shape of logits:", logits.shape)
162
 
163
- # # Get predictions
164
- # print("Getting predictions...")
165
- # prediction = torch.argmax(logits, dim=-1)
166
- # print("Prediction shape:", prediction.shape)
167
 
168
- # # Decode predictions
169
- # print("Decoding predictions...")
170
- # transcription = tokenizer.batch_decode(prediction)[0]
171
 
172
- # # Convert transcription to lowercase
173
- # transcription = transcription.lower()
174
 
175
- # print("Decoded transcription:", transcription)
176
 
177
- # incorrect, correct = count_spelled_words(transcription, english_words)
178
- # print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)
179
 
180
- # # Calculate pronunciation score
181
- # fraction = correct / (incorrect + correct)
182
- # score = round(fraction * 100, 2)
183
- # print("Pronunciation score for", transcription, ":", score)
184
 
185
- # print("Pronunciation scoring process complete.")
186
 
187
- # return {
188
- # "transcription": transcription,
189
- # "pronunciation_score": score
190
- # }
 
 
 
1
 
2
+
3
+
4
+ import re
5
+ import requests
6
+ import pyarrow as pa
7
+ import librosa
8
+ import torch
9
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
10
+ from fastapi import FastAPI, File, UploadFile
11
+ import warnings
12
+ from starlette.formparsers import MultiPartParser
13
+ import io
14
+ import random
15
+ import tempfile
16
+ import os
17
+ import numba
18
+ import soundfile as sf
19
+ import asyncio
20
+
21
# Raise Starlette's multipart limit so large audio uploads (up to 200 MiB)
# are accepted by the form parser.
MultiPartParser.max_file_size = 200 * 1024 * 1024

# FastAPI application instance used by every route decorator below.
app = FastAPI()

# Load the Wav2Vec2 tokenizer and CTC model once, at import time, from local
# checkpoint directories so request handlers can reuse them on every call.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
model = Wav2Vec2ForCTC.from_pretrained("./models/model")
29
+
30
+
31
# Function to download English word list
def download_word_list():
    """Download the dwyl English word list and return it as a set.

    Returns:
        set[str]: lowercase English words used for spell checking.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    print("Downloading English word list...")
    url = "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt"
    # FIX: the original call had no timeout (could hang app startup forever)
    # and never checked the status code, so a 404/error page body would have
    # silently become the "word list".
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    words = set(response.text.split())
    print("Word list downloaded.")
    return words


# Fetched once at import time; shared by the scoring endpoints.
english_words = download_word_list()
41
+
42
# Function to count correctly spelled words in text
def count_spelled_words(text, word_list):
    """Tokenize *text* and tally membership in *word_list*.

    Args:
        text: arbitrary string; lowercased before matching.
        word_list: collection of known-good (lowercase) words.

    Returns:
        tuple: (incorrect_count, correct_count).
    """
    print("Counting spelled words...")
    # Word tokens are maximal \w+ runs; casefolding happens once up front.
    tokens = re.findall(r'\b\w+\b', text.lower())

    hits = 0
    for token in tokens:
        if token in word_list:
            hits += 1
    misses = len(tokens) - hits

    print("Spelling check complete.")
    return misses, hits
53
+
54
# Function to apply spell check to an item (assuming it's a dictionary)
def apply_spell_check(item, word_list):
    """Attach incorrect/correct word counts to *item*.

    *item* is either a single dict carrying a 'transcription' string, or a
    batch object (presumably a pyarrow table — verify against caller) whose
    'transcription' column holds many strings.
    """
    print("Applying spell check...")
    if isinstance(item, dict):
        # Single-record path: mutate the dict in place and return it.
        incorrect, correct = count_spelled_words(item['transcription'], word_list)
        item['incorrect_words'] = incorrect
        item['correct_words'] = correct
        print("Spell check applied to single item.")
        return item

    # Batch path: score every transcription, then append two new columns.
    counts = [count_spelled_words(text, word_list) for text in item['transcription']]
    incorrect_counts, correct_counts = zip(*counts)
    item = item.append_column('incorrect_words', pa.array(incorrect_counts))
    item = item.append_column('correct_words', pa.array(correct_counts))
    print("Spell check applied to batch of items.")
    return item
77
+
78
# FastAPI routes
@app.get('/')
async def root():
    """Landing / health-check endpoint."""
    return "Welcome to the pronunciation scoring API!"
82
+
83
@app.post('/check_post')
async def rnc(number):
    """Echo endpoint used to sanity-check POST handling.

    FIX: the original body was `{"your value:" , number}` — a comma where a
    colon was intended — which builds a *set*, not the intended JSON object
    mapping the label to the submitted value.
    """
    return {
        "your value:": number
    }
88
+
89
@app.get('/check_get')
async def get_rnc():
    """Sanity-check GET endpoint: return a random integer in [0, 10]."""
    return random.randint(0, 10)
92
+
93
+
94
+
95
  @app.post('/fluency_score')
96
  async def fluency_scoring(file: UploadFile = File(...)):
97
  with sf.SoundFile(file.file, 'r') as sound_file:
 
104
 
105
  print(audio_array)
106
  return audio_array[:5].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
@app.post('/pronunciation_score')
async def pronunciation_scoring(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and score its pronunciation.

    The raw upload is forwarded to an external speech-processing service that
    returns an audio array; that array is transcribed with the local Wav2Vec2
    model, and the score is the percentage of transcribed words found in the
    English word list.

    Returns:
        dict: {'transcription': str, 'pronunciation_score': float (0-100)}.
    """
    print("loading the file")
    url = "https://speech-processing-6.onrender.com/process_audio"
    files = {'file': await file.read()}

    print("file loaded")

    print("making a POST request on speech processor")

    # Make the POST request. FIX: added a timeout so a dead upstream cannot
    # hang this handler forever, and surface HTTP errors explicitly instead
    # of trying to parse an error body as JSON.
    response = requests.post(url, files=files, timeout=120)
    response.raise_for_status()

    audio = response.json().get('audio_array')

    print("audio:", audio[:5])
    print("length of the audio array:", len(audio))
    print("*" * 100)

    # Tokenization
    print("Tokenizing audio...")
    input_values = tokenizer(
        audio,
        return_tensors="pt",
        padding="max_length",
        max_length=386380,
        truncation=True
    ).input_values

    print(input_values.shape)
    print("Tokenization complete. Shape of input_values:", input_values.shape)

    # FIX: the original returned "tokenization successful" at this point,
    # which made every line below — inference, decoding, and scoring —
    # unreachable dead code.

    # Perform inference. no_grad avoids building an autograd graph for a
    # forward-only pass.
    print("Performing inference with Wav2Vec2 model...")
    with torch.no_grad():
        logits = model(input_values).logits
    print("Inference complete. Shape of logits:", logits.shape)

    # Get predictions
    print("Getting predictions...")
    prediction = torch.argmax(logits, dim=-1)
    print("Prediction shape:", prediction.shape)

    # Decode predictions; lowercase to match the lowercase word list.
    print("Decoding predictions...")
    transcription = tokenizer.batch_decode(prediction)[0].lower()
    print("Decoded transcription:", transcription)

    incorrect, correct = count_spelled_words(transcription, english_words)
    print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)

    # Calculate pronunciation score. FIX: guard the division so an empty
    # transcription yields a score of 0.0 instead of ZeroDivisionError.
    total = incorrect + correct
    fraction = correct / total if total else 0.0
    score = round(fraction * 100, 2)
    print("Pronunciation score for", transcription, ":", score)

    print("Pronunciation scoring process complete.")

    return {
        "transcription": transcription,
        "pronunciation_score": score
    }