Alex Volkov committed
Commit 7db5fdc
1 Parent(s): 09cee30

Added a captions API that receives a URL and both transcribes AND translates it.

Files changed (6)
  1. app.py +1 -1
  2. download.py +68 -15
  3. requirements.txt +2 -1
  4. static/css/main.css +1 -1
  5. utils/apis.py +32 -5
  6. utils/subs.py +33 -25
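
For orientation, a rough sketch of how a client might hit the new endpoint over HTTP once this commit is deployed. The localhost URL, port 8111 (the default that call_gradio_api in utils/apis.py reads from SERVER_PORT), and the Gradio 3.x /api/<api_name> payload layout are assumptions; the tweet URL is a placeholder.

import requests

# Hypothetical call to the new "caption" endpoint registered in utils/apis.py.
# Inputs follow the order of the hidden Gradio elements: tweet_url, language, model size.
payload = {"data": ["https://twitter.com/i/status/<some_tweet_id>", "Autodetect", "base"]}
resp = requests.post("http://localhost:8111/api/caption", json=payload, timeout=600)
caption_json = resp.json()["data"][0]  # JSON string built by caption(): text, en_text, srt, en_srt, meta
print(caption_json)
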
app.py CHANGED
@@ -137,7 +137,7 @@ with gr.Blocks(css='@import "file=static/css/main.css";', theme='darkpeach', tit
     init_video.change(fn=init_video_manual_upload, inputs=[url_input, init_video], outputs=[])
 
     # Render imported buttons for API bindings
-    render_api_elements(url_input,download_status, output_text, sub_video)
+    render_api_elements(url_input,download_status, output_text, sub_video, output_file)
 
     queue_placeholder = demo.queue()
 
download.py CHANGED
@@ -13,7 +13,7 @@ import argparse
 import whisper
 from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 import ffmpeg
-from utils.subs import bake_subs
+from utils.subs import bake_subs, get_srt
 from utils.utils import get_args
 
 original_dir = os.getcwd()
@@ -106,6 +106,54 @@ def download_generator(url, translate_action=True, source_language='Autodetect',
         yield {"message": f"{e}"}
 
 
+def caption_generator(tweet_url, language="Autodetect", model_size=model_size):
+    # Download the file
+
+    try:
+        print(f"Downloading {tweet_url} ")
+        meta = check_download(tweet_url)
+        tempdir = output_dir / f"{meta['id']}"
+        print(f"Downloaded {meta['id']}.mp3 from {meta['uploader_id']} and url {meta['webpage_url']}")
+    except Exception as e:
+        print(f"Could not download file: {e}")
+        raise
+
+    try:
+        print(f"Starting audio only download with URL {tweet_url}, this may take a while")
+        meta, video, audio = download(tweet_url, tempdir, keepVideo=False)
+        print(f"Downloaded video and extracted audio")
+    except Exception as e:
+        print(f"Could not download file: {e}")
+        raise
+
+    # Run whisper on the audio with language unless auto
+    try:
+        print(f"Starting whisper transcribe with {meta['id']}.mp3")
+        transcribe_whisper_result = transcribe(audio, translate_action=False, language=language, override_model_size=model_size)
+        translate_whisper_result = transcribe(audio, translate_action=True, language=language, override_model_size=model_size)
+        srt = get_srt(transcribe_whisper_result["segments"])
+        en_srt = get_srt(translate_whisper_result["segments"])
+
+        print(f"Transcribe successful!")
+    except Exception as e:
+        print(f"Could not transcribe file: {e}")
+        return
+
+    return_dict = {
+        "detected_language": LANGUAGES[transcribe_whisper_result["language"]],
+        "requested_language": language,
+        "text": transcribe_whisper_result["text"],
+        "en_text": translate_whisper_result["text"],
+        "srt": srt,
+        "en_srt": en_srt,
+        "meta": meta,
+    }
+    return return_dict
+
+
+# Run whisper with translation task enabled (and save to different srt file)
+# Call anvil background task with both files, and both the plain texts
+
 def progress_hook(d):
     if d['status'] == 'downloading':
         print("downloading " + str(round(float(d['downloaded_bytes']) / float(d['total_bytes']) * 100, 1)) + "%")
@@ -115,11 +163,11 @@ def progress_hook(d):
         print(filename)
         yield f"Downloaded {filename}"
 
-def download(url, tempdir):
+def download(url, tempdir, format="bestvideo[ext=mp4]+bestaudio/best", verbose=False, keepVideo=True):
     try:
         ydl_opts = {
-            "format": "bestvideo[ext=mp4]+bestaudio/best",
-            "keepvideo": True,
+            "format": format,
+            "keepvideo": keepVideo,
             'postprocessors': [{
                 'key': 'FFmpegExtractAudio',
                 'preferredcodec': 'mp3',
@@ -128,7 +176,7 @@ def download(url, tempdir):
             "skip_download": False,
            "outtmpl": f"{tempdir}/%(id)s.%(ext)s",
            "noplaylist": True,
-           "verbose": False,
+           "verbose": verbose,
            "quiet": True,
            "progress_hooks": [progress_hook],
 
@@ -141,10 +189,13 @@ def download(url, tempdir):
     except DownloadError as e:
         raise e
     else:
-        video = tempdir / f"{meta['id']}.{meta['ext']}"
         audio = tempdir / f"{meta['id']}.mp3"
-        print(str(video.resolve()))
-        return meta, str(video.resolve()), str(audio.resolve())
+        if (keepVideo):
+            video = tempdir / f"{meta['id']}.{meta['ext']}"
+            return meta, str(video.resolve()), str(audio.resolve())
+        else:
+            return meta, None, str(audio.resolve())
+
 
 def check_download(url):
     ydl_opts = {
@@ -164,22 +215,24 @@ def check_download(url):
     else:
         return meta
 
-def transcribe(audio, translate_action=True, language='Autodetect'):
+def transcribe(audio, translate_action=True, language='Autodetect', override_model_size=''):
     task = "translate" if translate_action else "transcribe"
-    print(f'Starting {task} with whisper size {model_size}')
+    model_size_to_load = override_model_size if override_model_size else model_size
+    print(f'Starting {task} with whisper size {model_size_to_load} on {audio}')
     global model
-    if not preload_model:
-        model = whisper.load_model(model_size)
+    if not preload_model or model_size != override_model_size:
+        model = whisper.load_model(model_size_to_load)
+
     props = {
         "task": task,
     }
+
     if language != 'Autodetect':
         props["language"] = TO_LANGUAGE_CODE[language.lower()]
 
-    output = model.transcribe(audio, task=task)
+    output = model.transcribe(audio, verbose=True, **props)
 
-    output["language"] = LANGUAGES[output["language"]]
-    output['segments'] = [{"id": 0, "seek": 0, "start": 0.0, "end": 3, "text": " [AI transcription]"}] + output['segments']
+    output['segments'] = output['segments']
     print(f'Finished transcribe from {output["language"]}', output["text"])
     return output
 
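A minimal usage sketch of the new caption_generator (an assumption, not part of the commit): calling it directly from a Python shell inside the Space. The tweet URL is a hypothetical placeholder.

from download import caption_generator

# Hypothetical direct call; downloads audio only, then runs whisper twice (transcribe + translate).
result = caption_generator("https://twitter.com/i/status/<some_tweet_id>", language="Autodetect", model_size="base")
print(result["detected_language"])  # e.g. "english"
print(result["text"][:120])         # transcription in the detected language
print(result["en_srt"])             # SRT built from the translated segments
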
requirements.txt CHANGED
@@ -4,4 +4,5 @@ anvil-uplink==0.4.0
 gradio==3.4.0
 python-dotenv==0.21.0
 aiohttp==3.8.3
-aiohttp-requests==0.1.3
+aiohttp-requests==0.1.3
+fsspec==2022.8.2
static/css/main.css CHANGED
@@ -93,5 +93,5 @@ background: transparent
 }
 
 footer{
-    display: none !important;
+    /*display: none !important;*/
 }
utils/apis.py CHANGED
@@ -11,10 +11,11 @@ import anvil.media
 import dotenv
 import gradio as gr
 import requests
-from download import download_generator
-
+from download import download_generator, caption_generator
 
 dotenv.load_dotenv()
+
+
 @anvil.server.callable
 def call_gradio_api(api_name='test_api', data=()):
     port = os.environ.get('SERVER_PORT', 8111)
@@ -62,7 +63,19 @@ def test_api(url=''):
     # TODO: add an anvil server pingback to show we completed the queue operation
     return f"I've slept for 15 seconds and now I'm done. "
 
-def render_api_elements(url_input, download_status, output_text, sub_video):
+# TODO: add telegram error handler here
+def caption(tweet_url="", language="Autodetect", override_model_size=""):
+    """
+    :param tweet_url: tweet URL; the tweet can stop existing in the future, so we can upload on behalf of the user
+    :param language: source language name, or "Autodetect"
+    :param override_model_size: optional whisper model size to use instead of the default
+    :return: JSON string with the transcription, translation, both SRTs and the media metadata
+    """
+    response = caption_generator(tweet_url, language, override_model_size)
+    return json.dumps(response)
+
+
+def render_api_elements(url_input, download_status, output_text, sub_video, output_file):
     with gr.Group(elem_id='fake_ass_group') as api_buttons:
         # This is a hack to get APIs registered with the blocks interface
         translate_result = gr.Textbox(visible=False)
@@ -75,6 +88,21 @@ def render_api_elements(url_input, download_status, output_text, sub_video):
 
         gr.Button("remote_download", visible=False)\
             .click(api_name='remote_download', queue=True, fn=remote_download, inputs=[url_input], outputs=[download_status, output_text, translate_result, translate_language])
+
+        # Creating fake input elements just to make Gradio register the API, since there is no way to define an API signature directly
+
+        gr.Button("caption", visible=False)\
+            .click(api_name='caption',
+                   queue=True,
+                   fn=caption,
+                   inputs=[
+                       gr.Text(label='tweet_url'),
+                       gr.Text(label='language (optional)'),
+                       gr.Dropdown(label='Model Size', choices=['base', 'tiny', 'small', 'medium', 'large']),
+                   ],
+                   outputs=[
+                       gr.Text(label='response_json')
+                   ])
     return api_buttons
 
 
@@ -87,5 +115,4 @@ def cleanup_output_dir():
         if path.is_file():
            path.unlink()
         elif path.is_dir():
-            rmtree(path)
-
+            rmtree(path)
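Sketch of the Anvil side (assumed, not part of this commit): since call_gradio_api is registered with @anvil.server.callable, an Anvil app or uplink script could reach the new caption endpoint through it. The uplink key, the argument layout and the return shape are assumptions; the tweet URL is a placeholder.

import anvil.server

anvil.server.connect("<anvil-uplink-key>")  # hypothetical key
# Forward the three caption() inputs through the uplink helper defined above.
response = anvil.server.call("call_gradio_api", "caption",
                             ("https://twitter.com/i/status/<some_tweet_id>", "Autodetect", "base"))
print(response)
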
utils/subs.py CHANGED
@@ -6,7 +6,6 @@ import os
 from typing import Iterator, TextIO
 
 
-
 def bake_subs(input_file, output_file, subs_file, fontsdir, translate_action):
     print(f"Baking {subs_file} into video... {input_file} -> {output_file}")
 
@@ -30,39 +29,39 @@ def bake_subs(input_file, output_file, subs_file, fontsdir, translate_action):
     fontstyle = f'Fontsize={sub_size},OutlineColour=&H40000000,BorderStyle=3,FontName={fontname},Bold=1'
     (
         ffmpeg.concat(
-            video.filter('subtitles', subs_file, fontsdir=fontfile, force_style=fontstyle),
-            audio, v=1, a=1
-        )
-        .overlay(watermark.filter('scale', iw / 3, -1), x='10', y='10')
-        .output(filename=output_file)
-        .run(quiet=True, overwrite_output=True)
+            video.filter('subtitles', subs_file, fontsdir=fontfile, force_style=fontstyle),
+            audio, v=1, a=1
+        )
+        .overlay(watermark.filter('scale', iw / 3, -1), x='10', y='10')
+        .output(filename=output_file)
+        .run(quiet=True, overwrite_output=True)
     )
 
 
 def str2bool(string):
-    str2val = {"True": True, "False": False}
-    if string in str2val:
-        return str2val[string]
-    else:
-        raise ValueError(
-            f"Expected one of {set(str2val.keys())}, got {string}")
+    str2val = {"True": True, "False": False}
+    if string in str2val:
+        return str2val[string]
+    else:
+        raise ValueError(
+            f"Expected one of {set(str2val.keys())}, got {string}")
 
 
 def format_timestamp(seconds: float, always_include_hours: bool = False):
-    assert seconds >= 0, "non-negative timestamp expected"
-    milliseconds = round(seconds * 1000.0)
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
 
-    hours = milliseconds // 3_600_000
-    milliseconds -= hours * 3_600_000
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
 
-    minutes = milliseconds // 60_000
-    milliseconds -= minutes * 60_000
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
 
-    seconds = milliseconds // 1_000
-    milliseconds -= seconds * 1_000
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
 
-    hours_marker = f"{hours}:" if always_include_hours or hours > 0 else ""
-    return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
+    hours_marker = f"{hours}:" if always_include_hours or hours > 0 else ""
+    return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
 
 
 def write_srt(transcript: Iterator[dict], file: TextIO):
@@ -77,8 +76,17 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
         )
 
 
+def get_srt(transcript: Iterator[dict]):
+    srt = ''
+    for i, segment in enumerate(transcript, start=1):
+        srt += f"{i}\n" \
+               f"{format_timestamp(segment['start'], always_include_hours=True)} --> " \
+               f"{format_timestamp(segment['end'], always_include_hours=True)}\n" \
+               f"{segment['text'].strip().replace('-->', '->')}\n"
+    return srt
+
 def filename(path):
-    return os.path.splitext(os.path.basename(path))[0]
+    return os.path.splitext(os.path.basename(path))[0]
 
 
 
@@ -94,4 +102,4 @@ def filename(path):
 # os.chdir(tempdirname)
 # bake_subs(video_file_path, out_path, srt_path)
 # anvil_media = anvil.media.from_file(out_path, 'video/mp4')
-# print(anvil_media)
+# print(anvil_media)
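
For illustration (not part of the commit), what get_srt produces for two made-up whisper-style segments; the cues come out back to back, exactly as the string is built above.

from utils.subs import get_srt

# Two hypothetical whisper-style segments
segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there"},
    {"start": 2.5, "end": 5.0, "text": " General Kenobi"},
]
print(get_srt(segments))
# 1
# 0:00:00.000 --> 0:00:02.500
# Hello there
# 2
# 0:00:02.500 --> 0:00:05.000
# General Kenobi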