import copy

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Configuration for a composite ViT encoder / GPT-2 decoder model.

    Wraps a `ViTConfig` and a `GPT2Config`, and mirrors the token ids that
    `generate()` expects at the top level of the config.
    """

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        # Pop `project_encoder` before forwarding kwargs, so the parent class does
        # not also store it as a loose attribute on the config.
        project_encoder = kwargs.pop("project_encoder", None)
        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Required in `generate()`: mirror the decoder's special token ids at the
        # top level of the composite config.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id
        assert hasattr(self.text_config, "pad_token_id")
        self.pad_token_id = self.text_config.pad_token_id
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Reconcile `project_encoder` between the constructor kwarg and the text
        # config: if both are given they must agree; otherwise take whichever is
        # set, defaulting to False.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder
        elif project_encoder is None and _project_encoder is not None:
            project_encoder = _project_encoder
        elif project_encoder is None:
            project_encoder = False
        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Instantiate a `ViTGPT2Config` from a ViT encoder config and a GPT-2 decoder config."""
        return cls(
            vision_config_dict=vision_config.to_dict(),
            text_config_dict=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        # Serialize the nested configs explicitly so the output dict is fully
        # JSON-serializable (the raw attributes are config objects, not dicts).
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
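

# A minimal usage sketch: shows both construction paths and the `to_dict()`
# round trip. The hyperparameter values below are arbitrary and chosen only
# for illustration; they are not prescribed by this config class.
if __name__ == "__main__":
    vision_config = ViTConfig(image_size=224, patch_size=16)
    text_config = GPT2Config(n_layer=6, n_head=8, n_embd=512)

    # Path 1: from already-built sub-configs.
    config = ViTGPT2Config.from_vision_text_configs(
        vision_config, text_config, project_encoder=True
    )

    # Path 2: from plain dicts, as when deserializing a saved config.
    config_from_dicts = ViTGPT2Config(
        vision_config_dict=vision_config.to_dict(),
        text_config_dict=text_config.to_dict(),
    )
    assert config_from_dicts.project_encoder is False  # default when unset

    # The decoder's special token ids are mirrored at the top level for `generate()`.
    print(config.bos_token_id, config.eos_token_id, config.decoder_start_token_id)

    # Round-tripping keeps the nested configs as plain dicts.
    as_dict = config.to_dict()
    assert as_dict["model_type"] == "vit-gpt2"
    assert as_dict["text_config"]["project_encoder"] is True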