cointegrated committed on
Commit
2a62da0
1 Parent(s): d0ffdbf

add punctuation normalization and load the tokenizer only once

Browse files
Files changed (2) hide show
  1. app.py +14 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import spaces
2
  import gradio as gr
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from flores import code_mapping
5
  import platform
@@ -28,12 +29,11 @@ def load_model():
28
  model = load_model()
29
 
30
 
31
- def load_tokenizer(src_lang, tgt_lang):
32
- tokenizer = AutoTokenizer.from_pretrained(
33
- MODEL_NAME, src_lang=code_mapping[src_lang], tgt_lang=code_mapping[tgt_lang]
34
- )
35
- return tokenizer
36
 
 
37
 
38
  # cache function
39
  @lru_cache(maxsize=100)
@@ -44,10 +44,17 @@ def translate(text: str, src_lang: str, tgt_lang: str):
44
  raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
45
  return _translate(text, src_lang, tgt_lang)
46
 
 
47
  # Only assign GPU if cache not used
48
  @spaces.GPU
49
  def _translate(text: str, src_lang: str, tgt_lang: str):
50
- tokenizer = load_tokenizer(src_lang, tgt_lang)
 
 
 
 
 
 
51
 
52
  paragraphs = text.split("\n")
53
  translated_paragraphs = []
@@ -66,7 +73,7 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
66
  )
67
  translated_chunk = model.generate(
68
  input_ids=torch.tensor([input_tokens]).to(device),
69
- forced_bos_token_id=tokenizer.convert_tokens_to_ids(code_mapping[tgt_lang]),
70
  max_length=len(input_tokens) + 50,
71
  num_return_sequences=1,
72
  num_beams=5,
 
1
  import spaces
2
  import gradio as gr
3
+ from sacremoses import MosesPunctNormalizer
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from flores import code_mapping
6
  import platform
 
29
  model = load_model()
30
 
31
 
32
+ # Loading the tokenizer once, because re-loading it takes about 1.5 seconds each time
33
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
34
+
 
 
35
 
36
+ punct_normalizer = MosesPunctNormalizer(lang="en")
37
 
38
  # cache function
39
  @lru_cache(maxsize=100)
 
44
  raise gr.Error("The target language is empty! Please choose it in the dropdown list.")
45
  return _translate(text, src_lang, tgt_lang)
46
 
47
+
48
  # Only assign GPU if cache not used
49
  @spaces.GPU
50
  def _translate(text: str, src_lang: str, tgt_lang: str):
51
+ src_code = code_mapping[src_lang]
52
+ tgt_code = code_mapping[tgt_lang]
53
+ tokenizer.src_lang = src_code
54
+ tokenizer.tgt_lang = tgt_code
55
+
56
+ # normalizing the punctuation first
57
+ text = punct_normalizer.normalize(text)
58
 
59
  paragraphs = text.split("\n")
60
  translated_paragraphs = []
 
73
  )
74
  translated_chunk = model.generate(
75
  input_ids=torch.tensor([input_tokens]).to(device),
76
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
77
  max_length=len(input_tokens) + 50,
78
  num_return_sequences=1,
79
  num_beams=5,
requirements.txt CHANGED
@@ -3,4 +3,5 @@ transformers
3
  torch
4
  gradio==4.32.2
5
  spaces
6
- nltk
 
 
3
  torch
4
  gradio==4.32.2
5
  spaces
6
+ nltk
7
+ sacremoses