File size: 1,226 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import TranslationTask, TranslationConfig

try:
    import examples.simultaneous_translation  # noqa

    import_successful = True
except BaseException:
    import_successful = False


logger = logging.getLogger(__name__)


def check_import(flag):
    if not flag:
        raise ImportError(
            "'examples.simultaneous_translation' is not correctly imported. "
            "Please considering `pip install -e $FAIRSEQ_DIR`."
        )


@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    def __init__(self, args, tgt_dict):
        check_import(import_successful)
        super().__init__(args, tgt_dict)


@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    def __init__(self, cfg, src_dict, tgt_dict):
        check_import(import_successful)
        super().__init__(cfg, src_dict, tgt_dict)