Kartikeyssj2 committed on
Commit
7365efc
1 Parent(s): 427bb16
Files changed (3) hide show
  1. Dockerfile +5 -3
  2. download_models.py +19 -0
  3. main.py +5 -4
Dockerfile CHANGED
@@ -15,11 +15,13 @@ RUN apt-get update && apt-get install -y \
15
  # Install any needed packages specified in requirements.txt
16
  RUN pip install --no-cache-dir -r requirements.txt
17
 
18
- # Make port 80 available to the world outside this container
19
- EXPOSE 80
 
 
20
 
21
  # Define environment variable
22
- ENV NAME World
23
 
24
  # Run app.py when the container launches
25
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
15
  # Install any needed packages specified in requirements.txt
16
  RUN pip install --no-cache-dir -r requirements.txt
17
 
18
+ RUN python download_models.py
19
+
20
+ # Make port 7860 available to the world outside this container
21
+ EXPOSE 7860
22
 
23
  # Define environment variable
24
+ ENV TRANSFORMERS_CACHE=/tmp/.cache
25
 
26
  # Run app.py when the container launches
27
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
download_models.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
3
+
4
+ # Create the models directory if it doesn't exist
5
+ os.makedirs("./models", exist_ok=True)
6
+ os.makedirs("./models/tokenizer", exist_ok=True)
7
+ os.makedirs("./models/model", exist_ok=True)
8
+
9
+ print("Downloading and saving tokenizer...")
10
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
11
+ tokenizer.save_pretrained("./models/tokenizer")
12
+ print("Tokenizer saved successfully.")
13
+
14
+ print("Downloading and saving model...")
15
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
16
+ model.save_pretrained("./models/model")
17
+ print("Model saved successfully.")
18
+
19
+ print("Download and save process completed.")
main.py CHANGED
@@ -16,8 +16,10 @@ MultiPartParser.max_file_size = 200 * 1024 * 1024
16
  app = FastAPI()
17
 
18
  # Load Wav2Vec2 tokenizer and model
19
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
20
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
 
 
21
 
22
  # Function to download English word list
23
  def download_word_list():
@@ -28,7 +30,6 @@ def download_word_list():
28
  print("Word list downloaded.")
29
  return words
30
 
31
- english_words = download_word_list()
32
 
33
  # Function to count correctly spelled words in text
34
  def count_spelled_words(text, word_list):
@@ -114,7 +115,7 @@ async def unscripted_root(audio_file: UploadFile):
114
 
115
  # Calculate pronunciation score
116
  fraction = correct / (incorrect + correct)
117
- score = round(fraction * 10, 2)
118
  print("Pronunciation score for", transcription, ":", score)
119
 
120
  print("Pronunciation scoring process complete.")
 
16
  app = FastAPI()
17
 
18
  # Load Wav2Vec2 tokenizer and model
19
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained("./models/tokenizer")
20
+ model = Wav2Vec2ForCTC.from_pretrained("./models/model")
21
+ english_words = download_word_list()
22
+
23
 
24
  # Function to download English word list
25
  def download_word_list():
 
30
  print("Word list downloaded.")
31
  return words
32
 
 
33
 
34
  # Function to count correctly spelled words in text
35
  def count_spelled_words(text, word_list):
 
115
 
116
  # Calculate pronunciation score
117
  fraction = correct / (incorrect + correct)
118
+ score = round(fraction * 100, 2)
119
  print("Pronunciation score for", transcription, ":", score)
120
 
121
  print("Pronunciation scoring process complete.")