File size: 2,284 Bytes
3d1d83f
 
c6b2bc1
3d1d83f
 
 
 
0d4e984
 
 
3d1d83f
 
c6b2bc1
3d1d83f
 
c6b2bc1
 
3d1d83f
 
c6b2bc1
 
3d1d83f
c6b2bc1
c14e153
0d4e984
3d1d83f
c6b2bc1
 
 
 
 
3d1d83f
 
 
 
c6b2bc1
3d1d83f
c6b2bc1
 
 
 
 
 
 
3d1d83f
 
c6b2bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes

def load_models():
    """Load every configured checkpoint together with its tokenizer.

    Returns:
        dict: Two entries per checkpoint, keyed ``'<call_name>_model'`` and
        ``'<call_name>_tokenizer'``.
    """
    # Short call name -> identifier handed to ``from_pretrained``.
    # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M' is a second
    # option that is currently disabled.
    checkpoints = {
        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
    }

    loaded = {}
    for alias, checkpoint_id in checkpoints.items():
        print(f'\tLoading model: {alias}')
        # Dict displays evaluate in order, so the model is loaded before
        # the tokenizer.
        loaded.update({
            f'{alias}_model': AutoModelForSeq2SeqLM.from_pretrained(checkpoint_id),
            f'{alias}_tokenizer': AutoTokenizer.from_pretrained(checkpoint_id),
        })
    return loaded

# Load the models once at import time; translate_text() reads this mapping.
# No ``global`` statement is needed: a module-level assignment already
# creates a module global.
model_dict = load_models()

def translate_text(source_lang, target_lang, input_text):
    """Translate ``input_text`` using the preloaded NLLB model.

    Args:
        source_lang: Key looked up in ``flores_codes`` for the input language.
        target_lang: Key looked up in ``flores_codes`` for the output language.
        input_text: Text to translate.

    Returns:
        dict: On success, ``inference_time`` (elapsed seconds, including
        pipeline construction), ``source``, ``target`` and ``result`` (the
        translated text). On failure, a dict with a single ``error`` key.
    """
    # Always bind the model name. Previously it was only assigned when
    # model_dict held exactly two entries, so any other configuration
    # crashed with UnboundLocalError further down.
    model_name = 'nllb-distilled-1.3B'

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    source = flores_codes.get(source_lang)
    target = flores_codes.get(target_lang)

    if not source or not target:
        return {"error": "Invalid source or target language code"}

    model = model_dict.get(model_name + '_model')
    tokenizer = model_dict.get(model_name + '_tokenizer')
    if model is None or tokenizer is None:
        # Report a missing model the same way as invalid input instead of
        # raising from inside the request handler.
        return {"error": f"Model '{model_name}' is not loaded"}

    # The pipeline is built per call because the language pair is fixed at
    # construction time; the model and tokenizer themselves are reused.
    translator = pipeline(
        'translation',
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(input_text, max_length=400)
    end_time = time.perf_counter()

    return {
        'inference_time': end_time - start_time,
        'source': source_lang,
        'target': target_lang,
        'result': output[0]['translation_text'],
    }

# Gradio front end for translate_text: two single-line text boxes, one
# five-line text box, and a JSON component for the returned dict.
iface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.inputs.Textbox(lines=n_lines, placeholder=hint, label=caption)
        for n_lines, hint, caption in (
            (1, "Source language code", "Source Language Code"),
            (1, "Target language code", "Target Language Code"),
            (5, "Enter text to translate", "Input Text"),
        )
    ],
    outputs=gr.outputs.JSON(),
    title="Translation API",
    description="Translation API using NLLB model.",
)

# Start the Gradio server bound to 0.0.0.0 on port 7860, with share=True,
# enable_queue=True and show_error=True.
# NOTE(review): the previous comment said "API only", but nothing in this
# call disables the web UI built above — confirm the intent.
# NOTE(review): share=True together with server_name="0.0.0.0" presumably
# exposes the app beyond localhost — verify this is wanted in deployment.
iface.launch(share=True, enable_queue=True, show_error=True, server_name="0.0.0.0", server_port=7860)