File size: 9,419 Bytes
8077be2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
import warnings
import huggingface_hub
import requests
import torch

import ctranslate2
import transformers

from typing import Optional
from src.config import ModelConfig
from src.languages import Language
from src.nllb.nllbLangs import NllbLang, get_nllb_lang_from_code_whisper

class NllbModel:
    """Text translator backed by an NLLB-200 or mT5 checkpoint.

    The backend is chosen from substrings of the resolved model path: a path
    containing "ct2" is loaded with CTranslate2, one containing "mt5" with a
    transformers MT5 text2text pipeline, and anything else with a transformers
    NLLB translation pipeline.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        device: Optional[str] = None,
        whisper_lang: Optional[Language] = None,
        nllb_lang: Optional[NllbLang] = None,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        load_model: bool = False,
    ):
        """Resolves the model path and device; loads weights only if requested.

        Args:
          model_config: Config of the model to use. If ``model_config.url`` is an
            existing directory it is used directly as the model path; otherwise
            the model is fetched with ``download_model`` from the Hugging Face Hub.
          device: Device string handed to the backend. When None, "cuda" (for
            CTranslate2 models) or "cuda:0" (for transformers models) is used if
            CUDA is available, otherwise "cpu".
          whisper_lang: Source language (a Whisper ``Language``). When None,
            English ("en") is assumed.
          nllb_lang: Target language. Its ``code`` is used as the NLLB /
            CTranslate2 target token and its ``code_whisper`` in the mT5 prefix.
          download_root: Directory where downloaded models are cached. If not
            set, the standard Hugging Face cache directory is used.
          local_files_only: If True, avoid downloading and resolve the model
            from the local cache only.
          load_model: If True, call ``load_model()`` at the end of construction.
        """
        self.whisper_lang = whisper_lang
        # The source language falls back to English when none was provided.
        source_code = whisper_lang.code.lower() if whisper_lang is not None else "en"
        self.nllb_whisper_lang = get_nllb_lang_from_code_whisper(source_code)
        self.nllb_lang = nllb_lang
        self.model_config = model_config

        if os.path.isdir(model_config.url):
            self.model_path = model_config.url
        else:
            self.model_path = download_model(
                model_config,
                local_files_only=local_files_only,
                cache_dir=download_root,
            )

        if device is None:
            if torch.cuda.is_available():
                # CTranslate2 expects a bare device name; transformers pipelines
                # accept an explicit device index.
                device = "cuda" if "ct2" in self.model_path else "cuda:0"
            else:
                device = "cpu"

        self.device = device

        if load_model:
            self.load_model()

    def load_model(self):
        """Instantiates the tokenizer and model for the backend implied by the path.

        Requires ``self.nllb_lang`` (the target language) to be set.
        """
        print('\n\nLoading model: %s\n\n' % self.model_path)
        if "ct2" in self.model_path:
            # CTranslate2 is steered to the target language via a one-token prefix.
            self.target_prefix = [self.nllb_lang.code]
            self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, src_lang=self.nllb_whisper_lang.code)
            self.trans_model = ctranslate2.Translator(self.model_path, compute_type="auto", device=self.device)
        elif "mt5" in self.model_path:
            # mT5 is prompted with a "<src>2<tgt>: " task prefix. Fall back to "en"
            # when no source language was given, matching the default used for
            # nllb_whisper_lang in __init__ (this previously crashed on None).
            source_code = self.whisper_lang.code if self.whisper_lang is not None else "en"
            self.mt5_prefix = source_code + "2" + self.nllb_lang.code_whisper + ": "
            self.trans_tokenizer = transformers.T5Tokenizer.from_pretrained(self.model_path) #requires spiece.model
            self.trans_model = transformers.MT5ForConditionalGeneration.from_pretrained(self.model_path)
            self.trans_translator = transformers.pipeline('text2text-generation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer)
        else: #NLLB
            self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
            self.trans_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
            self.trans_translator = transformers.pipeline('translation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer, src_lang=self.nllb_whisper_lang.code, tgt_lang=self.nllb_lang.code)

    def release_vram(self):
        """Best-effort release of GPU memory held by the loaded model.

        Does nothing unless CUDA is available. Any error is printed, not raised.
        """
        try:
            if torch.cuda.is_available():
                if "ct2" not in self.model_path:
                    # transformers model: move the weights off the GPU first; the
                    # pipeline in self.trans_translator still references the model.
                    self.trans_model.to(torch.device("cpu"))
                del self.trans_model
                torch.cuda.empty_cache()
                print("release vram end.")
        except Exception as e:
            print("Error release vram: " + str(e))

    def translation(self, text: str, max_length: int = 400) -> Optional[str]:
        """Translates ``text`` into the target language.

        Args:
          text: Source text.
          max_length: Maximum generation length for the transformers pipelines;
            not used by the CTranslate2 path.

        Returns:
          The translated string, or None if translation raised (the error is
          printed).
        """
        result = None
        try:
            if "ct2" in self.model_path:
                source = self.trans_tokenizer.convert_ids_to_tokens(self.trans_tokenizer.encode(text))
                output = self.trans_model.translate_batch([source], target_prefix=[self.target_prefix])
                # Drop the leading target-language token forced by target_prefix.
                target = output[0].hypotheses[0][1:]
                result = self.trans_tokenizer.decode(self.trans_tokenizer.convert_tokens_to_ids(target))
            elif "mt5" in self.model_path:
                output = self.trans_translator(self.mt5_prefix + text, max_length=max_length, num_beams=4)
                result = output[0]['generated_text']
            else: #NLLB
                output = self.trans_translator(text, max_length=max_length)
                result = output[0]['translation_text']
        except Exception as e:
            print("Error translation text: " + str(e))

        return result


_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B", 
           "ct2fast-nllb-200-distilled-1.3B-int8_float16", 
           "ct2fast-nllb-200-3.3B-int8_float16", 
           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16", 
           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16", 
           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
           "mt5-zh-ja-en-trimmed",
           "mt5-zh-ja-en-trimmed-fine-tuned-v1"]

def check_model_name(name):
    """Return True if ``name`` contains any entry of ``_MODELS`` as a substring.

    Matching is a case-sensitive substring test, so e.g.
    "facebook/nllb-200-distilled-600M" is accepted via "distilled-600M".
    """
    for known_fragment in _MODELS:
        if known_fragment in name:
            return True
    return False

def download_model(
    model_config: ModelConfig,
    output_dir: Optional[str] = None,
    local_files_only: bool = False,
    cache_dir: Optional[str] = None,
):
    """Download a translation model snapshot from the Hugging Face Hub.

    Adapted from ``download_model`` in the ``utils.py`` script of the
    faster_whisper project (author: guillaumekln).

    Only files matching the allow-list below (configs, weights, tokenizer and
    vocabulary files) are fetched. If synchronizing with the Hub fails with an
    HTTP or connection error, a warning is emitted and the snapshot is resolved
    from the local cache instead.

    Args:
      model_config: Config of the model to download. ``model_config.name`` is
        validated with ``check_model_name``; ``model_config.url`` is used as the
        Hub repository id (e.g. "facebook/nllb-200-distilled-600M").
      output_dir: Directory where the model should be saved. If not set, the
        model is saved in the cache directory.
      local_files_only: If True, avoid downloading the files and return the path
        to the locally cached snapshot if it exists.
      cache_dir: Path to the folder where cached files are stored.

    Returns:
      The local path to the downloaded model snapshot.

    Raises:
      ValueError: If ``model_config.name`` does not match any known model name.
    """
    if not check_model_name(model_config.name):
        raise ValueError(
            "Invalid model name '%s', expected one of: %s" % (model_config.name, ", ".join(_MODELS))
        )

    repo_id = model_config.url

    # Patterns are matched with fnmatch by huggingface_hub.
    allow_patterns = [
        "config.json",
        "generation_config.json",
        "model.bin",  # CTranslate2 model weights
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
        # Sharded transformers checkpoints, whatever the shard count
        # (previously only the exact 3-shard file names were listed).
        "pytorch_model-*-of-*.bin",
        "sentencepiece.bpe.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "shared_vocabulary.txt",
        "shared_vocabulary.json",
        "special_tokens_map.json",
        "spiece.model",
    ]

    kwargs = {
        "local_files_only": local_files_only,
        "allow_patterns": allow_patterns,
    }

    if output_dir is not None:
        kwargs["local_dir"] = output_dir
        kwargs["local_dir_use_symlinks"] = False

    if cache_dir is not None:
        kwargs["cache_dir"] = cache_dir

    try:
        return huggingface_hub.snapshot_download(repo_id, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        # warnings.warn(message, category, stacklevel) does not do %-formatting:
        # passing repo_id/exception as extra positional args made it raise
        # TypeError (category must be a Warning subclass), which skipped the
        # local-cache fallback below. Format the message up front instead.
        warnings.warn(
            "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s"
            % (repo_id, exception)
        )
        warnings.warn(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repo_id, **kwargs)