Files changed (1)
  1. translation.py +22 -103
translation.py CHANGED
@@ -3,11 +3,12 @@ import sys
 import typing as tp
 import unicodedata
 
-import torch
 from sacremoses import MosesPunctNormalizer
 from sentence_splitter import SentenceSplitter
 from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
 
+import torch
+
 MODEL_URL = "slone/nllb-210-v1"
 LANGUAGES = {
     "Русский | Russian": "rus_Cyrl",
@@ -23,49 +24,22 @@ LANGUAGES = {
     "Татар | Tatar | Татарский": "tat_Cyrl",
     "Тыва | Тувинский | Tuvan ": "tyv_Cyrl",
 }
-L1 = "rus_Cyrl"
-L2 = "eng_Latn"
-
+L1, L2 = "rus_Cyrl", "eng_Latn"
 
 def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
-    non_printable_map = {
-        ord(c): replace_by
-        for c in (chr(i) for i in range(sys.maxunicode + 1))
-        # same as \p{C} in perl
-        # see https://www.unicode.org/reports/tr44/#General_Category_Values
-        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
-    }
-
-    def replace_non_printing_char(line) -> str:
-        return line.translate(non_printable_map)
-
-    return replace_non_printing_char
-
+    non_printable_map = {ord(c): replace_by for c in (chr(i) for i in range(sys.maxunicode + 1)) if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}}
+    return lambda line: line.translate(non_printable_map)
 
 class TextPreprocessor:
-    """
-    Mimic the text preprocessing made for the NLLB model.
-    This code is adapted from the Stopes repo of the NLLB team:
-    https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
-    """
-
     def __init__(self, lang="en"):
         self.mpn = MosesPunctNormalizer(lang=lang)
-        self.mpn.substitutions = [
-            (re.compile(r), sub) for r, sub in self.mpn.substitutions
-        ]
+        self.mpn.substitutions = [(re.compile(r), sub) for r, sub in self.mpn.substitutions]
         self.replace_nonprint = get_non_printing_char_replacer(" ")
 
     def __call__(self, text: str) -> str:
-        clean = self.mpn.normalize(text)
-        clean = self.replace_nonprint(clean)
-        # replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
-        clean = unicodedata.normalize("NFKC", clean)
-        return clean
-
+        return unicodedata.normalize("NFKC", self.replace_nonprint(self.mpn.normalize(text)))
 
 def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
-    """Apply a sentence splitter and return the sentences and all separators before and after them"""
     if fix_double_space:
         text = re.sub(" +", " ", text)
     sentences = splitter.split(text)
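
The condensed TextPreprocessor keeps the same pipeline as before (Moses punctuation normalization, non-printing-character removal, NFKC folding). A minimal spot-check, assuming this file is importable as a module named translation and that sacremoses is installed:

    # Hypothetical check, not part of the diff: NFKC folding still maps styled
    # Unicode letters back to plain ASCII, as the removed inline comment
    # ("replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca") documented.
    from translation import TextPreprocessor

    pre = TextPreprocessor()
    print(pre("𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞"))  # expected output: Francesca
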
@@ -74,7 +48,6 @@ def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=
     for sentence in sentences:
         start_idx = text.find(sentence, i)
         if ignore_errors and start_idx == -1:
-            # print(f"sent not found after {i}: `{sentence}`")
             start_idx = i + 1
         assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
         fillers.append(text[i:start_idx])
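
The fillers returned by sentenize_with_fillers are what later lets translate() stitch the translated sentences back together around the original whitespace. A sketch of that invariant, under the assumption that the module is importable as translation (the sample text is illustrative):

    from sentence_splitter import SentenceSplitter
    from translation import sentenize_with_fillers

    splitter = SentenceSplitter("ru")
    text = "Привет! Как дела? Всё хорошо."
    sents, fillers = sentenize_with_fillers(text, splitter)
    # fillers has one more element than sents; interleaving them restores the input
    restored = "".join(f + s for f, s in zip(fillers, sents)) + fillers[-1]
    print(restored == text)  # expected: True (when the input has no double spaces)
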
@@ -82,87 +55,33 @@ def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=
     fillers.append(text[i:])
     return sentences, fillers
 
-
 class Translator:
     def __init__(self):
         self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
-        if torch.cuda.is_available():
-            self.model.cuda()
+        self.model.cuda() if torch.cuda.is_available() else None
         self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
-
         self.splitter = SentenceSplitter("ru")
         self.preprocessor = TextPreprocessor()
-
         self.languages = LANGUAGES
 
-    def translate(
-        self,
-        text,
-        src_lang=L1,
-        tgt_lang=L2,
-        max_length="auto",
-        num_beams=4,
-        by_sentence=True,
-        preprocess=True,
-        **kwargs,
-    ):
-        """Translate a text sentence by sentence, preserving the fillers around the sentences."""
-        if by_sentence:
-            sents, fillers = sentenize_with_fillers(
-                text, splitter=self.splitter, ignore_errors=True
-            )
-        else:
-            sents = [text]
-            fillers = ["", ""]
-        if preprocess:
-            sents = [self.preprocessor(sent) for sent in sents]
+    def translate(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, by_sentence=True, preprocess=True, **kwargs):
+        sents, fillers = (sentenize_with_fillers(text, self.splitter, ignore_errors=True) if by_sentence else ([text], ["", ""]))
         results = []
-        for sent, sep in zip(sents, fillers):
+        if preprocess:
+            # normalize punctuation and strip non-printing characters before translating
+            sents = [self.preprocessor(sent) for sent in sents]
+        # `results` holds only output pieces (separator + translation per sentence);
+        # iterating over `sents`, not `results`, keeps the inputs out of the output
+        for sent, sep in zip(sents, fillers):
             results.append(sep)
-            results.append(
-                self.translate_single(
-                    sent,
-                    src_lang=src_lang,
-                    tgt_lang=tgt_lang,
-                    max_length=max_length,
-                    num_beams=num_beams,
-                    **kwargs,
-                )
-            )
+            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
         results.append(fillers[-1])
         return "".join(results)
 
-    def translate_single(
-        self,
-        text,
-        src_lang=L1,
-        tgt_lang=L2,
-        max_length="auto",
-        num_beams=4,
-        n_out=None,
-        **kwargs,
-    ):
+    def translate_single(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, n_out=None, **kwargs):
        self.tokenizer.src_lang = src_lang
-        encoded = self.tokenizer(
-            text, return_tensors="pt", truncation=True, max_length=512
-        )
-        if max_length == "auto":
-            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
-        generated_tokens = self.model.generate(
-            **encoded.to(self.model.device),
-            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
-            max_length=max_length,
-            num_beams=num_beams,
-            num_return_sequences=n_out or 1,
-            **kwargs,
-        )
+        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+        max_length = int(32 + 2.0 * encoded.input_ids.shape[1]) if max_length == "auto" else max_length
+        generated_tokens = self.model.generate(**encoded.to(self.model.device), forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang], max_length=max_length, num_beams=num_beams, num_return_sequences=n_out or 1, **kwargs)
         out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-        if isinstance(text, str) and n_out is None:
-            return out[0]
-        return out
+        return out[0] if isinstance(text, str) and n_out is None else out
-
-
-if __name__ == "__main__":
-    print("Initializing a translator to pre-download models...")
-    translator = Translator()
-    print("Initialization successful!")
 
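For reference, a minimal usage sketch of the refactored Translator, not part of the diff; it assumes the file is importable as translation and that the slone/nllb-210-v1 weights can be fetched from the Hub:

    from translation import Translator

    translator = Translator()  # loads the model and tokenizer, moves the model to GPU if available
    text = "Привет, мир! Как дела?"
    print(translator.translate(text, src_lang="rus_Cyrl", tgt_lang="eng_Latn"))

    # translate_single can return several beam hypotheses for a single sentence;
    # max_length="auto" caps generation at roughly 32 + 2 * the input token count
    variants = translator.translate_single("Привет, мир!", tgt_lang="eng_Latn", n_out=3, num_beams=4)
    print(variants)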