# Spaces status (scraped page header, shown twice): Build error
import re | |
import requests | |
import pyarrow as pa | |
import librosa | |
import torch | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer | |
from fastapi import FastAPI, File, UploadFile | |
import warnings | |
from starlette.formparsers import MultiPartParser | |
import io | |
import random | |
import tempfile | |
import os | |
import numba | |
import soundfile as sf | |
import asyncio | |
# Raise Starlette's multipart `max_file_size` class attribute to 200 MiB.
# NOTE(review): in the Starlette versions I am aware of, this attribute is the
# SpooledTemporaryFile rollover size (uploads larger than it spill to disk),
# not a hard upload limit. Confirm against the pinned Starlette version.
MultiPartParser.max_file_size = 200 << 20  # 200 MiB == 200 * 1024 * 1024

# The FastAPI application instance.
app = FastAPI()

# Wav2Vec2 CTC model and its tokenizer, loaded once at import time from local
# checkpoint directories. Import fails if either directory is missing.
model = Wav2Vec2ForCTC.from_pretrained("./models/model")
tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
def download_word_list(
    url="https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt",
    timeout=30,
):
    """Download a plain-text word list and return its words as a set.

    It is called once at import time to build ``english_words``, so any
    failure here prevents the module from loading.

    Args:
        url: Location of the whitespace-separated word list. The default is
            the dwyl ``words_alpha.txt`` list used originally.
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled connection would hang module import indefinitely.

    Returns:
        set[str]: Every whitespace-separated token in the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Without this check, the error page's text would silently be
            used as the word list.
        requests.RequestException: On connection errors or timeout.
    """
    print("Downloading English word list...")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    words = set(response.text.split())
    print("Word list downloaded.")
    return words


english_words = download_word_list()
def count_spelled_words(text, word_list):
    """Tally how many words of *text* appear in *word_list*.

    The text is lowercased and split into runs of word characters
    (letters, digits, underscore). Each run is looked up in *word_list*.

    Args:
        text: The string to check.
        word_list: A container of lowercase words supporting ``in``
            (a set, for O(1) lookups).

    Returns:
        tuple[int, int]: ``(incorrect, correct)`` word counts.
    """
    print("Counting spelled words...")
    tokens = re.findall(r'\b\w+\b', text.lower())
    known = 0
    for token in tokens:
        if token in word_list:
            known += 1
    print("Spelling check complete.")
    return len(tokens) - known, known
def apply_spell_check(item, word_list):
    """Attach spelling counts to a single record or to a columnar batch.

    Args:
        item: One of two forms.
            A ``dict`` with a ``'transcription'`` string (single record).
            A batch object supporting ``item['transcription']`` and
            ``append_column``. Presumably a ``pyarrow.Table`` or
            ``pyarrow.RecordBatch``; confirm against callers.
        word_list: A container of lowercase words supporting ``in``.

    Returns:
        For a dict: the same dict, mutated in place, with ``incorrect_words``
        and ``correct_words`` keys added.
        For a batch: a new batch with two int64 columns of those names
        appended.
    """
    print("Applying spell check...")
    if isinstance(item, dict):
        # Single record
        incorrect, correct = count_spelled_words(item['transcription'], word_list)
        item['incorrect_words'] = incorrect
        item['correct_words'] = correct
        print("Spell check applied to single item.")
        return item

    # Columnar batch
    texts = item['transcription']
    # Arrow columns iterate as pyarrow scalars, which have no .lower().
    # Convert them to plain Python strings first.
    if hasattr(texts, 'to_pylist'):
        texts = texts.to_pylist()
    results = [count_spelled_words(text, word_list) for text in texts]
    # Build the two columns separately. `zip(*results)` would fail to unpack
    # when the batch is empty.
    incorrect_counts = [incorrect for incorrect, _ in results]
    correct_counts = [correct for _, correct in results]
    # An explicit type keeps the columns int64 even for an empty batch
    # (type inference on [] would give a null-typed column).
    item = item.append_column('incorrect_words', pa.array(incorrect_counts, type=pa.int64()))
    item = item.append_column('correct_words', pa.array(correct_counts, type=pa.int64()))
    print("Spell check applied to batch of items.")
    return item
# ---- FastAPI routes ----
# NOTE(review): no @app.get/@app.post decorators appear in this copy of the
# file, so as written none of these handlers is registered on `app`.
# Confirm whether the decorators were lost.


async def root():
    """Return the API's welcome banner string."""
    greeting = "Welcome to the pronunciation scoring API!"
    return greeting
async def rnc(number):
    """Echo *number* back to the caller as a JSON object.

    Args:
        number: The value to echo.

    Returns:
        dict: ``{"your value": number}``.
    """
    # The original `{"your value:" , number}` used a comma where a colon was
    # intended. That built a *set* (serialized as a list in arbitrary order),
    # not a key/value mapping.
    return {"your value": number}
async def get_rnc():
    """Return a pseudo-random integer in the closed range [0, 10].

    It uses the module-level ``random`` generator, so it is not suitable for
    anything security-sensitive.
    """
    upper_exclusive = 11
    return random.randrange(upper_exclusive)
async def fluency_scoring(file: UploadFile = File(...)):
    """Decode an uploaded audio file at 16 kHz and return its first samples.

    Args:
        file: The uploaded audio file. Its underlying file object is passed
            straight to ``librosa.load``.

    Returns:
        list[float]: Up to the first five decoded samples.
    """
    # librosa.load decodes and resamples synchronously. Run it in a worker
    # thread so it does not block the event loop for other requests.
    audio_array, _sample_rate = await asyncio.to_thread(
        librosa.load, file.file, sr=16000
    )
    print(audio_array)
    # FastAPI's JSON encoder cannot serialize numpy arrays, so returning the
    # raw slice caused a server error. Convert it to plain Python floats.
    return audio_array[:5].tolist()
# Remote service that receives the uploaded audio and returns its samples as
# JSON under the "audio_array" key.
_SPEECH_PROCESSOR_URL = "https://speech-processing-6.onrender.com/process_audio"
# Upper bound (seconds) on the remote call, so a stalled upstream cannot hang
# the request forever. Tune as needed.
_SPEECH_PROCESSOR_TIMEOUT = 120
# Fixed model input length in samples. Shorter audio is padded and longer
# audio is truncated. That is about 24 s if the remote service returns
# 16 kHz audio (unverified).
_MAX_INPUT_SAMPLES = 386380


def _fetch_audio_array(audio_bytes):
    """POST raw upload bytes to the speech processor and return ``audio_array``.

    Raises:
        fastapi.HTTPException: 502 if the request fails, returns an error
            status or non-JSON, or yields no audio samples.
    """
    from fastapi import HTTPException

    try:
        response = requests.post(
            _SPEECH_PROCESSOR_URL,
            files={'file': audio_bytes},
            timeout=_SPEECH_PROCESSOR_TIMEOUT,
        )
        response.raise_for_status()
        payload = response.json()  # JSON decode errors subclass ValueError
    except (requests.RequestException, ValueError) as err:
        raise HTTPException(status_code=502, detail="speech processor request failed") from err

    audio = payload.get('audio_array') if isinstance(payload, dict) else None
    if not audio:
        raise HTTPException(status_code=502, detail="speech processor returned no audio")
    return audio


def _transcribe(audio):
    """Greedy CTC decode of raw samples with the Wav2Vec2 model; return lowercase text."""
    print("Tokenizing audio...")
    input_values = tokenizer(
        audio,
        return_tensors="pt",
        padding="max_length",
        max_length=_MAX_INPUT_SAMPLES,
        truncation=True,
    ).input_values
    print("Tokenization complete. Shape of input_values:", input_values.shape)

    print("Performing inference with Wav2Vec2 model...")
    # Inference only. no_grad skips building the autograd graph, saving
    # memory and time.
    with torch.no_grad():
        logits = model(input_values).logits
    print("Inference complete. Shape of logits:", logits.shape)

    print("Getting predictions...")
    prediction = torch.argmax(logits, dim=-1)
    print("Prediction shape:", prediction.shape)

    print("Decoding predictions...")
    transcription = tokenizer.batch_decode(prediction)[0].lower()
    print("Decoded transcription:", transcription)
    return transcription


def _pronunciation_score(correct, incorrect):
    """Return the percentage of dictionary words, rounded to 2 dp (0.0 if none)."""
    total = correct + incorrect
    if total == 0:
        # Silence or an unintelligible clip yields no words. Avoid dividing
        # by zero.
        return 0.0
    return round(correct / total * 100, 2)


async def pronunciation_scoring(file: UploadFile = File(...)):
    """Score pronunciation of an uploaded clip by the share of real English words.

    Steps:
        1. Forward the upload to the remote speech processor to get samples.
        2. Transcribe the samples with the Wav2Vec2 CTC model.
        3. Count transcript words found in ``english_words``.

    Args:
        file: The uploaded audio file.

    Returns:
        dict: ``{"transcription": str, "pronunciation_score": float}``, where
        the score is the percentage (0-100) of transcript words that are in
        the English word list.

    Raises:
        fastapi.HTTPException: 502 when the speech processor fails or returns
            no audio.
    """
    print("loading the file")
    audio_bytes = await file.read()
    print("file loaded")

    print("making a POST request on speech processor")
    # Blocking network I/O runs in a worker thread so the event loop stays free.
    audio = await asyncio.to_thread(_fetch_audio_array, audio_bytes)
    print("audio:", audio[:5])
    print("length of the audio array:", len(audio))
    print("*" * 100)

    # The previous version returned "tokenization successful" here, leaving
    # inference and scoring unreachable.
    transcription = await asyncio.to_thread(_transcribe, audio)

    incorrect, correct = count_spelled_words(transcription, english_words)
    print("Spelling check - Incorrect words:", incorrect, ", Correct words:", correct)

    score = _pronunciation_score(correct, incorrect)
    print("Pronunciation score for", transcription, ":", score)
    print("Pronunciation scoring process complete.")
    return {
        "transcription": transcription,
        "pronunciation_score": score,
    }