Hamza-cpp committed on
Commit
c6b2bc1
1 Parent(s): a156163

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -62
app.py CHANGED
@@ -1,85 +1,59 @@
1
  import os
2
- import torch
3
- import gradio as gr
4
  import time
 
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
  from flores200_codes import flores_codes
7
 
8
-
9
def load_models():
    """Download the configured NLLB checkpoint(s) and return them in a registry.

    Returns:
        dict: two entries per checkpoint alias — '<alias>_model' holding the
        seq2seq model and '<alias>_tokenizer' holding its tokenizer.
    """
    # Only the distilled 600M variant is enabled; the larger checkpoints
    # ('nllb-1.3B', 'nllb-distilled-1.3B', 'nllb-3.3B') were deliberately
    # left disabled in the original to keep the demo lightweight.
    checkpoints = {'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M'}

    registry = {}
    for alias, hub_id in checkpoints.items():
        print('\tLoading model: %s' % alias)
        registry[alias + '_model'] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        registry[alias + '_tokenizer'] = AutoTokenizer.from_pretrained(hub_id)
    return registry
27
 
 
 
28
 
29
def translation(source, target, text):
    """Translate `text` from `source` to `target`.

    Args:
        source: human-readable source language name (a key of `flores_codes`).
        target: human-readable target language name (a key of `flores_codes`).
        text: the text to translate.

    Returns:
        dict with 'inference_time' (seconds), 'source'/'target' (FLORES-200
        codes, matching the original behavior) and 'result' (translated text).

    Raises:
        KeyError: if `source` or `target` is not in `flores_codes`
        (unchanged behavior).
    """
    # BUG FIX: `model_name` used to be assigned only when
    # len(model_dict) == 2, raising NameError for any other registry size.
    # Bind it unconditionally — it is the only model this app loads.
    model_name = 'nllb-distilled-600M'

    start_time = time.time()
    # Map human-readable names to FLORES-200 language codes.
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    # Build the pipeline per request; cheap relative to the one-time model load.
    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)
    end_time = time.time()

    return {
        'inference_time': end_time - start_time,
        'source': source,
        'target': target,
        'result': output[0]['translation_text'],
    }
51
 
52
-
53
if __name__ == '__main__':
    print('\tinit models')

    # NOTE(review): `global` is a no-op at module level; kept only for
    # byte-level parity with the original script.
    global model_dict
    model_dict = load_models()

    # --- Gradio demo wiring ---
    lang_codes = list(flores_codes.keys())
    # A model-selection radio for the larger checkpoints was sketched but
    # left disabled in the original.
    inputs = [
        gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
        gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
        gr.inputs.Textbox(lines=5, label="Input text"),
    ]
    outputs = gr.outputs.JSON()

    title = "NLLB distilled 600M demo"
    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['English', 'Korean', 'Hi. nice to meet you']
    ]

    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
    ).launch()
 
1
  import os
 
 
2
  import time
3
+ import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from flores200_codes import flores_codes
6
 
 
7
def load_models():
    """Download the configured NLLB checkpoint(s) and return them in a registry.

    Returns:
        dict: two entries per checkpoint alias — '<alias>_model' and
        '<alias>_tokenizer'.
    """
    checkpoints = {'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M'}

    registry = {}
    for alias, hub_id in checkpoints.items():
        print(f'\tLoading model: {alias}')
        registry[f'{alias}_model'] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        registry[f'{alias}_tokenizer'] = AutoTokenizer.from_pretrained(hub_id)
    return registry
17
 
18
# Load models once at import time. NOTE: the original had a `global model_dict`
# statement here, which is a meaningless no-op at module level — a plain
# module-level assignment already creates the global; it has been removed.
model_dict = load_models()
20
 
21
def translate_text(source_lang, target_lang, input_text):
    """Translate `input_text` from `source_lang` to `target_lang`.

    Args:
        source_lang: human-readable source language name (a key of `flores_codes`).
        target_lang: human-readable target language name (a key of `flores_codes`).
        input_text: the text to translate.

    Returns:
        dict with 'inference_time' (seconds), the original language names under
        'source'/'target', and the translation under 'result'; or
        {'error': ...} when either language name is unknown.
    """
    # BUG FIX: `model_name` used to be assigned only when
    # len(model_dict) == 2, raising NameError for any other registry size.
    # Bind it unconditionally — it is the only model this app loads.
    model_name = 'nllb-distilled-600M'

    start_time = time.time()
    # Map human-readable names to FLORES-200 codes; .get returns None on
    # unknown names so we can fail gracefully below.
    source = flores_codes.get(source_lang)
    target = flores_codes.get(target_lang)

    if not source or not target:
        return {"error": "Invalid source or target language code"}

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    # Build the pipeline per request; cheap relative to the one-time model load.
    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(input_text, max_length=400)
    end_time = time.time()

    return {
        'inference_time': end_time - start_time,
        'source': source_lang,
        'target': target_lang,
        'result': output[0]['translation_text'],
    }
44
 
45
# --- Gradio interface ---
# Three free-text inputs (raw language names + text to translate) routed
# through translate_text, with the result shown as JSON.
_input_widgets = [
    gr.inputs.Textbox(lines=1, placeholder="Source language code", label="Source Language Code"),
    gr.inputs.Textbox(lines=1, placeholder="Target language code", label="Target Language Code"),
    gr.inputs.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text"),
]

iface = gr.Interface(
    fn=translate_text,
    inputs=_input_widgets,
    outputs=gr.outputs.JSON(),
    title="Translation API",
    description="Translation API using NLLB model."
)

# Serve on all interfaces at port 7860 with a public share link, request
# queueing, and client-visible error messages.
iface.launch(share=True, enable_queue=True, show_error=True,
             server_name="0.0.0.0", server_port=7860)