Spaces:
Runtime error
Runtime error
import copy | |
from transformers import GPT2Config, ViTConfig | |
from transformers.configuration_utils import PretrainedConfig | |
from transformers.utils import logging | |
logger = logging.get_logger(__name__) | |
class ViTGPT2Config(PretrainedConfig): | |
model_type = "vit-gpt2" | |
is_composition = True | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
if "vit_config" not in kwargs: | |
raise ValueError("`vit_config` can not be `None`.") | |
if "gpt2_config" not in kwargs: | |
raise ValueError("`gpt2_config` can not be `None`.") | |
vit_config = kwargs.pop("vit_config") | |
gpt2_config = kwargs.pop("gpt2_config") | |
self.vit_config = ViTConfig(**vit_config) | |
self.gpt2_config = GPT2Config(**gpt2_config) | |
def from_vit_gpt2_configs( | |
cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs | |
): | |
return cls( | |
vit_config=vit_config.to_dict(), | |
gpt2_config=gpt2_config.to_dict(), | |
**kwargs | |
) | |
def to_dict(self): | |
output = copy.deepcopy(self.__dict__) | |
output["vit_config"] = self.vit_config.to_dict() | |
output["gpt2_config"] = self.gpt2_config.to_dict() | |
output["model_type"] = self.__class__.model_type | |
return output |