File size: 4,504 Bytes
2b6deb0 4463ade 2b6deb0 4463ade 2b6deb0 4463ade 2b6deb0 4463ade 2b6deb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger (transformers' logging wrapper); currently not used by HybridCLIPConfig below.
logger = logging.get_logger(__name__)
class HybridCLIPConfig(PretrainedConfig):
    r"""
    :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
    :class:`~HybridCLIPModel`. It is used to instantiate a HybridCLIP model according to the specified arguments,
    defining the text model and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config (:obj:`dict`):
            Dictionary of configuration options that defines the text model config. Must contain a ``model_type``
            key, which selects the config class through :meth:`transformers.AutoConfig.for_model`.
        vision_config (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config. Must contain a ``model_type``
            key. For ``"clip"`` only the ``vision_config`` part of the resulting CLIP config is kept; for
            ``"clip_vision_model"`` a :class:`~transformers.CLIPVisionConfig` is built directly; any other model
            type goes through :meth:`transformers.AutoConfig.for_model`.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1.0):
            Stored as ``self.initializer_factor``. Not read in this file; presumably consumed by the model's
            weight initialization (TODO confirm).
        kwargs (`optional`):
            Remaining keyword arguments, forwarded to :class:`~transformers.PretrainedConfig`.

    Raises:
        ValueError: If ``text_config`` or ``vision_config`` is missing or ``None``.
        KeyError: If either config dictionary has no ``model_type`` key.

    Examples::

        >>> from transformers import BertConfig, CLIPConfig
        >>> # HybridCLIPConfig and FlaxHybridCLIP come from this project, not from transformers.

        >>> # Initializing a BERT and CLIP configuration
        >>> config_text = BertConfig()
        >>> config_vision = CLIPConfig()
        >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)

        >>> # Initializing a BERT and CLIPVision model
        >>> model = FlaxHybridCLIP(config=config)

        >>> # Accessing the model configuration
        >>> config_text = model.config.text_config
        >>> config_vision = model.config.vision_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('my-model')

        >>> # Loading model and config from pretrained folder
        >>> hybrid_clip_config = HybridCLIPConfig.from_pretrained('my-model')
        >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=hybrid_clip_config)
    """

    model_type = "hybrid-clip"
    # Composite config: tells PretrainedConfig not to build a default
    # instance of this class (which would fail without sub-configs).
    is_composition = True

    def __init__(self, projection_dim=512, **kwargs):
        # Take the sub-config dicts out of kwargs *before* the base initializer
        # runs, so PretrainedConfig does not setattr the raw dicts on self.
        text_config = kwargs.pop("text_config", None)
        vision_config = kwargs.pop("vision_config", None)
        if text_config is None:
            raise ValueError("`text_config` can not be `None`.")
        if vision_config is None:
            raise ValueError("`vision_config` can not be `None`.")
        initializer_factor = kwargs.pop("initializer_factor", 1.0)

        super().__init__(**kwargs)

        # Deep-copy so that popping ``model_type`` (and anything the sub-config
        # constructors do) never mutates the dictionaries owned by the caller,
        # e.g. the config dict handed to us by ``from_dict``.
        text_config = copy.deepcopy(text_config)
        vision_config = copy.deepcopy(vision_config)

        text_model_type = text_config.pop("model_type")
        vision_model_type = vision_config.pop("model_type")

        # Local import, as in the original code (avoids a module-level
        # dependency on AutoConfig at import time).
        from transformers import AutoConfig

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        if vision_model_type == "clip":
            # A full CLIP config was supplied; keep only its vision half.
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            from transformers import CLIPVisionConfig

            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.projection_dim = projection_dim
        self.initializer_factor = initializer_factor

    @classmethod
    def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`HybridCLIPConfig` (or a derived class) from a text model configuration and a
        vision model configuration.

        Args:
            text_config (:class:`~transformers.PretrainedConfig`): Text model configuration; serialized with
                ``to_dict()`` before being passed to the constructor.
            vision_config (:class:`~transformers.PretrainedConfig`): Vision model configuration; serialized with
                ``to_dict()`` before being passed to the constructor.
            kwargs: Extra keyword arguments forwarded to the constructor (e.g. ``projection_dim``).

        Returns:
            :class:`HybridCLIPConfig`: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict` so that the nested text and vision configs are
        serialized as plain dictionaries.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
|