|
import time
from functools import lru_cache

import torch

from transformers import set_seed, pipeline
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
@lru_cache(maxsize=8)
def _helsinki_pipeline(model_name: str):
    '''Build the translation pipeline for *model_name* once and reuse it on later calls.'''
    return pipeline("translation", model=model_name)


def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using HelsinkiNLP's Opus models for Mossi language.

    The pipeline for each language pair is built on first use and then
    kept in memory, so repeated calls do not reload the model.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The source-language code exactly as it appears in the Opus-MT model
        id ``Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}`` (e.g. "en", "fr",
        "mos"); note this is the 2-letter code for English/French.
    dest_iso: str
        The destination-language code, in the same form as ``src_iso``.

    Returns
    -------
    translation: str
        The translated text
    '''
    # Seed kept for parity with the other back-ends' global RNG state.
    set_seed(555)

    translator = _helsinki_pipeline(f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}")

    # The pipeline returns one dict per input string.
    return translator(s)[0]['translation_text']
|
|
|
|
|
@lru_cache(maxsize=4)
def _masakhane_model(model_name: str):
    '''Load the M2M100 model and tokenizer for *model_name* once and reuse them.'''
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    return model, tokenizer


def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Masakhane's M2M models for Mossi language.

    The model/tokenizer for each language pair is loaded on first use and
    then kept in memory, so repeated calls do not reload the checkpoint.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The source-language code exactly as it appears in the checkpoint id
        ``masakhane/m2m100_418m_{src_iso}_{dest_iso}_news`` (e.g. "fr", "mos").
    dest_iso: str
        The destination-language code, in the same form as ``src_iso``.

    Returns
    -------
    translation: str
        The translated text
    '''
    # Seed kept for parity with the other back-ends' global RNG state.
    set_seed(555)

    model, tokenizer = _masakhane_model(f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news")

    # NOTE(review): neither tokenizer.src_lang nor forced_bos_token_id is set
    # here; presumably the fine-tuned checkpoint's own config selects the
    # target language — confirm against the model card.
    encoded = tokenizer(s, return_tensors="pt")
    generated_tokens = model.generate(**encoded)

    # Single input string -> take the first (only) decoded sequence.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
|
|
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"


@lru_cache(maxsize=16)
def _nllb_tokenizer(src_lang: str):
    '''Load an NLLB tokenizer configured for *src_lang* once per source language.'''
    # One cached tokenizer per source language, rather than mutating the
    # src_lang of a single shared tokenizer on every call.
    return AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT, src_lang=src_lang)


@lru_cache(maxsize=1)
def _nllb_model():
    '''Load the NLLB-200 distilled 600M model once and reuse it.'''
    return AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)


def translate_facebook(s: str, src_iso: str, dest_iso: str, max_length: "int | None" = None) -> str:
    '''
    Translate the text using Meta's NLLB model for Mossi language.

    Model and tokenizer are loaded on first use and kept in memory.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO 639-3 code of the source language (e.g. "mos", "eng", "fra").
        It is suffixed with "_Latn" to form the NLLB language token, so only
        Latin-script languages are supported.
    dest_iso: str
        The ISO 639-3 code of the destination language (same form).
    max_length: int or None, optional
        Upper bound on the generated sequence length in tokens. When None
        (default) the model's own generation-config limit is used instead of
        truncating output at an arbitrary small length.

    Returns
    -------
    translation: str
        The translated text

    Raises
    ------
    ValueError
        If either language code does not map to a known NLLB language token.
    '''
    # Seed kept for parity with the other back-ends' global RNG state.
    set_seed(555)

    src_lang = f"{src_iso}_Latn"
    tgt_lang = f"{dest_iso}_Latn"

    tokenizer = _nllb_tokenizer(src_lang)
    model = _nllb_model()

    # Unknown language tokens silently map to <unk>, which would produce
    # meaningless output; fail loudly instead.
    for lang in (src_lang, tgt_lang):
        if tokenizer.convert_tokens_to_ids(lang) == tokenizer.unk_token_id:
            raise ValueError(f"Unsupported NLLB language code: {lang!r}")

    encoded = tokenizer(s, return_tensors="pt")

    # forced_bos_token_id makes the decoder start in the target language.
    generate_kwargs = {"forced_bos_token_id": tokenizer.convert_tokens_to_ids(tgt_lang)}
    if max_length is not None:
        generate_kwargs["max_length"] = max_length

    translated_tokens = model.generate(**encoded, **generate_kwargs)

    # Single input string -> take the first (only) decoded sequence.
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
|
|
|
|
|
|
def translate(s, src_iso, dest_iso):
    '''
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB model is always used. HelsinkiNLP's Opus-MT models are
    added for the pairs mos-eng, eng-mos and fra-mos, and Masakhane's M2M
    models for the pairs mos-fra and fra-mos. The elapsed wall time is
    printed to stdout.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO 639-3 code of the source language (e.g. "mos", "eng", "fra");
        matched case-insensitively.
    dest_iso: str
        The ISO 639-3 code of the destination language (same form).

    Returns
    -------
    translation: str
        The translated text, concatenated over different models, each
        section preceded by a header naming the model.
    '''
    # Normalise once, *before* building the pair key, so the pair lookups
    # below and every back-end see the same codes.
    src_iso = src_iso.lower()
    dest_iso = dest_iso.lower()
    iso_pair = f"{src_iso}-{dest_iso}"

    # The Opus-MT and Masakhane model ids use 2-letter codes for English and
    # French; other codes (e.g. "mos") are used unchanged.
    short_codes = {"eng": "en", "fra": "fr"}
    src_short = short_codes.get(src_iso, src_iso)
    dest_short = short_codes.get(dest_iso, dest_iso)

    # perf_counter is monotonic, unlike time.time, so it is suited to timing.
    start_time = time.perf_counter()

    translation = "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)

    if iso_pair in ("mos-eng", "eng-mos", "fra-mos"):
        translation += f"\n\n\nHelsinkiNLP's Opus translation is:\n\n {translate_helsinki_nlp(s, src_short, dest_short)}"

    if iso_pair in ("mos-fra", "fra-mos"):
        translation += "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_short, dest_short)

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")

    return translation
|
|