File size: 2,757 Bytes
00bfa33 f50408f 00bfa33 c3e5a3b 463444e 93d168d 463444e ccf30c7 463444e 06d2814 93d168d 06d2814 463444e 00bfa33 06d2814 00bfa33 5cac2b1 463444e 00bfa33 48ff56c 93d168d 00bfa33 93d168d 48ff56c 00bfa33 48ff56c 9003587 48ff56c 00bfa33 9003587 48ff56c 00bfa33 48ff56c 9003587 48ff56c 463444e 93d168d 06d2814 463444e 9003587 463444e 00bfa33 463444e 93d168d 00bfa33 9003587 00bfa33 463444e 1714a96 463444e 9003587 463444e 9003587 463444e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes
def load_models():
    """Load every configured translation checkpoint and its tokenizer.

    Each checkpoint is loaded with ``from_pretrained`` and a progress line
    is printed before it is loaded.

    Returns:
        A dict mapping the short model name to a dict with the keys
        ``"model"`` and ``"tokenizer"``.
    """
    # Short name used inside the app -> checkpoint identifier to load.
    checkpoints = {
        "nllb-distilled-1.3B": "facebook/nllb-200-distilled-1.3B",
    }

    loaded = {}
    for alias in checkpoints:
        checkpoint_id = checkpoints[alias]
        print("\tLoading model:", alias)
        # Dict displays evaluate in order: model first, tokenizer second.
        loaded[alias] = {
            "model": AutoModelForSeq2SeqLM.from_pretrained(checkpoint_id),
            "tokenizer": AutoTokenizer.from_pretrained(checkpoint_id),
        }
    return loaded
# Load models and tokenizers once, when this module is first executed, so that
# translate_text can reuse the same objects on every call instead of reloading.
model_dict = load_models()
# Translate text using preloaded models and tokenizers
def translate_text(source, target, text):
    """Translate ``text`` from the ``source`` language to the ``target`` language.

    Args:
        source: Key into ``flores_codes`` naming the source language
            (the UI passes names such as ``"English"``).
        target: Key into ``flores_codes`` naming the target language.
        text: The text to translate.

    Returns:
        A dict with the keys:
            ``"inference_time"``: elapsed seconds (float) spent resolving the
                language codes, building the pipeline and translating.
            ``"source"``: the value looked up in ``flores_codes`` for ``source``.
            ``"target"``: the value looked up in ``flores_codes`` for ``target``.
            ``"result"``: the translated text.

    Raises:
        KeyError: If the model is missing from ``model_dict`` (or its entry
            holds ``None``), or if ``source``/``target`` is not a key of
            ``flores_codes``.
    """
    model_name = "nllb-distilled-1.3B"

    # Guard clause: fail fast so the happy path below stays un-nested.
    if model_name not in model_dict or model_dict[model_name]["model"] is None:
        raise KeyError(f"Model '{model_name}' not found in model_dict")

    entry = model_dict[model_name]

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()

    source_code = flores_codes[source]
    target_code = flores_codes[target]

    # The pipeline is bound to one language pair, so it is built per request
    # around the already-loaded model and tokenizer.
    translator = pipeline(
        "translation",
        model=entry["model"],
        tokenizer=entry["tokenizer"],
        src_lang=source_code,
        tgt_lang=target_code,
    )
    output = translator(text, max_length=400)
    elapsed = time.perf_counter() - start_time

    return {
        "inference_time": elapsed,
        "source": source_code,
        "target": target_code,
        "result": output[0]["translation_text"],
    }
if __name__ == "__main__":
    # Build the Gradio interface around translate_text and start serving it.
    print("\tInitializing models")

    # The keys of flores_codes are the choices offered in both dropdowns.
    language_names = list(flores_codes)

    source_dropdown = gr.inputs.Dropdown(
        language_names, default="English", label="Source"
    )
    target_dropdown = gr.inputs.Dropdown(
        language_names, default="Nepali", label="Target"
    )
    text_box = gr.inputs.Textbox(lines=5, label="Input text")
    json_output = gr.outputs.JSON()

    app_title = "The Master Betters Translator"
    app_description = (
        "This is a beta version of The Master Betters Translator that utilizes "
        "pre-trained language models for translation. To use this app you need "
        "to have chosen the source and target language with your input text to "
        "get the output."
    )
    # One sample row: source language, target language, input text.
    sample_rows = [
        ["English", "Nepali", "The Master Betters Translator Welcomes You."],
    ]

    demo = gr.Interface(
        translate_text,
        [source_dropdown, target_dropdown, text_box],
        json_output,
        title=app_title,
        description=app_description,
        examples=sample_rows,
        examples_per_page=50,
    )
    demo.launch()
|