Tmeena commited on
Commit
2109067
1 Parent(s): eec8d2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -1,20 +1,19 @@
1
- import os
2
  import base64
 
3
  from flask import Flask, request, jsonify
4
  from pydub import AudioSegment
5
  import whisper
6
  from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
7
 
8
- # Set environment variables for Hugging Face cache
9
- os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
10
- os.environ["HF_HOME"] = "/app/cache"
11
 
12
- # Load the Whisper model
13
- whisper_model = whisper.load_model("base")
14
 
15
  # Load the translation model and tokenizer
16
- tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
17
- translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
18
 
19
  def preprocess_audio(audio_path):
20
  """Convert audio to 16kHz mono WAV format."""
@@ -67,7 +66,7 @@ def translate():
67
  data = request.json
68
  if 'audio' not in data or 'source_lang' not in data or 'target_lang' not in data:
69
  return jsonify({"error": "Invalid request format"}), 400
70
-
71
  audio_base64 = data['audio']
72
  source_lang = data['source_lang']
73
  target_lang = data['target_lang']
 
 
1
  import base64
2
+ import os
3
  from flask import Flask, request, jsonify
4
  from pydub import AudioSegment
5
  import whisper
6
  from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
7
 
8
+ # Define cache directory
9
+ os.environ['HF_HOME'] = '/app/cache'
 
10
 
11
+ # Load the Whisper model with a specified cache directory
12
+ whisper_model = whisper.load_model("base", download_root="/app/cache")
13
 
14
  # Load the translation model and tokenizer
15
+ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")
16
+ translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", cache_dir="/app/cache")
17
 
18
  def preprocess_audio(audio_path):
19
  """Convert audio to 16kHz mono WAV format."""
 
66
  data = request.json
67
  if 'audio' not in data or 'source_lang' not in data or 'target_lang' not in data:
68
  return jsonify({"error": "Invalid request format"}), 400
69
+
70
  audio_base64 = data['audio']
71
  source_lang = data['source_lang']
72
  target_lang = data['target_lang']