"""TaiVisionLM configuration"""

import warnings

from transformers import PretrainedConfig
from transformers import logging, CONFIG_MAPPING

logger = logging.get_logger(__name__)

class TaiVisionLMConfig(PretrainedConfig):
    """Configuration for TaiVisionLM: a SigLIP vision encoder paired with a
    Llama-style text decoder, plus the projection and image-token settings
    used to splice image embeddings into the language-model input."""

    model_type = "taivisionlm"
    is_composition = False

    def __init__(
            self,
            vision_config=None,
            text_config=None,
            ignore_index=-100,
            image_token_idx=32000,
            vocab_size=32001,
            projection_dim=768,
            hidden_size=2048,
            **kwargs,
    ):
        self.ignore_index = ignore_index          # label value ignored by the loss
        self.image_token_index = image_token_idx  # placeholder token id for image positions
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim      # width of the vision-to-text projection
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False
        
        # Vision tower config: accept a dict (model_type defaults to SigLIP),
        # an already-built config object, or fall back to SigLIP-base defaults.
        if isinstance(self.vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
                attention_dropout=0.0,
                hidden_act="gelu_pytorch_tanh",
                hidden_size=768,
                image_size=224,
                intermediate_size=3072,
                layer_norm_eps=1e-06,
                num_attention_heads=12,
                num_channels=3,
                num_hidden_layers=12,
                patch_size=16,
            )
        
        self.vocab_size = vocab_size
        self.text_config = text_config

        # Text decoder config: accept a dict (model_type defaults to "gpt2"),
        # an already-built config object, or fall back to the Llama defaults below.
        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "gpt2")
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            self.text_config = CONFIG_MAPPING["llama"](
                # Defaults mirror a TinyLlama-1.1B-style checkpoint, with the
                # vocabulary extended by one entry for the image token.
                architectures=["LlamaForCausalLM"],
                hidden_act="silu",
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=1,
                eos_token_id=2,
                hidden_size=2048,
                initializer_range=0.02,
                intermediate_size=5632,
                max_position_embeddings=2048,
                model_type="llama",
                num_attention_heads=32,
                num_hidden_layers=22,
                num_key_value_heads=4,
                pretraining_tp=1,
                rms_norm_eps=1e-05,
                rope_scaling=None,
                rope_theta=10000.0,
                tie_word_embeddings=False,
                torch_dtype="bfloat16",
                transformers_version="4.40.2",
                use_cache=True,
                vocab_size=32001,
            )
        # One image embedding per vision patch, e.g. (224 // 16) ** 2 = 196 by default.
        self.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.vision_config.projection_dim = projection_dim
        # Route the decoder's pad token through kwargs: PretrainedConfig.__init__
        # pops `pad_token_id` (defaulting to None) and would otherwise overwrite
        # a value assigned directly on `self` before the super() call.
        kwargs.setdefault("pad_token_id", self.text_config.pad_token_id)
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        warnings.warn(
            "The `vocab_size` attribute is deprecated and will be removed in v4.44. Please use `text_config.vocab_size` instead.",
            FutureWarning,
        )
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        self._vocab_size = value

    def to_dict(self):
        # Drop the private cache so serialized configs only expose the
        # canonical `text_config.vocab_size`.
        output = super().to_dict()
        output.pop("_vocab_size", None)
        return output
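

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): instantiate the config with
    # its default SigLIP-base vision tower and TinyLlama-style text decoder.
    config = TaiVisionLMConfig()
    print(config.model_type)                     # "taivisionlm"
    print(config.num_image_tokens)               # (224 // 16) ** 2 == 196
    print(config.text_config.num_hidden_layers)  # 22

    # Sub-configs can also be passed as plain dicts; unspecified fields fall
    # back to each sub-config's own defaults. The shapes below are arbitrary
    # example values, not tested checkpoints.
    small = TaiVisionLMConfig(
        vision_config={"model_type": "siglip_vision_model", "image_size": 112, "patch_size": 16},
        text_config={"model_type": "llama", "num_hidden_layers": 2, "vocab_size": 32001},
    )
    print(small.num_image_tokens)                # (112 // 16) ** 2 == 49
    assert "_vocab_size" not in small.to_dict()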