|
import gradio as gr |
|
from simpletransformers.seq2seq import Seq2SeqModel |
|
|
|
|
|
# Model identifiers handed to simpletransformers as ``encoder_decoder_name``.
# The ``en-pcm`` suffix presumably means English -> Nigerian Pidgin; confirm
# against the model cards.
BM_MODEL_PATH = "Enutrof/marian-mt-en-pcm"
BBGM_MODEL_PATH = "NITHUB-AI/marian-mt-bbc-en-pcm"


def _load_marian(model_path):
    """Build a CPU-only simpletransformers Marian seq2seq model from ``model_path``."""
    return Seq2SeqModel(
        encoder_decoder_type="marian",
        encoder_decoder_name=model_path,
        use_cuda=False,
    )


bm_model = _load_marian(BM_MODEL_PATH)
bbgm_model = _load_marian(BBGM_MODEL_PATH)

# Display name -> loaded model; ``translate`` looks the selected model up by
# this key.
models = {
    "BM Model": bm_model,
    "BBGM Model": bbgm_model,
}
|
|
|
def translate(model_name, source_sentence, num_beams):
    """Translate one English sentence and return up to three candidates.

    Args:
        model_name: Key into ``models`` selecting which loaded model to use.
        source_sentence: English text from the input box (may be empty or None
            while the user is still typing, since the UI runs live).
        num_beams: Beam width from the slider; may arrive as a float.

    Returns:
        A 3-tuple of strings, one per output textbox. When the beam width is
        smaller than three, or there is nothing to translate, the missing
        slots are filled with empty strings.

    Raises:
        Whatever ``Seq2SeqModel.predict`` raises for a failed generation.
    """
    num_outputs = 3  # one per "Prediction N" textbox in the interface
    blank = ("",) * num_outputs

    # Live mode fires on every edit: skip work for an empty box or no model yet.
    if model_name not in models or not (source_sentence or "").strip():
        return blank

    selected_model = models[model_name]
    beams = max(1, int(num_beams))

    # Seq2SeqModel.predict() takes only the list of inputs; generation settings
    # are read from the model's args. Beam search cannot return more sequences
    # than it has beams, so cap the requested count at the beam width.
    selected_model.args.num_beams = beams
    selected_model.args.num_return_sequences = min(num_outputs, beams)

    result = selected_model.predict([source_sentence])[0]
    # predict() gives a plain string per input for a single return sequence and
    # a list of strings for several; handle both shapes.
    candidates = [result] if isinstance(result, str) else list(result)
    candidates += [""] * (num_outputs - len(candidates))
    return tuple(candidates[:num_outputs])
|
|
|
|
|
# Gradio UI: pick a model, type an English sentence, choose a beam width and
# see up to three candidate translations. live=True re-runs translate() on
# every input change.
interface = gr.Interface(
    fn=translate,
    inputs=[
        # Choices come from the ``models`` registry so the two cannot drift
        # apart; an initial value keeps live mode from sending None.
        gr.Dropdown(choices=list(models), value="BM Model", label="Model Selection"),
        gr.Textbox(placeholder="Enter English sentence here...", label="Source Sentence"),
        # ``value`` sets the initial position (``default`` is not a Slider
        # parameter in the component API used here).
        gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Beams"),
    ],
    outputs=[
        gr.Textbox(label="Prediction 1"),
        gr.Textbox(label="Prediction 2"),
        gr.Textbox(label="Prediction 3"),
    ],
    live=True,
)

interface.launch()
|
|