import torch
from transformers import set_seed, pipeline
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

######### HELSINKI NLP ##################
def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Helsinki-NLP's OPUS-MT models for the Mossi language.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The source-language code as used in the OPUS-MT model name (e.g. 'mos', 'en', 'fr')
    dest_iso: str
        The destination-language code as used in the OPUS-MT model name

    Returns
    -------
    translation: str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)

    # Inference
    translator = pipeline("translation", model=f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}")
    translation = translator(s)[0]['translation_text']

    return translation

######### MASAKHANE ##################
def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Masakhane's M2M-100 models for the Mossi language.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The source-language code as used in the Masakhane model name (e.g. 'mos', 'fr')
    dest_iso: str
        The destination-language code as used in the Masakhane model name

    Returns
    -------
    translation: str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)

    # Load model and tokenizer
    model = M2M100ForConditionalGeneration.from_pretrained(f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news")
    tokenizer = M2M100Tokenizer.from_pretrained(f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news")

    # Inference
    encoded = tokenizer(s, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return translation

######### META ##################
def translate_facebook(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using Meta's NLLB-200 model for the Mossi language.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The ISO-3 code of the source language; '_Latn' is appended to form the NLLB code (e.g. 'mos' -> 'mos_Latn')
    dest_iso: str
        The ISO-3 code of the destination language; '_Latn' is appended to form the NLLB code

    Returns
    -------
    translation: str
        The translated text
    '''
    # Ensure replicability
    set_seed(555)

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang=f"{src_iso}_Latn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

    # Inference: force the decoder to start with the target-language token.
    # Note: if lang_code_to_id is unavailable in your transformers version,
    # tokenizer.convert_tokens_to_ids(f"{dest_iso}_Latn") returns the same id.
    encoded = tokenizer(s, return_tensors="pt")
    translated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.lang_code_to_id[f"{dest_iso}_Latn"],
        max_length=30,  # caps the generated sequence length; longer inputs may be truncated
    )
    translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    return translation


######### ALL OF THE ABOVE ##################
def translate(s: str, src_iso: str, dest_iso: str) -> str:
    '''
    Translate the text using all available models (Meta, Masakhane, and Helsinki-NLP where applicable).

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The ISO-3 code of the source language ('eng', 'fra', or 'mos')
    dest_iso: str
        The ISO-3 code of the destination language ('eng', 'fra', or 'mos')

    Returns
    -------
    translation: str
        The translated text, concatenated over the different models
    '''
    # Translate with Meta's NLLB (covers all language pairs handled here)
    translation = "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)

    # Check whether the pair is supported by the other models and, if so, append their translations.
    # Helsinki-NLP and Masakhane model names use two-letter codes for English and French, so convert first.
    iso_pair = f"{src_iso}-{dest_iso}"
    if iso_pair in ["mos-eng", "eng-mos", "fra-mos"]:
        src_iso = src_iso.lower().replace("eng", "en").replace("fra", "fr")
        dest_iso = dest_iso.replace("eng", "en").replace("fra", "fr")
        translation += f"\n\n\nHelsinkiNLP's Opus translation is:\n\n{translate_helsinki_nlp(s, src_iso, dest_iso)}"

    if iso_pair in ["mos-fra", "fra-mos"]:
        src_iso = src_iso.lower().replace("fra", "fr")
        dest_iso = dest_iso.replace("fra", "fr")
        translation += "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_iso, dest_iso)

    return translation
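

######### USAGE EXAMPLE ##################
# Minimal usage sketch: translates an English sentence into Mossi with every model that
# supports the pair. The sample sentence is illustrative; running this downloads the NLLB,
# OPUS-MT, and (for French pairs) Masakhane checkpoints from the Hugging Face Hub,
# so network access and the required model weights are assumed.
if __name__ == "__main__":
    print(translate("Good morning, how are you?", src_iso="eng", dest_iso="mos"))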