File size: 4,634 Bytes
427bb16
 
 
 
 
 
 
 
 
 
7dde9f0
4006ad9
 
c3b89d9
 
c34b637
c3b89d9
427bb16
 
 
 
 
 
7365efc
 
 
427bb16
 
 
 
 
 
 
 
 
 
c14d52f
427bb16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
882da5c
 
 
 
 
 
02e008b
882da5c
 
 
f2601c1
 
4006ad9
2e81661
bf1ce78
f89dd94
bf1ce78
b32a33b
2e81661
d7b0c66
 
91d4d7a
bf1ce78
990903e
91d4d7a
ee147ec
 
b32a33b
d466784
813004d
0cf878c
d466784
5efbe11
 
427bb16
5f74423
 
 
 
 
 
 
 
 
 
5efbe11
 
427bb16
 
3cc7981
 
 
5efbe11
 
427bb16
 
 
5efbe11
 
427bb16
 
 
 
 
 
 
 
5efbe11
427bb16
 
 
 
 
7365efc
427bb16
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import re
import requests
import pyarrow as pa
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from fastapi import FastAPI, File, UploadFile
import warnings
from starlette.formparsers import MultiPartParser
import io
import random
import tempfile
import os
import numba
import soundfile as sf 
import asyncio

# NOTE(review): in Starlette this class attribute appears to be the in-memory
# spool threshold for uploaded files (bytes kept in RAM before an upload is
# rolled over to a temporary file on disk), not a cap on upload size — confirm
# against the installed Starlette version. Here it is raised to 200 MiB.
MultiPartParser.max_file_size = 200 * 1024 * 1024

# Initialize FastAPI app
app = FastAPI()

# Load Wav2Vec2 tokenizer and model from local directories. This runs once at
# import time and both objects are shared by every request handler below.
# NOTE(review): Wav2Vec2Tokenizer is deprecated upstream in favour of
# Wav2Vec2Processor / Wav2Vec2FeatureExtractor — verify before upgrading
# transformers.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
model = Wav2Vec2ForCTC.from_pretrained("./models/model")


# Function to download English word list
def download_word_list(
    url="https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt",
    timeout=30,
):
    """Download a whitespace-separated English word list and return it as a set.

    Args:
        url: Location of a plain-text word list (one word per line). Defaults
            to the dwyl ``words_alpha.txt`` list.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang application startup forever.

    Returns:
        A set of the lower-cased words in the list. They are lower-cased
        because callers compare against lower-cased text.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    print("Downloading English word list...")
    response = requests.get(url, timeout=timeout)
    # Without this check an error page (e.g. a 404 body) would silently become
    # the "dictionary" and every transcription would score near zero.
    response.raise_for_status()
    words = set(response.text.lower().split())
    print("Word list downloaded.")
    return words

# Fetched once at import time (so network access is needed at startup); used
# by the /pronunciation_score route as the dictionary for spell checking.
english_words = download_word_list()

# Function to count correctly spelled words in text
def count_spelled_words(text, word_list):
    """Tally how many words of *text* are / are not present in *word_list*.

    The text is lower-cased and tokenised on regex word boundaries (runs of
    word characters), so punctuation such as an apostrophe splits a token.

    Args:
        text: The string to check.
        word_list: Container of known lower-case words, tested with ``in``.

    Returns:
        A ``(incorrect, correct)`` tuple of word counts.
    """
    print("Counting spelled words...")
    tokens = re.findall(r'\b\w+\b', text.lower())

    known = 0
    for token in tokens:
        if token in word_list:
            known += 1

    print("Spelling check complete.")
    return len(tokens) - known, known

# Function to apply spell check to an item (assuming it's a dictionary)
def apply_spell_check(item, word_list):
    """Annotate one transcription record, or a batch of them, with spelling counts.

    Args:
        item: Either a ``dict`` holding a single ``'transcription'`` string, or
            a columnar batch that supports ``item['transcription']`` and
            ``append_column`` (e.g. a ``pyarrow.Table``).
        word_list: Container of known lower-case words.

    Returns:
        For a dict: the same dict, mutated in place with ``incorrect_words``
        and ``correct_words`` keys. For a batch: a new batch with
        ``incorrect_words`` and ``correct_words`` int64 columns appended.
    """
    print("Applying spell check...")
    if isinstance(item, dict):
        # This is a single item
        incorrect, correct = count_spelled_words(item['transcription'], word_list)
        item['incorrect_words'] = incorrect
        item['correct_words'] = correct
        print("Spell check applied to single item.")
        return item

    # This is likely a batch
    texts = item['transcription']
    # Iterating a pyarrow (Chunked)Array yields pyarrow scalar objects, which
    # have no ``.lower()``; convert to plain Python values first.
    if hasattr(texts, 'to_pylist'):
        texts = texts.to_pylist()
    # Treat null transcriptions as empty text (zero words) instead of crashing.
    results = [
        count_spelled_words("" if text is None else text, word_list)
        for text in texts
    ]

    # Build the columns explicitly: ``zip(*results)`` raises ValueError on an
    # empty batch. Pin the type so an empty column is int64, not pyarrow's
    # null type (non-empty int lists already infer int64).
    incorrect_counts = [incorrect for incorrect, _ in results]
    correct_counts = [correct for _, correct in results]

    item = item.append_column(
        'incorrect_words', pa.array(incorrect_counts, type=pa.int64())
    )
    item = item.append_column(
        'correct_words', pa.array(correct_counts, type=pa.int64())
    )

    print("Spell check applied to batch of items.")
    return item

# FastAPI routes
@app.get('/')
async def root():
    """Landing endpoint: reply with a fixed greeting string."""
    greeting = "Welcome to the pronunciation scoring API!"
    return greeting

@app.post('/check_post')
async def rnc(number):
    """Echo the ``number`` request parameter back to the caller.

    Returns:
        A JSON object of the form ``{"your value": number}``.
    """
    # Must be a dict literal (key: value). A comma-separated pair in braces is
    # a *set*, which FastAPI serialises as an unordered two-element list.
    return {"your value": number}

@app.get('/check_get')
async def get_rnc():
    """Return a random integer drawn from the inclusive range [0, 10]."""
    low, high = 0, 10
    return random.randint(low, high)


@app.post('/pronunciation_score')
async def upload_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and score it by dictionary coverage.

    The raw upload is forwarded to an external speech-processing service,
    whose JSON reply must contain an ``audio_array`` (presumably a 1-D list of
    samples at the model's sampling rate — TODO confirm). The samples are
    transcribed with the module-level Wav2Vec2 ``tokenizer``/``model``. The
    transcription is scored as the percentage of its words found in
    ``english_words``.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        ``{"transcription": str, "pronunciation_score": float}``. The score
        is in [0, 100], rounded to two decimals, and is 0.0 when the
        transcription contains no words.

    Raises:
        HTTPException: 502 if the speech-processing service cannot be
            reached, times out, returns an error status or non-JSON body, or
            returns no ``audio_array``.
    """
    import functools

    from fastapi import HTTPException

    print("loading the file")
    url = "https://speech-processing-6.onrender.com/process_audio"
    # (connect, read) seconds; the read timeout is generous because the remote
    # service may need to wake up and process large files.
    request_timeout = (10, 300)
    files = {'file': await file.read()}
    print("file loaded")

    print("making a POST request on speech processor")
    # requests is blocking: run it in the default executor so the event loop
    # can keep serving other requests while we wait on the remote service.
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.post, url, files=files, timeout=request_timeout),
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as err:
        raise HTTPException(
            status_code=502, detail=f"Speech processor request failed: {err}"
        ) from err

    audio = payload.get('audio_array') if isinstance(payload, dict) else None
    if not audio:
        raise HTTPException(
            status_code=502, detail="Speech processor returned no audio_array"
        )

    print("audio:", audio[:5])
    print("length of the audio array:", len(audio))
    print("*" * 100)

    # Tokenization
    print("Tokenizing audio...")
    # NOTE(review): every clip is zero-padded to exactly 1,000,000 samples, so
    # short clips pay for a full-length forward pass. If the tokenizer honours
    # ``truncation``, longer clips are also silently cut off. Kept unchanged
    # because altering it may change transcriptions.
    input_values = tokenizer(
        audio,
        return_tensors="pt",
        padding="max_length",
        max_length=1000000,
        truncation=True,
    ).input_values
    print("Tokenization complete. Shape of input_values:", input_values.shape)

    # Perform inference
    print("Performing inference with Wav2Vec2 model...")
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits
    print("Inference complete. Shape of logits:", logits.shape)

    # Get predictions
    print("Getting predictions...")
    prediction = torch.argmax(logits, dim=-1)
    print("Prediction shape:", prediction.shape)

    # Decode predictions and lower-case them to match the lower-case word list
    print("Decoding predictions...")
    transcription = tokenizer.batch_decode(prediction)[0].lower()
    print("Decoded transcription:", transcription)

    incorrect, correct = count_spelled_words(transcription, english_words)
    print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)

    # Calculate pronunciation score. Silence or an unintelligible clip can
    # decode to no words at all; report 0.0 instead of dividing by zero.
    total = incorrect + correct
    score = round(correct / total * 100, 2) if total else 0.0
    print("Pronunciation score for", transcription, ":", score)

    print("Pronunciation scoring process complete.")

    return {
        "transcription": transcription,
        "pronunciation_score": score
    }