|
import os |
|
import time |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
from flores200_codes import flores_codes |
|
|
|
def load_models():
    """Load every configured checkpoint together with its tokenizer.

    Returns:
        dict: A flat mapping. Each short alias ``<alias>`` contributes two
        entries: ``'<alias>_model'`` (an ``AutoModelForSeq2SeqLM`` instance)
        and ``'<alias>_tokenizer'`` (the matching ``AutoTokenizer`` instance).
    """
    # Short alias used by the rest of the app -> identifier passed to
    # ``from_pretrained``.
    checkpoints = {
        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
    }

    loaded = {}
    for alias, repo_id in checkpoints.items():
        print(f'\tLoading model: {alias}')
        loaded[f'{alias}_model'] = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
        loaded[f'{alias}_tokenizer'] = AutoTokenizer.from_pretrained(repo_id)
    return loaded
|
|
|
# Load all models once, when this module is executed, so every request handled
# by translate_text reuses the same model and tokenizer objects.
# (No ``global`` statement is needed: this assignment is already at module scope.)
model_dict = load_models()
|
|
|
def translate_text(source_lang, target_lang, input_text):
    """Translate ``input_text`` from ``source_lang`` to ``target_lang``.

    Args:
        source_lang: Key looked up in ``flores_codes`` to obtain the source
            language code passed to the pipeline as ``src_lang``. Surrounding
            whitespace is ignored. (Presumably a language name such as
            "English"; confirm against ``flores200_codes``.)
        target_lang: Key looked up the same way to obtain ``tgt_lang``.
        input_text: Text to translate.

    Returns:
        dict: On success, ``inference_time`` (elapsed seconds), ``source`` and
        ``target`` (the language keys as received), and ``result`` (the
        translated text). On failure, a dict with a single ``error`` key.
    """
    model_name = 'nllb-distilled-1.3B'

    # Look the model up explicitly. The previous ``len(model_dict) == 2``
    # check fell through to an implicit ``None`` return when it failed.
    model = model_dict.get(model_name + '_model')
    tokenizer = model_dict.get(model_name + '_tokenizer')
    if model is None or tokenizer is None:
        return {"error": f"Model '{model_name}' is not loaded"}

    # perf_counter is monotonic and intended for measuring elapsed time.
    start_time = time.perf_counter()

    # Strip whitespace so stray spaces typed into the textbox do not make an
    # otherwise valid language key fail the lookup.
    source = flores_codes.get((source_lang or '').strip())
    target = flores_codes.get((target_lang or '').strip())
    if not source or not target:
        return {"error": "Invalid source or target language code"}

    if not input_text or not input_text.strip():
        return {"error": "Input text is empty"}

    # src_lang / tgt_lang are fixed when the pipeline is constructed, so a
    # pipeline is built per call for the requested language pair.
    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(input_text, max_length=400)
    output_text = output[0]['translation_text']

    end_time = time.perf_counter()

    return {
        'inference_time': end_time - start_time,
        'source': source_lang,
        'target': target_lang,
        'result': output_text,
    }
|
|
|
|
|
# Web UI / API definition. Components come from the top-level ``gr`` namespace:
# the ``gr.inputs`` / ``gr.outputs`` modules were deprecated in Gradio 3.x and
# removed in 4.0, while ``gr.Textbox`` / ``gr.JSON`` work in both.
iface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.Textbox(lines=1, placeholder="Source language code", label="Source Language Code"),
        gr.Textbox(lines=1, placeholder="Target language code", label="Target Language Code"),
        gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text"),
    ],
    outputs=gr.JSON(),
    title="Translation API",
    description="Translation API using NLLB model.",
)

# Request queueing is enabled with ``queue()``; the ``enable_queue`` argument to
# ``launch`` is deprecated and no longer accepted by Gradio 4.
# NOTE(review): share=True together with server_name="0.0.0.0" exposes the app
# beyond localhost; confirm that is intended.
iface.queue().launch(share=True, show_error=True, server_name="0.0.0.0", server_port=7860)
|
|