"""Configuration class for a composite ViT encoder + GPT-2 decoder model."""

import copy

from transformers import ViTConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ViTGPT2Config(PretrainedConfig):
    """Joint configuration holding a `ViTConfig` (vision encoder) and a
    `GPT2Config` (text decoder)."""

    model_type = "vit-gpt2"
    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, vision_config_dict=None, text_config_dict=None, **kwargs):
        # Pop `project_encoder` before calling the parent constructor so that
        # `PretrainedConfig` does not also store it as a stray attribute.
        project_encoder = kwargs.pop("project_encoder", None)

        super().__init__(
            vision_config_dict=vision_config_dict, text_config_dict=text_config_dict, **kwargs
        )

        if vision_config_dict is None:
            vision_config_dict = {}
            logger.info("vision_config_dict is None. Initializing the ViTConfig with default values.")

        if text_config_dict is None:
            text_config_dict = {}
            logger.info("text_config_dict is None. Initializing the GPT2Config with default values.")

        self.vision_config = ViTConfig(**vision_config_dict)
        self.text_config = GPT2Config(**text_config_dict)

        self.is_encoder_decoder = True

        # Generation-related token ids come from the GPT-2 decoder config.
        # Note: GPT-2 defines no pad token by default, so `pad_token_id` may be None.
        self.bos_token_id = self.text_config.bos_token_id
        self.eos_token_id = self.text_config.eos_token_id
        self.pad_token_id = self.text_config.pad_token_id
        self.decoder_start_token_id = self.text_config.bos_token_id
        self.forced_eos_token_id = self.text_config.eos_token_id

        # Reconcile `project_encoder` between the constructor kwarg and the
        # decoder config; if both are given they must agree.
        _project_encoder = getattr(self.text_config, "project_encoder", None)
        if project_encoder is not None and _project_encoder is not None:
            assert project_encoder == _project_encoder, (
                f"`project_encoder` mismatch: kwarg gives {project_encoder}, "
                f"text_config gives {_project_encoder}."
            )
        elif project_encoder is None:
            # Fall back to the decoder config's value, defaulting to False.
            project_encoder = _project_encoder if _project_encoder is not None else False

        self.project_encoder = project_encoder
        self.text_config.project_encoder = project_encoder

    @classmethod
    def from_vision_text_configs(cls, vision_config: ViTConfig, text_config: GPT2Config, **kwargs):
        """Instantiate a `ViTGPT2Config` from an existing vision and text config."""
        return cls(vision_config_dict=vision_config.to_dict(), text_config_dict=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """Serialize this instance to a dict, expanding the nested sub-configs."""
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
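

# A minimal usage sketch (an illustration, not part of the original module; it
# assumes a `transformers` version where `ViTConfig` and `GPT2Config` accept
# these standard arguments):
#
#     vision_config = ViTConfig(image_size=224, patch_size=16)
#     text_config = GPT2Config(n_layer=12, n_head=12)
#     config = ViTGPT2Config.from_vision_text_configs(
#         vision_config, text_config, project_encoder=True
#     )
#     assert config.project_encoder and config.text_config.project_encoder
#     serialized = config.to_dict()  # nested sub-configs expand to plain dicts
#     roundtrip = ViTGPT2Config(
#         vision_config_dict=serialized["vision_config"],
#         text_config_dict=serialized["text_config"],
#     )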