Spaces: Running on Zero

Optimize the preprocessing and generation #11
by cointegrated - opened
- app.py +33 -9
- flores.py +3 -3
- requirements.txt +3 -1
app.py
CHANGED
@@ -1,5 +1,7 @@
 import spaces
 import gradio as gr
+from sacremoses import MosesPunctNormalizer
+from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from flores import code_mapping
 import platform
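A minimal sketch of what the two new preprocessing imports provide (assuming the stopes branch pinned in requirements.txt is installed; the sample strings are purely illustrative):

from sacremoses import MosesPunctNormalizer
from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo

# Punctuation normalization: unifies quotes, dashes, spacing, etc.
punct_normalizer = MosesPunctNormalizer(lang="en")
clean = punct_normalizer.normalize("«Hello»  –  it's a test…")

# Language-specific sentence splitting: get_split_algo returns a callable that
# yields sentences, keyed by the 3-letter part of a FLORES code ("eng_Latn" -> "eng").
splitter = get_split_algo("eng", "default")
sentences = list(splitter("First sentence. Second sentence."))

print(clean)
print(sentences)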
@@ -28,28 +30,47 @@ def load_model():
 model = load_model()
 
 
-
-
-
-
-
+# Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+punct_normalizer = MosesPunctNormalizer(lang="en")
+
+
+@lru_cache(maxsize=202)
+def get_language_specific_sentence_splitter(language_code):
+    short_code = language_code[:3]
+    splitter = get_split_algo(short_code, "default")
+    return splitter
 
 
 # cache function
 @lru_cache(maxsize=100)
 def translate(text: str, src_lang: str, tgt_lang: str):
-
+    if not src_lang:
+        raise gr.Error("The source language is empty! Please choose it in the dropdown list.")
+    if not tgt_lang:
+        raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
+    return _translate(text, src_lang, tgt_lang)
+
 
 # Only assign GPU if cache not used
 @spaces.GPU
 def _translate(text: str, src_lang: str, tgt_lang: str):
-
+    src_code = code_mapping[src_lang]
+    tgt_code = code_mapping[tgt_lang]
+    tokenizer.src_lang = src_code
+    tokenizer.tgt_lang = tgt_code
+
+    # normalizing the punctuation first
+    text = punct_normalizer.normalize(text)
 
     paragraphs = text.split("\n")
     translated_paragraphs = []
 
     for paragraph in paragraphs:
-
+        splitter = get_language_specific_sentence_splitter(src_code)
+        sentences = list(splitter(paragraph))
         translated_sentences = []
 
         for sentence in sentences:
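A standalone sketch of the caching pattern above, with the model call stubbed out and gr.Error swapped for ValueError to keep it dependency-free (in the Space, _translate is decorated with @spaces.GPU, so a GPU worker is only requested when the input triple is not already cached):

from functools import lru_cache

def _translate(text: str, src_lang: str, tgt_lang: str) -> str:
    # stand-in for the @spaces.GPU-decorated function: only runs on a cache miss
    print("expensive GPU call")
    return f"{text} ({src_lang} -> {tgt_lang})"

@lru_cache(maxsize=100)
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    # invalid inputs are rejected up front; valid triples are memoized
    if not src_lang or not tgt_lang:
        raise ValueError("source and target languages must be set")
    return _translate(text, src_lang, tgt_lang)

translate("Hello", "English", "French")  # cache miss: prints "expensive GPU call"
translate("Hello", "English", "French")  # cache hit: returns instantly, no GPU work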
@@ -62,9 +83,12 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
             )
             translated_chunk = model.generate(
                 input_ids=torch.tensor([input_tokens]).to(device),
-                forced_bos_token_id=tokenizer.convert_tokens_to_ids(
+                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                 max_length=len(input_tokens) + 50,
                 num_return_sequences=1,
+                num_beams=5,
+                no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+                renormalize_logits=True,  # recompute token probabilities after banning the repetitions
             )
             translated_chunk = tokenizer.decode(
                 translated_chunk[0], skip_special_tokens=True
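For reference, the new decoding options in a self-contained form (the checkpoint name here is an assumption for illustration; the Space's MODEL_NAME, chunking, and device handling live elsewhere in app.py):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-600M"  # illustrative NLLB checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("The weather is nice today.", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),  # force the target language
    max_length=64,
    num_beams=5,              # beam search instead of greedy decoding
    no_repeat_ngram_size=4,   # ban repeated 4-grams; kept below num_beams
    renormalize_logits=True,  # re-normalize probabilities after banning repetitions
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))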
flores.py
CHANGED
@@ -10,7 +10,7 @@ code_mapping = {
     "Amharic": "amh_Ethi",
     "North Levantine Arabic": "apc_Arab",
     "Modern Standard Arabic": "arb_Arab",
-    "Modern Standard Arabic (Romanized)": "arb_Latn",
+    # "Modern Standard Arabic (Romanized)": "arb_Latn",  # it is in FLORES, but not in NLLB
     "Najdi Arabic": "ars_Arab",
     "Moroccan Arabic": "ary_Arab",
     "Egyptian Arabic": "arz_Arab",
@@ -115,7 +115,7 @@ code_mapping = {
     "Maithili": "mai_Deva",
     "Malayalam": "mal_Mlym",
     "Marathi": "mar_Deva",
-    "Minangkabau (Arabic script)": "min_Arab",
+    # "Minangkabau (Arabic script)": "min_Arab",  # it is in FLORES, but not in NLLB
     "Minangkabau (Latin script)": "min_Latn",
     "Macedonian": "mkd_Cyrl",
     "Plateau Malagasy": "plt_Latn",
@@ -149,7 +149,7 @@ code_mapping = {
     "Russian": "rus_Cyrl",
     "Sango": "sag_Latn",
     "Sanskrit": "san_Deva",
-    "Santali": "
+    "Santali": "sat_Beng",  # it is called sat_Olck in FLORES, but (less correctly) sat_Beng in NLLB
     "Sicilian": "scn_Latn",
     "Shan": "shn_Mymr",
     "Sinhala": "sin_Sinh",
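One way to sanity-check the mapping (a sketch, assuming that FLORES codes absent from the NLLB vocabulary resolve to the tokenizer's unk token, which is what makes arb_Latn and min_Arab above unusable):

from transformers import AutoTokenizer
from flores import code_mapping

# illustrative NLLB checkpoint; any NLLB-200 tokenizer carries the language-code tokens
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

unsupported = [
    code for code in code_mapping.values()
    if tokenizer.convert_tokens_to_ids(code) == tokenizer.unk_token_id
]
print(unsupported)  # expected to be empty once the FLORES-only codes are commented out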
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ transformers
 torch
 gradio==4.32.2
 spaces
-nltk
+nltk
+sacremoses
+stopes[mono] @ git+https://github.com/facebookresearch/stopes@better-sentence-splitters