Spaces:
Running
Enhanced Translation Model translation capabilities and optimized the Web UI interface.
Browse files1.Translation Model Enhancements:
* Added support for the M2M100 model.
* Added three new options for the translation model: Batch Size, No Repeat Ngram Size, Num Beams.
* When using the Translation Model for translation, it will now generate additional subtitle (srt) files for the original language (*-original.srt) and bilingual (*-bilingual.srt).
* In response to adjustments in the Translation Model functionality, nllbLangs has been renamed to translationLangs, and nllbModel has been renamed to translationModel.
2.Web UI Enhancements:
* Placed the translation model under tabs, with tabs for M2M100, NLLB, MT5.
* Organized the audio input under tabs for URL, Upload, Microphone.
* Categorized VAD options under tabs for VAD, Merge Window, Max Merge Size, Padding, Prompt Window, Initial Prompt Mode.
* Grouped Word Timestamps options under tabs for Word Timestamps, Highlight Words, Prepend Punctuations, Append Punctuations.
* On the Full page, the Whisper Advanced options have been organized into tabs, including Initial Prompt, Temperature, Best Of, Beam Size, Patience, Length Penalty, Suppress Tokens, Condition on previous text, FP16, Temperature increment on fallback, Compression ratio threshold, Logprob threshold, and No speech threshold.
3.New advanced options and program adjustments for Whisper:
* In the Whisper Advanced options on the Full page, Repetition Penalty and No Repeat Ngram Size options have been added for use with faster-whisper.
* Merged languages into translationLangs.
- app.py +485 -252
- cli.py +2 -2
- config.json5 +289 -243
- requirements-whisper.txt +0 -1
- src/config.py +28 -24
- src/languages.py +0 -147
- src/nllb/nllbLangs.py +0 -251
- src/translation/translationLangs.py +303 -0
- src/{nllb/nllbModel.py → translation/translationModel.py} +88 -72
- src/utils.py +79 -31
- src/vad.py +2 -3
- src/whisper/abstractWhisperContainer.py +3 -3
- src/whisper/fasterWhisperContainer.py +7 -17
- src/whisper/whisperContainer.py +7 -7
@@ -1,7 +1,7 @@
|
|
1 |
from datetime import datetime
|
2 |
import json
|
3 |
import math
|
4 |
-
from typing import Iterator, Union
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
@@ -20,7 +20,6 @@ from src.diarization.diarizationContainer import DiarizationContainer
|
|
20 |
from src.hooks.progressListener import ProgressListener
|
21 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
22 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
23 |
-
from src.languages import _TO_LANGUAGE_CODE, get_language_names, get_language_from_name, get_language_from_code
|
24 |
from src.modelCache import ModelCache
|
25 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
26 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
@@ -34,18 +33,18 @@ import ffmpeg
|
|
34 |
import gradio as gr
|
35 |
|
36 |
from src.download import ExceededMaximumDuration, download_url
|
37 |
-
from src.utils import optional_int, slugify, str2bool, write_srt, write_vtt
|
38 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
39 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
40 |
from src.whisper.whisperFactory import create_whisper_container
|
41 |
-
from src.
|
42 |
-
from src.
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
import shutil
|
47 |
import zhconv
|
48 |
import tqdm
|
|
|
49 |
|
50 |
# Configure more application defaults in config.json5
|
51 |
|
@@ -114,120 +113,231 @@ class WhisperTranscriber:
|
|
114 |
self.diarization.cleanup()
|
115 |
self.diarization_kwargs = None
|
116 |
|
117 |
-
# Entry function for the simple tab
|
118 |
-
def transcribe_webui_simple(self,
|
119 |
-
vad, vadMergeWindow, vadMaxMergeSize,
|
120 |
-
word_timestamps: bool = False, highlight_words: bool = False,
|
121 |
-
diarization: bool = False, diarization_speakers: int = 2,
|
122 |
-
diarization_min_speakers = 1, diarization_max_speakers = 8):
|
123 |
-
return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
124 |
-
vad, vadMergeWindow, vadMaxMergeSize,
|
125 |
-
word_timestamps, highlight_words,
|
126 |
-
diarization, diarization_speakers,
|
127 |
-
diarization_min_speakers, diarization_max_speakers)
|
128 |
|
129 |
-
# Entry function for the simple tab progress
|
130 |
-
def transcribe_webui_simple_progress(self,
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
else:
|
143 |
-
self.
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
153 |
-
# Word timestamps
|
154 |
-
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
155 |
-
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
156 |
-
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
157 |
-
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
158 |
-
diarization: bool = False, diarization_speakers: int = 2,
|
159 |
-
diarization_min_speakers = 1, diarization_max_speakers = 8):
|
160 |
-
|
161 |
-
return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
162 |
-
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
163 |
-
word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
|
164 |
-
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
165 |
-
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
166 |
-
compression_ratio_threshold, logprob_threshold, no_speech_threshold,
|
167 |
-
diarization, diarization_speakers,
|
168 |
-
diarization_min_speakers, diarization_max_speakers)
|
169 |
-
|
170 |
-
# Entry function for the full tab with progress
|
171 |
-
def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
|
172 |
-
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
173 |
-
# Word timestamps
|
174 |
-
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
175 |
-
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
176 |
-
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
177 |
-
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
178 |
-
diarization: bool = False, diarization_speakers: int = 2,
|
179 |
-
diarization_min_speakers = 1, diarization_max_speakers = 8,
|
180 |
-
progress=gr.Progress()):
|
181 |
-
|
182 |
-
# Handle temperature_increment_on_fallback
|
183 |
-
if temperature_increment_on_fallback is not None:
|
184 |
-
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
185 |
-
else:
|
186 |
-
temperature = [temperature]
|
187 |
-
|
188 |
-
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
189 |
|
190 |
-
# Set diarization
|
191 |
-
if diarization:
|
192 |
-
if diarization_speakers is not None and diarization_speakers < 1:
|
193 |
-
self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
|
194 |
-
else:
|
195 |
-
self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
|
196 |
-
else:
|
197 |
-
self.unset_diarization()
|
198 |
-
|
199 |
-
return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
200 |
-
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
201 |
-
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
202 |
-
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
203 |
-
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
204 |
-
progress=progress)
|
205 |
-
|
206 |
-
def transcribe_webui(self, modelName: str, languageName: str, nllbModelName: str, nllbLangName: str, urlData: str, multipleFiles, microphoneData: str, task: str,
|
207 |
-
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
208 |
-
**decodeOptions: dict):
|
209 |
-
try:
|
210 |
progress(0, desc="init audio sources")
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
if (len(sources) == 0):
|
213 |
raise Exception("init audio sources failed...")
|
|
|
214 |
try:
|
215 |
progress(0, desc="init whisper model")
|
216 |
-
|
217 |
-
|
218 |
-
selectedModel =
|
219 |
|
220 |
model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
|
221 |
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
222 |
-
cache=self.model_cache, models=self.app_config.models)
|
223 |
|
224 |
progress(0, desc="init translate model")
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
progress(0, desc="init transcribe")
|
232 |
# Result
|
233 |
download = []
|
@@ -238,7 +348,7 @@ class WhisperTranscriber:
|
|
238 |
# Write result
|
239 |
downloadDirectory = tempfile.mkdtemp()
|
240 |
source_index = 0
|
241 |
-
extra_tasks_count = 1 if
|
242 |
|
243 |
outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
|
244 |
|
@@ -267,10 +377,10 @@ class WhisperTranscriber:
|
|
267 |
sub_task_total=sub_task_total)
|
268 |
|
269 |
# Transcribe
|
270 |
-
result = self.transcribe_file(model, source.source_path,
|
271 |
-
if
|
272 |
-
|
273 |
-
|
274 |
|
275 |
short_name, suffix = source.get_short_name_suffix(max_length=self.app_config.input_max_file_name_length)
|
276 |
filePrefix = slugify(source_prefix + short_name, allow_unicode=True)
|
@@ -278,7 +388,7 @@ class WhisperTranscriber:
|
|
278 |
# Update progress
|
279 |
current_progress += source_audio_duration
|
280 |
|
281 |
-
source_download, source_text, source_vtt = self.write_result(result,
|
282 |
|
283 |
if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None:
|
284 |
print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
|
@@ -287,8 +397,8 @@ class WhisperTranscriber:
|
|
287 |
srt_path = source_download[0]
|
288 |
save_path = os.path.join(self.app_config.output_dir, filePrefix)
|
289 |
# save_without_ext, ext = os.path.splitext(save_path)
|
290 |
-
source_lang = "." +
|
291 |
-
translate_lang = "." +
|
292 |
output_with_srt = save_path + source_lang + translate_lang + suffix
|
293 |
|
294 |
#ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
|
@@ -363,12 +473,11 @@ class WhisperTranscriber:
|
|
363 |
except ExceededMaximumDuration as e:
|
364 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
365 |
except Exception as e:
|
366 |
-
import traceback
|
367 |
print(traceback.format_exc())
|
368 |
-
return [], ("Error occurred during transcribe: " + str(e)),
|
369 |
|
370 |
|
371 |
-
def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str,
|
372 |
vadOptions: VadOptions = VadOptions(),
|
373 |
progressListener: ProgressListener = None, **decodeOptions: dict):
|
374 |
|
@@ -398,7 +507,7 @@ class WhisperTranscriber:
|
|
398 |
raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
|
399 |
|
400 |
# Callable for processing an audio file
|
401 |
-
whisperCallable = model.create_callback(
|
402 |
|
403 |
# The results
|
404 |
if (vadOptions.vad == 'silero-vad'):
|
@@ -513,7 +622,7 @@ class WhisperTranscriber:
|
|
513 |
|
514 |
return config
|
515 |
|
516 |
-
def write_result(self, result: dict,
|
517 |
if not os.path.exists(output_dir):
|
518 |
os.makedirs(output_dir)
|
519 |
|
@@ -522,7 +631,7 @@ class WhisperTranscriber:
|
|
522 |
language = result["language"]
|
523 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
524 |
|
525 |
-
if
|
526 |
try:
|
527 |
segments_progress_listener = SubTaskProgressListener(progressListener,
|
528 |
base_task_total=progressListener.sub_task_total,
|
@@ -530,17 +639,15 @@ class WhisperTranscriber:
|
|
530 |
sub_task_total=1)
|
531 |
pbar = tqdm.tqdm(total=len(segments))
|
532 |
perf_start_time = time.perf_counter()
|
533 |
-
|
534 |
for idx, segment in enumerate(segments):
|
535 |
seg_text = segment["text"]
|
536 |
-
|
537 |
-
|
538 |
-
if nllb_model.nllb_lang is not None:
|
539 |
-
segment["text"] = nllb_model.translation(seg_text)
|
540 |
pbar.update(1)
|
541 |
segments_progress_listener.on_progress(idx+1, len(segments), desc=f"Process segments: {idx}/{len(segments)}")
|
542 |
|
543 |
-
|
544 |
perf_end_time = time.perf_counter()
|
545 |
# Call the finished callback
|
546 |
if segments_progress_listener is not None:
|
@@ -549,24 +656,57 @@ class WhisperTranscriber:
|
|
549 |
print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
|
550 |
except Exception as e:
|
551 |
# Ignore error - it's just a cleanup
|
|
|
552 |
print("Error process segments: " + str(e))
|
553 |
|
554 |
print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
|
555 |
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
556 |
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
557 |
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
|
565 |
output_files = []
|
566 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
567 |
output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
|
568 |
output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
|
569 |
output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
|
|
|
|
|
|
|
|
|
570 |
|
571 |
return output_files, text, vtt
|
572 |
|
@@ -593,6 +733,10 @@ class WhisperTranscriber:
|
|
593 |
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
594 |
elif format == 'srt':
|
595 |
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
|
|
|
|
|
|
|
|
596 |
else:
|
597 |
raise Exception("Unknown format " + format)
|
598 |
|
@@ -621,6 +765,16 @@ class WhisperTranscriber:
|
|
621 |
self.diarization = None
|
622 |
|
623 |
def create_ui(app_config: ApplicationConfig):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
|
625 |
app_config.delete_uploaded_files, app_config.output_dir, app_config)
|
626 |
|
@@ -639,59 +793,69 @@ def create_ui(app_config: ApplicationConfig):
|
|
639 |
# Try to convert from camel-case to title-case
|
640 |
implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
|
641 |
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
|
646 |
-
|
647 |
|
648 |
# Recommend faster-whisper
|
649 |
if is_whisper:
|
650 |
-
|
651 |
|
652 |
if app_config.input_audio_max_duration > 0:
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
whisper_models = app_config.get_model_names()
|
668 |
-
nllb_models = app_config.
|
|
|
|
|
669 |
|
670 |
-
common_whisper_inputs = lambda :
|
671 |
-
gr.Dropdown(label="Whisper Model (for audio)", choices=whisper_models, value=app_config.default_model_name),
|
672 |
-
gr.Dropdown(label="Whisper Language", choices=sorted(
|
673 |
-
|
674 |
-
|
675 |
-
gr.Dropdown(label="
|
676 |
-
gr.Dropdown(label="
|
677 |
-
|
678 |
-
|
679 |
-
gr.
|
680 |
-
gr.
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
|
687 |
-
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
|
688 |
-
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
|
689 |
-
]
|
690 |
|
691 |
-
|
692 |
-
gr.
|
693 |
-
gr.
|
694 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
695 |
|
696 |
has_diarization_libs = Diarization.has_libraries()
|
697 |
|
@@ -699,12 +863,12 @@ def create_ui(app_config: ApplicationConfig):
|
|
699 |
print("Diarization libraries not found - disabling diarization")
|
700 |
app_config.diarization = False
|
701 |
|
702 |
-
common_diarization_inputs = lambda :
|
703 |
-
gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
|
704 |
-
gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs),
|
705 |
-
gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
|
706 |
-
gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs)
|
707 |
-
|
708 |
|
709 |
common_output = lambda : [
|
710 |
gr.File(label="Download"),
|
@@ -714,84 +878,152 @@ def create_ui(app_config: ApplicationConfig):
|
|
714 |
|
715 |
is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
|
716 |
|
717 |
-
|
718 |
-
|
719 |
-
with gr.Blocks() as
|
720 |
-
gr.
|
|
|
|
|
721 |
with gr.Row():
|
722 |
with gr.Column():
|
723 |
-
|
724 |
with gr.Column():
|
725 |
with gr.Row():
|
726 |
-
|
727 |
-
with gr.
|
728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
729 |
with gr.Column():
|
730 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
731 |
with gr.Column():
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
|
746 |
-
|
|
|
747 |
|
748 |
-
with gr.Blocks() as
|
749 |
-
gr.
|
|
|
|
|
750 |
with gr.Row():
|
751 |
with gr.Column():
|
752 |
-
|
753 |
with gr.Column():
|
754 |
with gr.Row():
|
755 |
-
|
756 |
-
with gr.
|
757 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
758 |
with gr.Column():
|
759 |
-
|
760 |
-
|
761 |
-
gr.
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
gr.
|
769 |
-
gr.
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
gr.
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
with gr.Column():
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
|
|
|
|
793 |
|
794 |
-
demo = gr.TabbedInterface([
|
795 |
|
796 |
# Queue up the demo
|
797 |
if is_queue_mode:
|
@@ -807,8 +1039,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
807 |
|
808 |
if __name__ == '__main__':
|
809 |
default_app_config = ApplicationConfig.create_default()
|
810 |
-
whisper_models = default_app_config.get_model_names()
|
811 |
-
nllb_models = default_app_config.get_nllb_model_names()
|
812 |
|
813 |
# Environment variable overrides
|
814 |
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
|
@@ -846,9 +1077,10 @@ if __name__ == '__main__':
|
|
846 |
help="the compute type to use for inference")
|
847 |
parser.add_argument("--threads", type=optional_int, default=0,
|
848 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
|
|
849 |
parser.add_argument("--vad_max_merge_size", type=int, default=default_app_config.vad_max_merge_size, \
|
850 |
help="The number of VAD - Max Merge Size (s).") # 30
|
851 |
-
parser.add_argument("--language", type=str, default=None, choices=sorted(
|
852 |
help="language spoken in the audio, specify None to perform language detection")
|
853 |
parser.add_argument("--save_downloaded_files", action='store_true', \
|
854 |
help="True to move downloaded files to outputs directory. This argument will take effect only after output_dir is set.")
|
@@ -858,6 +1090,7 @@ if __name__ == '__main__':
|
|
858 |
help="Maximum length of a file name.")
|
859 |
parser.add_argument("--autolaunch", action='store_true', \
|
860 |
help="open the webui URL in the system's default browser upon launch")
|
|
|
861 |
parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
|
862 |
parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
|
863 |
help="whether to perform speaker diarization")
|
|
|
1 |
from datetime import datetime
|
2 |
import json
|
3 |
import math
|
4 |
+
from typing import Iterator, Union, List
|
5 |
import argparse
|
6 |
|
7 |
from io import StringIO
|
|
|
20 |
from src.hooks.progressListener import ProgressListener
|
21 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
22 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
|
|
23 |
from src.modelCache import ModelCache
|
24 |
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
25 |
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
|
|
33 |
import gradio as gr
|
34 |
|
35 |
from src.download import ExceededMaximumDuration, download_url
|
36 |
+
from src.utils import optional_int, slugify, str2bool, write_srt, write_srt_original, write_vtt
|
37 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
38 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
39 |
from src.whisper.whisperFactory import create_whisper_container
|
40 |
+
from src.translation.translationModel import TranslationModel
|
41 |
+
from src.translation.translationLangs import (TranslationLang,
|
42 |
+
_TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
|
43 |
+
get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
|
|
|
44 |
import shutil
|
45 |
import zhconv
|
46 |
import tqdm
|
47 |
+
import traceback
|
48 |
|
49 |
# Configure more application defaults in config.json5
|
50 |
|
|
|
113 |
self.diarization.cleanup()
|
114 |
self.diarization_kwargs = None
|
115 |
|
116 |
+
# Entry function for the simple tab, Queue mode disabled: progress bars will not be shown
|
117 |
+
def transcribe_webui_simple(self, data: dict): return self.transcribe_webui_simple_progress(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
+
# Entry function for the simple tab progress, Progress tracking requires queuing to be enabled
|
120 |
+
def transcribe_webui_simple_progress(self, data: dict, progress=gr.Progress()):
|
121 |
+
dataDict = {}
|
122 |
+
for key, value in data.items():
|
123 |
+
dataDict.update({key.elem_id: value})
|
124 |
+
|
125 |
+
return self.transcribe_webui(dataDict, progress=progress)
|
126 |
+
|
127 |
+
# Entry function for the full tab, Queue mode disabled: progress bars will not be shown
|
128 |
+
def transcribe_webui_full(self, data: dict): return self.transcribe_webui_full_progress(data)
|
129 |
+
|
130 |
+
# Entry function for the full tab with progress, Progress tracking requires queuing to be enabled
|
131 |
+
def transcribe_webui_full_progress(self, data: dict, progress=gr.Progress()):
|
132 |
+
dataDict = {}
|
133 |
+
for key, value in data.items():
|
134 |
+
dataDict.update({key.elem_id: value})
|
135 |
+
|
136 |
+
return self.transcribe_webui(dataDict, progress=progress)
|
137 |
+
|
138 |
+
def transcribe_webui(self, decodeOptions: dict, progress: gr.Progress = None):
|
139 |
+
"""
|
140 |
+
Transcribe an audio file using Whisper
|
141 |
+
https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L37
|
142 |
+
Parameters
|
143 |
+
----------
|
144 |
+
model: Whisper
|
145 |
+
The Whisper model instance
|
146 |
+
|
147 |
+
temperature: Union[float, Tuple[float, ...]]
|
148 |
+
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
149 |
+
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
150 |
+
|
151 |
+
compression_ratio_threshold: float
|
152 |
+
If the gzip compression ratio is above this value, treat as failed
|
153 |
+
|
154 |
+
logprob_threshold: float
|
155 |
+
If the average log probability over sampled tokens is below this value, treat as failed
|
156 |
+
|
157 |
+
no_speech_threshold: float
|
158 |
+
If the no_speech probability is higher than this value AND the average log probability
|
159 |
+
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
160 |
+
|
161 |
+
condition_on_previous_text: bool
|
162 |
+
if True, the previous output of the model is provided as a prompt for the next window;
|
163 |
+
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
164 |
+
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
165 |
+
|
166 |
+
word_timestamps: bool
|
167 |
+
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
168 |
+
and include the timestamps for each word in each segment.
|
169 |
+
|
170 |
+
prepend_punctuations: str
|
171 |
+
If word_timestamps is True, merge these punctuation symbols with the next word
|
172 |
+
|
173 |
+
append_punctuations: str
|
174 |
+
If word_timestamps is True, merge these punctuation symbols with the previous word
|
175 |
+
|
176 |
+
initial_prompt: Optional[str]
|
177 |
+
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
178 |
+
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
179 |
+
to make it more likely to predict those word correctly.
|
180 |
+
|
181 |
+
decode_options: dict
|
182 |
+
Keyword arguments to construct `DecodingOptions` instances
|
183 |
+
https://github.com/openai/whisper/blob/main/whisper/decoding.py#L81
|
184 |
+
|
185 |
+
task: str = "transcribe"
|
186 |
+
whether to perform X->X "transcribe" or X->English "translate"
|
187 |
+
|
188 |
+
language: Optional[str] = None
|
189 |
+
language that the audio is in; uses detected language if None
|
190 |
+
|
191 |
+
temperature: float = 0.0
|
192 |
+
sample_len: Optional[int] = None # maximum number of tokens to sample
|
193 |
+
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
194 |
+
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
195 |
+
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
196 |
+
sampling-related options
|
197 |
+
|
198 |
+
length_penalty: Optional[float] = None
|
199 |
+
"alpha" in Google NMT, or None for length norm, when ranking generations
|
200 |
+
to select which to return among the beams or best-of-N samples
|
201 |
+
|
202 |
+
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
203 |
+
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
204 |
+
text or tokens to feed as the prompt or the prefix; for more info:
|
205 |
+
https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
206 |
+
|
207 |
+
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
208 |
+
suppress_blank: bool = True # this will suppress blank outputs
|
209 |
+
list of tokens ids (or comma-separated token ids) to suppress
|
210 |
+
"-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
211 |
+
|
212 |
+
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
213 |
+
max_initial_timestamp: Optional[float] = 1.0
|
214 |
+
timestamp sampling options
|
215 |
+
|
216 |
+
fp16: bool = True # use fp16 for most of the calculation
|
217 |
+
implementation details
|
218 |
+
repetition_penalty: float
|
219 |
+
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
220 |
+
no_repeat_ngram_size: int
|
221 |
+
The model ensures that a sequence of words of no_repeat_ngram_size isn’t repeated in the output sequence. If specified, it must be a positive integer greater than 1.
|
222 |
+
"""
|
223 |
+
try:
|
224 |
+
whisperModelName: str = decodeOptions.pop("whisperModelName")
|
225 |
+
whisperLangName: str = decodeOptions.pop("whisperLangName")
|
226 |
+
|
227 |
+
translateInput: str = decodeOptions.pop("translateInput")
|
228 |
+
m2m100ModelName: str = decodeOptions.pop("m2m100ModelName")
|
229 |
+
m2m100LangName: str = decodeOptions.pop("m2m100LangName")
|
230 |
+
nllbModelName: str = decodeOptions.pop("nllbModelName")
|
231 |
+
nllbLangName: str = decodeOptions.pop("nllbLangName")
|
232 |
+
mt5ModelName: str = decodeOptions.pop("mt5ModelName")
|
233 |
+
mt5LangName: str = decodeOptions.pop("mt5LangName")
|
234 |
+
|
235 |
+
translationBatchSize: int = decodeOptions.pop("translationBatchSize")
|
236 |
+
translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
|
237 |
+
translationNumBeams: int = decodeOptions.pop("translationNumBeams")
|
238 |
+
|
239 |
+
sourceInput: str = decodeOptions.pop("sourceInput")
|
240 |
+
urlData: str = decodeOptions.pop("urlData")
|
241 |
+
multipleFiles: List = decodeOptions.pop("multipleFiles")
|
242 |
+
microphoneData: str = decodeOptions.pop("microphoneData")
|
243 |
+
task: str = decodeOptions.pop("task")
|
244 |
+
|
245 |
+
vad: str = decodeOptions.pop("vad")
|
246 |
+
vadMergeWindow: float = decodeOptions.pop("vadMergeWindow")
|
247 |
+
vadMaxMergeSize: float = decodeOptions.pop("vadMaxMergeSize")
|
248 |
+
vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
|
249 |
+
vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
|
250 |
+
vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
|
251 |
+
|
252 |
+
diarization: bool = decodeOptions.pop("diarization", False)
|
253 |
+
diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
|
254 |
+
diarization_min_speakers: int = decodeOptions.pop("diarization_min_speakers", 1)
|
255 |
+
diarization_max_speakers: int = decodeOptions.pop("diarization_max_speakers", 8)
|
256 |
+
highlight_words: bool = decodeOptions.pop("highlight_words", False)
|
257 |
+
|
258 |
+
temperature: float = decodeOptions.pop("temperature", None)
|
259 |
+
temperature_increment_on_fallback: float = decodeOptions.pop("temperature_increment_on_fallback", None)
|
260 |
+
|
261 |
+
whisperRepetitionPenalty: float = decodeOptions.get("repetition_penalty", None)
|
262 |
+
whisperNoRepeatNgramSize: int = decodeOptions.get("no_repeat_ngram_size", None)
|
263 |
+
if whisperRepetitionPenalty is not None and whisperRepetitionPenalty <= 1.0:
|
264 |
+
decodeOptions.pop("repetition_penalty")
|
265 |
+
if whisperNoRepeatNgramSize is not None and whisperNoRepeatNgramSize <= 1:
|
266 |
+
decodeOptions.pop("no_repeat_ngram_size")
|
267 |
+
|
268 |
+
# word_timestamps = options.get("word_timestamps", False)
|
269 |
+
# condition_on_previous_text = options.get("condition_on_previous_text", False)
|
270 |
+
|
271 |
+
# prepend_punctuations = options.get("prepend_punctuations", None)
|
272 |
+
# append_punctuations = options.get("append_punctuations", None)
|
273 |
+
# initial_prompt = options.get("initial_prompt", None)
|
274 |
+
# best_of = options.get("best_of", None)
|
275 |
+
# beam_size = options.get("beam_size", None)
|
276 |
+
# patience = options.get("patience", None)
|
277 |
+
# length_penalty = options.get("length_penalty", None)
|
278 |
+
# suppress_tokens = options.get("suppress_tokens", None)
|
279 |
+
# compression_ratio_threshold = options.get("compression_ratio_threshold", None)
|
280 |
+
# logprob_threshold = options.get("logprob_threshold", None)
|
281 |
+
|
282 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
283 |
+
|
284 |
+
if diarization:
|
285 |
+
if diarization_speakers is not None and diarization_speakers < 1:
|
286 |
+
self.set_diarization(auth_token=self.app_config.auth_token, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
|
287 |
+
else:
|
288 |
+
self.set_diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers, min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
|
289 |
else:
|
290 |
+
self.unset_diarization()
|
291 |
+
|
292 |
+
# Handle temperature_increment_on_fallback
|
293 |
+
if temperature is not None:
|
294 |
+
if temperature_increment_on_fallback is not None:
|
295 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
296 |
+
else:
|
297 |
+
temperature = [temperature]
|
298 |
+
decodeOptions["temperature"] = temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
progress(0, desc="init audio sources")
|
301 |
+
|
302 |
+
if sourceInput == "urlData":
|
303 |
+
sources = self.__get_source(urlData, None, None)
|
304 |
+
elif sourceInput == "multipleFiles":
|
305 |
+
sources = self.__get_source(None, multipleFiles, None)
|
306 |
+
elif sourceInput == "microphoneData":
|
307 |
+
sources = self.__get_source(None, None, microphoneData)
|
308 |
+
|
309 |
if (len(sources) == 0):
|
310 |
raise Exception("init audio sources failed...")
|
311 |
+
|
312 |
try:
|
313 |
progress(0, desc="init whisper model")
|
314 |
+
whisperLang: TranslationLang = get_lang_from_whisper_name(whisperLangName)
|
315 |
+
whisperLangCode = whisperLang.whisper.code if whisperLang is not None and whisperLang.whisper is not None else None
|
316 |
+
selectedModel = whisperModelName if whisperModelName is not None else "base"
|
317 |
|
318 |
model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
|
319 |
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
320 |
+
cache=self.model_cache, models=self.app_config.models["whisper"])
|
321 |
|
322 |
progress(0, desc="init translate model")
|
323 |
+
translationLang = None
|
324 |
+
translationModel = None
|
325 |
+
if translateInput == "m2m100" and m2m100LangName is not None and len(m2m100LangName) > 0:
|
326 |
+
selectedModelName = m2m100ModelName if m2m100ModelName is not None and len(m2m100ModelName) > 0 else "m2m100_418M/facebook"
|
327 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["m2m100"] if modelConfig.name == selectedModelName), None)
|
328 |
+
translationLang = get_lang_from_m2m100_name(m2m100LangName)
|
329 |
+
elif translateInput == "nllb" and nllbLangName is not None and len(nllbLangName) > 0:
|
330 |
+
selectedModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
|
331 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["nllb"] if modelConfig.name == selectedModelName), None)
|
332 |
+
translationLang = get_lang_from_nllb_name(nllbLangName)
|
333 |
+
elif translateInput == "mt5" and mt5LangName is not None and len(mt5LangName) > 0:
|
334 |
+
selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
|
335 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
|
336 |
+
translationLang = get_lang_from_m2m100_name(mt5LangName)
|
337 |
+
|
338 |
+
if translationLang is not None:
|
339 |
+
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
|
340 |
+
|
341 |
progress(0, desc="init transcribe")
|
342 |
# Result
|
343 |
download = []
|
|
|
348 |
# Write result
|
349 |
downloadDirectory = tempfile.mkdtemp()
|
350 |
source_index = 0
|
351 |
+
extra_tasks_count = 1 if translationLang is not None else 0
|
352 |
|
353 |
outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
|
354 |
|
|
|
377 |
sub_task_total=sub_task_total)
|
378 |
|
379 |
# Transcribe
|
380 |
+
result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
|
381 |
+
if whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
|
382 |
+
whisperLang = get_lang_from_whisper_code(result["language"])
|
383 |
+
translationModel.whisperLang = whisperLang
|
384 |
|
385 |
short_name, suffix = source.get_short_name_suffix(max_length=self.app_config.input_max_file_name_length)
|
386 |
filePrefix = slugify(source_prefix + short_name, allow_unicode=True)
|
|
|
388 |
# Update progress
|
389 |
current_progress += source_audio_duration
|
390 |
|
391 |
+
source_download, source_text, source_vtt = self.write_result(result, whisperLang, translationModel, filePrefix + suffix.replace(".", "_"), outputDirectory, highlight_words, scaled_progress_listener)
|
392 |
|
393 |
if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None:
|
394 |
print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
|
|
|
397 |
srt_path = source_download[0]
|
398 |
save_path = os.path.join(self.app_config.output_dir, filePrefix)
|
399 |
# save_without_ext, ext = os.path.splitext(save_path)
|
400 |
+
source_lang = "." + whisperLang.whisper.code if whisperLang is not None and whisperLang.whisper is not None else ""
|
401 |
+
translate_lang = "." + translationLang.nllb.code if translationLang is not None else ""
|
402 |
output_with_srt = save_path + source_lang + translate_lang + suffix
|
403 |
|
404 |
#ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
|
|
|
473 |
except ExceededMaximumDuration as e:
|
474 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
475 |
except Exception as e:
|
|
|
476 |
print(traceback.format_exc())
|
477 |
+
return [], ("Error occurred during transcribe: " + str(e)), traceback.format_exc()
|
478 |
|
479 |
|
480 |
+
def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, languageCode: str, task: str = None,
|
481 |
vadOptions: VadOptions = VadOptions(),
|
482 |
progressListener: ProgressListener = None, **decodeOptions: dict):
|
483 |
|
|
|
507 |
raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
|
508 |
|
509 |
# Callable for processing an audio file
|
510 |
+
whisperCallable = model.create_callback(languageCode, task, prompt_strategy=prompt_strategy, **decodeOptions)
|
511 |
|
512 |
# The results
|
513 |
if (vadOptions.vad == 'silero-vad'):
|
|
|
622 |
|
623 |
return config
|
624 |
|
625 |
+
def write_result(self, result: dict, whisperLang: TranslationLang, translationModel: TranslationModel, source_name: str, output_dir: str, highlight_words: bool = False, progressListener: ProgressListener = None):
|
626 |
if not os.path.exists(output_dir):
|
627 |
os.makedirs(output_dir)
|
628 |
|
|
|
631 |
language = result["language"]
|
632 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
633 |
|
634 |
+
if translationModel is not None and translationModel.translationLang is not None:
|
635 |
try:
|
636 |
segments_progress_listener = SubTaskProgressListener(progressListener,
|
637 |
base_task_total=progressListener.sub_task_total,
|
|
|
639 |
sub_task_total=1)
|
640 |
pbar = tqdm.tqdm(total=len(segments))
|
641 |
perf_start_time = time.perf_counter()
|
642 |
+
translationModel.load_model()
|
643 |
for idx, segment in enumerate(segments):
|
644 |
seg_text = segment["text"]
|
645 |
+
segment["original"] = seg_text
|
646 |
+
segment["text"] = translationModel.translation(seg_text)
|
|
|
|
|
647 |
pbar.update(1)
|
648 |
segments_progress_listener.on_progress(idx+1, len(segments), desc=f"Process segments: {idx}/{len(segments)}")
|
649 |
|
650 |
+
translationModel.release_vram()
|
651 |
perf_end_time = time.perf_counter()
|
652 |
# Call the finished callback
|
653 |
if segments_progress_listener is not None:
|
|
|
656 |
print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
|
657 |
except Exception as e:
|
658 |
# Ignore error - it's just a cleanup
|
659 |
+
print(traceback.format_exc())
|
660 |
print("Error process segments: " + str(e))
|
661 |
|
662 |
print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
|
663 |
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
664 |
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
665 |
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
666 |
+
srt_original = None
|
667 |
+
srt_bilingual = None
|
668 |
+
if translationModel is not None and translationModel.translationLang is not None:
|
669 |
+
srt_original = self.__get_subs(result["segments"], "srt_original", languageMaxLineWidth, highlight_words=highlight_words)
|
670 |
+
srt_bilingual = self.__get_subs(result["segments"], "srt_bilingual", languageMaxLineWidth, highlight_words=highlight_words)
|
671 |
+
|
672 |
+
whisperLangZho: bool = whisperLang is not None and whisperLang.nllb is not None and whisperLang.nllb.code in ["zho_Hant", "zho_Hans", "yue_Hant"]
|
673 |
+
translationZho: bool = translationModel is not None and translationModel.translationLang is not None and translationModel.translationLang.nllb is not None and translationModel.translationLang.nllb.code in ["zho_Hant", "zho_Hans", "yue_Hant"]
|
674 |
+
if whisperLangZho or translationZho:
|
675 |
+
locale = None
|
676 |
+
if whisperLangZho:
|
677 |
+
if whisperLang.nllb.code == "zho_Hant":
|
678 |
+
locale = "zh-tw"
|
679 |
+
elif whisperLang.nllb.code == "zho_Hans":
|
680 |
+
locale = "zh-cn"
|
681 |
+
elif whisperLang.nllb.code == "yue_Hant":
|
682 |
+
locale = "zh-hk"
|
683 |
+
if translationZho:
|
684 |
+
if translationModel.translationLang.nllb.code == "zho_Hant":
|
685 |
+
locale = "zh-tw"
|
686 |
+
elif translationModel.translationLang.nllb.code == "zho_Hans":
|
687 |
+
locale = "zh-cn"
|
688 |
+
elif translationModel.translationLang.nllb.code == "yue_Hant":
|
689 |
+
locale = "zh-hk"
|
690 |
+
if locale is not None:
|
691 |
+
vtt = zhconv.convert(vtt, locale)
|
692 |
+
srt = zhconv.convert(srt, locale)
|
693 |
+
text = zhconv.convert(text, locale)
|
694 |
+
json_result = zhconv.convert(json_result, locale)
|
695 |
+
if translationModel is not None and translationModel.translationLang is not None:
|
696 |
+
if srt_original is not None and len(srt_original) > 0:
|
697 |
+
srt_original = zhconv.convert(srt_original, locale)
|
698 |
+
if srt_bilingual is not None and len(srt_bilingual) > 0:
|
699 |
+
srt_bilingual = zhconv.convert(srt_bilingual, locale)
|
700 |
|
701 |
output_files = []
|
702 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
703 |
output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
|
704 |
output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
|
705 |
output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
|
706 |
+
if srt_original is not None and len(srt_original) > 0:
|
707 |
+
output_files.append(self.__create_file(srt_original, output_dir, source_name + "-original.srt"));
|
708 |
+
if srt_bilingual is not None and len(srt_bilingual) > 0:
|
709 |
+
output_files.append(self.__create_file(srt_bilingual, output_dir, source_name + "-bilingual.srt"));
|
710 |
|
711 |
return output_files, text, vtt
|
712 |
|
|
|
733 |
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
734 |
elif format == 'srt':
|
735 |
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
736 |
+
elif format == 'srt_original':
|
737 |
+
write_srt_original(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
738 |
+
elif format == 'srt_bilingual':
|
739 |
+
write_srt_original(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words, bilingual=True)
|
740 |
else:
|
741 |
raise Exception("Unknown format " + format)
|
742 |
|
|
|
765 |
self.diarization = None
|
766 |
|
767 |
def create_ui(app_config: ApplicationConfig):
|
768 |
+
optionsMd: str = None
|
769 |
+
readmeMd: str = None
|
770 |
+
try:
|
771 |
+
with open("docs\options.md", "r", encoding="utf-8") as optionsFile:
|
772 |
+
optionsMd = optionsFile.read()
|
773 |
+
with open("README.md", "r", encoding="utf-8") as readmeFile:
|
774 |
+
readmeMd = readmeFile.read()
|
775 |
+
except Exception as e:
|
776 |
+
print("Error occurred during read options.md file: ", str(e))
|
777 |
+
|
778 |
ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
|
779 |
app_config.delete_uploaded_files, app_config.output_dir, app_config)
|
780 |
|
|
|
793 |
# Try to convert from camel-case to title-case
|
794 |
implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
|
795 |
|
796 |
+
uiDescription = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
797 |
+
uiDescription += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
798 |
+
uiDescription += " as well as speech translation and language identification. "
|
799 |
|
800 |
+
uiDescription += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
|
801 |
|
802 |
# Recommend faster-whisper
|
803 |
if is_whisper:
|
804 |
+
uiDescription += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
|
805 |
|
806 |
if app_config.input_audio_max_duration > 0:
|
807 |
+
uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
808 |
+
|
809 |
+
uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
810 |
+
uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
|
811 |
+
uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
|
812 |
+
uiArticle += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
|
813 |
+
uiArticle += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
|
814 |
+
uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
|
815 |
+
uiArticle += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
|
816 |
+
uiArticle += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
|
817 |
+
uiArticle += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
|
818 |
+
uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
|
819 |
+
uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
|
820 |
+
|
821 |
+
whisper_models = app_config.get_model_names("whisper")
|
822 |
+
nllb_models = app_config.get_model_names("nllb")
|
823 |
+
m2m100_models = app_config.get_model_names("m2m100")
|
824 |
+
mt5_models = app_config.get_model_names("mt5")
|
825 |
|
826 |
+
common_whisper_inputs = lambda : {
|
827 |
+
gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
|
828 |
+
gr.Dropdown(label="Whisper - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="whisperLangName"),
|
829 |
+
}
|
830 |
+
common_m2m100_inputs = lambda : {
|
831 |
+
gr.Dropdown(label="M2M100 - Model (for translate)", choices=m2m100_models, elem_id="m2m100ModelName"),
|
832 |
+
gr.Dropdown(label="M2M100 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="m2m100LangName"),
|
833 |
+
}
|
834 |
+
common_nllb_inputs = lambda : {
|
835 |
+
gr.Dropdown(label="NLLB - Model (for translate)", choices=nllb_models, elem_id="nllbModelName"),
|
836 |
+
gr.Dropdown(label="NLLB - Language", choices=sorted(get_lang_nllb_names()), elem_id="nllbLangName"),
|
837 |
+
}
|
838 |
+
common_mt5_inputs = lambda : {
|
839 |
+
gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
|
840 |
+
gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
|
841 |
+
}
|
|
|
|
|
|
|
|
|
842 |
|
843 |
+
common_translation_inputs = lambda : {
|
844 |
+
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
845 |
+
gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize"),
|
846 |
+
gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams")
|
847 |
+
}
|
848 |
+
|
849 |
+
common_vad_inputs = lambda : {
|
850 |
+
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
|
851 |
+
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
|
852 |
+
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
|
853 |
+
}
|
854 |
+
|
855 |
+
common_word_timestamps_inputs = lambda : {
|
856 |
+
gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps, elem_id="word_timestamps"),
|
857 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words, elem_id="highlight_words"),
|
858 |
+
}
|
859 |
|
860 |
has_diarization_libs = Diarization.has_libraries()
|
861 |
|
|
|
863 |
print("Diarization libraries not found - disabling diarization")
|
864 |
app_config.diarization = False
|
865 |
|
866 |
+
common_diarization_inputs = lambda : {
|
867 |
+
gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs, elem_id="diarization"),
|
868 |
+
gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs, elem_id="diarization_speakers"),
|
869 |
+
gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs, elem_id="diarization_min_speakers"),
|
870 |
+
gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs, elem_id="diarization_max_speakers")
|
871 |
+
}
|
872 |
|
873 |
common_output = lambda : [
|
874 |
gr.File(label="Download"),
|
|
|
878 |
|
879 |
is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
|
880 |
|
881 |
+
simpleInputDict = {}
|
882 |
+
|
883 |
+
with gr.Blocks() as simpleTranscribe:
|
884 |
+
simpleTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
|
885 |
+
simpleSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
|
886 |
+
gr.Markdown(uiDescription)
|
887 |
with gr.Row():
|
888 |
with gr.Column():
|
889 |
+
simpleSubmit = gr.Button("Submit", variant="primary")
|
890 |
with gr.Column():
|
891 |
with gr.Row():
|
892 |
+
simpleInputDict = common_whisper_inputs()
|
893 |
+
with gr.Tab(label="M2M100") as simpleM2M100Tab:
|
894 |
+
with gr.Row():
|
895 |
+
simpleInputDict.update(common_m2m100_inputs())
|
896 |
+
with gr.Tab(label="NLLB") as simpleNllbTab:
|
897 |
+
with gr.Row():
|
898 |
+
simpleInputDict.update(common_nllb_inputs())
|
899 |
+
with gr.Tab(label="MT5") as simpleMT5Tab:
|
900 |
+
with gr.Row():
|
901 |
+
simpleInputDict.update(common_mt5_inputs())
|
902 |
+
simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
|
903 |
+
simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
|
904 |
+
simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
|
905 |
with gr.Column():
|
906 |
+
with gr.Tab(label="URL") as simpleUrlTab:
|
907 |
+
simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
908 |
+
with gr.Tab(label="Upload") as simpleUploadTab:
|
909 |
+
simpleInputDict.update({gr.File(label="Upload Files", file_count="multiple", elem_id = "multipleFiles")})
|
910 |
+
with gr.Tab(label="Microphone") as simpleMicTab:
|
911 |
+
simpleInputDict.update({gr.Audio(source="microphone", type="filepath", label="Microphone Input", elem_id = "microphoneData")})
|
912 |
+
simpleUrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [simpleSourceInput] )
|
913 |
+
simpleUploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [simpleSourceInput] )
|
914 |
+
simpleMicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [simpleSourceInput] )
|
915 |
+
simpleInputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
|
916 |
+
with gr.Accordion("VAD options", open=False):
|
917 |
+
simpleInputDict.update(common_vad_inputs())
|
918 |
+
with gr.Accordion("Word Timestamps options", open=False):
|
919 |
+
simpleInputDict.update(common_word_timestamps_inputs())
|
920 |
+
with gr.Accordion("Diarization options", open=False):
|
921 |
+
simpleInputDict.update(common_diarization_inputs())
|
922 |
+
with gr.Accordion("Translation options", open=False):
|
923 |
+
simpleInputDict.update(common_translation_inputs())
|
924 |
with gr.Column():
|
925 |
+
simpleOutput = common_output()
|
926 |
+
with gr.Accordion("Article"):
|
927 |
+
gr.Markdown(uiArticle)
|
928 |
+
if optionsMd is not None:
|
929 |
+
with gr.Accordion("docs/options.md", open=False):
|
930 |
+
gr.Markdown(optionsMd)
|
931 |
+
if readmeMd is not None:
|
932 |
+
with gr.Accordion("README.md", open=False):
|
933 |
+
gr.Markdown(readmeMd)
|
934 |
+
|
935 |
+
simpleInputDict.update({simpleTranslateInput, simpleSourceInput})
|
936 |
+
simpleSubmit.click(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
|
937 |
+
inputs=simpleInputDict, outputs=simpleOutput)
|
938 |
|
939 |
+
fullInputDict = {}
|
940 |
+
fullDescription = uiDescription + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
|
941 |
|
942 |
+
with gr.Blocks() as fullTranscribe:
|
943 |
+
fullTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
|
944 |
+
fullSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
|
945 |
+
gr.Markdown(fullDescription)
|
946 |
with gr.Row():
|
947 |
with gr.Column():
|
948 |
+
fullSubmit = gr.Button("Submit", variant="primary")
|
949 |
with gr.Column():
|
950 |
with gr.Row():
|
951 |
+
fullInputDict = common_whisper_inputs()
|
952 |
+
with gr.Tab(label="M2M100") as fullM2M100Tab:
|
953 |
+
with gr.Row():
|
954 |
+
fullInputDict.update(common_m2m100_inputs())
|
955 |
+
with gr.Tab(label="NLLB") as fullNllbTab:
|
956 |
+
with gr.Row():
|
957 |
+
fullInputDict.update(common_nllb_inputs())
|
958 |
+
with gr.Tab(label="MT5") as fullMT5Tab:
|
959 |
+
with gr.Row():
|
960 |
+
fullInputDict.update(common_mt5_inputs())
|
961 |
+
fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
|
962 |
+
fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
|
963 |
+
fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
|
964 |
with gr.Column():
|
965 |
+
with gr.Tab(label="URL") as fullUrlTab:
|
966 |
+
fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
967 |
+
with gr.Tab(label="Upload") as fullUploadTab:
|
968 |
+
fullInputDict.update({gr.File(label="Upload Files", file_count="multiple", elem_id = "multipleFiles")})
|
969 |
+
with gr.Tab(label="Microphone") as fullMicTab:
|
970 |
+
fullInputDict.update({gr.Audio(source="microphone", type="filepath", label="Microphone Input", elem_id = "microphoneData")})
|
971 |
+
fullUrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [fullSourceInput] )
|
972 |
+
fullUploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [fullSourceInput] )
|
973 |
+
fullMicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [fullSourceInput] )
|
974 |
+
fullInputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
|
975 |
+
with gr.Accordion("VAD options", open=False):
|
976 |
+
fullInputDict.update(common_vad_inputs())
|
977 |
+
fullInputDict.update({
|
978 |
+
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding"),
|
979 |
+
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow"),
|
980 |
+
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode")})
|
981 |
+
with gr.Accordion("Word Timestamps options", open=False):
|
982 |
+
fullInputDict.update(common_word_timestamps_inputs())
|
983 |
+
fullInputDict.update({
|
984 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations, elem_id = "prepend_punctuations"),
|
985 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations, elem_id = "append_punctuations")})
|
986 |
+
with gr.Accordion("Whisper Advanced options", open=False):
|
987 |
+
fullInputDict.update({
|
988 |
+
gr.TextArea(label="Initial Prompt", elem_id = "initial_prompt"),
|
989 |
+
gr.Number(label="Temperature", value=app_config.temperature, elem_id = "temperature"),
|
990 |
+
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0, elem_id = "best_of"),
|
991 |
+
gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0, elem_id = "beam_size"),
|
992 |
+
gr.Number(label="Patience - Zero temperature", value=app_config.patience, elem_id = "patience"),
|
993 |
+
gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty, elem_id = "length_penalty"),
|
994 |
+
gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens, elem_id = "suppress_tokens"),
|
995 |
+
gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text, elem_id = "condition_on_previous_text"),
|
996 |
+
gr.Checkbox(label="FP16", value=app_config.fp16, elem_id = "fp16"),
|
997 |
+
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback, elem_id = "temperature_increment_on_fallback"),
|
998 |
+
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold, elem_id = "compression_ratio_threshold"),
|
999 |
+
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold, elem_id = "logprob_threshold"),
|
1000 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold, elem_id = "no_speech_threshold"),
|
1001 |
+
})
|
1002 |
+
if app_config.whisper_implementation == "faster-whisper":
|
1003 |
+
fullInputDict.update({
|
1004 |
+
gr.Number(label="Repetition Penalty", value=app_config.repetition_penalty, elem_id = "repetition_penalty"),
|
1005 |
+
gr.Number(label="No Repeat Ngram Size", value=app_config.no_repeat_ngram_size, precision=0, elem_id = "no_repeat_ngram_size")
|
1006 |
+
})
|
1007 |
+
with gr.Accordion("Diarization options", open=False):
|
1008 |
+
fullInputDict.update(common_diarization_inputs())
|
1009 |
+
with gr.Accordion("Translation options", open=False):
|
1010 |
+
fullInputDict.update(common_translation_inputs())
|
1011 |
with gr.Column():
|
1012 |
+
fullOutput = common_output()
|
1013 |
+
with gr.Accordion("Article"):
|
1014 |
+
gr.Markdown(uiArticle)
|
1015 |
+
if optionsMd is not None:
|
1016 |
+
with gr.Accordion("docs/options.md", open=False):
|
1017 |
+
gr.Markdown(optionsMd)
|
1018 |
+
if readmeMd is not None:
|
1019 |
+
with gr.Accordion("README.md", open=False):
|
1020 |
+
gr.Markdown(readmeMd)
|
1021 |
+
|
1022 |
+
fullInputDict.update({fullTranslateInput, fullSourceInput})
|
1023 |
+
fullSubmit.click(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
|
1024 |
+
inputs=fullInputDict, outputs=fullOutput)
|
1025 |
|
1026 |
+
demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe], tab_names=["Simple", "Full"])
|
1027 |
|
1028 |
# Queue up the demo
|
1029 |
if is_queue_mode:
|
|
|
1039 |
|
1040 |
if __name__ == '__main__':
|
1041 |
default_app_config = ApplicationConfig.create_default()
|
1042 |
+
whisper_models = default_app_config.get_model_names("whisper")
|
|
|
1043 |
|
1044 |
# Environment variable overrides
|
1045 |
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
|
|
|
1077 |
help="the compute type to use for inference")
|
1078 |
parser.add_argument("--threads", type=optional_int, default=0,
|
1079 |
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
1080 |
+
|
1081 |
parser.add_argument("--vad_max_merge_size", type=int, default=default_app_config.vad_max_merge_size, \
|
1082 |
help="The number of VAD - Max Merge Size (s).") # 30
|
1083 |
+
parser.add_argument("--language", type=str, default=None, choices=sorted(get_lang_whisper_names()) + sorted([k.title() for k in _TO_LANG_CODE_WHISPER.keys()]),
|
1084 |
help="language spoken in the audio, specify None to perform language detection")
|
1085 |
parser.add_argument("--save_downloaded_files", action='store_true', \
|
1086 |
help="True to move downloaded files to outputs directory. This argument will take effect only after output_dir is set.")
|
|
|
1090 |
help="Maximum length of a file name.")
|
1091 |
parser.add_argument("--autolaunch", action='store_true', \
|
1092 |
help="open the webui URL in the system's default browser upon launch")
|
1093 |
+
|
1094 |
parser.add_argument('--auth_token', type=str, default=default_app_config.auth_token, help='HuggingFace API Token (optional)')
|
1095 |
parser.add_argument("--diarization", type=str2bool, default=default_app_config.diarization, \
|
1096 |
help="whether to perform speaker diarization")
|
@@ -10,7 +10,7 @@ from app import VadOptions, WhisperTranscriber
|
|
10 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
11 |
from src.diarization.diarization import Diarization
|
12 |
from src.download import download_url
|
13 |
-
from src.languages import get_language_names
|
14 |
|
15 |
from src.utils import optional_float, optional_int, str2bool
|
16 |
from src.whisper.whisperFactory import create_whisper_container
|
@@ -43,7 +43,7 @@ def cli():
|
|
43 |
|
44 |
parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
|
45 |
help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
46 |
-
parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(
|
47 |
help="language spoken in the audio, specify None to perform language detection")
|
48 |
|
49 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
|
|
10 |
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
11 |
from src.diarization.diarization import Diarization
|
12 |
from src.download import download_url
|
13 |
+
from src.translation.translationLangs import get_lang_whisper_names # from src.languages import get_language_names
|
14 |
|
15 |
from src.utils import optional_float, optional_int, str2bool
|
16 |
from src.whisper.whisperFactory import create_whisper_container
|
|
|
43 |
|
44 |
parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
|
45 |
help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
46 |
+
parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_lang_whisper_names()), \
|
47 |
help="language spoken in the audio, specify None to perform language detection")
|
48 |
|
49 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
@@ -1,254 +1,300 @@
|
|
1 |
{
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
49 |
],
|
50 |
-
"
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
"type": "huggingface"
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"name": "mt5-zh-ja-en-trimmed/K024",
|
83 |
-
"url": "K024/mt5-zh-ja-en-trimmed",
|
84 |
-
"type": "huggingface"
|
85 |
-
},
|
86 |
-
{
|
87 |
-
"name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
|
88 |
-
"url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
89 |
-
"type": "huggingface"
|
90 |
-
},
|
91 |
-
{
|
92 |
-
"name": "nllb-200-distilled-600M/facebook",
|
93 |
-
"url": "facebook/nllb-200-distilled-600M",
|
94 |
-
"type": "huggingface"
|
95 |
-
},
|
96 |
-
{
|
97 |
-
"name": "nllb-200-distilled-600M-ct2/JustFrederik",
|
98 |
-
"url": "JustFrederik/nllb-200-distilled-600M-ct2",
|
99 |
-
"type": "huggingface"
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
|
103 |
-
"url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
|
104 |
-
"type": "huggingface"
|
105 |
-
},
|
106 |
-
{
|
107 |
-
"name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
|
108 |
-
"url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
|
109 |
-
"type": "huggingface"
|
110 |
-
},
|
111 |
-
// Uncomment to add official Facebook 1.3B and 3.3B model
|
112 |
-
// The official Facebook 1.3B and 3.3B model files are too large,
|
113 |
-
// and to avoid occupying too much disk space on Hugging Face's free spaces,
|
114 |
-
// these models are not included in the config.
|
115 |
-
//{
|
116 |
-
// "name": "nllb-200-distilled-1.3B/facebook",
|
117 |
-
// "url": "facebook/nllb-200-distilled-1.3B",
|
118 |
-
// "type": "huggingface"
|
119 |
-
//},
|
120 |
-
//{
|
121 |
-
// "name": "nllb-200-1.3B/facebook",
|
122 |
-
// "url": "facebook/nllb-200-1.3B",
|
123 |
-
// "type": "huggingface"
|
124 |
-
//},
|
125 |
-
//{
|
126 |
-
// "name": "nllb-200-3.3B/facebook",
|
127 |
-
// "url": "facebook/nllb-200-3.3B",
|
128 |
-
// "type": "huggingface"
|
129 |
-
//},
|
130 |
-
//{
|
131 |
-
// "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
|
132 |
-
// "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
|
133 |
-
// "type": "huggingface"
|
134 |
-
//},
|
135 |
-
//{
|
136 |
-
// "name": "nllb-200-1.3B-ct2/JustFrederik",
|
137 |
-
// "url": "JustFrederik/nllb-200-1.3B-ct2",
|
138 |
-
// "type": "huggingface"
|
139 |
-
//},
|
140 |
-
//{
|
141 |
-
// "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
|
142 |
-
// "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
|
143 |
-
// "type": "huggingface"
|
144 |
-
//},
|
145 |
],
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
-
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
|
163 |
-
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
}
|
|
|
1 |
{
|
2 |
+
"models": {
|
3 |
+
"whisper": [
|
4 |
+
// Configuration for the built-in models. You can remove any of these
|
5 |
+
// if you don't want to use the default models.
|
6 |
+
{
|
7 |
+
"name": "tiny",
|
8 |
+
"url": "tiny"
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"name": "base",
|
12 |
+
"url": "base"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"name": "small",
|
16 |
+
"url": "small"
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "medium",
|
20 |
+
"url": "medium"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"name": "large",
|
24 |
+
"url": "large"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"name": "large-v2",
|
28 |
+
"url": "large-v2"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"name": "large-v3",
|
32 |
+
"url": "large-v3"
|
33 |
+
}
|
34 |
+
// Uncomment to add custom Japanese models
|
35 |
+
//{
|
36 |
+
// "name": "whisper-large-v2-mix-jp",
|
37 |
+
// "url": "vumichien/whisper-large-v2-mix-jp",
|
38 |
+
// // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
|
39 |
+
// // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
|
40 |
+
// "type": "huggingface",
|
41 |
+
//},
|
42 |
+
//{
|
43 |
+
// "name": "local-model",
|
44 |
+
// "url": "path/to/local/model",
|
45 |
+
//},
|
46 |
+
//{
|
47 |
+
// "name": "remote-model",
|
48 |
+
// "url": "https://example.com/path/to/model",
|
49 |
+
//}
|
50 |
],
|
51 |
+
"m2m100": [
|
52 |
+
{
|
53 |
+
"name": "m2m100_1.2B-ct2fast/michaelfeil",
|
54 |
+
"url": "michaelfeil/ct2fast-m2m100_1.2B",
|
55 |
+
"type": "huggingface",
|
56 |
+
"tokenizer_url": "facebook/m2m100_1.2B"
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"name": "m2m100_418M-ct2fast/michaelfeil",
|
60 |
+
"url": "michaelfeil/ct2fast-m2m100_418M",
|
61 |
+
"type": "huggingface",
|
62 |
+
"tokenizer_url": "facebook/m2m100_418M"
|
63 |
+
},
|
64 |
+
//{
|
65 |
+
// "name": "m2m100-12B-ct2fast/michaelfeil",
|
66 |
+
// "url": "michaelfeil/ct2fast-m2m100-12B-last-ckpt",
|
67 |
+
// "type": "huggingface",
|
68 |
+
// "tokenizer_url": "facebook/m2m100-12B-last-ckpt"
|
69 |
+
//},
|
70 |
+
{
|
71 |
+
"name": "m2m100_1.2B/facebook",
|
72 |
+
"url": "facebook/m2m100_1.2B",
|
73 |
+
"type": "huggingface"
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"name": "m2m100_418M/facebook",
|
77 |
+
"url": "facebook/m2m100_418M",
|
78 |
+
"type": "huggingface"
|
79 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
],
|
81 |
+
"nllb": [
|
82 |
+
{
|
83 |
+
"name": "nllb-200-distilled-1.3B-ct2fast:int8_float16/michaelfeil",
|
84 |
+
"url": "michaelfeil/ct2fast-nllb-200-distilled-1.3B",
|
85 |
+
"type": "huggingface",
|
86 |
+
"tokenizer_url": "facebook/nllb-200-distilled-1.3B"
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"name": "nllb-200-3.3B-ct2fast:int8_float16/michaelfeil",
|
90 |
+
"url": "michaelfeil/ct2fast-nllb-200-3.3B",
|
91 |
+
"type": "huggingface",
|
92 |
+
"tokenizer_url": "facebook/nllb-200-3.3B"
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"name": "nllb-200-1.3B-ct2:float16/JustFrederik",
|
96 |
+
"url": "JustFrederik/nllb-200-1.3B-ct2-float16",
|
97 |
+
"type": "huggingface",
|
98 |
+
"tokenizer_url": "facebook/nllb-200-1.3B"
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"name": "nllb-200-distilled-1.3B-ct2:float16/JustFrederik",
|
102 |
+
"url": "JustFrederik/nllb-200-distilled-1.3B-ct2-float16",
|
103 |
+
"type": "huggingface",
|
104 |
+
"tokenizer_url": "facebook/nllb-200-distilled-1.3B"
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"name": "nllb-200-1.3B-ct2:int8/JustFrederik",
|
108 |
+
"url": "JustFrederik/nllb-200-1.3B-ct2-int8",
|
109 |
+
"type": "huggingface",
|
110 |
+
"tokenizer_url": "facebook/nllb-200-1.3B"
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"name": "nllb-200-distilled-1.3B-ct2:int8/JustFrederik",
|
114 |
+
"url": "JustFrederik/nllb-200-distilled-1.3B-ct2-int8",
|
115 |
+
"type": "huggingface",
|
116 |
+
"tokenizer_url": "facebook/nllb-200-distilled-1.3B"
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"name": "nllb-200-distilled-600M/facebook",
|
120 |
+
"url": "facebook/nllb-200-distilled-600M",
|
121 |
+
"type": "huggingface"
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"name": "nllb-200-distilled-600M-ct2/JustFrederik",
|
125 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2",
|
126 |
+
"type": "huggingface",
|
127 |
+
"tokenizer_url": "facebook/nllb-200-distilled-600M"
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
|
131 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
|
132 |
+
"type": "huggingface",
|
133 |
+
"tokenizer_url": "facebook/nllb-200-distilled-600M"
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
|
137 |
+
"url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
|
138 |
+
"type": "huggingface",
|
139 |
+
"tokenizer_url": "facebook/nllb-200-distilled-600M"
|
140 |
+
}
|
141 |
+
// Uncomment to add official Facebook 1.3B and 3.3B model
|
142 |
+
// The official Facebook 1.3B and 3.3B model files are too large,
|
143 |
+
// and to avoid occupying too much disk space on Hugging Face's free spaces,
|
144 |
+
// these models are not included in the config.
|
145 |
+
//{
|
146 |
+
// "name": "nllb-200-distilled-1.3B/facebook",
|
147 |
+
// "url": "facebook/nllb-200-distilled-1.3B",
|
148 |
+
// "type": "huggingface"
|
149 |
+
//},
|
150 |
+
//{
|
151 |
+
// "name": "nllb-200-1.3B/facebook",
|
152 |
+
// "url": "facebook/nllb-200-1.3B",
|
153 |
+
// "type": "huggingface"
|
154 |
+
//},
|
155 |
+
//{
|
156 |
+
// "name": "nllb-200-3.3B/facebook",
|
157 |
+
// "url": "facebook/nllb-200-3.3B",
|
158 |
+
// "type": "huggingface"
|
159 |
+
//},
|
160 |
+
//{
|
161 |
+
// "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
|
162 |
+
// "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
|
163 |
+
// "type": "huggingface",
|
164 |
+
// "tokenizer_url": "facebook/nllb-200-distilled-1.3B"
|
165 |
+
//},
|
166 |
+
//{
|
167 |
+
// "name": "nllb-200-1.3B-ct2/JustFrederik",
|
168 |
+
// "url": "JustFrederik/nllb-200-1.3B-ct2",
|
169 |
+
// "type": "huggingface",
|
170 |
+
// "tokenizer_url": "facebook/nllb-200-1.3B"
|
171 |
+
//},
|
172 |
+
//{
|
173 |
+
// "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
|
174 |
+
// "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
|
175 |
+
// "type": "huggingface",
|
176 |
+
// "tokenizer_url": "facebook/nllb-200-3.3B"
|
177 |
+
//},
|
178 |
+
],
|
179 |
+
"mt5": [
|
180 |
+
{
|
181 |
+
"name": "mt5-zh-ja-en-trimmed/K024",
|
182 |
+
"url": "K024/mt5-zh-ja-en-trimmed",
|
183 |
+
"type": "huggingface"
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
|
187 |
+
"url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
188 |
+
"type": "huggingface"
|
189 |
+
}
|
190 |
+
]
|
191 |
+
},
|
192 |
+
// Configuration options that will be used if they are not specified in the command line arguments.
|
193 |
|
194 |
+
// * WEBUI options *
|
195 |
|
196 |
+
// Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
|
197 |
+
"input_audio_max_duration": 1800,
|
198 |
+
// True to share the app on HuggingFace.
|
199 |
+
"share": false,
|
200 |
+
// The host or IP to bind to. If None, bind to localhost.
|
201 |
+
"server_name": null,
|
202 |
+
// The port to bind to.
|
203 |
+
"server_port": 7860,
|
204 |
+
// The number of workers to use for the web server. Use -1 to disable queueing.
|
205 |
+
"queue_concurrency_count": 1,
|
206 |
+
// Whether or not to automatically delete all uploaded files, to save disk space
|
207 |
+
"delete_uploaded_files": true,
|
208 |
|
209 |
+
// * General options *
|
210 |
|
211 |
+
// The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
|
212 |
+
// Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
|
213 |
+
// or whisper (requirements.txt)
|
214 |
+
"whisper_implementation": "faster-whisper",
|
215 |
|
216 |
+
// The default model name.
|
217 |
+
"default_model_name": "large-v2",
|
218 |
+
// The default VAD.
|
219 |
+
"default_vad": "silero-vad",
|
220 |
+
// A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
|
221 |
+
"vad_parallel_devices": "",
|
222 |
+
// The number of CPU cores to use for VAD pre-processing.
|
223 |
+
"vad_cpu_cores": 1,
|
224 |
+
// The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
|
225 |
+
"vad_process_timeout": 1800,
|
226 |
+
// True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
|
227 |
+
"auto_parallel": false,
|
228 |
+
// Directory to save the outputs (CLI will use the current directory if not specified)
|
229 |
+
"output_dir": null,
|
230 |
+
// The path to save model files; uses ~/.cache/whisper by default
|
231 |
+
"model_dir": null,
|
232 |
+
// Device to use for PyTorch inference, or Null to use the default device
|
233 |
+
"device": null,
|
234 |
+
// Whether to print out the progress and debug messages
|
235 |
+
"verbose": true,
|
236 |
+
// Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
|
237 |
+
"task": "transcribe",
|
238 |
+
// Language spoken in the audio, specify None to perform language detection
|
239 |
+
"language": null,
|
240 |
+
// The window size (in seconds) to merge voice segments
|
241 |
+
"vad_merge_window": 5,
|
242 |
+
// The maximum size (in seconds) of a voice segment
|
243 |
+
"vad_max_merge_size": 90,
|
244 |
+
// The padding (in seconds) to add to each voice segment
|
245 |
+
"vad_padding": 1,
|
246 |
+
// Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
|
247 |
+
"vad_initial_prompt_mode": "prepend_first_segment",
|
248 |
+
// The window size of the prompt to pass to Whisper
|
249 |
+
"vad_prompt_window": 3,
|
250 |
+
// Temperature to use for sampling
|
251 |
+
"temperature": 0,
|
252 |
+
// Number of candidates when sampling with non-zero temperature
|
253 |
+
"best_of": 5,
|
254 |
+
// Number of beams in beam search, only applicable when temperature is zero
|
255 |
+
"beam_size": 5,
|
256 |
+
// Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
|
257 |
+
"patience": 1,
|
258 |
+
// Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
|
259 |
+
"length_penalty": null,
|
260 |
+
// Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
|
261 |
+
"suppress_tokens": "-1",
|
262 |
+
// Optional text to provide as a prompt for the first window
|
263 |
+
"initial_prompt": null,
|
264 |
+
// If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
|
265 |
+
"condition_on_previous_text": true,
|
266 |
+
// Whether to perform inference in fp16; True by default
|
267 |
+
"fp16": true,
|
268 |
+
// The compute type used by faster-whisper. Can be "int8". "int16" or "float16".
|
269 |
+
"compute_type": "auto",
|
270 |
+
// Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
|
271 |
+
"temperature_increment_on_fallback": 0.2,
|
272 |
+
// If the gzip compression ratio is higher than this value, treat the decoding as failed
|
273 |
+
"compression_ratio_threshold": 2.4,
|
274 |
+
// If the average log probability is lower than this value, treat the decoding as failed
|
275 |
+
"logprob_threshold": -1.0,
|
276 |
+
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
277 |
+
"no_speech_threshold": 0.6,
|
278 |
|
279 |
+
// (experimental) extract word-level timestamps and refine the results based on them
|
280 |
+
"word_timestamps": false,
|
281 |
+
// if word_timestamps is True, merge these punctuation symbols with the next word
|
282 |
+
"prepend_punctuations": "\"\'“¿([{-",
|
283 |
+
// if word_timestamps is True, merge these punctuation symbols with the previous word
|
284 |
+
"append_punctuations": "\"\'.。,,!!??::”)]}、",
|
285 |
+
// (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
|
286 |
+
"highlight_words": false,
|
287 |
|
288 |
+
// Diarization settings
|
289 |
+
"auth_token": null,
|
290 |
+
// Whether to perform speaker diarization
|
291 |
+
"diarization": false,
|
292 |
+
// The number of speakers to detect
|
293 |
+
"diarization_speakers": 2,
|
294 |
+
// The minimum number of speakers to detect
|
295 |
+
"diarization_min_speakers": 1,
|
296 |
+
// The maximum number of speakers to detect
|
297 |
+
"diarization_max_speakers": 8,
|
298 |
+
// The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
|
299 |
+
"diarization_process_timeout": 60,
|
300 |
}
|
@@ -1,6 +1,5 @@
|
|
1 |
git+https://github.com/huggingface/transformers
|
2 |
git+https://github.com/openai/whisper.git
|
3 |
-
transformers
|
4 |
ffmpeg-python==0.2.0
|
5 |
gradio==3.50.2
|
6 |
yt-dlp
|
|
|
1 |
git+https://github.com/huggingface/transformers
|
2 |
git+https://github.com/openai/whisper.git
|
|
|
3 |
ffmpeg-python==0.2.0
|
4 |
gradio==3.50.2
|
5 |
yt-dlp
|
@@ -1,16 +1,11 @@
|
|
1 |
from enum import Enum
|
2 |
-
import urllib
|
3 |
|
4 |
import os
|
5 |
-
from typing import List
|
6 |
-
from urllib.parse import urlparse
|
7 |
-
import json5
|
8 |
-
import torch
|
9 |
|
10 |
-
from tqdm import tqdm
|
11 |
|
12 |
class ModelConfig:
|
13 |
-
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
|
14 |
"""
|
15 |
Initialize a model configuration.
|
16 |
|
@@ -23,6 +18,7 @@ class ModelConfig:
|
|
23 |
self.url = url
|
24 |
self.path = path
|
25 |
self.type = type
|
|
|
26 |
|
27 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
28 |
|
@@ -33,7 +29,7 @@ class VadInitialPromptMode(Enum):
|
|
33 |
|
34 |
@staticmethod
|
35 |
def from_string(s: str):
|
36 |
-
normalized = s.lower() if s is not None else None
|
37 |
|
38 |
if normalized == "prepend_all_segments":
|
39 |
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
@@ -47,11 +43,11 @@ class VadInitialPromptMode(Enum):
|
|
47 |
return None
|
48 |
|
49 |
class ApplicationConfig:
|
50 |
-
def __init__(self, models:
|
51 |
-
share: bool = False, server_name: str = None, server_port: int = 7860,
|
52 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
53 |
-
whisper_implementation: str = "whisper",
|
54 |
-
|
55 |
vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
|
56 |
auto_parallel: bool = False, output_dir: str = None,
|
57 |
model_dir: str = None, device: str = None,
|
@@ -66,6 +62,7 @@ class ApplicationConfig:
|
|
66 |
compute_type: str = "float16",
|
67 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
68 |
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
|
|
69 |
# Word timestamp settings
|
70 |
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
71 |
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
@@ -73,10 +70,14 @@ class ApplicationConfig:
|
|
73 |
# Diarization
|
74 |
auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
|
75 |
diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
|
76 |
-
diarization_process_timeout: int = 60
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
self.models = models
|
79 |
-
self.nllb_models = nllb_models
|
80 |
|
81 |
# WebUI settings
|
82 |
self.input_audio_max_duration = input_audio_max_duration
|
@@ -120,6 +121,8 @@ class ApplicationConfig:
|
|
120 |
self.compression_ratio_threshold = compression_ratio_threshold
|
121 |
self.logprob_threshold = logprob_threshold
|
122 |
self.no_speech_threshold = no_speech_threshold
|
|
|
|
|
123 |
|
124 |
# Word timestamp settings
|
125 |
self.word_timestamps = word_timestamps
|
@@ -134,12 +137,13 @@ class ApplicationConfig:
|
|
134 |
self.diarization_min_speakers = diarization_min_speakers
|
135 |
self.diarization_max_speakers = diarization_max_speakers
|
136 |
self.diarization_process_timeout = diarization_process_timeout
|
|
|
|
|
|
|
|
|
137 |
|
138 |
-
def get_model_names(self):
|
139 |
-
return [ x.name for x in self.models ]
|
140 |
-
|
141 |
-
def get_nllb_model_names(self):
|
142 |
-
return [ x.name for x in self.nllb_models ]
|
143 |
|
144 |
def update(self, **new_values):
|
145 |
result = ApplicationConfig(**self.__dict__)
|
@@ -165,9 +169,9 @@ class ApplicationConfig:
|
|
165 |
# Load using json5
|
166 |
data = json5.load(f)
|
167 |
data_models = data.pop("models", [])
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
|
173 |
-
return ApplicationConfig(models,
|
|
|
1 |
from enum import Enum
|
|
|
2 |
|
3 |
import os
|
4 |
+
from typing import List, Dict, Literal
|
|
|
|
|
|
|
5 |
|
|
|
6 |
|
7 |
class ModelConfig:
|
8 |
+
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
|
9 |
"""
|
10 |
Initialize a model configuration.
|
11 |
|
|
|
18 |
self.url = url
|
19 |
self.path = path
|
20 |
self.type = type
|
21 |
+
self.tokenizer_url = tokenizer_url
|
22 |
|
23 |
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
24 |
|
|
|
29 |
|
30 |
@staticmethod
|
31 |
def from_string(s: str):
|
32 |
+
normalized = s.lower() if s is not None and len(s) > 0 else None
|
33 |
|
34 |
if normalized == "prepend_all_segments":
|
35 |
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
|
|
43 |
return None
|
44 |
|
45 |
class ApplicationConfig:
|
46 |
+
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]],
|
47 |
+
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
48 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
49 |
+
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
50 |
+
default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
|
51 |
vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
|
52 |
auto_parallel: bool = False, output_dir: str = None,
|
53 |
model_dir: str = None, device: str = None,
|
|
|
62 |
compute_type: str = "float16",
|
63 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
64 |
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
65 |
+
repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
|
66 |
# Word timestamp settings
|
67 |
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
68 |
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
|
|
70 |
# Diarization
|
71 |
auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
|
72 |
diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
|
73 |
+
diarization_process_timeout: int = 60,
|
74 |
+
# Translation
|
75 |
+
translation_batch_size: int = 2,
|
76 |
+
translation_no_repeat_ngram_size: int = 3,
|
77 |
+
translation_num_beams: int = 2,
|
78 |
+
):
|
79 |
|
80 |
self.models = models
|
|
|
81 |
|
82 |
# WebUI settings
|
83 |
self.input_audio_max_duration = input_audio_max_duration
|
|
|
121 |
self.compression_ratio_threshold = compression_ratio_threshold
|
122 |
self.logprob_threshold = logprob_threshold
|
123 |
self.no_speech_threshold = no_speech_threshold
|
124 |
+
self.repetition_penalty = repetition_penalty
|
125 |
+
self.no_repeat_ngram_size = no_repeat_ngram_size
|
126 |
|
127 |
# Word timestamp settings
|
128 |
self.word_timestamps = word_timestamps
|
|
|
137 |
self.diarization_min_speakers = diarization_min_speakers
|
138 |
self.diarization_max_speakers = diarization_max_speakers
|
139 |
self.diarization_process_timeout = diarization_process_timeout
|
140 |
+
# Translation
|
141 |
+
self.translation_batch_size = translation_batch_size
|
142 |
+
self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
|
143 |
+
self.translation_num_beams = translation_num_beams
|
144 |
|
145 |
+
def get_model_names(self, name: str):
|
146 |
+
return [ x.name for x in self.models[name] ]
|
|
|
|
|
|
|
147 |
|
148 |
def update(self, **new_values):
|
149 |
result = ApplicationConfig(**self.__dict__)
|
|
|
169 |
# Load using json5
|
170 |
data = json5.load(f)
|
171 |
data_models = data.pop("models", [])
|
172 |
+
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]] = {
|
173 |
+
key: [ModelConfig(**item) for item in value]
|
174 |
+
for key, value in data_models.items()
|
175 |
+
}
|
176 |
|
177 |
+
return ApplicationConfig(models, **data)
|
@@ -1,147 +0,0 @@
|
|
1 |
-
class Language():
|
2 |
-
def __init__(self, code, name):
|
3 |
-
self.code = code
|
4 |
-
self.name = name
|
5 |
-
|
6 |
-
def __str__(self):
|
7 |
-
return "Language(code={}, name={})".format(self.code, self.name)
|
8 |
-
|
9 |
-
LANGUAGES = [
|
10 |
-
Language('en', 'English'),
|
11 |
-
Language('zh', 'Chinese'),
|
12 |
-
Language('de', 'German'),
|
13 |
-
Language('es', 'Spanish'),
|
14 |
-
Language('ru', 'Russian'),
|
15 |
-
Language('ko', 'Korean'),
|
16 |
-
Language('fr', 'French'),
|
17 |
-
Language('ja', 'Japanese'),
|
18 |
-
Language('pt', 'Portuguese'),
|
19 |
-
Language('tr', 'Turkish'),
|
20 |
-
Language('pl', 'Polish'),
|
21 |
-
Language('ca', 'Catalan'),
|
22 |
-
Language('nl', 'Dutch'),
|
23 |
-
Language('ar', 'Arabic'),
|
24 |
-
Language('sv', 'Swedish'),
|
25 |
-
Language('it', 'Italian'),
|
26 |
-
Language('id', 'Indonesian'),
|
27 |
-
Language('hi', 'Hindi'),
|
28 |
-
Language('fi', 'Finnish'),
|
29 |
-
Language('vi', 'Vietnamese'),
|
30 |
-
Language('he', 'Hebrew'),
|
31 |
-
Language('uk', 'Ukrainian'),
|
32 |
-
Language('el', 'Greek'),
|
33 |
-
Language('ms', 'Malay'),
|
34 |
-
Language('cs', 'Czech'),
|
35 |
-
Language('ro', 'Romanian'),
|
36 |
-
Language('da', 'Danish'),
|
37 |
-
Language('hu', 'Hungarian'),
|
38 |
-
Language('ta', 'Tamil'),
|
39 |
-
Language('no', 'Norwegian'),
|
40 |
-
Language('th', 'Thai'),
|
41 |
-
Language('ur', 'Urdu'),
|
42 |
-
Language('hr', 'Croatian'),
|
43 |
-
Language('bg', 'Bulgarian'),
|
44 |
-
Language('lt', 'Lithuanian'),
|
45 |
-
Language('la', 'Latin'),
|
46 |
-
Language('mi', 'Maori'),
|
47 |
-
Language('ml', 'Malayalam'),
|
48 |
-
Language('cy', 'Welsh'),
|
49 |
-
Language('sk', 'Slovak'),
|
50 |
-
Language('te', 'Telugu'),
|
51 |
-
Language('fa', 'Persian'),
|
52 |
-
Language('lv', 'Latvian'),
|
53 |
-
Language('bn', 'Bengali'),
|
54 |
-
Language('sr', 'Serbian'),
|
55 |
-
Language('az', 'Azerbaijani'),
|
56 |
-
Language('sl', 'Slovenian'),
|
57 |
-
Language('kn', 'Kannada'),
|
58 |
-
Language('et', 'Estonian'),
|
59 |
-
Language('mk', 'Macedonian'),
|
60 |
-
Language('br', 'Breton'),
|
61 |
-
Language('eu', 'Basque'),
|
62 |
-
Language('is', 'Icelandic'),
|
63 |
-
Language('hy', 'Armenian'),
|
64 |
-
Language('ne', 'Nepali'),
|
65 |
-
Language('mn', 'Mongolian'),
|
66 |
-
Language('bs', 'Bosnian'),
|
67 |
-
Language('kk', 'Kazakh'),
|
68 |
-
Language('sq', 'Albanian'),
|
69 |
-
Language('sw', 'Swahili'),
|
70 |
-
Language('gl', 'Galician'),
|
71 |
-
Language('mr', 'Marathi'),
|
72 |
-
Language('pa', 'Punjabi'),
|
73 |
-
Language('si', 'Sinhala'),
|
74 |
-
Language('km', 'Khmer'),
|
75 |
-
Language('sn', 'Shona'),
|
76 |
-
Language('yo', 'Yoruba'),
|
77 |
-
Language('so', 'Somali'),
|
78 |
-
Language('af', 'Afrikaans'),
|
79 |
-
Language('oc', 'Occitan'),
|
80 |
-
Language('ka', 'Georgian'),
|
81 |
-
Language('be', 'Belarusian'),
|
82 |
-
Language('tg', 'Tajik'),
|
83 |
-
Language('sd', 'Sindhi'),
|
84 |
-
Language('gu', 'Gujarati'),
|
85 |
-
Language('am', 'Amharic'),
|
86 |
-
Language('yi', 'Yiddish'),
|
87 |
-
Language('lo', 'Lao'),
|
88 |
-
Language('uz', 'Uzbek'),
|
89 |
-
Language('fo', 'Faroese'),
|
90 |
-
Language('ht', 'Haitian creole'),
|
91 |
-
Language('ps', 'Pashto'),
|
92 |
-
Language('tk', 'Turkmen'),
|
93 |
-
Language('nn', 'Nynorsk'),
|
94 |
-
Language('mt', 'Maltese'),
|
95 |
-
Language('sa', 'Sanskrit'),
|
96 |
-
Language('lb', 'Luxembourgish'),
|
97 |
-
Language('my', 'Myanmar'),
|
98 |
-
Language('bo', 'Tibetan'),
|
99 |
-
Language('tl', 'Tagalog'),
|
100 |
-
Language('mg', 'Malagasy'),
|
101 |
-
Language('as', 'Assamese'),
|
102 |
-
Language('tt', 'Tatar'),
|
103 |
-
Language('haw', 'Hawaiian'),
|
104 |
-
Language('ln', 'Lingala'),
|
105 |
-
Language('ha', 'Hausa'),
|
106 |
-
Language('ba', 'Bashkir'),
|
107 |
-
Language('jw', 'Javanese'),
|
108 |
-
Language('su', 'Sundanese')
|
109 |
-
]
|
110 |
-
|
111 |
-
_TO_LANGUAGE_CODE = {
|
112 |
-
**{language.code: language for language in LANGUAGES},
|
113 |
-
"burmese": "my",
|
114 |
-
"valencian": "ca",
|
115 |
-
"flemish": "nl",
|
116 |
-
"haitian": "ht",
|
117 |
-
"letzeburgesch": "lb",
|
118 |
-
"pushto": "ps",
|
119 |
-
"panjabi": "pa",
|
120 |
-
"moldavian": "ro",
|
121 |
-
"moldovan": "ro",
|
122 |
-
"sinhalese": "si",
|
123 |
-
"castilian": "es",
|
124 |
-
}
|
125 |
-
|
126 |
-
_FROM_LANGUAGE_NAME = {
|
127 |
-
**{language.name.lower(): language for language in LANGUAGES}
|
128 |
-
}
|
129 |
-
|
130 |
-
def get_language_from_code(language_code, default=None) -> Language:
|
131 |
-
"""Return the language name from the language code."""
|
132 |
-
return _TO_LANGUAGE_CODE.get(language_code, default)
|
133 |
-
|
134 |
-
def get_language_from_name(language, default=None) -> Language:
|
135 |
-
"""Return the language code from the language name."""
|
136 |
-
return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
|
137 |
-
|
138 |
-
def get_language_names():
|
139 |
-
"""Return a list of language names."""
|
140 |
-
return [language.name for language in LANGUAGES]
|
141 |
-
|
142 |
-
if __name__ == "__main__":
|
143 |
-
# Test lookup
|
144 |
-
print(get_language_from_code('en'))
|
145 |
-
print(get_language_from_name('English'))
|
146 |
-
|
147 |
-
print(get_language_names())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,251 +0,0 @@
|
|
1 |
-
class NllbLang():
|
2 |
-
def __init__(self, code, name, code_whisper=None, name_whisper=None):
|
3 |
-
self.code = code
|
4 |
-
self.name = name
|
5 |
-
self.code_whisper = code_whisper
|
6 |
-
self.name_whisper = name_whisper
|
7 |
-
|
8 |
-
def __str__(self):
|
9 |
-
return "Language(code={}, name={})".format(self.code, self.name)
|
10 |
-
|
11 |
-
NLLB_LANGS = [
|
12 |
-
NllbLang('ace_Arab', 'Acehnese (Arabic script)'),
|
13 |
-
NllbLang('ace_Latn', 'Acehnese (Latin script)'),
|
14 |
-
NllbLang('acm_Arab', 'Mesopotamian Arabic', 'ar', 'Arabic'),
|
15 |
-
NllbLang('acq_Arab', 'Ta’izzi-Adeni Arabic', 'ar', 'Arabic'),
|
16 |
-
NllbLang('aeb_Arab', 'Tunisian Arabic'),
|
17 |
-
NllbLang('afr_Latn', 'Afrikaans', 'am', 'Amharic'),
|
18 |
-
NllbLang('ajp_Arab', 'South Levantine Arabic', 'ar', 'Arabic'),
|
19 |
-
NllbLang('aka_Latn', 'Akan'),
|
20 |
-
NllbLang('amh_Ethi', 'Amharic'),
|
21 |
-
NllbLang('apc_Arab', 'North Levantine Arabic', 'ar', 'Arabic'),
|
22 |
-
NllbLang('arb_Arab', 'Modern Standard Arabic', 'ar', 'Arabic'),
|
23 |
-
NllbLang('arb_Latn', 'Modern Standard Arabic (Romanized)'),
|
24 |
-
NllbLang('ars_Arab', 'Najdi Arabic', 'ar', 'Arabic'),
|
25 |
-
NllbLang('ary_Arab', 'Moroccan Arabic', 'ar', 'Arabic'),
|
26 |
-
NllbLang('arz_Arab', 'Egyptian Arabic', 'ar', 'Arabic'),
|
27 |
-
NllbLang('asm_Beng', 'Assamese', 'as', 'Assamese'),
|
28 |
-
NllbLang('ast_Latn', 'Asturian'),
|
29 |
-
NllbLang('awa_Deva', 'Awadhi'),
|
30 |
-
NllbLang('ayr_Latn', 'Central Aymara'),
|
31 |
-
NllbLang('azb_Arab', 'South Azerbaijani', 'az', 'Azerbaijani'),
|
32 |
-
NllbLang('azj_Latn', 'North Azerbaijani', 'az', 'Azerbaijani'),
|
33 |
-
NllbLang('bak_Cyrl', 'Bashkir', 'ba', 'Bashkir'),
|
34 |
-
NllbLang('bam_Latn', 'Bambara'),
|
35 |
-
NllbLang('ban_Latn', 'Balinese'),
|
36 |
-
NllbLang('bel_Cyrl', 'Belarusian', 'be', 'Belarusian'),
|
37 |
-
NllbLang('bem_Latn', 'Bemba'),
|
38 |
-
NllbLang('ben_Beng', 'Bengali', 'bn', 'Bengali'),
|
39 |
-
NllbLang('bho_Deva', 'Bhojpuri'),
|
40 |
-
NllbLang('bjn_Arab', 'Banjar (Arabic script)'),
|
41 |
-
NllbLang('bjn_Latn', 'Banjar (Latin script)'),
|
42 |
-
NllbLang('bod_Tibt', 'Standard Tibetan', 'bo', 'Tibetan'),
|
43 |
-
NllbLang('bos_Latn', 'Bosnian', 'bs', 'Bosnian'),
|
44 |
-
NllbLang('bug_Latn', 'Buginese'),
|
45 |
-
NllbLang('bul_Cyrl', 'Bulgarian', 'bg', 'Bulgarian'),
|
46 |
-
NllbLang('cat_Latn', 'Catalan', 'ca', 'Catalan'),
|
47 |
-
NllbLang('ceb_Latn', 'Cebuano'),
|
48 |
-
NllbLang('ces_Latn', 'Czech', 'cs', 'Czech'),
|
49 |
-
NllbLang('cjk_Latn', 'Chokwe'),
|
50 |
-
NllbLang('ckb_Arab', 'Central Kurdish'),
|
51 |
-
NllbLang('crh_Latn', 'Crimean Tatar'),
|
52 |
-
NllbLang('cym_Latn', 'Welsh', 'cy', 'Welsh'),
|
53 |
-
NllbLang('dan_Latn', 'Danish', 'da', 'Danish'),
|
54 |
-
NllbLang('deu_Latn', 'German', 'de', 'German'),
|
55 |
-
NllbLang('dik_Latn', 'Southwestern Dinka'),
|
56 |
-
NllbLang('dyu_Latn', 'Dyula'),
|
57 |
-
NllbLang('dzo_Tibt', 'Dzongkha'),
|
58 |
-
NllbLang('ell_Grek', 'Greek', 'el', 'Greek'),
|
59 |
-
NllbLang('eng_Latn', 'English', 'en', 'English'),
|
60 |
-
NllbLang('epo_Latn', 'Esperanto'),
|
61 |
-
NllbLang('est_Latn', 'Estonian', 'et', 'Estonian'),
|
62 |
-
NllbLang('eus_Latn', 'Basque', 'eu', 'Basque'),
|
63 |
-
NllbLang('ewe_Latn', 'Ewe'),
|
64 |
-
NllbLang('fao_Latn', 'Faroese', 'fo', 'Faroese'),
|
65 |
-
NllbLang('fij_Latn', 'Fijian'),
|
66 |
-
NllbLang('fin_Latn', 'Finnish', 'fi', 'Finnish'),
|
67 |
-
NllbLang('fon_Latn', 'Fon'),
|
68 |
-
NllbLang('fra_Latn', 'French', 'fr', 'French'),
|
69 |
-
NllbLang('fur_Latn', 'Friulian'),
|
70 |
-
NllbLang('fuv_Latn', 'Nigerian Fulfulde'),
|
71 |
-
NllbLang('gla_Latn', 'Scottish Gaelic'),
|
72 |
-
NllbLang('gle_Latn', 'Irish'),
|
73 |
-
NllbLang('glg_Latn', 'Galician', 'gl', 'Galician'),
|
74 |
-
NllbLang('grn_Latn', 'Guarani'),
|
75 |
-
NllbLang('guj_Gujr', 'Gujarati', 'gu', 'Gujarati'),
|
76 |
-
NllbLang('hat_Latn', 'Haitian Creole', 'ht', 'Haitian creole'),
|
77 |
-
NllbLang('hau_Latn', 'Hausa', 'ha', 'Hausa'),
|
78 |
-
NllbLang('heb_Hebr', 'Hebrew', 'he', 'Hebrew'),
|
79 |
-
NllbLang('hin_Deva', 'Hindi', 'hi', 'Hindi'),
|
80 |
-
NllbLang('hne_Deva', 'Chhattisgarhi'),
|
81 |
-
NllbLang('hrv_Latn', 'Croatian', 'hr', 'Croatian'),
|
82 |
-
NllbLang('hun_Latn', 'Hungarian', 'hu', 'Hungarian'),
|
83 |
-
NllbLang('hye_Armn', 'Armenian', 'hy', 'Armenian'),
|
84 |
-
NllbLang('ibo_Latn', 'Igbo'),
|
85 |
-
NllbLang('ilo_Latn', 'Ilocano'),
|
86 |
-
NllbLang('ind_Latn', 'Indonesian', 'id', 'Indonesian'),
|
87 |
-
NllbLang('isl_Latn', 'Icelandic', 'is', 'Icelandic'),
|
88 |
-
NllbLang('ita_Latn', 'Italian', 'it', 'Italian'),
|
89 |
-
NllbLang('jav_Latn', 'Javanese', 'jw', 'Javanese'),
|
90 |
-
NllbLang('jpn_Jpan', 'Japanese', 'ja', 'Japanese'),
|
91 |
-
NllbLang('kab_Latn', 'Kabyle'),
|
92 |
-
NllbLang('kac_Latn', 'Jingpho'),
|
93 |
-
NllbLang('kam_Latn', 'Kamba'),
|
94 |
-
NllbLang('kan_Knda', 'Kannada', 'kn', 'Kannada'),
|
95 |
-
NllbLang('kas_Arab', 'Kashmiri (Arabic script)'),
|
96 |
-
NllbLang('kas_Deva', 'Kashmiri (Devanagari script)'),
|
97 |
-
NllbLang('kat_Geor', 'Georgian', 'ka', 'Georgian'),
|
98 |
-
NllbLang('knc_Arab', 'Central Kanuri (Arabic script)'),
|
99 |
-
NllbLang('knc_Latn', 'Central Kanuri (Latin script)'),
|
100 |
-
NllbLang('kaz_Cyrl', 'Kazakh', 'kk', 'Kazakh'),
|
101 |
-
NllbLang('kbp_Latn', 'Kabiyè'),
|
102 |
-
NllbLang('kea_Latn', 'Kabuverdianu'),
|
103 |
-
NllbLang('khm_Khmr', 'Khmer', 'km', 'Khmer'),
|
104 |
-
NllbLang('kik_Latn', 'Kikuyu'),
|
105 |
-
NllbLang('kin_Latn', 'Kinyarwanda'),
|
106 |
-
NllbLang('kir_Cyrl', 'Kyrgyz'),
|
107 |
-
NllbLang('kmb_Latn', 'Kimbundu'),
|
108 |
-
NllbLang('kmr_Latn', 'Northern Kurdish'),
|
109 |
-
NllbLang('kon_Latn', 'Kikongo'),
|
110 |
-
NllbLang('kor_Hang', 'Korean', 'ko', 'Korean'),
|
111 |
-
NllbLang('lao_Laoo', 'Lao', 'lo', 'Lao'),
|
112 |
-
NllbLang('lij_Latn', 'Ligurian'),
|
113 |
-
NllbLang('lim_Latn', 'Limburgish'),
|
114 |
-
NllbLang('lin_Latn', 'Lingala', 'ln', 'Lingala'),
|
115 |
-
NllbLang('lit_Latn', 'Lithuanian', 'lt', 'Lithuanian'),
|
116 |
-
NllbLang('lmo_Latn', 'Lombard'),
|
117 |
-
NllbLang('ltg_Latn', 'Latgalian'),
|
118 |
-
NllbLang('ltz_Latn', 'Luxembourgish', 'lb', 'Luxembourgish'),
|
119 |
-
NllbLang('lua_Latn', 'Luba-Kasai'),
|
120 |
-
NllbLang('lug_Latn', 'Ganda'),
|
121 |
-
NllbLang('luo_Latn', 'Luo'),
|
122 |
-
NllbLang('lus_Latn', 'Mizo'),
|
123 |
-
NllbLang('lvs_Latn', 'Standard Latvian', 'lv', 'Latvian'),
|
124 |
-
NllbLang('mag_Deva', 'Magahi'),
|
125 |
-
NllbLang('mai_Deva', 'Maithili'),
|
126 |
-
NllbLang('mal_Mlym', 'Malayalam', 'ml', 'Malayalam'),
|
127 |
-
NllbLang('mar_Deva', 'Marathi', 'mr', 'Marathi'),
|
128 |
-
NllbLang('min_Arab', 'Minangkabau (Arabic script)'),
|
129 |
-
NllbLang('min_Latn', 'Minangkabau (Latin script)'),
|
130 |
-
NllbLang('mkd_Cyrl', 'Macedonian', 'mk', 'Macedonian'),
|
131 |
-
NllbLang('plt_Latn', 'Plateau Malagasy', 'mg', 'Malagasy'),
|
132 |
-
NllbLang('mlt_Latn', 'Maltese', 'mt', 'Maltese'),
|
133 |
-
NllbLang('mni_Beng', 'Meitei (Bengali script)'),
|
134 |
-
NllbLang('khk_Cyrl', 'Halh Mongolian', 'mn', 'Mongolian'),
|
135 |
-
NllbLang('mos_Latn', 'Mossi'),
|
136 |
-
NllbLang('mri_Latn', 'Maori', 'mi', 'Maori'),
|
137 |
-
NllbLang('mya_Mymr', 'Burmese', 'my', 'Myanmar'),
|
138 |
-
NllbLang('nld_Latn', 'Dutch', 'nl', 'Dutch'),
|
139 |
-
NllbLang('nno_Latn', 'Norwegian Nynorsk', 'nn', 'Nynorsk'),
|
140 |
-
NllbLang('nob_Latn', 'Norwegian Bokmål', 'no', 'Norwegian'),
|
141 |
-
NllbLang('npi_Deva', 'Nepali', 'ne', 'Nepali'),
|
142 |
-
NllbLang('nso_Latn', 'Northern Sotho'),
|
143 |
-
NllbLang('nus_Latn', 'Nuer'),
|
144 |
-
NllbLang('nya_Latn', 'Nyanja'),
|
145 |
-
NllbLang('oci_Latn', 'Occitan', 'oc', 'Occitan'),
|
146 |
-
NllbLang('gaz_Latn', 'West Central Oromo'),
|
147 |
-
NllbLang('ory_Orya', 'Odia'),
|
148 |
-
NllbLang('pag_Latn', 'Pangasinan'),
|
149 |
-
NllbLang('pan_Guru', 'Eastern Panjabi', 'pa', 'Punjabi'),
|
150 |
-
NllbLang('pap_Latn', 'Papiamento'),
|
151 |
-
NllbLang('pes_Arab', 'Western Persian', 'fa', 'Persian'),
|
152 |
-
NllbLang('pol_Latn', 'Polish', 'pl', 'Polish'),
|
153 |
-
NllbLang('por_Latn', 'Portuguese', 'pt', 'Portuguese'),
|
154 |
-
NllbLang('prs_Arab', 'Dari'),
|
155 |
-
NllbLang('pbt_Arab', 'Southern Pashto', 'ps', 'Pashto'),
|
156 |
-
NllbLang('quy_Latn', 'Ayacucho Quechua'),
|
157 |
-
NllbLang('ron_Latn', 'Romanian', 'ro', 'Romanian'),
|
158 |
-
NllbLang('run_Latn', 'Rundi'),
|
159 |
-
NllbLang('rus_Cyrl', 'Russian', 'ru', 'Russian'),
|
160 |
-
NllbLang('sag_Latn', 'Sango'),
|
161 |
-
NllbLang('san_Deva', 'Sanskrit', 'sa', 'Sanskrit'),
|
162 |
-
NllbLang('sat_Olck', 'Santali'),
|
163 |
-
NllbLang('scn_Latn', 'Sicilian'),
|
164 |
-
NllbLang('shn_Mymr', 'Shan'),
|
165 |
-
NllbLang('sin_Sinh', 'Sinhala', 'si', 'Sinhala'),
|
166 |
-
NllbLang('slk_Latn', 'Slovak', 'sk', 'Slovak'),
|
167 |
-
NllbLang('slv_Latn', 'Slovenian', 'sl', 'Slovenian'),
|
168 |
-
NllbLang('smo_Latn', 'Samoan'),
|
169 |
-
NllbLang('sna_Latn', 'Shona', 'sn', 'Shona'),
|
170 |
-
NllbLang('snd_Arab', 'Sindhi', 'sd', 'Sindhi'),
|
171 |
-
NllbLang('som_Latn', 'Somali', 'so', 'Somali'),
|
172 |
-
NllbLang('sot_Latn', 'Southern Sotho'),
|
173 |
-
NllbLang('spa_Latn', 'Spanish', 'es', 'Spanish'),
|
174 |
-
NllbLang('als_Latn', 'Tosk Albanian', 'sq', 'Albanian'),
|
175 |
-
NllbLang('srd_Latn', 'Sardinian'),
|
176 |
-
NllbLang('srp_Cyrl', 'Serbian', 'sr', 'Serbian'),
|
177 |
-
NllbLang('ssw_Latn', 'Swati'),
|
178 |
-
NllbLang('sun_Latn', 'Sundanese', 'su', 'Sundanese'),
|
179 |
-
NllbLang('swe_Latn', 'Swedish', 'sv', 'Swedish'),
|
180 |
-
NllbLang('swh_Latn', 'Swahili', 'sw', 'Swahili'),
|
181 |
-
NllbLang('szl_Latn', 'Silesian'),
|
182 |
-
NllbLang('tam_Taml', 'Tamil', 'ta', 'Tamil'),
|
183 |
-
NllbLang('tat_Cyrl', 'Tatar', 'tt', 'Tatar'),
|
184 |
-
NllbLang('tel_Telu', 'Telugu', 'te', 'Telugu'),
|
185 |
-
NllbLang('tgk_Cyrl', 'Tajik', 'tg', 'Tajik'),
|
186 |
-
NllbLang('tgl_Latn', 'Tagalog', 'tl', 'Tagalog'),
|
187 |
-
NllbLang('tha_Thai', 'Thai', 'th', 'Thai'),
|
188 |
-
NllbLang('tir_Ethi', 'Tigrinya'),
|
189 |
-
NllbLang('taq_Latn', 'Tamasheq (Latin script)'),
|
190 |
-
NllbLang('taq_Tfng', 'Tamasheq (Tifinagh script)'),
|
191 |
-
NllbLang('tpi_Latn', 'Tok Pisin'),
|
192 |
-
NllbLang('tsn_Latn', 'Tswana'),
|
193 |
-
NllbLang('tso_Latn', 'Tsonga'),
|
194 |
-
NllbLang('tuk_Latn', 'Turkmen', 'tk', 'Turkmen'),
|
195 |
-
NllbLang('tum_Latn', 'Tumbuka'),
|
196 |
-
NllbLang('tur_Latn', 'Turkish', 'tr', 'Turkish'),
|
197 |
-
NllbLang('twi_Latn', 'Twi'),
|
198 |
-
NllbLang('tzm_Tfng', 'Central Atlas Tamazight'),
|
199 |
-
NllbLang('uig_Arab', 'Uyghur'),
|
200 |
-
NllbLang('ukr_Cyrl', 'Ukrainian', 'uk', 'Ukrainian'),
|
201 |
-
NllbLang('umb_Latn', 'Umbundu'),
|
202 |
-
NllbLang('urd_Arab', 'Urdu', 'ur', 'Urdu'),
|
203 |
-
NllbLang('uzn_Latn', 'Northern Uzbek', 'uz', 'Uzbek'),
|
204 |
-
NllbLang('vec_Latn', 'Venetian'),
|
205 |
-
NllbLang('vie_Latn', 'Vietnamese', 'vi', 'Vietnamese'),
|
206 |
-
NllbLang('war_Latn', 'Waray'),
|
207 |
-
NllbLang('wol_Latn', 'Wolof'),
|
208 |
-
NllbLang('xho_Latn', 'Xhosa'),
|
209 |
-
NllbLang('ydd_Hebr', 'Eastern Yiddish', 'yi', 'Yiddish'),
|
210 |
-
NllbLang('yor_Latn', 'Yoruba', 'yo', 'Yoruba'),
|
211 |
-
NllbLang('yue_Hant', 'Yue Chinese', 'zh', 'Chinese'),
|
212 |
-
NllbLang('zho_Hans', 'Chinese (Simplified)', 'zh', 'Chinese'),
|
213 |
-
NllbLang('zho_Hant', 'Chinese (Traditional)', 'zh', 'Chinese'),
|
214 |
-
NllbLang('zsm_Latn', 'Standard Malay', 'ms', 'Malay'),
|
215 |
-
NllbLang('zul_Latn', 'Zulu'),
|
216 |
-
]
|
217 |
-
|
218 |
-
_TO_NLLB_LANG_CODE = {language.code.lower(): language for language in NLLB_LANGS if language.code is not None}
|
219 |
-
|
220 |
-
_TO_NLLB_LANG_NAME = {language.name.lower(): language for language in NLLB_LANGS if language.name is not None}
|
221 |
-
|
222 |
-
_TO_NLLB_LANG_WHISPER_CODE = {language.code_whisper.lower(): language for language in NLLB_LANGS if language.code_whisper is not None}
|
223 |
-
|
224 |
-
_TO_NLLB_LANG_WHISPER_NAME = {language.name_whisper.lower(): language for language in NLLB_LANGS if language.name_whisper is not None}
|
225 |
-
|
226 |
-
def get_nllb_lang_from_code(lang_code, default=None) -> NllbLang:
|
227 |
-
"""Return the language from the language code."""
|
228 |
-
return _TO_NLLB_LANG_CODE.get(lang_code, default)
|
229 |
-
|
230 |
-
def get_nllb_lang_from_name(lang_name, default=None) -> NllbLang:
|
231 |
-
"""Return the language from the language name."""
|
232 |
-
return _TO_NLLB_LANG_NAME.get(lang_name.lower() if lang_name else None, default)
|
233 |
-
|
234 |
-
def get_nllb_lang_from_code_whisper(lang_code_whisper, default=None) -> NllbLang:
|
235 |
-
"""Return the language from the language code."""
|
236 |
-
return _TO_NLLB_LANG_WHISPER_CODE.get(lang_code_whisper, default)
|
237 |
-
|
238 |
-
def get_nllb_lang_from_name_whisper(lang_name_whisper, default=None) -> NllbLang:
|
239 |
-
"""Return the language from the language name."""
|
240 |
-
return _TO_NLLB_LANG_WHISPER_NAME.get(lang_name_whisper.lower() if lang_name_whisper else None, default)
|
241 |
-
|
242 |
-
def get_nllb_lang_names():
|
243 |
-
"""Return a list of language names."""
|
244 |
-
return [language.name for language in NLLB_LANGS]
|
245 |
-
|
246 |
-
if __name__ == "__main__":
|
247 |
-
# Test lookup
|
248 |
-
print(get_nllb_lang_from_code('eng_Latn'))
|
249 |
-
print(get_nllb_lang_from_name('English'))
|
250 |
-
|
251 |
-
print(get_nllb_lang_names())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Lang():
|
2 |
+
def __init__(self, code: str, *names: str):
|
3 |
+
self.code = code
|
4 |
+
self.names = names
|
5 |
+
|
6 |
+
def __repr__(self):
|
7 |
+
return f"code:{self.code}, name:{self.names}"
|
8 |
+
|
9 |
+
class TranslationLang():
|
10 |
+
def __init__(self, nllb: Lang, whisper: Lang = None, m2m100: Lang = None):
|
11 |
+
self.nllb = nllb
|
12 |
+
self.whisper = whisper
|
13 |
+
self.m2m100 = None
|
14 |
+
|
15 |
+
if m2m100 is None: m2m100 = whisper
|
16 |
+
if m2m100 is not None and len(m2m100.names) > 0:
|
17 |
+
self.m2m100 = m2m100
|
18 |
+
|
19 |
+
def __repr__(self):
|
20 |
+
result = ""
|
21 |
+
if self.nllb is not None:
|
22 |
+
result += f"NLLB={self.nllb} "
|
23 |
+
if self.whisper is not None:
|
24 |
+
result += f"WHISPER={self.whisper} "
|
25 |
+
if self.m2m100 is not None:
|
26 |
+
result += f"M@M100={self.m2m100} "
|
27 |
+
return f"Language {result}"
|
28 |
+
|
29 |
+
"""
|
30 |
+
Model available Languages
|
31 |
+
|
32 |
+
[NLLB]
|
33 |
+
ace_Latn:Acehnese (Latin script), aka_Latn:Akan, als_Latn:Tosk Albanian, amh_Ethi:Amharic, asm_Beng:Assamese, awa_Deva:Awadhi, ayr_Latn:Central Aymara, azb_Arab:South Azerbaijani, azj_Latn:North Azerbaijani, bak_Cyrl:Bashkir, bam_Latn:Bambara, ban_Latn:Balinese, bel_Cyrl:Belarusian, bem_Latn:Bemba, ben_Beng:Bengali, bho_Deva:Bhojpuri, bjn_Latn:Banjar (Latin script), bod_Tibt:Standard Tibetan, bug_Latn:Buginese, ceb_Latn:Cebuano, cjk_Latn:Chokwe, ckb_Arab:Central Kurdish, crh_Latn:Crimean Tatar, cym_Latn:Welsh, dik_Latn:Southwestern Dinka, diq_Latn:Southern Zaza, dyu_Latn:Dyula, dzo_Tibt:Dzongkha, ewe_Latn:Ewe, fao_Latn:Faroese, fij_Latn:Fijian, fon_Latn:Fon, fur_Latn:Friulian, fuv_Latn:Nigerian Fulfulde, gaz_Latn:West Central Oromo, gla_Latn:Scottish Gaelic, gle_Latn:Irish, grn_Latn:Guarani, guj_Gujr:Gujarati, hat_Latn:Haitian Creole, hau_Latn:Hausa, hin_Deva:Hindi, hne_Deva:Chhattisgarhi, hye_Armn:Armenian, ibo_Latn:Igbo, ilo_Latn:Ilocano, ind_Latn:Indonesian, jav_Latn:Javanese, kab_Latn:Kabyle, kac_Latn:Jingpho, kam_Latn:Kamba, kan_Knda:Kannada, kas_Arab:Kashmiri (Arabic script), kas_Deva:Kashmiri (Devanagari script), kat_Geor:Georgian, kaz_Cyrl:Kazakh, kbp_Latn:Kabiyè, kea_Latn:Kabuverdianu, khk_Cyrl:Halh Mongolian, khm_Khmr:Khmer, kik_Latn:Kikuyu, kin_Latn:Kinyarwanda, kir_Cyrl:Kyrgyz, kmb_Latn:Kimbundu, kmr_Latn:Northern Kurdish, knc_Arab:Central Kanuri (Arabic script), knc_Latn:Central Kanuri (Latin script), kon_Latn:Kikongo, lao_Laoo:Lao, lij_Latn:Ligurian, lim_Latn:Limburgish, lin_Latn:Lingala, lmo_Latn:Lombard, ltg_Latn:Latgalian, ltz_Latn:Luxembourgish, lua_Latn:Luba-Kasai, lug_Latn:Ganda, luo_Latn:Luo, lus_Latn:Mizo, mag_Deva:Magahi, mai_Deva:Maithili, mal_Mlym:Malayalam, mar_Deva:Marathi, min_Latn:Minangkabau (Latin script), mlt_Latn:Maltese, mni_Beng:Meitei (Bengali script), mos_Latn:Mossi, mri_Latn:Maori, mya_Mymr:Burmese, npi_Deva:Nepali, nso_Latn:Northern Sotho, nus_Latn:Nuer, nya_Latn:Nyanja, ory_Orya:Odia, pag_Latn:Pangasinan, pan_Guru:Eastern Panjabi, pap_Latn:Papiamento, pbt_Arab:Southern Pashto, pes_Arab:Western Persian, plt_Latn:Plateau Malagasy, prs_Arab:Dari, quy_Latn:Ayacucho Quechua, run_Latn:Rundi, sag_Latn:Sango, san_Deva:Sanskrit, sat_Beng:Santali, scn_Latn:Sicilian, shn_Mymr:Shan, sin_Sinh:Sinhala, smo_Latn:Samoan, sna_Latn:Shona, snd_Arab:Sindhi, som_Latn:Somali, sot_Latn:Southern Sotho, srd_Latn:Sardinian, ssw_Latn:Swati, sun_Latn:Sundanese, swh_Latn:Swahili, szl_Latn:Silesian, tam_Taml:Tamil, taq_Latn:Tamasheq (Latin script), tat_Cyrl:Tatar, tel_Telu:Telugu, tgk_Cyrl:Tajik, tgl_Latn:Tagalog, tha_Thai:Thai, tir_Ethi:Tigrinya, tpi_Latn:Tok Pisin, tsn_Latn:Tswana, tso_Latn:Tsonga, tuk_Latn:Turkmen, tum_Latn:Tumbuka, tur_Latn:Turkish, twi_Latn:Twi, tzm_Tfng:Central Atlas Tamazight, uig_Arab:Uyghur, umb_Latn:Umbundu, urd_Arab:Urdu, uzn_Latn:Northern Uzbek, vec_Latn:Venetian, war_Latn:Waray, wol_Latn:Wolof, xho_Latn:Xhosa, ydd_Hebr:Eastern Yiddish, yor_Latn:Yoruba, zsm_Latn:Standard Malay, zul_Latn:Zulu
|
34 |
+
https://github.com/facebookresearch/LASER/blob/main/nllb/README.md
|
35 |
+
|
36 |
+
In the NLLB model, languages are identified by a FLORES-200 code of the form {language}_{script}, where the language is an ISO 639-3 code and the script is an ISO 15924 code.
|
37 |
+
https://github.com/sillsdev/serval/wiki/FLORES%E2%80%90200-Language-Code-Resolution-for-NMT-Engine
|
38 |
+
|
39 |
+
[whisper]
|
40 |
+
en:english, zh:chinese, de:german, es:spanish, ru:russian, ko:korean, fr:french, ja:japanese, pt:portuguese, tr:turkish, pl:polish, ca:catalan, nl:dutch, ar:arabic, sv:swedish, it:italian, id:indonesian, hi:hindi, fi:finnish, vi:vietnamese, he:hebrew, uk:ukrainian, el:greek, ms:malay, cs:czech, ro:romanian, da:danish, hu:hungarian, ta:tamil, no:norwegian, th:thai, ur:urdu, hr:croatian, bg:bulgarian, lt:lithuanian, la:latin, mi:maori, ml:malayalam, cy:welsh, sk:slovak, te:telugu, fa:persian, lv:latvian, bn:bengali, sr:serbian, az:azerbaijani, sl:slovenian, kn:kannada, et:estonian, mk:macedonian, br:breton, eu:basque, is:icelandic, hy:armenian, ne:nepali, mn:mongolian, bs:bosnian, kk:kazakh, sq:albanian, sw:swahili, gl:galician, mr:marathi, pa:punjabi, si:sinhala, km:khmer, sn:shona, yo:yoruba, so:somali, af:afrikaans, oc:occitan, ka:georgian, be:belarusian, tg:tajik, sd:sindhi, gu:gujarati, am:amharic, yi:yiddish, lo:lao, uz:uzbek, fo:faroese, ht:haitian creole, ps:pashto, tk:turkmen, nn:nynorsk, mt:maltese, sa:sanskrit, lb:luxembourgish, my:myanmar, bo:tibetan, tl:tagalog, mg:malagasy, as:assamese, tt:tatar, haw:hawaiian, ln:lingala, ha:hausa, ba:bashkir, jw:javanese, su:sundanese, yue:cantonese,
|
41 |
+
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
|
42 |
+
|
43 |
+
[m2m100]
|
44 |
+
af:Afrikaans, am:Amharic, ar:Arabic, ast:Asturian, az:Azerbaijani, ba:Bashkir, be:Belarusian, bg:Bulgarian, bn:Bengali, br:Breton, bs:Bosnian, ca:Catalan; Valencian, ceb:Cebuano, cs:Czech, cy:Welsh, da:Danish, de:German, el:Greek, en:English, es:Spanish, et:Estonian, fa:Persian, ff:Fulah, fi:Finnish, fr:French, fy:Western Frisian, ga:Irish, gd:Gaelic; Scottish Gaelic, gl:Galician, gu:Gujarati, ha:Hausa, he:Hebrew, hi:Hindi, hr:Croatian, ht:Haitian; Haitian Creole, hu:Hungarian, hy:Armenian, id:Indonesian, ig:Igbo, ilo:Iloko, is:Icelandic, it:Italian, ja:Japanese, jv:Javanese, ka:Georgian, kk:Kazakh, km:Central Khmer, kn:Kannada, ko:Korean, lb:Luxembourgish; Letzeburgesch, lg:Ganda, ln:Lingala, lo:Lao, lt:Lithuanian, lv:Latvian, mg:Malagasy, mk:Macedonian, ml:Malayalam, mn:Mongolian, mr:Marathi, ms:Malay, my:Burmese, ne:Nepali, nl:Dutch; Flemish, no:Norwegian, ns:Northern Sotho, Occitan (oc:post 1500), or:Oriya, pa:Panjabi; Punjabi, pl:Polish, ps:Pushto; Pashto, pt:Portuguese, ro:Romanian; Moldavian; Moldovan, ru:Russian, sd:Sindhi, si:Sinhala; Sinhalese, sk:Slovak, sl:Slovenian, so:Somali, sq:Albanian, sr:Serbian, ss:Swati, su:Sundanese, sv:Swedish, sw:Swahili, ta:Tamil, th:Thai, tl:Tagalog, tn:Tswana, tr:Turkish, uk:Ukrainian, ur:Urdu, uz:Uzbek, vi:Vietnamese, wo:Wolof, xh:Xhosa, yi:Yiddish, yo:Yoruba, zh:Chinese, zu:Zulu
|
45 |
+
https://huggingface.co/facebook/m2m100_1.2B
|
46 |
+
|
47 |
+
The available languages for m2m100 and whisper are almost identical. Most of the codes correspond to the ISO 639-1 standard. For detailed information, please refer to the official documentation provided.
|
48 |
+
"""
|
49 |
+
TranslationLangs = [
|
50 |
+
TranslationLang(Lang("ace_Arab", "Acehnese (Arabic script)")),
|
51 |
+
TranslationLang(Lang("ace_Latn", "Acehnese (Latin script)")),
|
52 |
+
TranslationLang(Lang("acm_Arab", "Mesopotamian Arabic"), Lang("ar", "Arabic")),
|
53 |
+
TranslationLang(Lang("acq_Arab", "Ta’izzi-Adeni Arabic"), Lang("ar", "Arabic")),
|
54 |
+
TranslationLang(Lang("aeb_Arab", "Tunisian Arabic")),
|
55 |
+
TranslationLang(Lang("afr_Latn", "Afrikaans"), Lang("af", "Afrikaans")),
|
56 |
+
TranslationLang(Lang("ajp_Arab", "South Levantine Arabic"), Lang("ar", "Arabic")),
|
57 |
+
TranslationLang(Lang("aka_Latn", "Akan")),
|
58 |
+
TranslationLang(Lang("amh_Ethi", "Amharic"), Lang("am", "Amharic")),
|
59 |
+
TranslationLang(Lang("apc_Arab", "North Levantine Arabic"), Lang("ar", "Arabic")),
|
60 |
+
TranslationLang(Lang("arb_Arab", "Modern Standard Arabic"), Lang("ar", "Arabic")),
|
61 |
+
TranslationLang(Lang("arb_Latn", "Modern Standard Arabic (Romanized)")),
|
62 |
+
TranslationLang(Lang("ars_Arab", "Najdi Arabic"), Lang("ar", "Arabic")),
|
63 |
+
TranslationLang(Lang("ary_Arab", "Moroccan Arabic"), Lang("ar", "Arabic")),
|
64 |
+
TranslationLang(Lang("arz_Arab", "Egyptian Arabic"), Lang("ar", "Arabic")),
|
65 |
+
TranslationLang(Lang("asm_Beng", "Assamese"), Lang("as", "Assamese")),
|
66 |
+
TranslationLang(Lang("ast_Latn", "Asturian"), None, Lang("ast", "Asturian")),
|
67 |
+
TranslationLang(Lang("awa_Deva", "Awadhi")),
|
68 |
+
TranslationLang(Lang("ayr_Latn", "Central Aymara")),
|
69 |
+
TranslationLang(Lang("azb_Arab", "South Azerbaijani"), Lang("az", "Azerbaijani")),
|
70 |
+
TranslationLang(Lang("azj_Latn", "North Azerbaijani"), Lang("az", "Azerbaijani")),
|
71 |
+
TranslationLang(Lang("bak_Cyrl", "Bashkir"), Lang("ba", "Bashkir")),
|
72 |
+
TranslationLang(Lang("bam_Latn", "Bambara")),
|
73 |
+
TranslationLang(Lang("ban_Latn", "Balinese")),
|
74 |
+
TranslationLang(Lang("bel_Cyrl", "Belarusian"), Lang("be", "Belarusian")),
|
75 |
+
TranslationLang(Lang("bem_Latn", "Bemba")),
|
76 |
+
TranslationLang(Lang("ben_Beng", "Bengali"), Lang("bn", "Bengali")),
|
77 |
+
TranslationLang(Lang("bho_Deva", "Bhojpuri")),
|
78 |
+
TranslationLang(Lang("bjn_Arab", "Banjar (Arabic script)")),
|
79 |
+
TranslationLang(Lang("bjn_Latn", "Banjar (Latin script)")),
|
80 |
+
TranslationLang(Lang("bod_Tibt", "Standard Tibetan"), Lang("bo", "Tibetan")),
|
81 |
+
TranslationLang(Lang("bos_Latn", "Bosnian"), Lang("bs", "Bosnian")),
|
82 |
+
TranslationLang(Lang("bug_Latn", "Buginese")),
|
83 |
+
TranslationLang(Lang("bul_Cyrl", "Bulgarian"), Lang("bg", "Bulgarian")),
|
84 |
+
TranslationLang(Lang("cat_Latn", "Catalan"), Lang("ca", "Catalan", "valencian")),
|
85 |
+
TranslationLang(Lang("ceb_Latn", "Cebuano"), None, Lang("ceb", "Cebuano")),
|
86 |
+
TranslationLang(Lang("ces_Latn", "Czech"), Lang("cs", "Czech")),
|
87 |
+
TranslationLang(Lang("cjk_Latn", "Chokwe")),
|
88 |
+
TranslationLang(Lang("ckb_Arab", "Central Kurdish")),
|
89 |
+
TranslationLang(Lang("crh_Latn", "Crimean Tatar")),
|
90 |
+
TranslationLang(Lang("cym_Latn", "Welsh"), Lang("cy", "Welsh")),
|
91 |
+
TranslationLang(Lang("dan_Latn", "Danish"), Lang("da", "Danish")),
|
92 |
+
TranslationLang(Lang("deu_Latn", "German"), Lang("de", "German")),
|
93 |
+
TranslationLang(Lang("dik_Latn", "Southwestern Dinka")),
|
94 |
+
TranslationLang(Lang("dyu_Latn", "Dyula")),
|
95 |
+
TranslationLang(Lang("dzo_Tibt", "Dzongkha")),
|
96 |
+
TranslationLang(Lang("ell_Grek", "Greek"), Lang("el", "Greek")),
|
97 |
+
TranslationLang(Lang("eng_Latn", "English"), Lang("en", "English")),
|
98 |
+
TranslationLang(Lang("epo_Latn", "Esperanto")),
|
99 |
+
TranslationLang(Lang("est_Latn", "Estonian"), Lang("et", "Estonian")),
|
100 |
+
TranslationLang(Lang("eus_Latn", "Basque"), Lang("eu", "Basque")),
|
101 |
+
TranslationLang(Lang("ewe_Latn", "Ewe")),
|
102 |
+
TranslationLang(Lang("fao_Latn", "Faroese"), Lang("fo", "Faroese")),
|
103 |
+
TranslationLang(Lang("fij_Latn", "Fijian")),
|
104 |
+
TranslationLang(Lang("fin_Latn", "Finnish"), Lang("fi", "Finnish")),
|
105 |
+
TranslationLang(Lang("fon_Latn", "Fon")),
|
106 |
+
TranslationLang(Lang("fra_Latn", "French"), Lang("fr", "French")),
|
107 |
+
TranslationLang(Lang("fur_Latn", "Friulian")),
|
108 |
+
TranslationLang(Lang("fuv_Latn", "Nigerian Fulfulde"), None, Lang("ff", "Fulah")),
|
109 |
+
TranslationLang(Lang("gla_Latn", "Scottish Gaelic"), None, Lang("gd", "Scottish Gaelic")),
|
110 |
+
TranslationLang(Lang("gle_Latn", "Irish"), None, Lang("ga", "Irish")),
|
111 |
+
TranslationLang(Lang("glg_Latn", "Galician"), Lang("gl", "Galician")),
|
112 |
+
TranslationLang(Lang("grn_Latn", "Guarani")),
|
113 |
+
TranslationLang(Lang("guj_Gujr", "Gujarati"), Lang("gu", "Gujarati")),
|
114 |
+
TranslationLang(Lang("hat_Latn", "Haitian Creole"), Lang("ht", "Haitian creole", "haitian")),
|
115 |
+
TranslationLang(Lang("hau_Latn", "Hausa"), Lang("ha", "Hausa")),
|
116 |
+
TranslationLang(Lang("heb_Hebr", "Hebrew"), Lang("he", "Hebrew")),
|
117 |
+
TranslationLang(Lang("hin_Deva", "Hindi"), Lang("hi", "Hindi")),
|
118 |
+
TranslationLang(Lang("hne_Deva", "Chhattisgarhi")),
|
119 |
+
TranslationLang(Lang("hrv_Latn", "Croatian"), Lang("hr", "Croatian")),
|
120 |
+
TranslationLang(Lang("hun_Latn", "Hungarian"), Lang("hu", "Hungarian")),
|
121 |
+
TranslationLang(Lang("hye_Armn", "Armenian"), Lang("hy", "Armenian")),
|
122 |
+
TranslationLang(Lang("ibo_Latn", "Igbo"), None, Lang("ig", "Igbo")),
|
123 |
+
TranslationLang(Lang("ilo_Latn", "Ilocano"), None, Lang("ilo", "Iloko")),
|
124 |
+
TranslationLang(Lang("ind_Latn", "Indonesian"), Lang("id", "Indonesian")),
|
125 |
+
TranslationLang(Lang("isl_Latn", "Icelandic"), Lang("is", "Icelandic")),
|
126 |
+
TranslationLang(Lang("ita_Latn", "Italian"), Lang("it", "Italian")),
|
127 |
+
TranslationLang(Lang("jav_Latn", "Javanese"), Lang("jw", "Javanese"), Lang("jv", "Javanese")),
|
128 |
+
TranslationLang(Lang("jpn_Jpan", "Japanese"), Lang("ja", "Japanese")),
|
129 |
+
TranslationLang(Lang("kab_Latn", "Kabyle")),
|
130 |
+
TranslationLang(Lang("kac_Latn", "Jingpho")),
|
131 |
+
TranslationLang(Lang("kam_Latn", "Kamba")),
|
132 |
+
TranslationLang(Lang("kan_Knda", "Kannada"), Lang("kn", "Kannada")),
|
133 |
+
TranslationLang(Lang("kas_Arab", "Kashmiri (Arabic script)")),
|
134 |
+
TranslationLang(Lang("kas_Deva", "Kashmiri (Devanagari script)")),
|
135 |
+
TranslationLang(Lang("kat_Geor", "Georgian"), Lang("ka", "Georgian")),
|
136 |
+
TranslationLang(Lang("knc_Arab", "Central Kanuri (Arabic script)")),
|
137 |
+
TranslationLang(Lang("knc_Latn", "Central Kanuri (Latin script)")),
|
138 |
+
TranslationLang(Lang("kaz_Cyrl", "Kazakh"), Lang("kk", "Kazakh")),
|
139 |
+
TranslationLang(Lang("kbp_Latn", "Kabiyè")),
|
140 |
+
TranslationLang(Lang("kea_Latn", "Kabuverdianu")),
|
141 |
+
TranslationLang(Lang("khm_Khmr", "Khmer"), Lang("km", "Khmer")),
|
142 |
+
TranslationLang(Lang("kik_Latn", "Kikuyu")),
|
143 |
+
TranslationLang(Lang("kin_Latn", "Kinyarwanda")),
|
144 |
+
TranslationLang(Lang("kir_Cyrl", "Kyrgyz")),
|
145 |
+
TranslationLang(Lang("kmb_Latn", "Kimbundu")),
|
146 |
+
TranslationLang(Lang("kmr_Latn", "Northern Kurdish")),
|
147 |
+
TranslationLang(Lang("kon_Latn", "Kikongo")),
|
148 |
+
TranslationLang(Lang("kor_Hang", "Korean"), Lang("ko", "Korean")),
|
149 |
+
TranslationLang(Lang("lao_Laoo", "Lao"), Lang("lo", "Lao")),
|
150 |
+
TranslationLang(Lang("lij_Latn", "Ligurian")),
|
151 |
+
TranslationLang(Lang("lim_Latn", "Limburgish")),
|
152 |
+
TranslationLang(Lang("lin_Latn", "Lingala"), Lang("ln", "Lingala")),
|
153 |
+
TranslationLang(Lang("lit_Latn", "Lithuanian"), Lang("lt", "Lithuanian")),
|
154 |
+
TranslationLang(Lang("lmo_Latn", "Lombard")),
|
155 |
+
TranslationLang(Lang("ltg_Latn", "Latgalian")),
|
156 |
+
TranslationLang(Lang("ltz_Latn", "Luxembourgish"), Lang("lb", "Luxembourgish", "letzeburgesch")),
|
157 |
+
TranslationLang(Lang("lua_Latn", "Luba-Kasai")),
|
158 |
+
TranslationLang(Lang("lug_Latn", "Ganda"), None, Lang("lg", "Ganda")),
|
159 |
+
TranslationLang(Lang("luo_Latn", "Luo")),
|
160 |
+
TranslationLang(Lang("lus_Latn", "Mizo")),
|
161 |
+
TranslationLang(Lang("lvs_Latn", "Standard Latvian"), Lang("lv", "Latvian")),
|
162 |
+
TranslationLang(Lang("mag_Deva", "Magahi")),
|
163 |
+
TranslationLang(Lang("mai_Deva", "Maithili")),
|
164 |
+
TranslationLang(Lang("mal_Mlym", "Malayalam"), Lang("ml", "Malayalam")),
|
165 |
+
TranslationLang(Lang("mar_Deva", "Marathi"), Lang("mr", "Marathi")),
|
166 |
+
TranslationLang(Lang("min_Arab", "Minangkabau (Arabic script)")),
|
167 |
+
TranslationLang(Lang("min_Latn", "Minangkabau (Latin script)")),
|
168 |
+
TranslationLang(Lang("mkd_Cyrl", "Macedonian"), Lang("mk", "Macedonian")),
|
169 |
+
TranslationLang(Lang("plt_Latn", "Plateau Malagasy"), Lang("mg", "Malagasy")),
|
170 |
+
TranslationLang(Lang("mlt_Latn", "Maltese"), Lang("mt", "Maltese")),
|
171 |
+
TranslationLang(Lang("mni_Beng", "Meitei (Bengali script)")),
|
172 |
+
TranslationLang(Lang("khk_Cyrl", "Halh Mongolian"), Lang("mn", "Mongolian")),
|
173 |
+
TranslationLang(Lang("mos_Latn", "Mossi")),
|
174 |
+
TranslationLang(Lang("mri_Latn", "Maori"), Lang("mi", "Maori")),
|
175 |
+
TranslationLang(Lang("mya_Mymr", "Burmese"), Lang("my", "Myanmar", "burmese")),
|
176 |
+
TranslationLang(Lang("nld_Latn", "Dutch"), Lang("nl", "Dutch", "flemish")),
|
177 |
+
TranslationLang(Lang("nno_Latn", "Norwegian Nynorsk"), Lang("nn", "Nynorsk")),
|
178 |
+
TranslationLang(Lang("nob_Latn", "Norwegian Bokmål"), Lang("no", "Norwegian")),
|
179 |
+
TranslationLang(Lang("npi_Deva", "Nepali"), Lang("ne", "Nepali")),
|
180 |
+
TranslationLang(Lang("nso_Latn", "Northern Sotho"), None, Lang("ns", "Northern Sotho")),
|
181 |
+
TranslationLang(Lang("nus_Latn", "Nuer")),
|
182 |
+
TranslationLang(Lang("nya_Latn", "Nyanja")),
|
183 |
+
TranslationLang(Lang("oci_Latn", "Occitan"), Lang("oc", "Occitan")),
|
184 |
+
TranslationLang(Lang("gaz_Latn", "West Central Oromo")),
|
185 |
+
TranslationLang(Lang("ory_Orya", "Odia"), None, Lang("or", "Oriya")),
|
186 |
+
TranslationLang(Lang("pag_Latn", "Pangasinan")),
|
187 |
+
TranslationLang(Lang("pan_Guru", "Eastern Panjabi"), Lang("pa", "Punjabi", "panjabi")),
|
188 |
+
TranslationLang(Lang("pap_Latn", "Papiamento")),
|
189 |
+
TranslationLang(Lang("pes_Arab", "Western Persian"), Lang("fa", "Persian")),
|
190 |
+
TranslationLang(Lang("pol_Latn", "Polish"), Lang("pl", "Polish")),
|
191 |
+
TranslationLang(Lang("por_Latn", "Portuguese"), Lang("pt", "Portuguese")),
|
192 |
+
TranslationLang(Lang("prs_Arab", "Dari")),
|
193 |
+
TranslationLang(Lang("pbt_Arab", "Southern Pashto"), Lang("ps", "Pashto", "pushto")),
|
194 |
+
TranslationLang(Lang("quy_Latn", "Ayacucho Quechua")),
|
195 |
+
TranslationLang(Lang("ron_Latn", "Romanian"), Lang("ro", "Romanian", "moldavian", "moldovan")),
|
196 |
+
TranslationLang(Lang("run_Latn", "Rundi")),
|
197 |
+
TranslationLang(Lang("rus_Cyrl", "Russian"), Lang("ru", "Russian")),
|
198 |
+
TranslationLang(Lang("sag_Latn", "Sango")),
|
199 |
+
TranslationLang(Lang("san_Deva", "Sanskrit"), Lang("sa", "Sanskrit")),
|
200 |
+
TranslationLang(Lang("sat_Olck", "Santali")),
|
201 |
+
TranslationLang(Lang("scn_Latn", "Sicilian")),
|
202 |
+
TranslationLang(Lang("shn_Mymr", "Shan")),
|
203 |
+
TranslationLang(Lang("sin_Sinh", "Sinhala"), Lang("si", "Sinhala", "sinhalese")),
|
204 |
+
TranslationLang(Lang("slk_Latn", "Slovak"), Lang("sk", "Slovak")),
|
205 |
+
TranslationLang(Lang("slv_Latn", "Slovenian"), Lang("sl", "Slovenian")),
|
206 |
+
TranslationLang(Lang("smo_Latn", "Samoan")),
|
207 |
+
TranslationLang(Lang("sna_Latn", "Shona"), Lang("sn", "Shona")),
|
208 |
+
TranslationLang(Lang("snd_Arab", "Sindhi"), Lang("sd", "Sindhi")),
|
209 |
+
TranslationLang(Lang("som_Latn", "Somali"), Lang("so", "Somali")),
|
210 |
+
TranslationLang(Lang("sot_Latn", "Southern Sotho")),
|
211 |
+
TranslationLang(Lang("spa_Latn", "Spanish"), Lang("es", "Spanish", "castilian")),
|
212 |
+
TranslationLang(Lang("als_Latn", "Tosk Albanian"), Lang("sq", "Albanian")),
|
213 |
+
TranslationLang(Lang("srd_Latn", "Sardinian")),
|
214 |
+
TranslationLang(Lang("srp_Cyrl", "Serbian"), Lang("sr", "Serbian")),
|
215 |
+
TranslationLang(Lang("ssw_Latn", "Swati"), None, Lang("ss", "Swati")),
|
216 |
+
TranslationLang(Lang("sun_Latn", "Sundanese"), Lang("su", "Sundanese")),
|
217 |
+
TranslationLang(Lang("swe_Latn", "Swedish"), Lang("sv", "Swedish")),
|
218 |
+
TranslationLang(Lang("swh_Latn", "Swahili"), Lang("sw", "Swahili")),
|
219 |
+
TranslationLang(Lang("szl_Latn", "Silesian")),
|
220 |
+
TranslationLang(Lang("tam_Taml", "Tamil"), Lang("ta", "Tamil")),
|
221 |
+
TranslationLang(Lang("tat_Cyrl", "Tatar"), Lang("tt", "Tatar")),
|
222 |
+
TranslationLang(Lang("tel_Telu", "Telugu"), Lang("te", "Telugu")),
|
223 |
+
TranslationLang(Lang("tgk_Cyrl", "Tajik"), Lang("tg", "Tajik")),
|
224 |
+
TranslationLang(Lang("tgl_Latn", "Tagalog"), Lang("tl", "Tagalog")),
|
225 |
+
TranslationLang(Lang("tha_Thai", "Thai"), Lang("th", "Thai")),
|
226 |
+
TranslationLang(Lang("tir_Ethi", "Tigrinya")),
|
227 |
+
TranslationLang(Lang("taq_Latn", "Tamasheq (Latin script)")),
|
228 |
+
TranslationLang(Lang("taq_Tfng", "Tamasheq (Tifinagh script)")),
|
229 |
+
TranslationLang(Lang("tpi_Latn", "Tok Pisin")),
|
230 |
+
TranslationLang(Lang("tsn_Latn", "Tswana"), None, Lang("tn", "Tswana")),
|
231 |
+
TranslationLang(Lang("tso_Latn", "Tsonga")),
|
232 |
+
TranslationLang(Lang("tuk_Latn", "Turkmen"), Lang("tk", "Turkmen")),
|
233 |
+
TranslationLang(Lang("tum_Latn", "Tumbuka")),
|
234 |
+
TranslationLang(Lang("tur_Latn", "Turkish"), Lang("tr", "Turkish")),
|
235 |
+
TranslationLang(Lang("twi_Latn", "Twi")),
|
236 |
+
TranslationLang(Lang("tzm_Tfng", "Central Atlas Tamazight")),
|
237 |
+
TranslationLang(Lang("uig_Arab", "Uyghur")),
|
238 |
+
TranslationLang(Lang("ukr_Cyrl", "Ukrainian"), Lang("uk", "Ukrainian")),
|
239 |
+
TranslationLang(Lang("umb_Latn", "Umbundu")),
|
240 |
+
TranslationLang(Lang("urd_Arab", "Urdu"), Lang("ur", "Urdu")),
|
241 |
+
TranslationLang(Lang("uzn_Latn", "Northern Uzbek"), Lang("uz", "Uzbek")),
|
242 |
+
TranslationLang(Lang("vec_Latn", "Venetian")),
|
243 |
+
TranslationLang(Lang("vie_Latn", "Vietnamese"), Lang("vi", "Vietnamese")),
|
244 |
+
TranslationLang(Lang("war_Latn", "Waray")),
|
245 |
+
TranslationLang(Lang("wol_Latn", "Wolof"), None, Lang("wo", "Wolof")),
|
246 |
+
TranslationLang(Lang("xho_Latn", "Xhosa"), None, Lang("xh", "Xhosa")),
|
247 |
+
TranslationLang(Lang("ydd_Hebr", "Eastern Yiddish"), Lang("yi", "Yiddish")),
|
248 |
+
TranslationLang(Lang("yor_Latn", "Yoruba"), Lang("yo", "Yoruba")),
|
249 |
+
TranslationLang(Lang("yue_Hant", "Yue Chinese"), Lang("yue", "cantonese"), Lang("zh", "Chinese (zh-yue)")),
|
250 |
+
TranslationLang(Lang("zho_Hans", "Chinese (Simplified)"), Lang("zh", "Chinese (Simplified)", "Chinese", "mandarin")),
|
251 |
+
TranslationLang(Lang("zho_Hant", "Chinese (Traditional)"), Lang("zh", "Chinese (Traditional)")),
|
252 |
+
TranslationLang(Lang("zsm_Latn", "Standard Malay"), Lang("ms", "Malay")),
|
253 |
+
TranslationLang(Lang("zul_Latn", "Zulu"), None, Lang("zu", "Zulu")),
|
254 |
+
TranslationLang(None, Lang("br", "Breton")), # Both whisper and m2m100 support the Breton language, but nllb does not have this language.
|
255 |
+
]
|
256 |
+
|
257 |
+
|
258 |
+
_TO_LANG_NAME_NLLB = {name.lower(): language for language in TranslationLangs if language.nllb is not None for name in language.nllb.names}
|
259 |
+
|
260 |
+
_TO_LANG_NAME_M2M100 = {name.lower(): language for language in TranslationLangs if language.m2m100 is not None for name in language.m2m100.names}
|
261 |
+
|
262 |
+
_TO_LANG_NAME_WHISPER = {name.lower(): language for language in TranslationLangs if language.whisper is not None for name in language.whisper.names}
|
263 |
+
|
264 |
+
_TO_LANG_CODE_WHISPER = {language.whisper.code.lower(): language for language in TranslationLangs if language.whisper is not None and len(language.whisper.code) > 0}
|
265 |
+
|
266 |
+
|
267 |
+
def get_lang_from_nllb_name(nllbName, default=None) -> TranslationLang:
|
268 |
+
"""Return the TranslationLang from the lang_name_nllb."""
|
269 |
+
return _TO_LANG_NAME_NLLB.get(nllbName.lower() if nllbName else None, default)
|
270 |
+
|
271 |
+
def get_lang_from_m2m100_name(m2m100Name, default=None) -> TranslationLang:
|
272 |
+
"""Return the TranslationLang from the lang_name_m2m100 name."""
|
273 |
+
return _TO_LANG_NAME_M2M100.get(m2m100Name.lower() if m2m100Name else None, default)
|
274 |
+
|
275 |
+
def get_lang_from_whisper_name(whisperName, default=None) -> TranslationLang:
|
276 |
+
"""Return the TranslationLang from the lang_name_whisper name."""
|
277 |
+
return _TO_LANG_NAME_WHISPER.get(whisperName.lower() if whisperName else None, default)
|
278 |
+
|
279 |
+
def get_lang_from_whisper_code(whisperCode, default=None) -> TranslationLang:
|
280 |
+
"""Return the TranslationLang from the lang_code_whisper."""
|
281 |
+
return _TO_LANG_CODE_WHISPER.get(whisperCode, default)
|
282 |
+
|
283 |
+
def get_lang_nllb_names():
|
284 |
+
"""Return a list of nllb language names."""
|
285 |
+
return list(_TO_LANG_NAME_NLLB.keys())
|
286 |
+
|
287 |
+
def get_lang_m2m100_names(codes = []):
|
288 |
+
"""Return a list of m2m100 language names."""
|
289 |
+
return list({name.lower(): None for language in TranslationLangs if language.m2m100 is not None and (len(codes) == 0 or any(code in language.m2m100.code for code in codes)) for name in language.m2m100.names}.keys())
|
290 |
+
|
291 |
+
def get_lang_whisper_names():
|
292 |
+
"""Return a list of whisper language names."""
|
293 |
+
return list(_TO_LANG_NAME_WHISPER.keys())
|
294 |
+
|
295 |
+
if __name__ == "__main__":
|
296 |
+
# Test lookup
|
297 |
+
print("name:Chinese (Traditional)", get_lang_from_nllb_name("Chinese (Traditional)"))
|
298 |
+
print("name:moldavian", get_lang_from_m2m100_name("moldavian"))
|
299 |
+
print("code:ja", get_lang_from_whisper_code("ja"))
|
300 |
+
print("name:English", get_lang_from_nllb_name('English'))
|
301 |
+
|
302 |
+
print(get_lang_m2m100_names(["en", "ja", "zh"]))
|
303 |
+
print(get_lang_nllb_names())
|
@@ -9,24 +9,26 @@ import transformers
|
|
9 |
|
10 |
from typing import Optional
|
11 |
from src.config import ModelConfig
|
12 |
-
from src.
|
13 |
-
from src.nllb.nllbLangs import NllbLang, get_nllb_lang_from_code_whisper
|
14 |
|
15 |
-
class
|
16 |
def __init__(
|
17 |
self,
|
18 |
-
|
19 |
device: str = None,
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
25 |
):
|
26 |
-
"""Initializes the Nllb-200 model.
|
27 |
|
28 |
Args:
|
29 |
-
|
30 |
1.3B, 3.3B...) or a path to a converted
|
31 |
model directory. When a size is configured, the converted model is downloaded
|
32 |
from the Hugging Face Hub.
|
@@ -44,62 +46,72 @@ class NllbModel:
|
|
44 |
having multiple workers enables true parallelism when running the model
|
45 |
(concurrent calls to self.model.generate() will run in parallel).
|
46 |
This can improve the global throughput at the cost of increased memory usage.
|
47 |
-
|
48 |
are saved in the standard Hugging Face cache directory.
|
49 |
-
|
50 |
local cached file if it exists.
|
51 |
"""
|
52 |
-
self.
|
53 |
-
self.
|
54 |
-
self.
|
55 |
-
self.model_config = model_config
|
56 |
|
57 |
-
if
|
58 |
return
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
if os.path.isdir(
|
61 |
-
self.
|
62 |
else:
|
63 |
-
self.
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
)
|
68 |
|
69 |
if device is None:
|
70 |
if torch.cuda.is_available():
|
71 |
-
device = "cuda" if "ct2" in self.
|
72 |
else:
|
73 |
device = "cpu"
|
74 |
|
75 |
self.device = device
|
76 |
|
77 |
-
if
|
78 |
self.load_model()
|
79 |
|
80 |
def load_model(self):
|
81 |
-
print('\n\nLoading model: %s\n\n' % self.
|
82 |
-
if "ct2" in self.
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
self.
|
90 |
-
|
91 |
-
|
92 |
-
self.
|
93 |
-
self.
|
94 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
def release_vram(self):
|
97 |
try:
|
98 |
if torch.cuda.is_available():
|
99 |
-
if "ct2" not in self.
|
100 |
device = torch.device("cpu")
|
101 |
-
self.
|
102 |
-
del self.
|
103 |
torch.cuda.empty_cache()
|
104 |
print("release vram end.")
|
105 |
except Exception as e:
|
@@ -110,16 +122,16 @@ class NllbModel:
|
|
110 |
output = None
|
111 |
result = None
|
112 |
try:
|
113 |
-
if "ct2" in self.
|
114 |
-
source = self.
|
115 |
-
output = self.
|
116 |
target = output[0].hypotheses[0][1:]
|
117 |
-
result = self.
|
118 |
-
elif "mt5" in self.
|
119 |
-
output = self.
|
120 |
result = output[0]['generated_text']
|
121 |
-
else: #NLLB
|
122 |
-
output = self.
|
123 |
result = output[0]['translation_text']
|
124 |
except Exception as e:
|
125 |
print("Error translation text: " + str(e))
|
@@ -133,6 +145,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
|
|
133 |
"nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
|
134 |
"nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
|
135 |
"nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
|
|
|
|
|
136 |
"mt5-zh-ja-en-trimmed",
|
137 |
"mt5-zh-ja-en-trimmed-fine-tuned-v1"]
|
138 |
|
@@ -140,10 +154,10 @@ def check_model_name(name):
|
|
140 |
return any(allowed_name in name for allowed_name in _MODELS)
|
141 |
|
142 |
def download_model(
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
):
|
148 |
""""download_model" is referenced from the "utils.py" script
|
149 |
of the "faster_whisper" project, authored by guillaumekln.
|
@@ -153,13 +167,13 @@ def download_model(
|
|
153 |
The model is downloaded from https://huggingface.co/facebook.
|
154 |
|
155 |
Args:
|
156 |
-
|
157 |
facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
|
158 |
-
|
159 |
the cache directory.
|
160 |
-
|
161 |
cached file if it exists.
|
162 |
-
|
163 |
|
164 |
Returns:
|
165 |
The path to the downloaded model.
|
@@ -167,19 +181,20 @@ def download_model(
|
|
167 |
Raises:
|
168 |
ValueError: if the model size is invalid.
|
169 |
"""
|
170 |
-
if not check_model_name(
|
171 |
raise ValueError(
|
172 |
-
"Invalid model name '%s', expected one of: %s" % (
|
173 |
)
|
174 |
|
175 |
-
|
176 |
|
177 |
-
|
178 |
"config.json",
|
179 |
"generation_config.json",
|
180 |
"model.bin",
|
181 |
"pytorch_model.bin",
|
182 |
"pytorch_model.bin.index.json",
|
|
|
183 |
"pytorch_model-00001-of-00003.bin",
|
184 |
"pytorch_model-00002-of-00003.bin",
|
185 |
"pytorch_model-00003-of-00003.bin",
|
@@ -190,30 +205,31 @@ def download_model(
|
|
190 |
"shared_vocabulary.json",
|
191 |
"special_tokens_map.json",
|
192 |
"spiece.model",
|
|
|
193 |
]
|
194 |
|
195 |
kwargs = {
|
196 |
-
"local_files_only":
|
197 |
-
"allow_patterns":
|
198 |
#"tqdm_class": disabled_tqdm,
|
199 |
}
|
200 |
|
201 |
-
if
|
202 |
-
kwargs["local_dir"] =
|
203 |
kwargs["local_dir_use_symlinks"] = False
|
204 |
|
205 |
-
if
|
206 |
-
kwargs["cache_dir"] =
|
207 |
|
208 |
try:
|
209 |
-
return huggingface_hub.snapshot_download(
|
210 |
except (
|
211 |
huggingface_hub.utils.HfHubHTTPError,
|
212 |
requests.exceptions.ConnectionError,
|
213 |
) as exception:
|
214 |
warnings.warn(
|
215 |
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
216 |
-
|
217 |
exception,
|
218 |
)
|
219 |
warnings.warn(
|
@@ -221,4 +237,4 @@ def download_model(
|
|
221 |
)
|
222 |
|
223 |
kwargs["local_files_only"] = True
|
224 |
-
return huggingface_hub.snapshot_download(
|
|
|
9 |
|
10 |
from typing import Optional
|
11 |
from src.config import ModelConfig
|
12 |
+
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
|
|
|
13 |
|
14 |
+
class TranslationModel:
|
15 |
def __init__(
|
16 |
self,
|
17 |
+
modelConfig: ModelConfig,
|
18 |
device: str = None,
|
19 |
+
whisperLang: TranslationLang = None,
|
20 |
+
translationLang: TranslationLang = None,
|
21 |
+
batchSize: int = 2,
|
22 |
+
noRepeatNgramSize: int = 3,
|
23 |
+
numBeams: int = 2,
|
24 |
+
downloadRoot: Optional[str] = None,
|
25 |
+
localFilesOnly: bool = False,
|
26 |
+
loadModel: bool = False,
|
27 |
):
|
28 |
+
"""Initializes the M2M100 / Nllb-200 / mt5 model.
|
29 |
|
30 |
Args:
|
31 |
+
modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
|
32 |
1.3B, 3.3B...) or a path to a converted
|
33 |
model directory. When a size is configured, the converted model is downloaded
|
34 |
from the Hugging Face Hub.
|
|
|
46 |
having multiple workers enables true parallelism when running the model
|
47 |
(concurrent calls to self.model.generate() will run in parallel).
|
48 |
This can improve the global throughput at the cost of increased memory usage.
|
49 |
+
downloadRoot: Directory where the models should be saved. If not set, the models
|
50 |
are saved in the standard Hugging Face cache directory.
|
51 |
+
localFilesOnly: If True, avoid downloading the file and return the path to the
|
52 |
local cached file if it exists.
|
53 |
"""
|
54 |
+
self.modelConfig = modelConfig
|
55 |
+
self.whisperLang = whisperLang # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
|
56 |
+
self.translationLang = translationLang
|
|
|
57 |
|
58 |
+
if translationLang is None:
|
59 |
return
|
60 |
+
|
61 |
+
self.batchSize = batchSize
|
62 |
+
self.noRepeatNgramSize = noRepeatNgramSize
|
63 |
+
self.numBeams = numBeams
|
64 |
|
65 |
+
if os.path.isdir(modelConfig.url):
|
66 |
+
self.modelPath = modelConfig.url
|
67 |
else:
|
68 |
+
self.modelPath = download_model(
|
69 |
+
modelConfig,
|
70 |
+
localFilesOnly=localFilesOnly,
|
71 |
+
cacheDir=downloadRoot,
|
72 |
)
|
73 |
|
74 |
if device is None:
|
75 |
if torch.cuda.is_available():
|
76 |
+
device = "cuda" if "ct2" in self.modelPath else "cuda:0"
|
77 |
else:
|
78 |
device = "cpu"
|
79 |
|
80 |
self.device = device
|
81 |
|
82 |
+
if loadModel:
|
83 |
self.load_model()
|
84 |
|
85 |
def load_model(self):
|
86 |
+
print('\n\nLoading model: %s\n\n' % self.modelPath)
|
87 |
+
if "ct2" in self.modelPath:
|
88 |
+
if "nllb" in self.modelPath:
|
89 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
|
90 |
+
self.targetPrefix = [self.translationLang.nllb.code]
|
91 |
+
elif "m2m100" in self.modelPath:
|
92 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
|
93 |
+
self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
|
94 |
+
self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
|
95 |
+
elif "mt5" in self.modelPath:
|
96 |
+
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
97 |
+
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
98 |
+
self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
|
99 |
+
self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
|
100 |
+
else:
|
101 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
|
102 |
+
self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
|
103 |
+
if "m2m100" in self.modelPath:
|
104 |
+
self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
|
105 |
+
else: #NLLB
|
106 |
+
self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
|
107 |
|
108 |
def release_vram(self):
|
109 |
try:
|
110 |
if torch.cuda.is_available():
|
111 |
+
if "ct2" not in self.modelPath:
|
112 |
device = torch.device("cpu")
|
113 |
+
self.transModel.to(device)
|
114 |
+
del self.transModel
|
115 |
torch.cuda.empty_cache()
|
116 |
print("release vram end.")
|
117 |
except Exception as e:
|
|
|
122 |
output = None
|
123 |
result = None
|
124 |
try:
|
125 |
+
if "ct2" in self.modelPath:
|
126 |
+
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
|
127 |
+
output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
|
128 |
target = output[0].hypotheses[0][1:]
|
129 |
+
result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
|
130 |
+
elif "mt5" in self.modelPath:
|
131 |
+
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
132 |
result = output[0]['generated_text']
|
133 |
+
else: #M2M100 & NLLB
|
134 |
+
output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
135 |
result = output[0]['translation_text']
|
136 |
except Exception as e:
|
137 |
print("Error translation text: " + str(e))
|
|
|
145 |
"nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
|
146 |
"nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
|
147 |
"nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
|
148 |
+
"m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
|
149 |
+
"m2m100_1.2B", "m2m100_418M",
|
150 |
"mt5-zh-ja-en-trimmed",
|
151 |
"mt5-zh-ja-en-trimmed-fine-tuned-v1"]
|
152 |
|
|
|
154 |
return any(allowed_name in name for allowed_name in _MODELS)
|
155 |
|
156 |
def download_model(
|
157 |
+
modelConfig: ModelConfig,
|
158 |
+
outputDir: Optional[str] = None,
|
159 |
+
localFilesOnly: bool = False,
|
160 |
+
cacheDir: Optional[str] = None,
|
161 |
):
|
162 |
""""download_model" is referenced from the "utils.py" script
|
163 |
of the "faster_whisper" project, authored by guillaumekln.
|
|
|
167 |
The model is downloaded from https://huggingface.co/facebook.
|
168 |
|
169 |
Args:
|
170 |
+
modelConfig: config of the model to download (facebook/nllb-distilled-600M,
|
171 |
facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
|
172 |
+
outputDir: Directory where the model should be saved. If not set, the model is saved in
|
173 |
the cache directory.
|
174 |
+
localFilesOnly: If True, avoid downloading the file and return the path to the local
|
175 |
cached file if it exists.
|
176 |
+
cacheDir: Path to the folder where cached files are stored.
|
177 |
|
178 |
Returns:
|
179 |
The path to the downloaded model.
|
|
|
181 |
Raises:
|
182 |
ValueError: if the model size is invalid.
|
183 |
"""
|
184 |
+
if not check_model_name(modelConfig.name):
|
185 |
raise ValueError(
|
186 |
+
"Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
|
187 |
)
|
188 |
|
189 |
+
repoId = modelConfig.url #"facebook/nllb-200-%s" %
|
190 |
|
191 |
+
allowPatterns = [
|
192 |
"config.json",
|
193 |
"generation_config.json",
|
194 |
"model.bin",
|
195 |
"pytorch_model.bin",
|
196 |
"pytorch_model.bin.index.json",
|
197 |
+
"pytorch_model-*.bin",
|
198 |
"pytorch_model-00001-of-00003.bin",
|
199 |
"pytorch_model-00002-of-00003.bin",
|
200 |
"pytorch_model-00003-of-00003.bin",
|
|
|
205 |
"shared_vocabulary.json",
|
206 |
"special_tokens_map.json",
|
207 |
"spiece.model",
|
208 |
+
"vocab.json", #m2m100
|
209 |
]
|
210 |
|
211 |
kwargs = {
|
212 |
+
"local_files_only": localFilesOnly,
|
213 |
+
"allow_patterns": allowPatterns,
|
214 |
#"tqdm_class": disabled_tqdm,
|
215 |
}
|
216 |
|
217 |
+
if outputDir is not None:
|
218 |
+
kwargs["local_dir"] = outputDir
|
219 |
kwargs["local_dir_use_symlinks"] = False
|
220 |
|
221 |
+
if cacheDir is not None:
|
222 |
+
kwargs["cache_dir"] = cacheDir
|
223 |
|
224 |
try:
|
225 |
+
return huggingface_hub.snapshot_download(repoId, **kwargs)
|
226 |
except (
|
227 |
huggingface_hub.utils.HfHubHTTPError,
|
228 |
requests.exceptions.ConnectionError,
|
229 |
) as exception:
|
230 |
warnings.warn(
|
231 |
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
232 |
+
repoId,
|
233 |
exception,
|
234 |
)
|
235 |
warnings.warn(
|
|
|
237 |
)
|
238 |
|
239 |
kwargs["local_files_only"] = True
|
240 |
+
return huggingface_hub.snapshot_download(repoId, **kwargs)
|
@@ -100,46 +100,91 @@ def write_srt(transcript: Iterator[dict], file: TextIO,
|
|
100 |
flush=True,
|
101 |
)
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
|
104 |
for segment in transcript:
|
105 |
words: list = segment.get('words', [])
|
106 |
|
107 |
# Append longest speaker ID if available
|
108 |
segment_longest_speaker = segment.get('longest_speaker', None)
|
|
|
|
|
|
|
|
|
|
|
109 |
if segment_longest_speaker is not None:
|
110 |
segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
112 |
if len(words) == 0:
|
113 |
-
#
|
114 |
-
if
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
'text': process_text(text, maxLineWidth)
|
127 |
-
}
|
128 |
# We are done
|
129 |
continue
|
130 |
|
131 |
-
subtitle_start = segment['start']
|
132 |
-
subtitle_end = segment['end']
|
133 |
-
|
134 |
if segment_longest_speaker is not None:
|
135 |
# Add the beginning
|
136 |
words.insert(0, {
|
137 |
'start': subtitle_start,
|
138 |
-
'end': subtitle_start,
|
139 |
-
'word': f"({segment_longest_speaker})"
|
140 |
})
|
141 |
|
142 |
-
text_words = [ this_word["word"] for this_word in words ]
|
143 |
subtitle_text = __join_words(text_words, maxLineWidth)
|
144 |
|
145 |
# Iterate over the words in the segment
|
@@ -154,15 +199,15 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
|
|
154 |
# Display the text up to this point
|
155 |
yield {
|
156 |
'start': last,
|
157 |
-
'end': start,
|
158 |
-
'text': subtitle_text
|
159 |
}
|
160 |
|
161 |
# Display the text with the current word highlighted
|
162 |
yield {
|
163 |
'start': start,
|
164 |
-
'end': end,
|
165 |
-
'text': __join_words(
|
166 |
[
|
167 |
{
|
168 |
"word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
@@ -180,17 +225,20 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
|
|
180 |
# Display the last part of the text
|
181 |
yield {
|
182 |
'start': last,
|
183 |
-
'end': subtitle_end,
|
184 |
-
'text': subtitle_text
|
185 |
}
|
186 |
|
187 |
# Just return the subtitle text
|
188 |
else:
|
189 |
-
|
190 |
'start': subtitle_start,
|
191 |
-
'end': subtitle_end,
|
192 |
-
'text': subtitle_text
|
193 |
}
|
|
|
|
|
|
|
194 |
|
195 |
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
|
196 |
if maxLineWidth is None or maxLineWidth < 0:
|
|
|
100 |
flush=True,
|
101 |
)
|
102 |
|
103 |
+
def write_srt_original(transcript: Iterator[dict], file: TextIO,
|
104 |
+
maxLineWidth=None, highlight_words: bool = False, bilingual: bool = False):
|
105 |
+
"""
|
106 |
+
Write a transcript to a file in SRT format.
|
107 |
+
Example usage:
|
108 |
+
from pathlib import Path
|
109 |
+
from whisper.utils import write_srt
|
110 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
111 |
+
# save SRT
|
112 |
+
audio_basename = Path(audio_path).stem
|
113 |
+
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
114 |
+
write_srt(result["segments"], file=srt)
|
115 |
+
"""
|
116 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
117 |
+
|
118 |
+
for i, segment in enumerate(iterator, start=1):
|
119 |
+
if "original" not in segment:
|
120 |
+
continue
|
121 |
+
|
122 |
+
original = segment['original'].replace('-->', '->')
|
123 |
+
|
124 |
+
# write srt lines
|
125 |
+
print(
|
126 |
+
f"{i}\n"
|
127 |
+
f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
|
128 |
+
f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}",
|
129 |
+
file=file,
|
130 |
+
flush=True,
|
131 |
+
)
|
132 |
+
|
133 |
+
if original is not None: print(f"{original}",
|
134 |
+
file=file,
|
135 |
+
flush=True)
|
136 |
+
|
137 |
+
if bilingual:
|
138 |
+
text = segment['text'].replace('-->', '->')
|
139 |
+
print(f"{text}\n",
|
140 |
+
file=file,
|
141 |
+
flush=True)
|
142 |
+
|
143 |
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
|
144 |
for segment in transcript:
|
145 |
words: list = segment.get('words', [])
|
146 |
|
147 |
# Append longest speaker ID if available
|
148 |
segment_longest_speaker = segment.get('longest_speaker', None)
|
149 |
+
|
150 |
+
# Yield the segment as-is or processed
|
151 |
+
if len(words) == 0 and (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
|
152 |
+
yield segment
|
153 |
+
|
154 |
if segment_longest_speaker is not None:
|
155 |
segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
|
156 |
+
|
157 |
+
subtitle_start = segment['start']
|
158 |
+
subtitle_end = segment['end']
|
159 |
+
text = segment['text'].strip()
|
160 |
+
original_text = segment['original'].strip() if 'original' in segment else None
|
161 |
+
|
162 |
if len(words) == 0:
|
163 |
+
# Prepend the longest speaker ID if available
|
164 |
+
if segment_longest_speaker is not None:
|
165 |
+
text = f"({segment_longest_speaker}) {text}"
|
166 |
+
|
167 |
+
result = {
|
168 |
+
'start': subtitle_start,
|
169 |
+
'end' : subtitle_end,
|
170 |
+
'text' : process_text(text, maxLineWidth)
|
171 |
+
}
|
172 |
+
if original_text is not None and len(original_text) > 0:
|
173 |
+
result.update({'original': process_text(original_text, maxLineWidth)})
|
174 |
+
yield result
|
175 |
+
|
|
|
|
|
176 |
# We are done
|
177 |
continue
|
178 |
|
|
|
|
|
|
|
179 |
if segment_longest_speaker is not None:
|
180 |
# Add the beginning
|
181 |
words.insert(0, {
|
182 |
'start': subtitle_start,
|
183 |
+
'end' : subtitle_start,
|
184 |
+
'word' : f"({segment_longest_speaker})"
|
185 |
})
|
186 |
|
187 |
+
text_words = [text] if not highlight_words and original_text is not None and len(original_text) > 0 else [ this_word["word"] for this_word in words ]
|
188 |
subtitle_text = __join_words(text_words, maxLineWidth)
|
189 |
|
190 |
# Iterate over the words in the segment
|
|
|
199 |
# Display the text up to this point
|
200 |
yield {
|
201 |
'start': last,
|
202 |
+
'end' : start,
|
203 |
+
'text' : subtitle_text
|
204 |
}
|
205 |
|
206 |
# Display the text with the current word highlighted
|
207 |
yield {
|
208 |
'start': start,
|
209 |
+
'end' : end,
|
210 |
+
'text' : __join_words(
|
211 |
[
|
212 |
{
|
213 |
"word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
|
|
225 |
# Display the last part of the text
|
226 |
yield {
|
227 |
'start': last,
|
228 |
+
'end' : subtitle_end,
|
229 |
+
'text' : subtitle_text
|
230 |
}
|
231 |
|
232 |
# Just return the subtitle text
|
233 |
else:
|
234 |
+
result = {
|
235 |
'start': subtitle_start,
|
236 |
+
'end' : subtitle_end,
|
237 |
+
'text' : subtitle_text
|
238 |
}
|
239 |
+
if original_text is not None and len(original_text) > 0:
|
240 |
+
result.update({'original': original_text})
|
241 |
+
yield result
|
242 |
|
243 |
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
|
244 |
if maxLineWidth is None or maxLineWidth < 0:
|
@@ -242,9 +242,8 @@ class AbstractTranscription(ABC):
|
|
242 |
|
243 |
# Update prompt window
|
244 |
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
|
245 |
-
|
246 |
-
if detected_language is not None
|
247 |
-
result['language'] = detected_language
|
248 |
finally:
|
249 |
# Notify progress listener that we are done
|
250 |
if progressListener is not None:
|
|
|
242 |
|
243 |
# Update prompt window
|
244 |
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
|
245 |
+
|
246 |
+
result['language'] = detected_language if detected_language is not None else segment_result['language']
|
|
|
247 |
finally:
|
248 |
# Notify progress listener that we are done
|
249 |
if progressListener is not None:
|
@@ -71,7 +71,7 @@ class AbstractWhisperContainer:
|
|
71 |
pass
|
72 |
|
73 |
@abc.abstractmethod
|
74 |
-
def create_callback(self,
|
75 |
prompt_strategy: AbstractPromptStrategy = None,
|
76 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
77 |
"""
|
@@ -79,8 +79,8 @@ class AbstractWhisperContainer:
|
|
79 |
|
80 |
Parameters
|
81 |
----------
|
82 |
-
|
83 |
-
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
84 |
task: str
|
85 |
The task - either translate or transcribe.
|
86 |
prompt_strategy: AbstractPromptStrategy
|
|
|
71 |
pass
|
72 |
|
73 |
@abc.abstractmethod
|
74 |
+
def create_callback(self, languageCode: str = None, task: str = None,
|
75 |
prompt_strategy: AbstractPromptStrategy = None,
|
76 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
77 |
"""
|
|
|
79 |
|
80 |
Parameters
|
81 |
----------
|
82 |
+
languageCode: str
|
83 |
+
The target language code of the transcription. If not specified, the language will be inferred from the audio content.
|
84 |
task: str
|
85 |
The task - either translate or transcribe.
|
86 |
prompt_strategy: AbstractPromptStrategy
|
@@ -4,7 +4,6 @@ from typing import List, Union
|
|
4 |
from faster_whisper import WhisperModel, download_model
|
5 |
from src.config import ModelConfig, VadInitialPromptMode
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
-
from src.languages import get_language_from_name
|
8 |
from src.modelCache import ModelCache
|
9 |
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
10 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
@@ -57,7 +56,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
57 |
model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
|
58 |
return model
|
59 |
|
60 |
-
def create_callback(self,
|
61 |
prompt_strategy: AbstractPromptStrategy = None,
|
62 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
63 |
"""
|
@@ -65,8 +64,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
65 |
|
66 |
Parameters
|
67 |
----------
|
68 |
-
|
69 |
-
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
70 |
task: str
|
71 |
The task - either translate or transcribe.
|
72 |
prompt_strategy: AbstractPromptStrategy
|
@@ -78,14 +77,14 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
78 |
-------
|
79 |
A WhisperCallback object.
|
80 |
"""
|
81 |
-
return FasterWhisperCallback(self,
|
82 |
|
83 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
84 |
-
def __init__(self, model_container: FasterWhisperContainer,
|
85 |
prompt_strategy: AbstractPromptStrategy = None,
|
86 |
**decodeOptions: dict):
|
87 |
self.model_container = model_container
|
88 |
-
self.
|
89 |
self.task = task
|
90 |
self.prompt_strategy = prompt_strategy
|
91 |
self.decodeOptions = decodeOptions
|
@@ -108,7 +107,6 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
108 |
A callback to receive progress updates.
|
109 |
"""
|
110 |
model: WhisperModel = self.model_container.get_model()
|
111 |
-
language_code = self._lookup_language_code(self.language) if self.language else None
|
112 |
|
113 |
# Copy decode options and remove options that are not supported by faster-whisper
|
114 |
decodeOptions = self.decodeOptions.copy()
|
@@ -139,7 +137,7 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
139 |
if self.prompt_strategy else prompt
|
140 |
|
141 |
segments_generator, info = model.transcribe(audio, \
|
142 |
-
language=
|
143 |
initial_prompt=initial_prompt, \
|
144 |
**decodeOptions
|
145 |
)
|
@@ -197,11 +195,3 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
197 |
return suppress_tokens
|
198 |
|
199 |
return [int(token) for token in suppress_tokens.split(",")]
|
200 |
-
|
201 |
-
def _lookup_language_code(self, language: str):
|
202 |
-
language = get_language_from_name(language)
|
203 |
-
|
204 |
-
if language is None:
|
205 |
-
raise ValueError("Invalid language: " + language)
|
206 |
-
|
207 |
-
return language.code
|
|
|
4 |
from faster_whisper import WhisperModel, download_model
|
5 |
from src.config import ModelConfig, VadInitialPromptMode
|
6 |
from src.hooks.progressListener import ProgressListener
|
|
|
7 |
from src.modelCache import ModelCache
|
8 |
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
9 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
|
|
56 |
model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
|
57 |
return model
|
58 |
|
59 |
+
def create_callback(self, languageCode: str = None, task: str = None,
|
60 |
prompt_strategy: AbstractPromptStrategy = None,
|
61 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
62 |
"""
|
|
|
64 |
|
65 |
Parameters
|
66 |
----------
|
67 |
+
languageCode: str
|
68 |
+
The target language code of the transcription. If not specified, the language will be inferred from the audio content.
|
69 |
task: str
|
70 |
The task - either translate or transcribe.
|
71 |
prompt_strategy: AbstractPromptStrategy
|
|
|
77 |
-------
|
78 |
A WhisperCallback object.
|
79 |
"""
|
80 |
+
return FasterWhisperCallback(self, languageCode=languageCode, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
81 |
|
82 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
83 |
+
def __init__(self, model_container: FasterWhisperContainer, languageCode: str = None, task: str = None,
|
84 |
prompt_strategy: AbstractPromptStrategy = None,
|
85 |
**decodeOptions: dict):
|
86 |
self.model_container = model_container
|
87 |
+
self.languageCode = languageCode
|
88 |
self.task = task
|
89 |
self.prompt_strategy = prompt_strategy
|
90 |
self.decodeOptions = decodeOptions
|
|
|
107 |
A callback to receive progress updates.
|
108 |
"""
|
109 |
model: WhisperModel = self.model_container.get_model()
|
|
|
110 |
|
111 |
# Copy decode options and remove options that are not supported by faster-whisper
|
112 |
decodeOptions = self.decodeOptions.copy()
|
|
|
137 |
if self.prompt_strategy else prompt
|
138 |
|
139 |
segments_generator, info = model.transcribe(audio, \
|
140 |
+
language=self.languageCode if self.languageCode else detected_language, task=self.task, \
|
141 |
initial_prompt=initial_prompt, \
|
142 |
**decodeOptions
|
143 |
)
|
|
|
195 |
return suppress_tokens
|
196 |
|
197 |
return [int(token) for token in suppress_tokens.split(",")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -70,7 +70,7 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
70 |
|
71 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
72 |
|
73 |
-
def create_callback(self,
|
74 |
prompt_strategy: AbstractPromptStrategy = None,
|
75 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
76 |
"""
|
@@ -78,8 +78,8 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
78 |
|
79 |
Parameters
|
80 |
----------
|
81 |
-
|
82 |
-
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
83 |
task: str
|
84 |
The task - either translate or transcribe.
|
85 |
prompt_strategy: AbstractPromptStrategy
|
@@ -91,7 +91,7 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
91 |
-------
|
92 |
A WhisperCallback object.
|
93 |
"""
|
94 |
-
return WhisperCallback(self,
|
95 |
|
96 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
97 |
from src.conversion.hf_converter import convert_hf_whisper
|
@@ -160,11 +160,11 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
160 |
return model_config.path
|
161 |
|
162 |
class WhisperCallback(AbstractWhisperCallback):
|
163 |
-
def __init__(self, model_container: WhisperContainer,
|
164 |
prompt_strategy: AbstractPromptStrategy = None,
|
165 |
**decodeOptions: dict):
|
166 |
self.model_container = model_container
|
167 |
-
self.
|
168 |
self.task = task
|
169 |
self.prompt_strategy = prompt_strategy
|
170 |
|
@@ -204,7 +204,7 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
204 |
if self.prompt_strategy else prompt
|
205 |
|
206 |
result = model.transcribe(audio, \
|
207 |
-
language=self.
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
)
|
|
|
70 |
|
71 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
72 |
|
73 |
+
def create_callback(self, languageCode: str = None, task: str = None,
|
74 |
prompt_strategy: AbstractPromptStrategy = None,
|
75 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
76 |
"""
|
|
|
78 |
|
79 |
Parameters
|
80 |
----------
|
81 |
+
languageCode: str
|
82 |
+
The target language code of the transcription. If not specified, the language will be inferred from the audio content.
|
83 |
task: str
|
84 |
The task - either translate or transcribe.
|
85 |
prompt_strategy: AbstractPromptStrategy
|
|
|
91 |
-------
|
92 |
A WhisperCallback object.
|
93 |
"""
|
94 |
+
return WhisperCallback(self, languageCode=languageCode, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
95 |
|
96 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
97 |
from src.conversion.hf_converter import convert_hf_whisper
|
|
|
160 |
return model_config.path
|
161 |
|
162 |
class WhisperCallback(AbstractWhisperCallback):
|
163 |
+
def __init__(self, model_container: WhisperContainer, languageCode: str = None, task: str = None,
|
164 |
prompt_strategy: AbstractPromptStrategy = None,
|
165 |
**decodeOptions: dict):
|
166 |
self.model_container = model_container
|
167 |
+
self.languageCode = languageCode
|
168 |
self.task = task
|
169 |
self.prompt_strategy = prompt_strategy
|
170 |
|
|
|
204 |
if self.prompt_strategy else prompt
|
205 |
|
206 |
result = model.transcribe(audio, \
|
207 |
+
language=self.languageCode if self.languageCode else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
)
|