ViTamin-XL-336px / configuration_vitamin.py
bbexx's picture
update
82f517a
raw
history blame
5.47 kB
""" ViTamin
Paper: Designing Scalable Vision Models in the Vision-Language Era
@misc{chen2023designing,
title={Designing Scalable Vision Models in the Vision-Language Era},
author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
year={2023},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Based on Apache 2.0 licensed code at https://github.com/Beckschen/ViTamin
by Jieneng Chen 2024
"""
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
if TYPE_CHECKING:
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class ViTaminTextConfig(PretrainedConfig):
    """Configuration for the ViTamin text tower.

    Every constructor argument is copied onto the instance unchanged; this class does
    no validation of its own.

    Args:
        context_length: default 77.
        vocab_size: default 49408.
        width: default 1024.
        heads: default 16.
        layers: default 24.
        **kwargs: handed through to ``PretrainedConfig.__init__``.
    """

    model_type = "vitamin_text_model"

    def __init__(
        self,
        context_length=77,
        vocab_size=49408,
        width=1024,
        heads=16,
        layers=24,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.width = width
        self.heads = heads
        self.layers = layers

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a text config from a standalone text config or from a full ViTamin config.

        If the loaded dict has a ``"text_config"`` entry, only that nested section is used.
        A warning is logged when the stored ``model_type`` differs from this class's.

        Returns:
            The instance built by ``cls.from_dict``.
        """
        raw_config, remaining = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # ViTaminConfig.to_dict stores the text settings under "text_config".
        text_section = raw_config.get("text_config", raw_config)

        has_stored_type = "model_type" in text_section and hasattr(cls, "model_type")
        if has_stored_type and text_section["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {text_section['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(text_section, **remaining)
class ViTaminVisionConfig(PretrainedConfig):
    """Configuration for the ViTamin vision tower.

    All arguments are stored verbatim as attributes; this class neither validates nor
    interprets them. The ``timm_*`` names suggest the modeling code (not shown here)
    forwards them to a timm backbone. NOTE(review): confirm against that code.

    Args:
        timm_model_name: default ``"vitamin_large"``.
        timm_model_pretrained: default ``False``.
        timm_pool: default ``""``.
        timm_proj: default ``"linear"``.
        timm_drop: default ``0.0``.
        timm_drop_path: default ``0.1``.
        image_size: default ``256``.
        timm_proj_bias: default ``False``.
        patch_dropout: default ``0.0``.
        drop_path: default ``None``. It is kept as an attribute so that a value read
            from a checkpoint is not silently dropped on a save/load round trip.
            NOTE(review): this file does not show whether the modeling code reads it.
        **kwargs: handed through to ``PretrainedConfig.__init__``.
    """

    model_type = "vitamin_vision_model"

    def __init__(
        self,
        timm_model_name="vitamin_large",
        timm_model_pretrained=False,
        timm_pool="",
        timm_proj="linear",
        timm_drop=0.0,
        timm_drop_path=0.1,
        image_size=256,
        timm_proj_bias=False,
        patch_dropout=0.0,
        drop_path=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.timm_model_name = timm_model_name
        self.timm_model_pretrained = timm_model_pretrained
        self.timm_pool = timm_pool
        self.timm_proj = timm_proj
        self.timm_drop = timm_drop
        self.timm_drop_path = timm_drop_path
        self.timm_proj_bias = timm_proj_bias
        self.patch_dropout = patch_dropout
        self.image_size = image_size
        # Previously accepted and discarded. Because it is a named parameter it never
        # reaches **kwargs, so it must be stored explicitly to be preserved.
        self.drop_path = drop_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a vision config from a standalone vision config or from a full ViTamin config.

        If the loaded dict has a ``"vision_config"`` entry, only that nested section is used.
        A warning is logged when the stored ``model_type`` differs from this class's.

        Returns:
            The instance built by ``cls.from_dict``.
        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # A composite ViTaminConfig stores the vision settings under "vision_config".
        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(config_dict, **kwargs)
class ViTaminConfig(PretrainedConfig):
    """Composite ViTamin configuration: a text sub-config, a vision sub-config, and ``embed_dim``.

    Args:
        text_config: a mapping of ``ViTaminTextConfig`` keyword arguments, an existing
            config object (converted via its ``to_dict``), or ``None`` for defaults.
        vision_config: the same, for ``ViTaminVisionConfig``.
        embed_dim: stored as-is (default 512). It presumably gives the shared embedding
            width; confirm against the modeling code.
        **kwargs: handed through to ``PretrainedConfig.__init__``.
    """

    model_type = "vitamin"
    is_composition = True

    def __init__(self, text_config=None, vision_config=None, embed_dim=512, **kwargs):
        super().__init__(**kwargs)

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `ViTaminTextConfig` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # Accept a ready-made config object as well as a plain dict; `**obj` would raise TypeError.
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `ViTaminVisionConfig` with default values.")
        elif isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()

        self.embed_dim = embed_dim
        self.text_config = ViTaminTextConfig(**text_config)
        self.vision_config = ViTaminVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: ViTaminTextConfig, vision_config: ViTaminVisionConfig, **kwargs):
        r"""
        Instantiate a [`ViTaminConfig`] (or a derived class) from a ViTamin text configuration and a
        ViTamin vision configuration.

        Returns:
            [`ViTaminConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serialize this instance to a Python dictionary, expanding the nested sub-configs.

        Builds on [`~PretrainedConfig.to_dict`] so that whatever clean-up the base class performs
        is preserved. It then replaces the sub-config entries with their own dictionary form.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output