|
import copy |
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import logging |
|
from transformers import AutoConfig, CLIPVisionConfig |
|
|
|
# Module-level logger obtained through transformers' logging wrapper; it is not
# referenced by any code visible in this file.
logger = logging.get_logger(__name__)
|
|
|
|
|
class MedCLIPConfig(PretrainedConfig):
    r"""
    :class:`MedCLIPConfig` is the configuration class to store the configuration of a
    :class:`~MedCLIPModel`. It is used to instantiate a MedCLIP model according to the specified arguments,
    defining the text model and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config (:obj:`dict` or :class:`~transformers.PretrainedConfig`):
            Configuration options that define the text model config. A dict must contain a ``model_type`` key.
            The caller's dict is not modified.
        vision_config (:obj:`dict` or :class:`~transformers.PretrainedConfig`):
            Configuration options that define the vision model config. A dict must contain a ``model_type`` key.
            When ``model_type`` is ``"clip"``, only the ``vision_config`` part of the CLIP config is kept.
            The caller's dict is not modified.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        kwargs (`optional`):
            Remaining keyword arguments, forwarded to :class:`~transformers.PretrainedConfig`.

    Raises:
        ValueError: If ``text_config`` or ``vision_config`` is missing or ``None``.

    Examples::

        >>> from transformers import BertConfig, CLIPConfig
        >>> # MedCLIPConfig and FlaxMedCLIP are not part of transformers; import them
        >>> # from this project's own modules.

        >>> # Initializing a BERT and CLIP configuration
        >>> config_text = BertConfig()
        >>> config_vision = CLIPConfig()

        >>> config = MedCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)

        >>> # Initializing a BERT and CLIP vision model
        >>> model = FlaxMedCLIP(config=config)

        >>> # Accessing the model configuration
        >>> config_text = model.config.text_config
        >>> config_vision = model.config.vision_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('my-model')

        >>> # Loading model and config from pretrained folder
        >>> medclip_config = MedCLIPConfig.from_pretrained('my-model')
        >>> model = FlaxMedCLIP.from_pretrained('my-model', config=medclip_config)
    """

    # Written into the serialized dict by ``to_dict``; changing it would alter saved configs.
    model_type = "hybrid-clip"

    # Marks this as a composite config assembled from sub-configs (text + vision, see ``__init__``).
    is_composition = True

    def __init__(self, projection_dim=512, **kwargs):
        super().__init__(**kwargs)

        # Use ``.get(...) is None`` rather than ``in`` so that an explicit ``None`` is
        # rejected here instead of crashing later on ``None.pop``.
        if kwargs.get("text_config") is None:
            raise ValueError("`text_config` can not be `None`.")

        if kwargs.get("vision_config") is None:
            raise ValueError("`vision_config` can not be `None`.")

        # Work on private copies: popping "model_type" from the caller's own dicts
        # would make a second construction from the same dicts fail with KeyError.
        text_config = self._config_to_dict(kwargs.pop("text_config"))
        vision_config = self._config_to_dict(kwargs.pop("vision_config"))

        text_model_type = text_config.pop("model_type")
        vision_model_type = vision_config.pop("model_type")

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        if vision_model_type == "clip":
            # A full CLIP config bundles text and vision parts; keep only the vision part.
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            # Stand-alone CLIP vision config (presumably what ``to_dict`` emits after the
            # "clip" branch above); built directly rather than through AutoConfig.
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.projection_dim = projection_dim
        self.initializer_factor = 1.0

    @staticmethod
    def _config_to_dict(config):
        """Return a fresh, mutable dict for ``config`` (a dict or a PretrainedConfig)."""
        if isinstance(config, PretrainedConfig):
            # Same conversion ``from_text_vision_configs`` applies.
            return config.to_dict()
        return copy.deepcopy(config)

    @classmethod
    def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`MedCLIPConfig` (or a derived class) from a text model configuration and a
        vision model configuration.

        Args:
            text_config (:class:`~transformers.PretrainedConfig`): Configuration of the text model.
            vision_config (:class:`~transformers.PretrainedConfig`): Configuration of the vision model.
            kwargs: Extra keyword arguments forwarded to the constructor (e.g. ``projection_dim``).

        Returns:
            :class:`MedCLIPConfig`: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
            with ``text_config`` and ``vision_config`` expanded to dicts and ``model_type`` set to the class's
            ``model_type``.
        """
        output = copy.deepcopy(self.__dict__)
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
|
|