SoybeanMilk committed • Commit e8762f9 • 1 parent: 50167d4

Add support for the ALMA model.
Changed files:
- app.py +19 -0
- config.json5 +7 -0
- src/config.py +2 -2
- src/translation/translationModel.py +18 -1
- src/utils.py +1 -1
app.py CHANGED

@@ -231,6 +231,8 @@ class WhisperTranscriber:
         nllbLangName: str = decodeOptions.pop("nllbLangName")
         mt5ModelName: str = decodeOptions.pop("mt5ModelName")
         mt5LangName: str = decodeOptions.pop("mt5LangName")
+        ALMAModelName: str = decodeOptions.pop("ALMAModelName")
+        ALMALangName: str = decodeOptions.pop("ALMALangName")

         translationBatchSize: int = decodeOptions.pop("translationBatchSize")
         translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")

@@ -337,6 +339,10 @@ class WhisperTranscriber:
             selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
             selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
             translationLang = get_lang_from_m2m100_name(mt5LangName)
+        elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
+            selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(ALMALangName)

         if translationLang is not None:
             translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)

@@ -828,6 +834,7 @@ def create_ui(app_config: ApplicationConfig):
     nllb_models = app_config.get_model_names("nllb")
     m2m100_models = app_config.get_model_names("m2m100")
     mt5_models = app_config.get_model_names("mt5")
+    ALMA_models = app_config.get_model_names("ALMA")

     common_whisper_inputs = lambda : {
         gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),

@@ -845,6 +852,10 @@ def create_ui(app_config: ApplicationConfig):
         gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
         gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
     }
+    common_ALMA_inputs = lambda : {
+        gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
+        gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
+    }

     common_translation_inputs = lambda : {
         gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),

@@ -905,9 +916,13 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="MT5") as simpleMT5Tab:
                     with gr.Row():
                         simpleInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as simpleALMATab:
+                    with gr.Row():
+                        simpleInputDict.update(common_ALMA_inputs())
                 simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
                 simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
                 simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
+                simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as simpleUrlTab:
                     simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})

@@ -964,9 +979,13 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="MT5") as fullMT5Tab:
                     with gr.Row():
                         fullInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as fullALMATab:
+                    with gr.Row():
+                        fullInputDict.update(common_ALMA_inputs())
                 fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
                 fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
                 fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
+                fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as fullUrlTab:
                     fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
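Note: to see how the new dropdown values travel from decodeOptions to a model choice, here is a minimal, self-contained sketch of the branch added at new lines 342-345. The model name and default checkpoint string are taken from the diff; ModelConfig is reduced to just the fields that appear in config.json5, and the surrounding harness is hypothetical rather than copied from app.py.

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class ModelConfig:
    # Reduced to the fields visible in config.json5; the real class may carry more.
    name: str
    url: str
    type: str = "huggingface"

# Hypothetical stand-in for self.app_config.models as parsed from config.json5.
models: Dict[str, List[ModelConfig]] = {
    "ALMA": [ModelConfig(name="ALMA-13B-GPTQ/TheBloke", url="TheBloke/ALMA-13B-GPTQ")],
}

def pick_alma_model(translateInput: str,
                    ALMAModelName: Optional[str],
                    ALMALangName: Optional[str]) -> Optional[ModelConfig]:
    """Mirrors the new elif branch: fall back to the bundled default when no model is picked."""
    if translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
        selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
        return next((m for m in models["ALMA"] if m.name == selectedModelName), None)
    return None

print(pick_alma_model("ALMA", None, "English"))  # default ALMA-13B-GPTQ/TheBloke entry
print(pick_alma_model("mt5", None, "English"))   # None: a different translation tab was selected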
config.json5 CHANGED

@@ -191,6 +191,13 @@
             "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
             "type": "huggingface"
         }
+    ],
+    "ALMA": [
+        {
+            "name": "ALMA-13B-GPTQ/TheBloke",
+            "url": "TheBloke/ALMA-13B-GPTQ",
+            "type": "huggingface",
+        },
     ]
 },
 // Configuration options that will be used if they are not specified in the command line arguments.
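Note: the trailing commas in the added block (after "type": "huggingface", and },) are legal here because the file is JSON5, not strict JSON; strict json.loads would reject them, while the json5 loader that src/config.py already uses accepts them. A quick check, with the model entry copied from the diff:

import json  # standard library, strict JSON
import json5  # the json5 module that src/config.py already imports

snippet = """
{
    "ALMA": [
        {
            "name": "ALMA-13B-GPTQ/TheBloke",
            "url": "TheBloke/ALMA-13B-GPTQ",
            "type": "huggingface",
        },
    ]
}
"""

data = json5.loads(snippet)        # trailing commas are fine in JSON5
print(data["ALMA"][0]["name"])     # -> ALMA-13B-GPTQ/TheBloke

try:
    json.loads(snippet)            # strict JSON rejects the trailing commas
except json.JSONDecodeError as e:
    print("plain JSON would fail:", e)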
src/config.py CHANGED

@@ -43,7 +43,7 @@ class VadInitialPromptMode(Enum):
         return None

 class ApplicationConfig:
-    def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]],
+    def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
                  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
                  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                  whisper_implementation: str = "whisper", default_model_name: str = "medium",

@@ -169,7 +169,7 @@ class ApplicationConfig:
             # Load using json5
             data = json5.load(f)
             data_models = data.pop("models", [])
-            models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]] = {
+            models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
                 key: [ModelConfig(**item) for item in value]
                 for key, value in data_models.items()
             }
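Note: both config.py hunks only widen a Literal type annotation; runtime behaviour is unchanged, because Literal is checked by static type checkers (mypy/pyright) rather than by the interpreter, and the dict comprehension already accepts whatever keys config.json5 contains. A small illustration of that point:

from typing import Dict, List, Literal

ModelKind = Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"]

def count_models(models: Dict[ModelKind, List[str]]) -> int:
    # Counts entries regardless of key; the annotation documents the expected keys.
    return sum(len(v) for v in models.values())

print(count_models({"ALMA": ["ALMA-13B-GPTQ/TheBloke"]}))  # 1, and type-checks
print(count_models({"unknown": ["whatever"]}))             # 1 at runtime; only a type checker would flag the key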
src/translation/translationModel.py CHANGED

@@ -7,6 +7,8 @@ import torch
 import ctranslate2
 import transformers

+import re
+
 from typing import Optional
 from src.config import ModelConfig
 from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code

@@ -97,6 +99,11 @@ class TranslationModel:
             self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
             self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
             self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
+        elif "ALMA" in self.modelPath:
+            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
+            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
+            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
+            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
         else:
             self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
             self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)

@@ -130,6 +137,12 @@ class TranslationModel:
         elif "mt5" in self.modelPath:
             output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
             result = output[0]['generated_text']
+        elif "ALMA" in self.modelPath:
+            output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
+            result = output[0]['generated_text']
+            result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
+            result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
+            return result.strip()
         else: #M2M100 & NLLB
             output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
             result = output[0]['translation_text']

@@ -148,7 +161,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
            "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
            "m2m100_1.2B", "m2m100_418M",
            "mt5-zh-ja-en-trimmed",
-           "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
+           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
+           "ALMA-13B-GPTQ"]

 def check_model_name(name):
     return any(allowed_name in name for allowed_name in _MODELS)

@@ -206,6 +220,9 @@ def download_model(
         "special_tokens_map.json",
         "spiece.model",
         "vocab.json", #m2m100
+        "model.safetensors",
+        "quantize_config.json",
+        "tokenizer.model"
     ]

     kwargs = {
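Note: the ALMA branch drives a decoder-only model through a plain text-generation pipeline, so the translation has to be carved back out of the generated continuation. The sketch below reproduces the prompt layout and the regex cleanup from the hunk above without downloading the 13B checkpoint; fake_generate stands in for the transformers pipeline call, and the language codes are purely illustrative.

import re

# Prompt layout from the diff: "Translate this from <src> to <tgt>:<src>:" + text + "<tgt>:"
src_code, tgt_code = "zh", "en"   # illustrative Whisper language codes
alma_prefix = "Translate this from " + src_code + " to " + tgt_code + ":" + src_code + ":"
text = "吃饭了吗？"
prompt = alma_prefix + text + tgt_code + ":"

def fake_generate(prompt: str) -> str:
    """Stand-in for the text-generation pipeline: returns the prompt plus a continuation."""
    return prompt + " Have you eaten yet?"

result = fake_generate(prompt)
# Same cleanup as translate(): strip everything up to the last "<tgt>: " marker,
# drop any leftover instruction line, then trim whitespace.
result = re.sub(rf'^(.*{tgt_code}: )', '', result)
result = re.sub(rf'^(Translate this from .* to .*:)', '', result)
print(result.strip())  # -> "Have you eaten yet?"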
src/utils.py CHANGED

@@ -130,7 +130,7 @@ def write_srt_original(transcript: Iterator[dict], file: TextIO,
               flush=True,
               )

-        if original is not None: print(f"{original}",
+        if original is not None: print(f"{original}\n",
                                        file=file,
                                        flush=True)
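Note: this change only appends "\n" inside the printed original-language line. Since print already adds its own newline, the result is an extra blank line after the original text, which presumably restores the blank separator that SRT blocks need. A minimal check of the difference in output:

import io

def write_block(file, original, fixed=True):
    # Before the fix: print(f"{original}", ...) ends the line but leaves no blank separator.
    # After the fix:  print(f"{original}\n", ...) emits the text plus an extra blank line.
    print(f"{original}\n" if fixed else f"{original}", file=file, flush=True)

before, after = io.StringIO(), io.StringIO()
write_block(before, "original text", fixed=False)
write_block(after, "original text", fixed=True)
print(repr(before.getvalue()))  # 'original text\n'
print(repr(after.getvalue()))   # 'original text\n\n'  <- blank line separating SRT blocks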