Spaces:
Sleeping
Added support for translation models (NLLB, NLLB-CT2, MT5)
Browse filesto provide full translation capabilities for Whisper.
The interface now includes optional selection of NLLB Model (for translate) and NLLB Language. If not selected, the translation feature will not be activated.
__________________
Whisper’s Task ‘translate’ only implements the functionality of translating other languages into English. OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. However, it’s important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual.
The larger the parameters of the NLLB model, the better its performance is expected to be. However, it also requires higher computational resources, making it slower to operate. On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed.
Currently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string.
The ‘mt5-zh-ja-en-trimmed’ model is finetuned from Google’s ‘mt5-base’ model. This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English.
- README.md +1 -1
- app.py +115 -56
- config.json5 +96 -0
- requirements-fasterWhisper.txt +3 -2
- requirements-whisper.txt +2 -1
- requirements.txt +3 -2
- src/config.py +11 -4
- src/nllb/nllbLangs.py +251 -0
- src/nllb/nllbModel.py +221 -0
- src/vadParallel.py +1 -1
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: Faster Whisper Webui
|
3 |
emoji: ✨
|
4 |
colorFrom: blue
|
5 |
colorTo: purple
|
|
|
1 |
---
|
2 |
+
title: Faster Whisper Webui with translate
|
3 |
emoji: ✨
|
4 |
colorFrom: blue
|
5 |
colorTo: purple
|
@@ -5,8 +5,8 @@ from typing import Iterator, Union
|
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
|
|
8 |
import os
|
9 |
-
import pathlib
|
10 |
import tempfile
|
11 |
import zipfile
|
12 |
import numpy as np
|
@@ -37,9 +37,14 @@ from src.utils import optional_int, slugify, write_srt, write_vtt
|
|
37 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
38 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
39 |
from src.whisper.whisperFactory import create_whisper_container
|
|
|
|
|
|
|
|
|
40 |
|
41 |
import shutil
|
42 |
import zhconv
|
|
|
43 |
|
44 |
# Configure more application defaults in config.json5
|
45 |
|
@@ -92,26 +97,26 @@ class WhisperTranscriber:
|
|
92 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
93 |
|
94 |
# Entry function for the simple tab
|
95 |
-
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
96 |
vad, vadMergeWindow, vadMaxMergeSize,
|
97 |
word_timestamps: bool = False, highlight_words: bool = False):
|
98 |
-
return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
99 |
vad, vadMergeWindow, vadMaxMergeSize,
|
100 |
word_timestamps, highlight_words)
|
101 |
|
102 |
# Entry function for the simple tab progress
|
103 |
-
def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
104 |
vad, vadMergeWindow, vadMaxMergeSize,
|
105 |
word_timestamps: bool = False, highlight_words: bool = False,
|
106 |
progress=gr.Progress()):
|
107 |
|
108 |
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
|
109 |
|
110 |
-
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
111 |
word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
|
112 |
|
113 |
# Entry function for the full tab
|
114 |
-
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
115 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
116 |
# Word timestamps
|
117 |
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
@@ -119,7 +124,7 @@ class WhisperTranscriber:
|
|
119 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
120 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
|
121 |
|
122 |
-
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
123 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
124 |
word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
|
125 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
@@ -127,7 +132,7 @@ class WhisperTranscriber:
|
|
127 |
compression_ratio_threshold, logprob_threshold, no_speech_threshold)
|
128 |
|
129 |
# Entry function for the full tab with progress
|
130 |
-
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
131 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
132 |
# Word timestamps
|
133 |
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
@@ -144,21 +149,21 @@ class WhisperTranscriber:
|
|
144 |
|
145 |
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
146 |
|
147 |
-
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
148 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
149 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
150 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
151 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
152 |
progress=progress)
|
153 |
|
154 |
-
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
155 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
156 |
**decodeOptions: dict):
|
157 |
try:
|
158 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
159 |
|
160 |
try:
|
161 |
-
|
162 |
selectedLanguage = languageName.lower() if languageName is not None and len(languageName) > 0 else None
|
163 |
selectedModel = modelName if modelName is not None else "base"
|
164 |
|
@@ -166,6 +171,12 @@ class WhisperTranscriber:
|
|
166 |
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
167 |
cache=self.model_cache, models=self.app_config.models)
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
# Result
|
170 |
download = []
|
171 |
zip_file_lookup = {}
|
@@ -208,7 +219,7 @@ class WhisperTranscriber:
|
|
208 |
# Update progress
|
209 |
current_progress += source_audio_duration
|
210 |
|
211 |
-
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
212 |
|
213 |
if len(sources) > 1:
|
214 |
# Add new line separators
|
@@ -252,30 +263,19 @@ class WhisperTranscriber:
|
|
252 |
return download, text, vtt
|
253 |
|
254 |
finally:
|
255 |
-
if languageName == "Chinese":
|
256 |
-
for file_path in source_download:
|
257 |
-
try:
|
258 |
-
with open(file_path, "r+", encoding="utf-8") as source:
|
259 |
-
content = source.read()
|
260 |
-
content = zhconv.convert(content, "zh-tw")
|
261 |
-
source.seek(0)
|
262 |
-
source.write(content)
|
263 |
-
except Exception as e:
|
264 |
-
# Ignore error - it's just a cleanup
|
265 |
-
print("Error converting Traditional Chinese with download source file: \n" + file_path + ", \n" + str(e))
|
266 |
-
|
267 |
# Cleanup source
|
268 |
if self.deleteUploadedFiles:
|
269 |
for source in sources:
|
270 |
if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None and len(source_download) > 0:
|
271 |
-
print("
|
272 |
outRsult = ""
|
273 |
try:
|
274 |
srt_path = source_download[0]
|
275 |
save_path = os.path.join(self.app_config.output_dir, source.source_name)
|
276 |
save_without_ext, ext = os.path.splitext(save_path)
|
277 |
-
|
278 |
-
|
|
|
279 |
|
280 |
#ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
|
281 |
input_file = ffmpeg.input(source.source_path)
|
@@ -435,20 +435,41 @@ class WhisperTranscriber:
|
|
435 |
|
436 |
return config
|
437 |
|
438 |
-
def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
|
439 |
if not os.path.exists(output_dir):
|
440 |
os.makedirs(output_dir)
|
441 |
|
442 |
text = result["text"]
|
|
|
443 |
language = result["language"]
|
444 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
445 |
|
446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
447 |
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
448 |
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
449 |
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
450 |
|
451 |
-
if language == "zh":
|
452 |
vtt = zhconv.convert(vtt, "zh-tw")
|
453 |
srt = zhconv.convert(srt, "zh-tw")
|
454 |
text = zhconv.convert(text, "zh-tw")
|
@@ -541,12 +562,29 @@ def create_ui(app_config: ApplicationConfig):
|
|
541 |
ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
542 |
|
543 |
ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
|
545 |
whisper_models = app_config.get_model_names()
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
gr.Text(label="URL (YouTube, etc.)"),
|
551 |
gr.File(label="Upload Files", file_count="multiple"),
|
552 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
@@ -579,7 +617,13 @@ def create_ui(app_config: ApplicationConfig):
|
|
579 |
with gr.Row():
|
580 |
with gr.Column():
|
581 |
simple_submit = gr.Button("Submit", variant="primary")
|
582 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
with gr.Column():
|
584 |
simple_output = common_output()
|
585 |
simple_flag = gr.Button("Flag")
|
@@ -602,27 +646,33 @@ def create_ui(app_config: ApplicationConfig):
|
|
602 |
with gr.Row():
|
603 |
with gr.Column():
|
604 |
full_submit = gr.Button("Submit", variant="primary")
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
626 |
|
627 |
with gr.Column():
|
628 |
full_output = common_output()
|
@@ -654,6 +704,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
654 |
if __name__ == '__main__':
|
655 |
default_app_config = ApplicationConfig.create_default()
|
656 |
whisper_models = default_app_config.get_model_names()
|
|
|
657 |
|
658 |
# Environment variable overrides
|
659 |
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
|
@@ -707,6 +758,14 @@ if __name__ == '__main__':
|
|
707 |
|
708 |
updated_config = default_app_config.update(**args)
|
709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
710 |
if (threads := args.pop("threads")) > 0:
|
711 |
torch.set_num_threads(threads)
|
712 |
|
|
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
8 |
+
import time
|
9 |
import os
|
|
|
10 |
import tempfile
|
11 |
import zipfile
|
12 |
import numpy as np
|
|
|
37 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
38 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
39 |
from src.whisper.whisperFactory import create_whisper_container
|
40 |
+
from src.nllb.nllbModel import NllbModel
|
41 |
+
from src.nllb.nllbLangs import _TO_NLLB_LANG_CODE
|
42 |
+
from src.nllb.nllbLangs import get_nllb_lang_names
|
43 |
+
from src.nllb.nllbLangs import get_nllb_lang_from_name
|
44 |
|
45 |
import shutil
|
46 |
import zhconv
|
47 |
+
import tqdm
|
48 |
|
49 |
# Configure more application defaults in config.json5
|
50 |
|
|
|
97 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
98 |
|
99 |
# Entry function for the simple tab
|
100 |
+
def transcribe_webui_simple(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
101 |
vad, vadMergeWindow, vadMaxMergeSize,
|
102 |
word_timestamps: bool = False, highlight_words: bool = False):
|
103 |
+
return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
104 |
vad, vadMergeWindow, vadMaxMergeSize,
|
105 |
word_timestamps, highlight_words)
|
106 |
|
107 |
# Entry function for the simple tab progress
|
108 |
+
def transcribe_webui_simple_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
109 |
vad, vadMergeWindow, vadMaxMergeSize,
|
110 |
word_timestamps: bool = False, highlight_words: bool = False,
|
111 |
progress=gr.Progress()):
|
112 |
|
113 |
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
|
114 |
|
115 |
+
return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
116 |
word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
|
117 |
|
118 |
# Entry function for the full tab
|
119 |
+
def transcribe_webui_full(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
120 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
121 |
# Word timestamps
|
122 |
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
|
|
124 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
125 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
|
126 |
|
127 |
+
return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
128 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
129 |
word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
|
130 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
|
|
132 |
compression_ratio_threshold, logprob_threshold, no_speech_threshold)
|
133 |
|
134 |
# Entry function for the full tab with progress
|
135 |
+
def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
136 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
137 |
# Word timestamps
|
138 |
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
|
|
149 |
|
150 |
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
151 |
|
152 |
+
return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
153 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
154 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
155 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
156 |
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
157 |
progress=progress)
|
158 |
|
159 |
+
def transcribe_webui(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
160 |
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
161 |
**decodeOptions: dict):
|
162 |
try:
|
163 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
164 |
|
165 |
try:
|
166 |
+
whisper_lang = get_language_from_name(languageName)
|
167 |
selectedLanguage = languageName.lower() if languageName is not None and len(languageName) > 0 else None
|
168 |
selectedModel = modelName if modelName is not None else "base"
|
169 |
|
|
|
171 |
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
172 |
cache=self.model_cache, models=self.app_config.models)
|
173 |
|
174 |
+
nllb_lang = get_nllb_lang_from_name(nllbLangName)
|
175 |
+
selectedNllbModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
|
176 |
+
selectedNllbModel = next((modelConfig for modelConfig in self.app_config.nllb_models if modelConfig.name == selectedNllbModelName), None)
|
177 |
+
|
178 |
+
nllb_model = NllbModel(model_config=selectedNllbModel, whisper_lang=whisper_lang, nllb_lang=nllb_lang) # load_model=True
|
179 |
+
|
180 |
# Result
|
181 |
download = []
|
182 |
zip_file_lookup = {}
|
|
|
219 |
# Update progress
|
220 |
current_progress += source_audio_duration
|
221 |
|
222 |
+
source_download, source_text, source_vtt = self.write_result(result, nllb_model, filePrefix, outputDirectory, highlight_words)
|
223 |
|
224 |
if len(sources) > 1:
|
225 |
# Add new line separators
|
|
|
263 |
return download, text, vtt
|
264 |
|
265 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
# Cleanup source
|
267 |
if self.deleteUploadedFiles:
|
268 |
for source in sources:
|
269 |
if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None and len(source_download) > 0:
|
270 |
+
print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
|
271 |
outRsult = ""
|
272 |
try:
|
273 |
srt_path = source_download[0]
|
274 |
save_path = os.path.join(self.app_config.output_dir, source.source_name)
|
275 |
save_without_ext, ext = os.path.splitext(save_path)
|
276 |
+
source_lang = "." + whisper_lang.code if whisper_lang is not None else ""
|
277 |
+
translate_lang = "." + nllb_lang.code if nllb_lang is not None else ""
|
278 |
+
output_with_srt = save_without_ext + source_lang + translate_lang + ext
|
279 |
|
280 |
#ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
|
281 |
input_file = ffmpeg.input(source.source_path)
|
|
|
435 |
|
436 |
return config
|
437 |
|
438 |
+
def write_result(self, result: dict, nllb_model: NllbModel, source_name: str, output_dir: str, highlight_words: bool = False):
|
439 |
if not os.path.exists(output_dir):
|
440 |
os.makedirs(output_dir)
|
441 |
|
442 |
text = result["text"]
|
443 |
+
segments = result["segments"]
|
444 |
language = result["language"]
|
445 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
446 |
|
447 |
+
if nllb_model.nllb_lang is not None:
|
448 |
+
try:
|
449 |
+
pbar = tqdm.tqdm(total=len(segments))
|
450 |
+
perf_start_time = time.perf_counter()
|
451 |
+
nllb_model.load_model()
|
452 |
+
for idx, segment in enumerate(segments):
|
453 |
+
seg_text = segment["text"]
|
454 |
+
if language == "zh":
|
455 |
+
segment["text"] = zhconv.convert(seg_text, "zh-tw")
|
456 |
+
if nllb_model.nllb_lang is not None:
|
457 |
+
segment["text"] = nllb_model.translation(seg_text)
|
458 |
+
pbar.update(1)
|
459 |
+
|
460 |
+
nllb_model.release_vram()
|
461 |
+
perf_end_time = time.perf_counter()
|
462 |
+
print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
|
463 |
+
except Exception as e:
|
464 |
+
# Ignore error - it's just a cleanup
|
465 |
+
print("Error process segments: " + str(e))
|
466 |
+
|
467 |
+
print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
|
468 |
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
469 |
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
470 |
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
471 |
|
472 |
+
if language == "zh" or (nllb_model.nllb_lang is not None and nllb_model.nllb_lang.code == "zho_Hant"):
|
473 |
vtt = zhconv.convert(vtt, "zh-tw")
|
474 |
srt = zhconv.convert(srt, "zh-tw")
|
475 |
text = zhconv.convert(text, "zh-tw")
|
|
|
562 |
ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
563 |
|
564 |
ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
565 |
+
ui_article += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
|
566 |
+
ui_article += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
|
567 |
+
ui_article += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
|
568 |
+
ui_article += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
|
569 |
+
ui_article += "However, it also requires higher computational resources, making it slower to operate. "
|
570 |
+
ui_article += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
|
571 |
+
ui_article += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
|
572 |
+
ui_article += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
|
573 |
+
ui_article += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
|
574 |
+
ui_article += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
|
575 |
|
576 |
whisper_models = app_config.get_model_names()
|
577 |
+
nllb_models = app_config.get_nllb_model_names()
|
578 |
+
|
579 |
+
common_whisper_inputs = lambda : [
|
580 |
+
gr.Dropdown(label="Whisper Model (for audio)", choices=whisper_models, value=app_config.default_model_name),
|
581 |
+
gr.Dropdown(label="Whisper Language", choices=sorted(get_language_names()), value=app_config.language),
|
582 |
+
]
|
583 |
+
common_nllb_inputs = lambda : [
|
584 |
+
gr.Dropdown(label="NLLB Model (for translate)", choices=nllb_models),
|
585 |
+
gr.Dropdown(label="NLLB Language", choices=sorted(get_nllb_lang_names())),
|
586 |
+
]
|
587 |
+
common_audio_inputs = lambda : [
|
588 |
gr.Text(label="URL (YouTube, etc.)"),
|
589 |
gr.File(label="Upload Files", file_count="multiple"),
|
590 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
|
|
617 |
with gr.Row():
|
618 |
with gr.Column():
|
619 |
simple_submit = gr.Button("Submit", variant="primary")
|
620 |
+
with gr.Column():
|
621 |
+
with gr.Row():
|
622 |
+
simple_input = common_whisper_inputs()
|
623 |
+
with gr.Row():
|
624 |
+
simple_input += common_nllb_inputs()
|
625 |
+
with gr.Column():
|
626 |
+
simple_input += common_audio_inputs() + common_vad_inputs() + common_word_timestamps_inputs()
|
627 |
with gr.Column():
|
628 |
simple_output = common_output()
|
629 |
simple_flag = gr.Button("Flag")
|
|
|
646 |
with gr.Row():
|
647 |
with gr.Column():
|
648 |
full_submit = gr.Button("Submit", variant="primary")
|
649 |
+
with gr.Column():
|
650 |
+
with gr.Row():
|
651 |
+
full_input1 = common_whisper_inputs()
|
652 |
+
with gr.Row():
|
653 |
+
full_input1 += common_nllb_inputs()
|
654 |
+
with gr.Column():
|
655 |
+
full_input1 += common_audio_inputs() + common_vad_inputs() + [
|
656 |
+
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
|
657 |
+
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
|
658 |
+
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode")]
|
659 |
+
|
660 |
+
full_input2 = common_word_timestamps_inputs() + [
|
661 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
662 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
|
663 |
+
gr.TextArea(label="Initial Prompt"),
|
664 |
+
gr.Number(label="Temperature", value=app_config.temperature),
|
665 |
+
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
666 |
+
gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
|
667 |
+
gr.Number(label="Patience - Zero temperature", value=app_config.patience),
|
668 |
+
gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
|
669 |
+
gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
|
670 |
+
gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
|
671 |
+
gr.Checkbox(label="FP16", value=app_config.fp16),
|
672 |
+
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
673 |
+
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
674 |
+
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
675 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)]
|
676 |
|
677 |
with gr.Column():
|
678 |
full_output = common_output()
|
|
|
704 |
if __name__ == '__main__':
|
705 |
default_app_config = ApplicationConfig.create_default()
|
706 |
whisper_models = default_app_config.get_model_names()
|
707 |
+
nllb_models = default_app_config.get_nllb_model_names()
|
708 |
|
709 |
# Environment variable overrides
|
710 |
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
|
|
|
758 |
|
759 |
updated_config = default_app_config.update(**args)
|
760 |
|
761 |
+
#updated_config.whisper_implementation = "faster-whisper"
|
762 |
+
#updated_config.input_audio_max_duration = -1
|
763 |
+
#updated_config.default_model_name = "large-v2"
|
764 |
+
#updated_config.output_dir = "output"
|
765 |
+
#updated_config.vad_max_merge_size = 90
|
766 |
+
#updated_config.merge_subtitle_with_sources = True
|
767 |
+
#updated_config.autolaunch = True
|
768 |
+
|
769 |
if (threads := args.pop("threads")) > 0:
|
770 |
torch.set_num_threads(threads)
|
771 |
|
@@ -43,6 +43,102 @@
|
|
43 |
// "url": "https://example.com/path/to/model",
|
44 |
//}
|
45 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
47 |
|
48 |
// * WEBUI options *
|
|
|
43 |
// "url": "https://example.com/path/to/model",
|
44 |
//}
|
45 |
],
|
46 |
+
"nllb_models": [
|
47 |
+
{
|
48 |
+
"name": "nllb-200-distilled-1.3B-ct2fast:int8_float16/michaelfeil",
|
49 |
+
"url": "michaelfeil/ct2fast-nllb-200-distilled-1.3B",
|
50 |
+
"type": "huggingface"
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"name": "nllb-200-3.3B-ct2fast:int8_float16/michaelfeil",
|
54 |
+
"url": "michaelfeil/ct2fast-nllb-200-3.3B",
|
55 |
+
"type": "huggingface"
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"name": "nllb-200-1.3B-ct2:float16/JustFrederik",
|
59 |
+
"url": "JustFrederik/nllb-200-1.3B-ct2-float16",
|
60 |
+
"type": "huggingface"
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"name": "nllb-200-distilled-1.3B-ct2:float16/JustFrederik",
|
64 |
+
"url": "JustFrederik/nllb-200-distilled-1.3B-ct2-float16",
|
65 |
+
"type": "huggingface"
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "nllb-200-1.3B-ct2:int8/JustFrederik",
|
69 |
+
"url": "JustFrederik/nllb-200-1.3B-ct2-int8",
|
70 |
+
"type": "huggingface"
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"name": "nllb-200-distilled-1.3B-ct2:int8/JustFrederik",
|
74 |
+
"url": "JustFrederik/nllb-200-distilled-1.3B-ct2-int8",
|
75 |
+
"type": "huggingface"
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"name": "mt5-zh-ja-en-trimmed/K024",
|
79 |
+
"url": "K024/mt5-zh-ja-en-trimmed",
|
80 |
+
"type": "huggingface"
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
|
84 |
+
"url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
85 |
+
"type": "huggingface"
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"name": "nllb-200-distilled-600M/facebook",
|
89 |
+
"url": "facebook/nllb-200-distilled-600M",
|
90 |
+
"type": "huggingface"
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"name": "nllb-200-distilled-600M-ct2/JustFrederik",
|
94 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2",
|
95 |
+
"type": "huggingface"
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
|
99 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
|
100 |
+
"type": "huggingface"
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
|
104 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
|
105 |
+
"type": "huggingface"
|
106 |
+
},
|
107 |
+
// Uncomment to add official Facebook 1.3B and 3.3B model
|
108 |
+
// The official Facebook 1.3B and 3.3B model files are too large,
|
109 |
+
// and to avoid occupying too much disk space on Hugging Face's free spaces,
|
110 |
+
// these models are not included in the config.
|
111 |
+
//{
|
112 |
+
// "name": "nllb-200-distilled-1.3B/facebook",
|
113 |
+
// "url": "facebook/nllb-200-distilled-1.3B",
|
114 |
+
// "type": "huggingface"
|
115 |
+
//},
|
116 |
+
//{
|
117 |
+
// "name": "nllb-200-1.3B/facebook",
|
118 |
+
// "url": "facebook/nllb-200-1.3B",
|
119 |
+
// "type": "huggingface"
|
120 |
+
//},
|
121 |
+
//{
|
122 |
+
// "name": "nllb-200-3.3B/facebook",
|
123 |
+
// "url": "facebook/nllb-200-3.3B",
|
124 |
+
// "type": "huggingface"
|
125 |
+
//},
|
126 |
+
//{
|
127 |
+
// "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
|
128 |
+
// "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
|
129 |
+
// "type": "huggingface"
|
130 |
+
//},
|
131 |
+
//{
|
132 |
+
// "name": "nllb-200-1.3B-ct2/JustFrederik",
|
133 |
+
// "url": "JustFrederik/nllb-200-1.3B-ct2",
|
134 |
+
// "type": "huggingface"
|
135 |
+
//},
|
136 |
+
//{
|
137 |
+
// "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
|
138 |
+
// "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
|
139 |
+
// "type": "huggingface"
|
140 |
+
//},
|
141 |
+
],
|
142 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
143 |
|
144 |
// * WEBUI options *
|
@@ -1,4 +1,4 @@
|
|
1 |
-
ctranslate2
|
2 |
faster-whisper
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.36.0
|
@@ -7,4 +7,5 @@ json5
|
|
7 |
torch
|
8 |
torchaudio
|
9 |
more_itertools
|
10 |
-
zhconv
|
|
|
|
1 |
+
ctranslate2>=3.16.0
|
2 |
faster-whisper
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.36.0
|
|
|
7 |
torch
|
8 |
torchaudio
|
9 |
more_itertools
|
10 |
+
zhconv
|
11 |
+
sentencepiece
|
@@ -7,4 +7,5 @@ yt-dlp
|
|
7 |
torchaudio
|
8 |
altair
|
9 |
json5
|
10 |
-
zhconv
|
|
|
|
7 |
torchaudio
|
8 |
altair
|
9 |
json5
|
10 |
+
zhconv
|
11 |
+
sentencepiece
|
@@ -1,4 +1,4 @@
|
|
1 |
-
ctranslate2
|
2 |
faster-whisper
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.36.0
|
@@ -7,4 +7,5 @@ json5
|
|
7 |
torch
|
8 |
torchaudio
|
9 |
more_itertools
|
10 |
-
zhconv
|
|
|
|
1 |
+
ctranslate2>=3.16.0
|
2 |
faster-whisper
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.36.0
|
|
|
7 |
torch
|
8 |
torchaudio
|
9 |
more_itertools
|
10 |
+
zhconv
|
11 |
+
sentencepiece
|
@@ -47,11 +47,11 @@ class VadInitialPromptMode(Enum):
|
|
47 |
return None
|
48 |
|
49 |
class ApplicationConfig:
|
50 |
-
def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
|
51 |
share: bool = False, server_name: str = None, server_port: int = 7860,
|
52 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
53 |
whisper_implementation: str = "whisper",
|
54 |
-
default_model_name: str = "medium", default_vad: str = "silero-vad",
|
55 |
vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
|
56 |
auto_parallel: bool = False, output_dir: str = None,
|
57 |
model_dir: str = None, device: str = None,
|
@@ -72,6 +72,7 @@ class ApplicationConfig:
|
|
72 |
highlight_words: bool = False):
|
73 |
|
74 |
self.models = models
|
|
|
75 |
|
76 |
# WebUI settings
|
77 |
self.input_audio_max_duration = input_audio_max_duration
|
@@ -83,6 +84,7 @@ class ApplicationConfig:
|
|
83 |
|
84 |
self.whisper_implementation = whisper_implementation
|
85 |
self.default_model_name = default_model_name
|
|
|
86 |
self.default_vad = default_vad
|
87 |
self.vad_parallel_devices = vad_parallel_devices
|
88 |
self.vad_cpu_cores = vad_cpu_cores
|
@@ -124,6 +126,9 @@ class ApplicationConfig:
|
|
124 |
def get_model_names(self):
|
125 |
return [ x.name for x in self.models ]
|
126 |
|
|
|
|
|
|
|
127 |
def update(self, **new_values):
|
128 |
result = ApplicationConfig(**self.__dict__)
|
129 |
|
@@ -148,7 +153,9 @@ class ApplicationConfig:
|
|
148 |
# Load using json5
|
149 |
data = json5.load(f)
|
150 |
data_models = data.pop("models", [])
|
151 |
-
|
|
|
152 |
models = [ ModelConfig(**x) for x in data_models ]
|
|
|
153 |
|
154 |
-
return ApplicationConfig(models, **data)
|
|
|
47 |
return None
|
48 |
|
49 |
class ApplicationConfig:
|
50 |
+
def __init__(self, models: List[ModelConfig] = [], nllb_models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
|
51 |
share: bool = False, server_name: str = None, server_port: int = 7860,
|
52 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
53 |
whisper_implementation: str = "whisper",
|
54 |
+
default_model_name: str = "medium", default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
|
55 |
vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
|
56 |
auto_parallel: bool = False, output_dir: str = None,
|
57 |
model_dir: str = None, device: str = None,
|
|
|
72 |
highlight_words: bool = False):
|
73 |
|
74 |
self.models = models
|
75 |
+
self.nllb_models = nllb_models
|
76 |
|
77 |
# WebUI settings
|
78 |
self.input_audio_max_duration = input_audio_max_duration
|
|
|
84 |
|
85 |
self.whisper_implementation = whisper_implementation
|
86 |
self.default_model_name = default_model_name
|
87 |
+
self.default_nllb_model_name = default_nllb_model_name
|
88 |
self.default_vad = default_vad
|
89 |
self.vad_parallel_devices = vad_parallel_devices
|
90 |
self.vad_cpu_cores = vad_cpu_cores
|
|
|
126 |
def get_model_names(self):
|
127 |
return [ x.name for x in self.models ]
|
128 |
|
129 |
+
def get_nllb_model_names(self):
|
130 |
+
return [ x.name for x in self.nllb_models ]
|
131 |
+
|
132 |
def update(self, **new_values):
|
133 |
result = ApplicationConfig(**self.__dict__)
|
134 |
|
|
|
153 |
# Load using json5
|
154 |
data = json5.load(f)
|
155 |
data_models = data.pop("models", [])
|
156 |
+
data_nllb_models = data.pop("nllb_models", [])
|
157 |
+
|
158 |
models = [ ModelConfig(**x) for x in data_models ]
|
159 |
+
nllb_models = [ ModelConfig(**x) for x in data_nllb_models ]
|
160 |
|
161 |
+
return ApplicationConfig(models, nllb_models, **data)
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class NllbLang():
|
2 |
+
def __init__(self, code, name, code_whisper=None, name_whisper=None):
|
3 |
+
self.code = code
|
4 |
+
self.name = name
|
5 |
+
self.code_whisper = code_whisper
|
6 |
+
self.name_whisper = name_whisper
|
7 |
+
|
8 |
+
def __str__(self):
|
9 |
+
return "Language(code={}, name={})".format(self.code, self.name)
|
10 |
+
|
11 |
+
NLLB_LANGS = [
|
12 |
+
NllbLang('ace_Arab', 'Acehnese (Arabic script)'),
|
13 |
+
NllbLang('ace_Latn', 'Acehnese (Latin script)'),
|
14 |
+
NllbLang('acm_Arab', 'Mesopotamian Arabic', 'ar', 'Arabic'),
|
15 |
+
NllbLang('acq_Arab', 'Ta’izzi-Adeni Arabic', 'ar', 'Arabic'),
|
16 |
+
NllbLang('aeb_Arab', 'Tunisian Arabic'),
|
17 |
+
NllbLang('afr_Latn', 'Afrikaans', 'am', 'Amharic'),
|
18 |
+
NllbLang('ajp_Arab', 'South Levantine Arabic', 'ar', 'Arabic'),
|
19 |
+
NllbLang('aka_Latn', 'Akan'),
|
20 |
+
NllbLang('amh_Ethi', 'Amharic'),
|
21 |
+
NllbLang('apc_Arab', 'North Levantine Arabic', 'ar', 'Arabic'),
|
22 |
+
NllbLang('arb_Arab', 'Modern Standard Arabic', 'ar', 'Arabic'),
|
23 |
+
NllbLang('arb_Latn', 'Modern Standard Arabic (Romanized)'),
|
24 |
+
NllbLang('ars_Arab', 'Najdi Arabic', 'ar', 'Arabic'),
|
25 |
+
NllbLang('ary_Arab', 'Moroccan Arabic', 'ar', 'Arabic'),
|
26 |
+
NllbLang('arz_Arab', 'Egyptian Arabic', 'ar', 'Arabic'),
|
27 |
+
NllbLang('asm_Beng', 'Assamese', 'as', 'Assamese'),
|
28 |
+
NllbLang('ast_Latn', 'Asturian'),
|
29 |
+
NllbLang('awa_Deva', 'Awadhi'),
|
30 |
+
NllbLang('ayr_Latn', 'Central Aymara'),
|
31 |
+
NllbLang('azb_Arab', 'South Azerbaijani', 'az', 'Azerbaijani'),
|
32 |
+
NllbLang('azj_Latn', 'North Azerbaijani', 'az', 'Azerbaijani'),
|
33 |
+
NllbLang('bak_Cyrl', 'Bashkir', 'ba', 'Bashkir'),
|
34 |
+
NllbLang('bam_Latn', 'Bambara'),
|
35 |
+
NllbLang('ban_Latn', 'Balinese'),
|
36 |
+
NllbLang('bel_Cyrl', 'Belarusian', 'be', 'Belarusian'),
|
37 |
+
NllbLang('bem_Latn', 'Bemba'),
|
38 |
+
NllbLang('ben_Beng', 'Bengali', 'bn', 'Bengali'),
|
39 |
+
NllbLang('bho_Deva', 'Bhojpuri'),
|
40 |
+
NllbLang('bjn_Arab', 'Banjar (Arabic script)'),
|
41 |
+
NllbLang('bjn_Latn', 'Banjar (Latin script)'),
|
42 |
+
NllbLang('bod_Tibt', 'Standard Tibetan', 'bo', 'Tibetan'),
|
43 |
+
NllbLang('bos_Latn', 'Bosnian', 'bs', 'Bosnian'),
|
44 |
+
NllbLang('bug_Latn', 'Buginese'),
|
45 |
+
NllbLang('bul_Cyrl', 'Bulgarian', 'bg', 'Bulgarian'),
|
46 |
+
NllbLang('cat_Latn', 'Catalan', 'ca', 'Catalan'),
|
47 |
+
NllbLang('ceb_Latn', 'Cebuano'),
|
48 |
+
NllbLang('ces_Latn', 'Czech', 'cs', 'Czech'),
|
49 |
+
NllbLang('cjk_Latn', 'Chokwe'),
|
50 |
+
NllbLang('ckb_Arab', 'Central Kurdish'),
|
51 |
+
NllbLang('crh_Latn', 'Crimean Tatar'),
|
52 |
+
NllbLang('cym_Latn', 'Welsh', 'cy', 'Welsh'),
|
53 |
+
NllbLang('dan_Latn', 'Danish', 'da', 'Danish'),
|
54 |
+
NllbLang('deu_Latn', 'German', 'de', 'German'),
|
55 |
+
NllbLang('dik_Latn', 'Southwestern Dinka'),
|
56 |
+
NllbLang('dyu_Latn', 'Dyula'),
|
57 |
+
NllbLang('dzo_Tibt', 'Dzongkha'),
|
58 |
+
NllbLang('ell_Grek', 'Greek', 'el', 'Greek'),
|
59 |
+
NllbLang('eng_Latn', 'English', 'en', 'English'),
|
60 |
+
NllbLang('epo_Latn', 'Esperanto'),
|
61 |
+
NllbLang('est_Latn', 'Estonian', 'et', 'Estonian'),
|
62 |
+
NllbLang('eus_Latn', 'Basque', 'eu', 'Basque'),
|
63 |
+
NllbLang('ewe_Latn', 'Ewe'),
|
64 |
+
NllbLang('fao_Latn', 'Faroese', 'fo', 'Faroese'),
|
65 |
+
NllbLang('fij_Latn', 'Fijian'),
|
66 |
+
NllbLang('fin_Latn', 'Finnish', 'fi', 'Finnish'),
|
67 |
+
NllbLang('fon_Latn', 'Fon'),
|
68 |
+
NllbLang('fra_Latn', 'French', 'fr', 'French'),
|
69 |
+
NllbLang('fur_Latn', 'Friulian'),
|
70 |
+
NllbLang('fuv_Latn', 'Nigerian Fulfulde'),
|
71 |
+
NllbLang('gla_Latn', 'Scottish Gaelic'),
|
72 |
+
NllbLang('gle_Latn', 'Irish'),
|
73 |
+
NllbLang('glg_Latn', 'Galician', 'gl', 'Galician'),
|
74 |
+
NllbLang('grn_Latn', 'Guarani'),
|
75 |
+
NllbLang('guj_Gujr', 'Gujarati', 'gu', 'Gujarati'),
|
76 |
+
NllbLang('hat_Latn', 'Haitian Creole', 'ht', 'Haitian creole'),
|
77 |
+
NllbLang('hau_Latn', 'Hausa', 'ha', 'Hausa'),
|
78 |
+
NllbLang('heb_Hebr', 'Hebrew', 'he', 'Hebrew'),
|
79 |
+
NllbLang('hin_Deva', 'Hindi', 'hi', 'Hindi'),
|
80 |
+
NllbLang('hne_Deva', 'Chhattisgarhi'),
|
81 |
+
NllbLang('hrv_Latn', 'Croatian', 'hr', 'Croatian'),
|
82 |
+
NllbLang('hun_Latn', 'Hungarian', 'hu', 'Hungarian'),
|
83 |
+
NllbLang('hye_Armn', 'Armenian', 'hy', 'Armenian'),
|
84 |
+
NllbLang('ibo_Latn', 'Igbo'),
|
85 |
+
NllbLang('ilo_Latn', 'Ilocano'),
|
86 |
+
NllbLang('ind_Latn', 'Indonesian', 'id', 'Indonesian'),
|
87 |
+
NllbLang('isl_Latn', 'Icelandic', 'is', 'Icelandic'),
|
88 |
+
NllbLang('ita_Latn', 'Italian', 'it', 'Italian'),
|
89 |
+
NllbLang('jav_Latn', 'Javanese', 'jw', 'Javanese'),
|
90 |
+
NllbLang('jpn_Jpan', 'Japanese', 'ja', 'Japanese'),
|
91 |
+
NllbLang('kab_Latn', 'Kabyle'),
|
92 |
+
NllbLang('kac_Latn', 'Jingpho'),
|
93 |
+
NllbLang('kam_Latn', 'Kamba'),
|
94 |
+
NllbLang('kan_Knda', 'Kannada', 'kn', 'Kannada'),
|
95 |
+
NllbLang('kas_Arab', 'Kashmiri (Arabic script)'),
|
96 |
+
NllbLang('kas_Deva', 'Kashmiri (Devanagari script)'),
|
97 |
+
NllbLang('kat_Geor', 'Georgian', 'ka', 'Georgian'),
|
98 |
+
NllbLang('knc_Arab', 'Central Kanuri (Arabic script)'),
|
99 |
+
NllbLang('knc_Latn', 'Central Kanuri (Latin script)'),
|
100 |
+
NllbLang('kaz_Cyrl', 'Kazakh', 'kk', 'Kazakh'),
|
101 |
+
NllbLang('kbp_Latn', 'Kabiyè'),
|
102 |
+
NllbLang('kea_Latn', 'Kabuverdianu'),
|
103 |
+
NllbLang('khm_Khmr', 'Khmer', 'km', 'Khmer'),
|
104 |
+
NllbLang('kik_Latn', 'Kikuyu'),
|
105 |
+
NllbLang('kin_Latn', 'Kinyarwanda'),
|
106 |
+
NllbLang('kir_Cyrl', 'Kyrgyz'),
|
107 |
+
NllbLang('kmb_Latn', 'Kimbundu'),
|
108 |
+
NllbLang('kmr_Latn', 'Northern Kurdish'),
|
109 |
+
NllbLang('kon_Latn', 'Kikongo'),
|
110 |
+
NllbLang('kor_Hang', 'Korean', 'ko', 'Korean'),
|
111 |
+
NllbLang('lao_Laoo', 'Lao', 'lo', 'Lao'),
|
112 |
+
NllbLang('lij_Latn', 'Ligurian'),
|
113 |
+
NllbLang('lim_Latn', 'Limburgish'),
|
114 |
+
NllbLang('lin_Latn', 'Lingala', 'ln', 'Lingala'),
|
115 |
+
NllbLang('lit_Latn', 'Lithuanian', 'lt', 'Lithuanian'),
|
116 |
+
NllbLang('lmo_Latn', 'Lombard'),
|
117 |
+
NllbLang('ltg_Latn', 'Latgalian'),
|
118 |
+
NllbLang('ltz_Latn', 'Luxembourgish', 'lb', 'Luxembourgish'),
|
119 |
+
NllbLang('lua_Latn', 'Luba-Kasai'),
|
120 |
+
NllbLang('lug_Latn', 'Ganda'),
|
121 |
+
NllbLang('luo_Latn', 'Luo'),
|
122 |
+
NllbLang('lus_Latn', 'Mizo'),
|
123 |
+
NllbLang('lvs_Latn', 'Standard Latvian', 'lv', 'Latvian'),
|
124 |
+
NllbLang('mag_Deva', 'Magahi'),
|
125 |
+
NllbLang('mai_Deva', 'Maithili'),
|
126 |
+
NllbLang('mal_Mlym', 'Malayalam', 'ml', 'Malayalam'),
|
127 |
+
NllbLang('mar_Deva', 'Marathi', 'mr', 'Marathi'),
|
128 |
+
NllbLang('min_Arab', 'Minangkabau (Arabic script)'),
|
129 |
+
NllbLang('min_Latn', 'Minangkabau (Latin script)'),
|
130 |
+
NllbLang('mkd_Cyrl', 'Macedonian', 'mk', 'Macedonian'),
|
131 |
+
NllbLang('plt_Latn', 'Plateau Malagasy', 'mg', 'Malagasy'),
|
132 |
+
NllbLang('mlt_Latn', 'Maltese', 'mt', 'Maltese'),
|
133 |
+
NllbLang('mni_Beng', 'Meitei (Bengali script)'),
|
134 |
+
NllbLang('khk_Cyrl', 'Halh Mongolian', 'mn', 'Mongolian'),
|
135 |
+
NllbLang('mos_Latn', 'Mossi'),
|
136 |
+
NllbLang('mri_Latn', 'Maori', 'mi', 'Maori'),
|
137 |
+
NllbLang('mya_Mymr', 'Burmese', 'my', 'Myanmar'),
|
138 |
+
NllbLang('nld_Latn', 'Dutch', 'nl', 'Dutch'),
|
139 |
+
NllbLang('nno_Latn', 'Norwegian Nynorsk', 'nn', 'Nynorsk'),
|
140 |
+
NllbLang('nob_Latn', 'Norwegian Bokmål', 'no', 'Norwegian'),
|
141 |
+
NllbLang('npi_Deva', 'Nepali', 'ne', 'Nepali'),
|
142 |
+
NllbLang('nso_Latn', 'Northern Sotho'),
|
143 |
+
NllbLang('nus_Latn', 'Nuer'),
|
144 |
+
NllbLang('nya_Latn', 'Nyanja'),
|
145 |
+
NllbLang('oci_Latn', 'Occitan', 'oc', 'Occitan'),
|
146 |
+
NllbLang('gaz_Latn', 'West Central Oromo'),
|
147 |
+
NllbLang('ory_Orya', 'Odia'),
|
148 |
+
NllbLang('pag_Latn', 'Pangasinan'),
|
149 |
+
NllbLang('pan_Guru', 'Eastern Panjabi', 'pa', 'Punjabi'),
|
150 |
+
NllbLang('pap_Latn', 'Papiamento'),
|
151 |
+
NllbLang('pes_Arab', 'Western Persian', 'fa', 'Persian'),
|
152 |
+
NllbLang('pol_Latn', 'Polish', 'pl', 'Polish'),
|
153 |
+
NllbLang('por_Latn', 'Portuguese', 'pt', 'Portuguese'),
|
154 |
+
NllbLang('prs_Arab', 'Dari'),
|
155 |
+
NllbLang('pbt_Arab', 'Southern Pashto', 'ps', 'Pashto'),
|
156 |
+
NllbLang('quy_Latn', 'Ayacucho Quechua'),
|
157 |
+
NllbLang('ron_Latn', 'Romanian', 'ro', 'Romanian'),
|
158 |
+
NllbLang('run_Latn', 'Rundi'),
|
159 |
+
NllbLang('rus_Cyrl', 'Russian', 'ru', 'Russian'),
|
160 |
+
NllbLang('sag_Latn', 'Sango'),
|
161 |
+
NllbLang('san_Deva', 'Sanskrit', 'sa', 'Sanskrit'),
|
162 |
+
NllbLang('sat_Olck', 'Santali'),
|
163 |
+
NllbLang('scn_Latn', 'Sicilian'),
|
164 |
+
NllbLang('shn_Mymr', 'Shan'),
|
165 |
+
NllbLang('sin_Sinh', 'Sinhala', 'si', 'Sinhala'),
|
166 |
+
NllbLang('slk_Latn', 'Slovak', 'sk', 'Slovak'),
|
167 |
+
NllbLang('slv_Latn', 'Slovenian', 'sl', 'Slovenian'),
|
168 |
+
NllbLang('smo_Latn', 'Samoan'),
|
169 |
+
NllbLang('sna_Latn', 'Shona', 'sn', 'Shona'),
|
170 |
+
NllbLang('snd_Arab', 'Sindhi', 'sd', 'Sindhi'),
|
171 |
+
NllbLang('som_Latn', 'Somali', 'so', 'Somali'),
|
172 |
+
NllbLang('sot_Latn', 'Southern Sotho'),
|
173 |
+
NllbLang('spa_Latn', 'Spanish', 'es', 'Spanish'),
|
174 |
+
NllbLang('als_Latn', 'Tosk Albanian', 'sq', 'Albanian'),
|
175 |
+
NllbLang('srd_Latn', 'Sardinian'),
|
176 |
+
NllbLang('srp_Cyrl', 'Serbian', 'sr', 'Serbian'),
|
177 |
+
NllbLang('ssw_Latn', 'Swati'),
|
178 |
+
NllbLang('sun_Latn', 'Sundanese', 'su', 'Sundanese'),
|
179 |
+
NllbLang('swe_Latn', 'Swedish', 'sv', 'Swedish'),
|
180 |
+
NllbLang('swh_Latn', 'Swahili', 'sw', 'Swahili'),
|
181 |
+
NllbLang('szl_Latn', 'Silesian'),
|
182 |
+
NllbLang('tam_Taml', 'Tamil', 'ta', 'Tamil'),
|
183 |
+
NllbLang('tat_Cyrl', 'Tatar', 'tt', 'Tatar'),
|
184 |
+
NllbLang('tel_Telu', 'Telugu', 'te', 'Telugu'),
|
185 |
+
NllbLang('tgk_Cyrl', 'Tajik', 'tg', 'Tajik'),
|
186 |
+
NllbLang('tgl_Latn', 'Tagalog', 'tl', 'Tagalog'),
|
187 |
+
NllbLang('tha_Thai', 'Thai', 'th', 'Thai'),
|
188 |
+
NllbLang('tir_Ethi', 'Tigrinya'),
|
189 |
+
NllbLang('taq_Latn', 'Tamasheq (Latin script)'),
|
190 |
+
NllbLang('taq_Tfng', 'Tamasheq (Tifinagh script)'),
|
191 |
+
NllbLang('tpi_Latn', 'Tok Pisin'),
|
192 |
+
NllbLang('tsn_Latn', 'Tswana'),
|
193 |
+
NllbLang('tso_Latn', 'Tsonga'),
|
194 |
+
NllbLang('tuk_Latn', 'Turkmen', 'tk', 'Turkmen'),
|
195 |
+
NllbLang('tum_Latn', 'Tumbuka'),
|
196 |
+
NllbLang('tur_Latn', 'Turkish', 'tr', 'Turkish'),
|
197 |
+
NllbLang('twi_Latn', 'Twi'),
|
198 |
+
NllbLang('tzm_Tfng', 'Central Atlas Tamazight'),
|
199 |
+
NllbLang('uig_Arab', 'Uyghur'),
|
200 |
+
NllbLang('ukr_Cyrl', 'Ukrainian', 'uk', 'Ukrainian'),
|
201 |
+
NllbLang('umb_Latn', 'Umbundu'),
|
202 |
+
NllbLang('urd_Arab', 'Urdu', 'ur', 'Urdu'),
|
203 |
+
NllbLang('uzn_Latn', 'Northern Uzbek', 'uz', 'Uzbek'),
|
204 |
+
NllbLang('vec_Latn', 'Venetian'),
|
205 |
+
NllbLang('vie_Latn', 'Vietnamese', 'vi', 'Vietnamese'),
|
206 |
+
NllbLang('war_Latn', 'Waray'),
|
207 |
+
NllbLang('wol_Latn', 'Wolof'),
|
208 |
+
NllbLang('xho_Latn', 'Xhosa'),
|
209 |
+
NllbLang('ydd_Hebr', 'Eastern Yiddish', 'yi', 'Yiddish'),
|
210 |
+
NllbLang('yor_Latn', 'Yoruba', 'yo', 'Yoruba'),
|
211 |
+
NllbLang('yue_Hant', 'Yue Chinese', 'zh', 'Chinese'),
|
212 |
+
NllbLang('zho_Hans', 'Chinese (Simplified)', 'zh', 'Chinese'),
|
213 |
+
NllbLang('zho_Hant', 'Chinese (Traditional)', 'zh', 'Chinese'),
|
214 |
+
NllbLang('zsm_Latn', 'Standard Malay', 'ms', 'Malay'),
|
215 |
+
NllbLang('zul_Latn', 'Zulu'),
|
216 |
+
]
|
217 |
+
|
218 |
+
_TO_NLLB_LANG_CODE = {language.code.lower(): language for language in NLLB_LANGS if language.code is not None}
|
219 |
+
|
220 |
+
_TO_NLLB_LANG_NAME = {language.name.lower(): language for language in NLLB_LANGS if language.name is not None}
|
221 |
+
|
222 |
+
_TO_NLLB_LANG_WHISPER_CODE = {language.code_whisper.lower(): language for language in NLLB_LANGS if language.code_whisper is not None}
|
223 |
+
|
224 |
+
_TO_NLLB_LANG_WHISPER_NAME = {language.name_whisper.lower(): language for language in NLLB_LANGS if language.name_whisper is not None}
|
225 |
+
|
226 |
+
def get_nllb_lang_from_code(lang_code, default=None) -> NllbLang:
|
227 |
+
"""Return the language from the language code."""
|
228 |
+
return _TO_NLLB_LANG_CODE.get(lang_code, default)
|
229 |
+
|
230 |
+
def get_nllb_lang_from_name(lang_name, default=None) -> NllbLang:
|
231 |
+
"""Return the language from the language name."""
|
232 |
+
return _TO_NLLB_LANG_NAME.get(lang_name.lower() if lang_name else None, default)
|
233 |
+
|
234 |
+
def get_nllb_lang_from_code_whisper(lang_code_whisper, default=None) -> NllbLang:
|
235 |
+
"""Return the language from the language code."""
|
236 |
+
return _TO_NLLB_LANG_WHISPER_CODE.get(lang_code_whisper, default)
|
237 |
+
|
238 |
+
def get_nllb_lang_from_name_whisper(lang_name_whisper, default=None) -> NllbLang:
|
239 |
+
"""Return the language from the language name."""
|
240 |
+
return _TO_NLLB_LANG_WHISPER_NAME.get(lang_name_whisper.lower() if lang_name_whisper else None, default)
|
241 |
+
|
242 |
+
def get_nllb_lang_names():
|
243 |
+
"""Return a list of language names."""
|
244 |
+
return [language.name for language in NLLB_LANGS]
|
245 |
+
|
246 |
+
if __name__ == "__main__":
|
247 |
+
# Test lookup
|
248 |
+
print(get_nllb_lang_from_code('eng_Latn'))
|
249 |
+
print(get_nllb_lang_from_name('English'))
|
250 |
+
|
251 |
+
print(get_nllb_lang_names())
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
import huggingface_hub
|
4 |
+
import requests
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import ctranslate2
|
8 |
+
import transformers
|
9 |
+
|
10 |
+
from typing import Optional
|
11 |
+
from src.config import ModelConfig
|
12 |
+
from src.languages import Language
|
13 |
+
from src.nllb.nllbLangs import NllbLang, get_nllb_lang_from_code_whisper
|
14 |
+
|
15 |
+
class NllbModel:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
model_config: ModelConfig,
|
19 |
+
device: str = None,
|
20 |
+
whisper_lang: Language = None,
|
21 |
+
nllb_lang: NllbLang = None,
|
22 |
+
download_root: Optional[str] = None,
|
23 |
+
local_files_only: bool = False,
|
24 |
+
load_model: bool = False,
|
25 |
+
):
|
26 |
+
"""Initializes the Nllb-200 model.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
model_config: Config of the model to use (distilled-600M, distilled-1.3B,
|
30 |
+
1.3B, 3.3B...) or a path to a converted
|
31 |
+
model directory. When a size is configured, the converted model is downloaded
|
32 |
+
from the Hugging Face Hub.
|
33 |
+
device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
|
34 |
+
ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
|
35 |
+
device_index: Device ID to use.
|
36 |
+
The model can also be loaded on multiple GPUs by passing a list of IDs
|
37 |
+
(e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
|
38 |
+
when transcribe() is called from multiple Python threads (see also num_workers).
|
39 |
+
compute_type: Type to use for computation.
|
40 |
+
See https://opennmt.net/CTranslate2/quantization.html.
|
41 |
+
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
42 |
+
A non zero value overrides the OMP_NUM_THREADS environment variable.
|
43 |
+
num_workers: When transcribe() is called from multiple Python threads,
|
44 |
+
having multiple workers enables true parallelism when running the model
|
45 |
+
(concurrent calls to self.model.generate() will run in parallel).
|
46 |
+
This can improve the global throughput at the cost of increased memory usage.
|
47 |
+
download_root: Directory where the models should be saved. If not set, the models
|
48 |
+
are saved in the standard Hugging Face cache directory.
|
49 |
+
local_files_only: If True, avoid downloading the file and return the path to the
|
50 |
+
local cached file if it exists.
|
51 |
+
"""
|
52 |
+
self.whisper_lang = whisper_lang
|
53 |
+
self.nllb_whisper_lang = get_nllb_lang_from_code_whisper(whisper_lang.code.lower() if whisper_lang is not None else "en")
|
54 |
+
self.nllb_lang = nllb_lang
|
55 |
+
self.model_config = model_config
|
56 |
+
|
57 |
+
if os.path.isdir(model_config.url):
|
58 |
+
self.model_path = model_config.url
|
59 |
+
else:
|
60 |
+
self.model_path = download_model(
|
61 |
+
model_config,
|
62 |
+
local_files_only=local_files_only,
|
63 |
+
cache_dir=download_root,
|
64 |
+
)
|
65 |
+
|
66 |
+
if device is None:
|
67 |
+
if torch.cuda.is_available():
|
68 |
+
device = "cuda" if "ct2" in self.model_path else "cuda:0"
|
69 |
+
else:
|
70 |
+
device = "cpu"
|
71 |
+
|
72 |
+
self.device = device
|
73 |
+
|
74 |
+
if load_model:
|
75 |
+
self.load_model()
|
76 |
+
|
77 |
+
def load_model(self):
|
78 |
+
print('\n\nLoading model: %s\n\n' % self.model_path)
|
79 |
+
if "ct2" in self.model_path:
|
80 |
+
self.target_prefix = [self.nllb_lang.code]
|
81 |
+
self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, src_lang=self.nllb_whisper_lang.code)
|
82 |
+
self.trans_model = ctranslate2.Translator(self.model_path, compute_type="auto", device=self.device)
|
83 |
+
elif "mt5" in self.model_path:
|
84 |
+
self.mt5_prefix = self.whisper_lang.code + "2" + self.nllb_lang.code_whisper + ": "
|
85 |
+
self.trans_tokenizer = transformers.T5Tokenizer.from_pretrained(self.model_path) #requires spiece.model
|
86 |
+
self.trans_model = transformers.MT5ForConditionalGeneration.from_pretrained(self.model_path)
|
87 |
+
self.trans_translator = transformers.pipeline('text2text-generation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer)
|
88 |
+
else: #NLLB
|
89 |
+
self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
|
90 |
+
self.trans_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
|
91 |
+
self.trans_translator = transformers.pipeline('translation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer, src_lang=self.nllb_whisper_lang.code, tgt_lang=self.nllb_lang.code)
|
92 |
+
|
93 |
+
def release_vram(self):
|
94 |
+
try:
|
95 |
+
if torch.cuda.is_available():
|
96 |
+
if "ct2" not in self.model_path:
|
97 |
+
device = torch.device("cpu")
|
98 |
+
self.trans_model.to(device)
|
99 |
+
del self.trans_model
|
100 |
+
torch.cuda.empty_cache()
|
101 |
+
print("release vram end.")
|
102 |
+
except Exception as e:
|
103 |
+
print("Error release vram: " + str(e))
|
104 |
+
|
105 |
+
|
106 |
+
def translation(self, text: str, max_length: int = 400):
|
107 |
+
output = None
|
108 |
+
result = None
|
109 |
+
try:
|
110 |
+
if "ct2" in self.model_path:
|
111 |
+
source = self.trans_tokenizer.convert_ids_to_tokens(self.trans_tokenizer.encode(text))
|
112 |
+
output = self.trans_model.translate_batch([source], target_prefix=[self.target_prefix])
|
113 |
+
target = output[0].hypotheses[0][1:]
|
114 |
+
result = self.trans_tokenizer.decode(self.trans_tokenizer.convert_tokens_to_ids(target))
|
115 |
+
elif "mt5" in self.model_path:
|
116 |
+
output = self.trans_translator(self.mt5_prefix + text, max_length=max_length, num_beams=4)
|
117 |
+
result = output[0]['generated_text']
|
118 |
+
else: #NLLB
|
119 |
+
output = self.trans_translator(text, max_length=max_length)
|
120 |
+
result = output[0]['translation_text']
|
121 |
+
except Exception as e:
|
122 |
+
print("Error translation text: " + str(e))
|
123 |
+
|
124 |
+
return result
|
125 |
+
|
126 |
+
|
127 |
+
_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
|
128 |
+
"ct2fast-nllb-200-distilled-1.3B-int8_float16",
|
129 |
+
"ct2fast-nllb-200-3.3B-int8_float16",
|
130 |
+
"nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
|
131 |
+
"nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
|
132 |
+
"nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
|
133 |
+
"mt5-zh-ja-en-trimmed",
|
134 |
+
"mt5-zh-ja-en-trimmed-fine-tuned-v1"]
|
135 |
+
|
136 |
+
def check_model_name(name):
|
137 |
+
return any(allowed_name in name for allowed_name in _MODELS)
|
138 |
+
|
139 |
+
def download_model(
|
140 |
+
model_config: ModelConfig,
|
141 |
+
output_dir: Optional[str] = None,
|
142 |
+
local_files_only: bool = False,
|
143 |
+
cache_dir: Optional[str] = None,
|
144 |
+
):
|
145 |
+
""""download_model" is referenced from the "utils.py" script
|
146 |
+
of the "faster_whisper" project, authored by guillaumekln.
|
147 |
+
|
148 |
+
Downloads a nllb-200 model from the Hugging Face Hub.
|
149 |
+
|
150 |
+
The model is downloaded from https://huggingface.co/facebook.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
model_config: config of the model to download (facebook/nllb-distilled-600M,
|
154 |
+
facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
|
155 |
+
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
156 |
+
the cache directory.
|
157 |
+
local_files_only: If True, avoid downloading the file and return the path to the local
|
158 |
+
cached file if it exists.
|
159 |
+
cache_dir: Path to the folder where cached files are stored.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
The path to the downloaded model.
|
163 |
+
|
164 |
+
Raises:
|
165 |
+
ValueError: if the model size is invalid.
|
166 |
+
"""
|
167 |
+
if not check_model_name(model_config.name):
|
168 |
+
raise ValueError(
|
169 |
+
"Invalid model name '%s', expected one of: %s" % (model_config.name, ", ".join(_MODELS))
|
170 |
+
)
|
171 |
+
|
172 |
+
repo_id = model_config.url #"facebook/nllb-200-%s" %
|
173 |
+
|
174 |
+
allow_patterns = [
|
175 |
+
"config.json",
|
176 |
+
"generation_config.json",
|
177 |
+
"model.bin",
|
178 |
+
"pytorch_model.bin",
|
179 |
+
"pytorch_model.bin.index.json",
|
180 |
+
"pytorch_model-00001-of-00003.bin",
|
181 |
+
"pytorch_model-00002-of-00003.bin",
|
182 |
+
"pytorch_model-00003-of-00003.bin",
|
183 |
+
"sentencepiece.bpe.model",
|
184 |
+
"tokenizer.json",
|
185 |
+
"tokenizer_config.json",
|
186 |
+
"shared_vocabulary.txt",
|
187 |
+
"shared_vocabulary.json",
|
188 |
+
"special_tokens_map.json",
|
189 |
+
"spiece.model",
|
190 |
+
]
|
191 |
+
|
192 |
+
kwargs = {
|
193 |
+
"local_files_only": local_files_only,
|
194 |
+
"allow_patterns": allow_patterns,
|
195 |
+
#"tqdm_class": disabled_tqdm,
|
196 |
+
}
|
197 |
+
|
198 |
+
if output_dir is not None:
|
199 |
+
kwargs["local_dir"] = output_dir
|
200 |
+
kwargs["local_dir_use_symlinks"] = False
|
201 |
+
|
202 |
+
if cache_dir is not None:
|
203 |
+
kwargs["cache_dir"] = cache_dir
|
204 |
+
|
205 |
+
try:
|
206 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
207 |
+
except (
|
208 |
+
huggingface_hub.utils.HfHubHTTPError,
|
209 |
+
requests.exceptions.ConnectionError,
|
210 |
+
) as exception:
|
211 |
+
warnings.warn(
|
212 |
+
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
213 |
+
repo_id,
|
214 |
+
exception,
|
215 |
+
)
|
216 |
+
warnings.warn(
|
217 |
+
"Trying to load the model directly from the local cache, if it exists."
|
218 |
+
)
|
219 |
+
|
220 |
+
kwargs["local_files_only"] = True
|
221 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
@@ -204,7 +204,7 @@ class ParallelTranscription(AbstractTranscription):
|
|
204 |
gpu_parallel_context.close()
|
205 |
|
206 |
perf_end_gpu = time.perf_counter()
|
207 |
-
print("
|
208 |
|
209 |
return merged
|
210 |
|
|
|
204 |
gpu_parallel_context.close()
|
205 |
|
206 |
perf_end_gpu = time.perf_counter()
|
207 |
+
print("\nParallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
|
208 |
|
209 |
return merged
|
210 |
|