import copy

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Configuration for a ViT encoder / GPT-2 decoder model.

    Stores a nested `ViTConfig` (vision encoder) and `GPT2Config` (text decoder),
    and exposes the decoder's special token ids at the top level so that `generate()` works.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        # `project_encoder` may arrive either as a top-level kwarg or inside the text config;
        # it is reconciled at the end of this method.
        project_encoder = kwargs.pop("project_encoder", None)

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Required in `generate()`.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id

        assert hasattr(self.text_config, "pad_token_id"), "The text config is expected to define `pad_token_id`."
        self.pad_token_id = self.text_config.pad_token_id

        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Reconcile `project_encoder` between the top-level kwarg and the value stored in the text config.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder, (
                "`project_encoder` passed as a kwarg conflicts with the value already set in `text_config`."
            )
        elif project_encoder is not None:
            _project_encoder = project_encoder
        elif _project_encoder is not None:
            project_encoder = _project_encoder
        else:
            project_encoder = False

        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Instantiate a `ViTGPT2Config` from an existing `ViTConfig` and `GPT2Config`."""
        return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """Serialize this instance to a Python dictionary, expanding the nested sub-configurations."""
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
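

# A minimal usage sketch (not part of the original module): build the composite configuration
# from small ViT / GPT-2 sub-configs. The sizes below are illustrative assumptions only.
if __name__ == "__main__":
    vision_config = ViTConfig(
        image_size=224, patch_size=16, hidden_size=192, num_hidden_layers=4, num_attention_heads=4
    )
    text_config = GPT2Config(n_layer=4, n_head=4, n_embd=256)

    config = ViTGPT2Config.from_vision_text_configs(vision_config, text_config, project_encoder=True)

    # The decoder's special token ids are propagated to the top-level config so `generate()` can use them.
    print(config.model_type, config.is_encoder_decoder, config.project_encoder)
    print(config.bos_token_id, config.eos_token_id, config.decoder_start_token_id)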