nobrowning commited on
Commit
7936150
1 Parent(s): dcfd438

add language detection

Browse files
Files changed (2) hide show
  1. app.py +41 -16
  2. languages.py +47 -0
app.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  import os
3
  import io
4
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
 
 
5
  import time
6
  import json
7
  from typing import List
@@ -135,6 +137,17 @@ def load_model(
135
  return tokenizer, model
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
138
  st.title("M2M100 Translator")
139
  st.write("M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper https://arxiv.org/abs/2010.11125 and first released in https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. The model that can directly translate between the 9,900 directions of 100 languages.\n")
140
 
@@ -147,26 +160,38 @@ user_input: str = st.text_area(
147
  max_chars=5120,
148
  )
149
 
150
- source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
151
  target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
152
 
153
  if st.button("Run"):
154
  time_start = time.time()
155
  tokenizer, model = load_model()
 
156
 
157
- src_lang = lang_id[source_lang]
158
- trg_lang = lang_id[target_lang]
159
- tokenizer.src_lang = src_lang
160
  with torch.no_grad():
161
- encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
162
- generated_tokens = model.generate(
163
- **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
164
- )
165
- translated_text = tokenizer.batch_decode(
166
- generated_tokens, skip_special_tokens=True
167
- )[0]
168
-
169
- time_end = time.time()
170
- st.success(translated_text)
171
-
172
- st.write(f"Computation time: {round((time_end-time_start),3)} segs")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import io
4
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from languages import LANGUANGE_MAP
7
  import time
8
  import json
9
  from typing import List
 
137
  return tokenizer, model
138
 
139
 
140
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
141
+ def load_detection_model(
142
+ pretrained_model: str = "ivanlau/language-detection-fine-tuned-on-xlm-roberta-base",
143
+ cache_dir: str = "models/",
144
+ ):
145
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
146
+ model = AutoModelForSequenceClassification.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
147
+ model.eval()
148
+ return tokenizer, model
149
+
150
+
151
  st.title("M2M100 Translator")
152
  st.write("M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper https://arxiv.org/abs/2010.11125 and first released in https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. The model that can directly translate between the 9,900 directions of 100 languages.\n")
153
 
 
160
  max_chars=5120,
161
  )
162
 
 
163
  target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
164
 
165
  if st.button("Run"):
166
  time_start = time.time()
167
  tokenizer, model = load_model()
168
+ de_tokenizer, de_model = load_detection_model()
169
 
 
 
 
170
  with torch.no_grad():
171
+
172
+ tokenized_sentence = de_tokenizer(user_input, return_tensors='pt')
173
+ output = de_model(**tokenized_sentence)
174
+ de_predictions = torch.nn.functional.softmax(output.logits, dim=-1)
175
+ _, preds = torch.max(de_predictions, dim=-1)
176
+
177
+ lang_type = LANGUANGE_MAP[preds.item()]
178
+
179
+ if lang_type not in lang_id:
180
+ st.success('Unsupported Language')
181
+ st.write(f"Computation time: {round((time_end-time_start),3)} segs")
182
+ else:
183
+ src_lang = lang_id[]
184
+ trg_lang = lang_id[target_lang]
185
+ tokenizer.src_lang = src_lang
186
+ encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
187
+ generated_tokens = model.generate(
188
+ **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
189
+ )
190
+ translated_text = tokenizer.batch_decode(
191
+ generated_tokens, skip_special_tokens=True
192
+ )[0]
193
+
194
+ time_end = time.time()
195
+ st.success(translated_text)
196
+
197
+ st.write(f"Computation time: {round((time_end-time_start),3)} segs")
languages.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LANGUANGE_MAP = {
2
+ 0: 'Arabic',
3
+ 1: 'Basque',
4
+ 2: 'Breton',
5
+ 3: 'Catalan',
6
+ 4: 'Chinese',
7
+ 5: 'Chinese',
8
+ 6: 'Chinese',
9
+ 7: 'Chuvash',
10
+ 8: 'Czech',
11
+ 9: 'Dhivehi',
12
+ 10: 'Dutch',
13
+ 11: 'English',
14
+ 12: 'Esperanto',
15
+ 13: 'Estonian',
16
+ 14: 'French',
17
+ 15: 'Frisian',
18
+ 16: 'Georgian',
19
+ 17: 'German',
20
+ 18: 'Greek',
21
+ 19: 'Hakha_Chin',
22
+ 20: 'Indonesian',
23
+ 21: 'Interlingua',
24
+ 22: 'Italian',
25
+ 23: 'Japanese',
26
+ 24: 'Kabyle',
27
+ 25: 'Kinyarwanda',
28
+ 26: 'Kyrgyz',
29
+ 27: 'Latvian',
30
+ 28: 'Maltese',
31
+ 29: 'Mongolian',
32
+ 30: 'Persian',
33
+ 31: 'Polish',
34
+ 32: 'Portuguese',
35
+ 33: 'Romanian',
36
+ 34: 'Romansh_Sursilvan',
37
+ 35: 'Russian',
38
+ 36: 'Sakha',
39
+ 37: 'Slovenian',
40
+ 38: 'Spanish',
41
+ 39: 'Swedish',
42
+ 40: 'Tamil',
43
+ 41: 'Tatar',
44
+ 42: 'Turkish',
45
+ 43: 'Ukranian',
46
+ 44: 'Welsh'
47
+ }