dipesh1701
commited on
Commit
•
93d168d
1
Parent(s):
48ff56c
fix
Browse files
app.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import time
|
4 |
-
import asyncio
|
5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
6 |
from flores200_codes import flores_codes
|
7 |
|
8 |
# Load models and tokenizers once during initialization
|
9 |
-
|
10 |
model_name_dict = {
|
11 |
"nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
|
12 |
}
|
@@ -15,8 +14,8 @@ async def load_models():
|
|
15 |
|
16 |
for call_name, real_name in model_name_dict.items():
|
17 |
print("\tLoading model:", call_name)
|
18 |
-
model =
|
19 |
-
tokenizer =
|
20 |
model_dict[call_name] = {
|
21 |
"model": model,
|
22 |
"tokenizer": tokenizer,
|
@@ -28,14 +27,14 @@ async def load_models():
|
|
28 |
def translate_text(source_lang, target_lang, input_text, model_dict):
|
29 |
model_name = "nllb-distilled-600M"
|
30 |
|
31 |
-
start_time = time.time()
|
32 |
-
source_code = flores_codes[source_lang]
|
33 |
-
target_code = flores_codes[target_lang]
|
34 |
-
|
35 |
if model_name in model_dict:
|
36 |
model = model_dict[model_name]["model"]
|
37 |
tokenizer = model_dict[model_name]["tokenizer"]
|
38 |
|
|
|
|
|
|
|
|
|
39 |
translator = pipeline(
|
40 |
"translation",
|
41 |
model=model,
|
@@ -57,11 +56,11 @@ def translate_text(source_lang, target_lang, input_text, model_dict):
|
|
57 |
else:
|
58 |
raise KeyError(f"Model '{model_name}' not found in model_dict")
|
59 |
|
60 |
-
|
61 |
print("\tInitializing models")
|
62 |
|
63 |
# Load models and tokenizers
|
64 |
-
model_dict =
|
65 |
|
66 |
lang_codes = list(flores_codes.keys())
|
67 |
inputs = [
|
@@ -72,10 +71,10 @@ async def main():
|
|
72 |
|
73 |
outputs = gr.outputs.JSON()
|
74 |
|
75 |
-
title = "
|
76 |
|
77 |
app_description = (
|
78 |
-
"This is a beta version of
|
79 |
)
|
80 |
examples = [["English", "Nepali", "Hello, how are you?"]]
|
81 |
|
@@ -88,6 +87,3 @@ async def main():
|
|
88 |
examples=examples,
|
89 |
examples_per_page=50,
|
90 |
).launch()
|
91 |
-
|
92 |
-
if __name__ == "__main__":
|
93 |
-
asyncio.run(main())
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import time
|
|
|
4 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
5 |
from flores200_codes import flores_codes
|
6 |
|
7 |
# Load models and tokenizers once during initialization
|
8 |
+
def load_models():
|
9 |
model_name_dict = {
|
10 |
"nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
|
11 |
}
|
|
|
14 |
|
15 |
for call_name, real_name in model_name_dict.items():
|
16 |
print("\tLoading model:", call_name)
|
17 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(real_name)
|
19 |
model_dict[call_name] = {
|
20 |
"model": model,
|
21 |
"tokenizer": tokenizer,
|
|
|
27 |
def translate_text(source_lang, target_lang, input_text, model_dict):
|
28 |
model_name = "nllb-distilled-600M"
|
29 |
|
|
|
|
|
|
|
|
|
30 |
if model_name in model_dict:
|
31 |
model = model_dict[model_name]["model"]
|
32 |
tokenizer = model_dict[model_name]["tokenizer"]
|
33 |
|
34 |
+
start_time = time.time()
|
35 |
+
source_code = flores_codes[source_lang]
|
36 |
+
target_code = flores_codes[target_lang]
|
37 |
+
|
38 |
translator = pipeline(
|
39 |
"translation",
|
40 |
model=model,
|
|
|
56 |
else:
|
57 |
raise KeyError(f"Model '{model_name}' not found in model_dict")
|
58 |
|
59 |
+
if __name__ == "__main__":
|
60 |
print("\tInitializing models")
|
61 |
|
62 |
# Load models and tokenizers
|
63 |
+
model_dict = load_models()
|
64 |
|
65 |
lang_codes = list(flores_codes.keys())
|
66 |
inputs = [
|
|
|
71 |
|
72 |
outputs = gr.outputs.JSON()
|
73 |
|
74 |
+
title = "The Master Betters Translator"
|
75 |
|
76 |
app_description = (
|
77 |
+
"This is a beta version of The Master Betters Translator that utilizes pre-trained language models for translation."
|
78 |
)
|
79 |
examples = [["English", "Nepali", "Hello, how are you?"]]
|
80 |
|
|
|
87 |
examples=examples,
|
88 |
examples_per_page=50,
|
89 |
).launch()
|
|
|
|
|
|