tjysdsg commited on
Commit
ed95978
1 Parent(s): f5a907f

Move model downloading to initialization stage

Browse files
Files changed (2) hide show
  1. app.py +78 -79
  2. utils.py +4 -0
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
  import gradio as gr
3
- import numpy as np
4
- import torch
5
  import torchaudio
6
  from typing import Tuple, Optional
7
  import soundfile as sf
8
  from s2st_inference import s2st_inference
 
9
 
10
  SAMPLE_RATE = 16000
11
  MAX_INPUT_LENGTH = 60 # seconds
@@ -19,85 +18,83 @@ NGPU = 0
19
  BEAM_SIZE = 1
20
 
21
 
22
- def download_model(tag: str, out_dir: str):
23
- from huggingface_hub import snapshot_download
24
-
25
- return snapshot_download(repo_id=tag, local_dir=out_dir)
26
-
27
-
28
- def s2st(
29
- audio_source: str,
30
- input_audio_mic: Optional[str],
31
- input_audio_file: Optional[str],
32
- ):
33
- if audio_source == 'file':
34
- input_path = input_audio_file
35
- else:
36
- input_path = input_audio_mic
37
-
38
- if input_path is None:
39
- gr.Error(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
40
- return (None, None), None
41
-
42
- orig_wav, orig_sr = torchaudio.load(input_path)
43
- wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)
44
- max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
45
- if wav.shape[1] > max_length:
46
- wav = wav[:, :max_length]
47
- gr.Warning(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
48
-
49
- wav = wav[0] # mono
50
-
51
- # Download models
52
- os.makedirs(S2UT_DIR, exist_ok=True)
53
- os.makedirs(VOCODER_DIR, exist_ok=True)
54
- s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
55
- vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)
56
-
57
- # Temporary change cwd to model dir so that it loads correctly
58
- cwd = os.getcwd()
59
- os.chdir(s2ut_path)
60
-
61
- # Translate wav
62
- out_wav = s2st_inference(
63
- wav,
64
- train_config=os.path.join(
65
- s2ut_path,
66
- 'exp',
67
- 's2st_train_s2st_discrete_unit_raw_fbank_es_en',
68
- 'config.yaml',
69
- ),
70
- model_file=os.path.join(
71
- s2ut_path,
72
- 'exp',
73
- 's2st_train_s2st_discrete_unit_raw_fbank_es_en',
74
- '500epoch.pth',
75
- ),
76
- vocoder_file=os.path.join(
77
- vocoder_path,
78
- 'checkpoint-450000steps.pkl',
79
- ),
80
- vocoder_config=os.path.join(
81
- vocoder_path,
82
- 'config.yml',
83
- ),
84
- ngpu=NGPU,
85
- beam_size=BEAM_SIZE,
86
- )
87
 
88
- # Restore working directory
89
- os.chdir(cwd)
90
 
91
- # Save result
92
- output_path = 'output.wav'
93
- sf.write(
94
- output_path,
95
- out_wav,
96
- 16000,
97
- "PCM_16",
98
- )
99
 
100
- return output_path, f'Source: {audio_source}'
101
 
102
 
103
  def update_audio_ui(audio_source: str) -> Tuple[dict, dict]:
@@ -109,6 +106,8 @@ def update_audio_ui(audio_source: str) -> Tuple[dict, dict]:
109
 
110
 
111
  def main():
 
 
112
  with gr.Blocks() as demo:
113
  with gr.Group():
114
  with gr.Row() as audio_box:
@@ -153,7 +152,7 @@ def main():
153
  )
154
 
155
  btn.click(
156
- fn=s2st,
157
  inputs=[
158
  audio_source,
159
  input_audio_mic,
 
1
  import os
2
  import gradio as gr
 
 
3
  import torchaudio
4
  from typing import Tuple, Optional
5
  import soundfile as sf
6
  from s2st_inference import s2st_inference
7
+ from utils import download_model
8
 
9
  SAMPLE_RATE = 16000
10
  MAX_INPUT_LENGTH = 60 # seconds
 
18
  BEAM_SIZE = 1
19
 
20
 
21
+ class App:
22
+ def __init__(self):
23
+ # Download models
24
+ os.makedirs(S2UT_DIR, exist_ok=True)
25
+ os.makedirs(VOCODER_DIR, exist_ok=True)
26
+
27
+ self.s2ut_path = download_model(S2UT_TAG, S2UT_DIR)
28
+ self.vocoder_path = download_model(VOCODER_TAG, VOCODER_DIR)
29
+
30
+ def s2st(
31
+ self,
32
+ audio_source: str,
33
+ input_audio_mic: Optional[str],
34
+ input_audio_file: Optional[str],
35
+ ):
36
+ if audio_source == 'file':
37
+ input_path = input_audio_file
38
+ else:
39
+ input_path = input_audio_mic
40
+
41
+ if input_path is None:
42
+ gr.Error(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
43
+ return (None, None), None
44
+
45
+ orig_wav, orig_sr = torchaudio.load(input_path)
46
+ wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)
47
+ max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
48
+ if wav.shape[1] > max_length:
49
+ wav = wav[:, :max_length]
50
+ gr.Warning(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
51
+
52
+ wav = wav[0] # mono
53
+
54
+ # Temporary change cwd to model dir so that it loads correctly
55
+ cwd = os.getcwd()
56
+ os.chdir(self.s2ut_path)
57
+
58
+ # Translate wav
59
+ out_wav = s2st_inference(
60
+ wav,
61
+ train_config=os.path.join(
62
+ self.s2ut_path,
63
+ 'exp',
64
+ 's2st_train_s2st_discrete_unit_raw_fbank_es_en',
65
+ 'config.yaml',
66
+ ),
67
+ model_file=os.path.join(
68
+ self.s2ut_path,
69
+ 'exp',
70
+ 's2st_train_s2st_discrete_unit_raw_fbank_es_en',
71
+ '500epoch.pth',
72
+ ),
73
+ vocoder_file=os.path.join(
74
+ self.vocoder_path,
75
+ 'checkpoint-450000steps.pkl',
76
+ ),
77
+ vocoder_config=os.path.join(
78
+ self.vocoder_path,
79
+ 'config.yml',
80
+ ),
81
+ ngpu=NGPU,
82
+ beam_size=BEAM_SIZE,
83
+ )
 
 
84
 
85
+ # Restore working directory
86
+ os.chdir(cwd)
87
 
88
+ # Save result
89
+ output_path = 'output.wav'
90
+ sf.write(
91
+ output_path,
92
+ out_wav,
93
+ 16000,
94
+ "PCM_16",
95
+ )
96
 
97
+ return output_path, f'Source: {audio_source}'
98
 
99
 
100
  def update_audio_ui(audio_source: str) -> Tuple[dict, dict]:
 
106
 
107
 
108
  def main():
109
+ app = App()
110
+
111
  with gr.Blocks() as demo:
112
  with gr.Group():
113
  with gr.Row() as audio_box:
 
152
  )
153
 
154
  btn.click(
155
+ fn=app.s2st,
156
  inputs=[
157
  audio_source,
158
  input_audio_mic,
utils.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ def download_model(tag: str, out_dir: str):
2
+ from huggingface_hub import snapshot_download
3
+
4
+ return snapshot_download(repo_id=tag, local_dir=out_dir)