jiedong-yang committed on
Commit
dddc03b
•
1 Parent(s): d27921b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import whisper
5
+ import validators
6
+ import gradio as gr
7
+
8
+ from wordcloud import WordCloud, STOPWORDS
9
+
10
+ from scipy.io.wavfile import write
11
+ from espnet2.bin.tts_inference import Text2Speech
12
+
13
+ from utils import *
14
+
15
# Load the Whisper ASR model, the BART summarizer and the ESPnet TTS model once at import time.
# The initial Whisper checkpoint follows the same rule as the default of the "models" dropdown
# in the UI ('Base.en' when CUDA is available, 'Tiny.en' otherwise); loading 'base.en'
# unconditionally made CPU-only hosts run a different model than the one the dropdown shows.
asr_model = whisper.load_model('base.en' if torch.cuda.is_available() else 'tiny.en')
summarizer = gr.Interface.load("facebook/bart-large-cnn", src='huggingface')
tts_model = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan")
19
+
20
+
21
def load_model(name: str):
    """ Replace the module-level Whisper model with the checkpoint called ``name``

    :param name: model option as shown in the UI dropdown (e.g. 'Tiny.en', 'Base.en');
        it is lower-cased before being handed to ``whisper.load_model``
    :return: ``name`` unchanged, so the caller can write it back to the dropdown
    """
    global asr_model

    checkpoint = name.lower()
    asr_model = whisper.load_model(checkpoint)

    return name
30
+
31
+
32
def audio_from_url(url, dst_dir='data', name=None, format='wav'):
    """ Download a video from url and save its audio track

    The file is always written to ``<dst_dir>/audio.<format>``; a previous
    download at that location is removed first.

    :param url: str, the video url
    :param dst_dir: destination directory for the saved audio
    :param name: currently unused; kept for interface compatibility
    :param format: format type for the audio file, such as 'wav', 'mp3'. WAV is preferred.
    :return: path of the audio file, or None if the url is invalid or the download failed
    """
    # Local import so this function does not depend on a file-level import being added.
    import subprocess

    if not validators.url(url):
        return None

    os.makedirs(dst_dir, exist_ok=True)

    path = os.path.join(dst_dir, f"audio.{format}")
    if os.path.exists(path):
        os.remove(path)

    # Pass the arguments as a list (no shell): the url comes from the user, and
    # characters such as '&' (common in video urls), ';' or '$(...)' would
    # otherwise be interpreted by the shell.
    command = [
        "yt-dlp",
        "-f", "ba",
        "-x",
        "--audio-format", format,
        url,
        "-o", path,
        "--quiet",
    ]
    try:
        subprocess.run(command, check=True)
    except (OSError, subprocess.CalledProcessError):
        # yt-dlp is missing or the download/conversion failed.
        return None

    # Only report a path that actually exists.
    return path if os.path.exists(path) else None
54
+
55
+
56
def speech_to_text(audio, beam_size=5, best_of=5, language='en'):
    """ ASR inference with Whisper

    :param audio: filepath of the audio to transcribe; None or '' means "no audio"
    :param beam_size: beam search parameter
    :param best_of: number of best results
    :param language: Currently English only
    :return: transcription, or '' when no audio was given
    """

    # The UI passes None when nothing was uploaded, and audio_from_url returns
    # None for an invalid url; there is nothing to transcribe in either case.
    if not audio:
        return ''

    result = asr_model.transcribe(audio, language=language, beam_size=beam_size, best_of=best_of, fp16=False)

    return result['text']
69
+
70
+
71
def text_summarization(text):
    """ Summarize text with the module-level ``summarizer``

    :param text: text to summarize (the transcription, in the UI wiring)
    :return: whatever ``summarizer`` returns for ``text``
    """
    summary = summarizer(text)
    return summary
73
+
74
+
75
def wordcloud_func(text: str, out_path='data/wordcloud_output.png'):
    """ Generate a wordcloud image from text

    :param text: transcription
    :param out_path: filepath the image is written to
    :return: filepath of the image, or None when the text has no words to plot
    """

    # None, empty and whitespace-only input cannot produce a wordcloud.
    if not text or not text.strip():
        return None

    # The output directory may not exist yet (e.g. audio was uploaded rather
    # than downloaded, so nothing has created 'data' so far).
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    wc = WordCloud(
        background_color='white',
        stopwords=STOPWORDS,
        height=600,
        width=600,
    )

    try:
        wc.generate(text)
    except ValueError:
        # Raised by WordCloud when no word is left to plot
        # (e.g. the text consists of stopwords only).
        return None

    wc.to_file(out_path)

    return out_path
99
+
100
+
101
def normalize_dollars(text):
    """ Text normalization for '$' amounts followed by a magnitude word

    Rewrites e.g. '$1.4 trillion' as '1.4 trillion dollars'. Amounts without one
    of the magnitude words below (e.g. '$5') are left untouched.

    :param text: input text
    :return: text with every matching '$<number> <unit>' rewritten as '<number> <unit> dollars'
    """

    units = ('hundred', 'thousand', 'million', 'billion', 'trillion')

    # The trailing \b keeps the unit from matching a prefix of a longer word,
    # so '$5 millionaire' is not turned into '5 million dollarsaire'.
    dollars_re = re.compile(r"\$([0-9.,]*[0-9]+ (?:" + "|".join(units) + r"))\b")

    # Drop the '$' and append ' dollars' after the captured '<number> <unit>'.
    return dollars_re.sub(r"\1 dollars", text)
118
+
119
+
120
def text_to_speech(text: str, out_path="data/short_speech.wav"):
    """ Synthesize speech for text with the module-level TTS model

    :param text: text to read out (the summary, in the UI wiring)
    :param out_path: filepath the WAV file is written to
    :return: filepath of the WAV file, or None when there is no text to read
    """

    # Nothing to synthesize for None, empty or whitespace-only input.
    if not text or not text.strip():
        return None

    # espnet tts model process '$1.4 trillion' as 'one point four dollar trillion'
    # use this function to fix this issue
    text = normalize_dollars(text)

    # The output directory may not exist yet if nothing else has created it.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    output = tts_model(text)
    # NOTE(review): 22050 is presumably the sample rate of the LJSpeech-trained
    # model loaded at module level -- confirm if the model is ever changed.
    write(out_path, 22050, output['wav'].numpy())

    return out_path
130
+
131
+
132
def _blank(_):
    """Return an empty string; every "clear" button uses this to reset its textbox."""
    return ''


# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statements themselves -- confirm it against the rendered page.
demo = gr.Blocks(css=demo_css, title="Speech Summarization")

demo.encrypt = False

with demo:
    # demo description
    gr.Markdown("""
    ## Speech Summarization with Whisper
    This space is intended to summarize a speech, a short one or long one, to save us sometime
    (runs faster with GPU inference). Check the example links provided below:
    [3 mins speech](https://www.youtube.com/watch?v=DuX4K4eeTz8),
    [13 mins speech](https://www.youtube.com/watch?v=nepOSEGHHCQ)

    1. Type in a youtube URL or upload an audio file
    2. Generate transcription with Whisper (English Only)
    3. Summarize the transcribed speech
    4. Generate summary speech with the ESPNet model
    """)

    # data preparation: a url box (with its clear button) next to the audio widget
    with gr.Row():
        with gr.Column():
            url = gr.Textbox(label="URL", placeholder="video url")

            url_btn = gr.Button("clear")
            url_btn.click(_blank, inputs=url, outputs=url)

        speech = gr.Audio(label="Speech", type="filepath")

        # editing the url downloads its audio into the audio widget
        url.change(audio_from_url, inputs=url, outputs=speech)

    # ASR
    text = gr.Textbox(label="Transcription", placeholder="transcription")

    with gr.Row():
        # larger model and wider search when CUDA is available, smallest settings otherwise
        if torch.cuda.is_available():
            default_values = dict(model='Base.en', bs=5, bo=5)
        else:
            default_values = dict(model='Tiny.en', bs=1, bo=1)

        model_options = gr.Dropdown(['Tiny.en', 'Base.en'], value=default_values['model'], label="models")
        model_options.change(load_model, inputs=model_options, outputs=model_options)

        beam_size_slider = gr.Slider(1, 10, value=default_values['bs'], step=1, label="param: beam_size")
        best_of_slider = gr.Slider(1, 10, value=default_values['bo'], step=1, label="param: best_of")

    with gr.Row():
        asr_clr_btn = gr.Button("clear")
        asr_clr_btn.click(_blank, inputs=text, outputs=text)
        asr_btn = gr.Button("Recognize Speech")
        asr_btn.click(speech_to_text, inputs=[speech, beam_size_slider, best_of_slider], outputs=text)

    # summarization
    summary = gr.Textbox(label="Summarization")

    with gr.Row():
        sum_clr_btn = gr.Button("clear")
        sum_clr_btn.click(_blank, inputs=summary, outputs=summary)
        sum_btn = gr.Button("Summarize")
        sum_btn.click(text_summarization, inputs=text, outputs=summary)

    with gr.Row():
        # wordcloud
        image = gr.Image(label="wordcloud", show_label=False).style(height=400, width=400)
        with gr.Column():
            tts = gr.Audio(label="Short Speech", type="filepath")
            tts_btn = gr.Button("Read Summary")
            tts_btn.click(text_to_speech, inputs=summary, outputs=tts)

    # every change of the transcription regenerates the wordcloud image
    text.change(wordcloud_func, inputs=text, outputs=image)

    examples = gr.Examples(
        examples=[
            "https://www.youtube.com/watch?v=DuX4K4eeTz8",
            "https://www.youtube.com/watch?v=nepOSEGHHCQ",
        ],
        fn=lambda x: speech_to_text(audio_from_url(x)),
        inputs=url,
        outputs=text,
        cache_examples=True,
    )

    gr.HTML(footer_html)


if __name__ == '__main__':
    demo.launch()