Spaces:
Runtime error
Runtime error
jiedong-yang
commited on
Commit
β’
dddc03b
1
Parent(s):
d27921b
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import whisper
|
5 |
+
import validators
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from wordcloud import WordCloud, STOPWORDS
|
9 |
+
|
10 |
+
from scipy.io.wavfile import write
|
11 |
+
from espnet2.bin.tts_inference import Text2Speech
|
12 |
+
|
13 |
+
from utils import *
|
14 |
+
|
15 |
+
# load whisper model for ASR and BART for summarization
|
16 |
+
asr_model = whisper.load_model('base.en')
|
17 |
+
summarizer = gr.Interface.load("facebook/bart-large-cnn", src='huggingface')
|
18 |
+
tts_model = Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan")
|
19 |
+
|
20 |
+
|
21 |
+
def load_model(name: str):
|
22 |
+
"""
|
23 |
+
|
24 |
+
:param name: model options, tiny or base only, for quick inference
|
25 |
+
:return:
|
26 |
+
"""
|
27 |
+
global asr_model
|
28 |
+
asr_model = whisper.load_model(f"{name.lower()}")
|
29 |
+
return name
|
30 |
+
|
31 |
+
|
32 |
+
def audio_from_url(url, dst_dir='data', name=None, format='wav'):
|
33 |
+
""" Download video from url and save the audio from video
|
34 |
+
|
35 |
+
:param url: str, the video url
|
36 |
+
:param dst_dir: destination directory for save audio
|
37 |
+
:param name: audio file's name, if none, assign the name as the video's title
|
38 |
+
:param format: format type for audio file, such as 'wav', 'mp3'. WAV is preferred.
|
39 |
+
:return: path of audio
|
40 |
+
"""
|
41 |
+
|
42 |
+
if not validators.url(url):
|
43 |
+
return None
|
44 |
+
|
45 |
+
os.makedirs(dst_dir, exist_ok=True)
|
46 |
+
|
47 |
+
# download audio
|
48 |
+
path = os.path.join(dst_dir, f"audio.{format}")
|
49 |
+
if os.path.exists(path):
|
50 |
+
os.remove(path)
|
51 |
+
os.system(f"yt-dlp -f 'ba' -x --audio-format {format} {url} -o {path} --quiet")
|
52 |
+
|
53 |
+
return path
|
54 |
+
|
55 |
+
|
56 |
+
def speech_to_text(audio, beam_size=5, best_of=5, language='en'):
|
57 |
+
""" ASR inference with Whisper
|
58 |
+
|
59 |
+
:param audio: filepath
|
60 |
+
:param beam_size: beam search parameter
|
61 |
+
:param best_of: number of best results
|
62 |
+
:param language: Currently English only
|
63 |
+
:return: transcription
|
64 |
+
"""
|
65 |
+
|
66 |
+
result = asr_model.transcribe(audio, language=language, beam_size=beam_size, best_of=best_of, fp16=False)
|
67 |
+
|
68 |
+
return result['text']
|
69 |
+
|
70 |
+
|
71 |
+
def text_summarization(text):
|
72 |
+
return summarizer(text)
|
73 |
+
|
74 |
+
|
75 |
+
def wordcloud_func(text: str, out_path='data/wordcloud_output.png'):
|
76 |
+
""" generate wordcloud based on text
|
77 |
+
|
78 |
+
:param text: transcription
|
79 |
+
:param out_path: filepath
|
80 |
+
:return: filepath
|
81 |
+
"""
|
82 |
+
|
83 |
+
if len(text) == 0:
|
84 |
+
return None
|
85 |
+
|
86 |
+
stopwords = STOPWORDS
|
87 |
+
|
88 |
+
wc = WordCloud(
|
89 |
+
background_color='white',
|
90 |
+
stopwords=stopwords,
|
91 |
+
height=600,
|
92 |
+
width=600
|
93 |
+
)
|
94 |
+
|
95 |
+
wc.generate(text)
|
96 |
+
wc.to_file(out_path)
|
97 |
+
|
98 |
+
return out_path
|
99 |
+
|
100 |
+
|
101 |
+
def normalize_dollars(text):
|
102 |
+
""" text normalization for '$'
|
103 |
+
|
104 |
+
:param text:
|
105 |
+
:return:
|
106 |
+
"""
|
107 |
+
|
108 |
+
def expand_dollars(m):
|
109 |
+
match = m.group(1)
|
110 |
+
parts = match.split(' ')
|
111 |
+
parts.append('dollars')
|
112 |
+
return ' '.join(parts)
|
113 |
+
|
114 |
+
units = ['hundred', 'thousand', 'million', 'billion', 'trillion']
|
115 |
+
_dollars_re = re.compile(fr"\$([0-9\.\,]*[0-9]+ (?:{'|'.join(units)}))")
|
116 |
+
|
117 |
+
return re.sub(_dollars_re, expand_dollars, text)
|
118 |
+
|
119 |
+
|
120 |
+
def text_to_speech(text: str, out_path="data/short_speech.wav"):
|
121 |
+
|
122 |
+
# espnet tts model process '$1.4 trillion' as 'one point four dollar trillion'
|
123 |
+
# use this function to fix this issue
|
124 |
+
text = normalize_dollars(text)
|
125 |
+
|
126 |
+
output = tts_model(text)
|
127 |
+
write(out_path, 22050, output['wav'].numpy())
|
128 |
+
|
129 |
+
return out_path
|
130 |
+
|
131 |
+
|
132 |
+
demo = gr.Blocks(css=demo_css, title="Speech Summarization")
|
133 |
+
|
134 |
+
demo.encrypt = False
|
135 |
+
|
136 |
+
with demo:
|
137 |
+
# demo description
|
138 |
+
gr.Markdown("""
|
139 |
+
## Speech Summarization with Whisper
|
140 |
+
This space is intended to summarize a speech, a short one or long one, to save us sometime
|
141 |
+
(runs faster with GPU inference). Check the example links provided below:
|
142 |
+
[3 mins speech](https://www.youtube.com/watch?v=DuX4K4eeTz8),
|
143 |
+
[13 mins speech](https://www.youtube.com/watch?v=nepOSEGHHCQ)
|
144 |
+
|
145 |
+
1. Type in a youtube URL or upload an audio file
|
146 |
+
2. Generate transcription with Whisper (English Only)
|
147 |
+
3. Summarize the transcribed speech
|
148 |
+
4. Generate summary speech with the ESPNet model
|
149 |
+
""")
|
150 |
+
|
151 |
+
# data preparation
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column():
|
154 |
+
url = gr.Textbox(label="URL", placeholder="video url")
|
155 |
+
|
156 |
+
url_btn = gr.Button("clear")
|
157 |
+
url_btn.click(lambda x: '', inputs=url, outputs=url)
|
158 |
+
|
159 |
+
speech = gr.Audio(label="Speech", type="filepath")
|
160 |
+
|
161 |
+
url.change(audio_from_url, inputs=url, outputs=speech)
|
162 |
+
|
163 |
+
# ASR
|
164 |
+
text = gr.Textbox(label="Transcription", placeholder="transcription")
|
165 |
+
|
166 |
+
with gr.Row():
|
167 |
+
default_values = dict(model='Base.en', bs=5, bo=5) if torch.cuda.is_available() \
|
168 |
+
else dict(model='Tiny.en', bs=1, bo=1)
|
169 |
+
model_options = gr.Dropdown(['Tiny.en', 'Base.en'], value=default_values['model'], label="models")
|
170 |
+
model_options.change(load_model, inputs=model_options, outputs=model_options)
|
171 |
+
|
172 |
+
beam_size_slider = gr.Slider(1, 10, value=default_values['bs'], step=1, label="param: beam_size")
|
173 |
+
best_of_slider = gr.Slider(1, 10, value=default_values['bo'], step=1, label="param: best_of")
|
174 |
+
|
175 |
+
with gr.Row():
|
176 |
+
asr_clr_btn = gr.Button("clear")
|
177 |
+
asr_clr_btn.click(lambda x: '', inputs=text, outputs=text)
|
178 |
+
asr_btn = gr.Button("Recognize Speech")
|
179 |
+
asr_btn.click(speech_to_text, inputs=[speech, beam_size_slider, best_of_slider], outputs=text)
|
180 |
+
|
181 |
+
# summarization
|
182 |
+
summary = gr.Textbox(label="Summarization")
|
183 |
+
|
184 |
+
with gr.Row():
|
185 |
+
sum_clr_btn = gr.Button("clear")
|
186 |
+
sum_clr_btn.click(lambda x: '', inputs=summary, outputs=summary)
|
187 |
+
sum_btn = gr.Button("Summarize")
|
188 |
+
sum_btn.click(text_summarization, inputs=text, outputs=summary)
|
189 |
+
|
190 |
+
with gr.Row():
|
191 |
+
# wordcloud
|
192 |
+
image = gr.Image(label="wordcloud", show_label=False).style(height=400, width=400)
|
193 |
+
with gr.Column():
|
194 |
+
tts = gr.Audio(label="Short Speech", type="filepath")
|
195 |
+
tts_btn = gr.Button("Read Summary")
|
196 |
+
tts_btn.click(text_to_speech, inputs=summary, outputs=tts)
|
197 |
+
|
198 |
+
text.change(wordcloud_func, inputs=text, outputs=image)
|
199 |
+
|
200 |
+
examples = gr.Examples(examples=["https://www.youtube.com/watch?v=DuX4K4eeTz8",
|
201 |
+
"https://www.youtube.com/watch?v=nepOSEGHHCQ"],
|
202 |
+
fn=lambda x: speech_to_text(audio_from_url(x)),
|
203 |
+
inputs=url, outputs=text, cache_examples=True)
|
204 |
+
|
205 |
+
gr.HTML(footer_html)
|
206 |
+
|
207 |
+
|
208 |
+
if __name__ == '__main__':
|
209 |
+
demo.launch()
|