ohashi56225 committed on
Commit
f188f75
1 Parent(s): 71c9844

Upload LlavaForConditionalGeneration

config.json ADDED
@@ -0,0 +1,242 @@
+{
+  "_name_or_path": "output/jp-llava-small-sfcoco-bs8-lr5e5/checkpoints",
+  "architectures": [
+    "LlavaForConditionalGeneration"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_llava.LlavaConfig",
+    "AutoModelForVision2Seq": "modeling_llava.LlavaForConditionalGeneration"
+  },
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "mlp_config": {
+    "_name_or_path": "",
+    "add_cross_attention": false,
+    "architectures": null,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": null,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": null,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "llava_mlp",
+    "no_repeat_ngram_size": 0,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 2,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "typical_p": 1.0,
+    "use_bfloat16": false
+  },
+  "model_type": "llava",
+  "text_config": {
+    "_name_or_path": "rinna/japanese-gpt-neox-small",
+    "add_cross_attention": false,
+    "architectures": [
+      "GPTNeoXForCausalLM"
+    ],
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": 2,
+    "chunk_size_feed_forward": 0,
+    "classifier_dropout": 0.1,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 3,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "gelu",
+    "hidden_dropout": 0.0,
+    "hidden_size": 768,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_range": 0.02,
+    "intermediate_size": 3072,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 2048,
+    "min_length": 0,
+    "model_type": "gpt_neox",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 12,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 12,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "rope_scaling": null,
+    "rotary_emb_base": 10000,
+    "rotary_pct": 1.0,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": false,
+    "tokenizer_class": "T5Tokenizer",
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": "float32",
+    "torchscript": false,
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "use_cache": true,
+    "use_parallel_residual": false,
+    "vocab_size": 44416
+  },
+  "tie_word_embeddings": false,
+  "torch_dtype": "float32",
+  "transformers_version": "4.35.2",
+  "use_decoder_only_language_model": true,
+  "vision_config": {
+    "_name_or_path": "openai/clip-vit-large-patch14",
+    "add_cross_attention": false,
+    "architectures": null,
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "begin_suppress_tokens": null,
+    "bos_token_id": null,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.0,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": null,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "image_size": 224,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "clip_vision_model",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 16,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "patch_size": 14,
+    "prefix": null,
+    "problem_type": null,
+    "projection_dim": 768,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "suppress_tokens": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "typical_p": 1.0,
+    "use_bfloat16": false
+  },
+  "vision_select_feature": "patch",
+  "vision_select_layer": -2
+}
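
The `auto_map` block above wires the custom `LlavaConfig` and `LlavaForConditionalGeneration` classes into the Auto classes, so the checkpoint must be loaded with `trust_remote_code=True`. A minimal loading sketch; the repository id below is a placeholder and not part of this commit:

# Loading sketch; "<namespace>/<repo-id>" is a placeholder for this repository's Hub id.
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

repo_id = "<namespace>/<repo-id>"  # hypothetical: replace with the actual repository id

# auto_map resolves AutoConfig to configuration_llava.LlavaConfig.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.vision_config.model_type)  # "clip_vision_model" (CLIP ViT-L/14)
print(config.text_config.model_type)    # "gpt_neox" (rinna/japanese-gpt-neox-small)

# auto_map resolves AutoModelForVision2Seq to modeling_llava.LlavaForConditionalGeneration.
model = AutoModelForVision2Seq.from_pretrained(
    repo_id,
    torch_dtype=torch.float32,  # matches "torch_dtype" in config.json
    trust_remote_code=True,
)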
configuration_llava.py ADDED
@@ -0,0 +1,131 @@
+# Copyright 2023 Stability AI team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Union
+
+from transformers import PretrainedConfig, CLIPVisionConfig
+from transformers.models.auto import CONFIG_MAPPING
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class LlavaMlpConfig(PretrainedConfig):
+    model_type = "llava_mlp"
+
+    def __init__(
+        self,
+        num_hidden_layers=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_hidden_layers = num_hidden_layers
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        # get the mlp config dict if we are loading from LlavaConfig
+        if config_dict.get("model_type") == "llava":
+            config_dict = config_dict["mlp_config"]
+
+        if (
+            "model_type" in config_dict
+            and hasattr(cls, "model_type")
+            and config_dict["model_type"] != cls.model_type
+        ):
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class LlavaConfig(PretrainedConfig):
+    model_type = "llava"
+    is_composition = True
+
+    def __init__(
+        self,
+        vision_config=None,
+        mlp_config=None,
+        text_config=None,
+        vision_select_layer=-2,
+        vision_select_feature="patch",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info(
+                "vision_config is None. Initializing the CLIPVisionConfig with default values."
+            )
+
+        if mlp_config is None:
+            mlp_config = {}
+            logger.info(
+                "mlp_config is None. Initializing the LlavaMlpConfig with default values."
+            )
+
+        if text_config is None:
+            text_config = {}
+            logger.info(
+                "text_config is None. Initializing the text config with default values (`OPTConfig`)."
+            )
+
+        self.vision_config = CLIPVisionConfig(**vision_config)
+        self.mlp_config = LlavaMlpConfig(**mlp_config)
+        text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
+        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
+
+        self.tie_word_embeddings = self.text_config.tie_word_embeddings
+        self.is_encoder_decoder = self.text_config.is_encoder_decoder
+
+        self.use_decoder_only_language_model = (
+            self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+        )
+        self.vision_select_layer = vision_select_layer
+        assert vision_select_feature in [
+            "cls_patch",
+            "patch",
+        ], f"Unexpected select feature: {vision_select_feature}"
+        self.vision_select_feature = vision_select_feature
+        self.initializer_factor = 1.0
+        self.initializer_range = 0.02
+
+    @classmethod
+    def from_vision_mlp_text_configs(
+        cls,
+        vision_config: CLIPVisionConfig,
+        mlp_config: LlavaMlpConfig,
+        text_config: PretrainedConfig,
+        **kwargs,
+    ):
+        return cls(
+            vision_config=vision_config.to_dict(),
+            mlp_config=mlp_config.to_dict(),
+            text_config=text_config.to_dict(),
+            **kwargs,
+        )
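
For reference, a hedged sketch of composing the same kind of configuration with `from_vision_mlp_text_configs`. It assumes configuration_llava.py is importable from the working directory; the sub-config sources mirror the `_name_or_path` values recorded in config.json above:

# Sketch: compose a LlavaConfig from its three sub-configs (values mirror config.json above).
from transformers import CLIPVisionConfig, GPTNeoXConfig

from configuration_llava import LlavaConfig, LlavaMlpConfig  # assumes the file sits next to this script

vision_config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
mlp_config = LlavaMlpConfig(num_hidden_layers=2)  # projector: Linear -> GELU -> Linear
text_config = GPTNeoXConfig.from_pretrained("rinna/japanese-gpt-neox-small")

config = LlavaConfig.from_vision_mlp_text_configs(
    vision_config=vision_config,
    mlp_config=mlp_config,
    text_config=text_config,
    vision_select_layer=-2,          # take the second-to-last ViT layer
    vision_select_feature="patch",   # drop the CLS token, keep patch features
)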
generation_config.json ADDED
@@ -0,0 +1,6 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 2,
+  "eos_token_id": 3,
+  "transformers_version": "4.35.2"
+}
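
The `bos_token_id`/`eos_token_id` values match the rinna/japanese-gpt-neox-small tokenizer and are picked up by `generate` as defaults. A small sketch of inspecting them explicitly (placeholder repo id):

# Sketch: load and inspect the generation defaults shipped with this repository.
from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained("<namespace>/<repo-id>")  # placeholder repo id
print(generation_config.bos_token_id, generation_config.eos_token_id)  # 2 3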
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dca99e6a876047698b4bde214395e27b6babfa133584f341efe7c03006134a5f
+size 1831419000
modeling_llava.py ADDED
@@ -0,0 +1,345 @@
+# Copyright 2023 Stability AI team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union, Any
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    PreTrainedModel,
+    CLIPVisionModel,
+)
+
+from transformers.utils import logging, ModelOutput
+from .configuration_llava import LlavaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LlavaForConditionalGenerationModelOutput(ModelOutput):
+    loss: Optional[Tuple[torch.FloatTensor]] = None
+    logits: Optional[Tuple[torch.FloatTensor]] = None
+    vision_outputs: Optional[torch.FloatTensor] = None
+    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None
+
+    def to_tuple(self) -> Tuple[Any]:
+        return tuple(
+            self[k]
+            if k not in ["vision_outputs", "language_model_outputs"]
+            else getattr(self, k).to_tuple()
+            for k in self.keys()
+        )
+
+
+class LlavaPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = LlavaConfig
+    base_model_prefix = "llava"
+
+    # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->Llava
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        factor = self.config.initializer_range
+        if (
+            isinstance(module, nn.Conv2d)
+            or isinstance(module, nn.Embedding)
+            or isinstance(module, nn.Linear)
+        ):
+            module.weight.data.normal_(mean=0.0, std=factor)
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()
+
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class LlavaForConditionalGeneration(LlavaPreTrainedModel):
+    config_class = LlavaConfig
+    main_input_name = "pixel_values"
+    _no_split_modules = []
+
+    def __init__(self, config: LlavaConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        if config.use_decoder_only_language_model:
+            language_model = AutoModelForCausalLM.from_config(config.text_config)
+        else:
+            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+
+        if language_model._no_split_modules is not None:
+            self._no_split_modules.extend(language_model._no_split_modules)
+
+        if language_model._keep_in_fp32_modules is not None:
+            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
+
+        self.language_model = language_model
+
+        modules = [
+            nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
+        ]
+        for _ in range(1, config.mlp_config.num_hidden_layers):
+            modules.append(nn.GELU())
+            modules.append(
+                nn.Linear(
+                    config.text_config.hidden_size, config.text_config.hidden_size
+                )
+            )
+        self.mlp = nn.Sequential(*modules)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared
+
+    def _preprocess_accelerate(self):
+        r"""
+        Some pre-processing hacks to make the model `accelerate` compatible. Check
+        https://github.com/huggingface/transformers/pull/21707 for more details.
+        """
+        hf_device_map = self.hf_device_map
+
+        if (
+            len(hf_device_map) > 1
+            and "language_model" not in hf_device_map
+            and torch.cuda.device_count() > 1
+        ):
+            # warn users about unexpected behavior when using multi-GPU + LLaVA + `accelerate`.
+            logger.warning(
+                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
+                " Please pass a `device_map` that contains `language_model` to remove this warning."
+                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+                " more details on creating a `device_map` for large models.",
+            )
+
+        if hasattr(self.language_model, "_hf_hook"):
+            self.language_model._hf_hook.io_same_device = (
+                True  # For `generate` compatibility
+            )
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, LlavaForConditionalGenerationModelOutput]:
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # step 1: forward the images through the vision encoder
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            return_dict=return_dict,
+            output_hidden_states=True,
+        )
+        # (bsz, seq len, hidden_size)
+        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
+        if self.config.vision_select_feature == "patch":
+            image_embeds = image_embeds[:, 1:]
+        elif self.config.vision_select_feature == "cls_patch":
+            image_embeds = image_embeds
+        else:
+            raise ValueError(f"Unexpected select feature: {self.config.vision_select_feature}")
+
+        # step 2: forward the image embeddings through the mlp
+        image_embeds = self.mlp(image_embeds)
+        image_attention_mask = torch.ones(
+            image_embeds.size()[:-1], device=image_embeds.device
+        )
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        # step 3: concatenate
+        inputs_embeds = torch.cat(
+            [image_embeds, inputs_embeds.to(image_embeds.device)],
+            dim=1,
+        )
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, device=input_ids.device)
+
+        attention_mask = torch.cat(
+            [image_attention_mask.to(attention_mask.device), attention_mask],
+            dim=1,
+        )
+
+        if self.config.use_decoder_only_language_model:
+            outputs = self.language_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            logits = outputs.logits if return_dict else outputs[0]
+            loss = None
+            # we compute the loss here since we need to take into account the sequence length of the prepended image embeddings
+            if labels is not None:
+                labels = labels.to(logits.device)
+                logits = logits[:, -labels.size(1) :, :]
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous().to(logits.device)
+
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss(reduction="mean")
+
+                loss = loss_fct(
+                    shift_logits.view(-1, self.config.text_config.vocab_size),
+                    shift_labels.view(-1),
+                )
+        else:
+            outputs = self.language_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                labels=labels,
+            )
+            loss = outputs.loss if return_dict else outputs[0]
+            logits = outputs.logits if return_dict else outputs[1]
+
+        if not return_dict:
+            output = (logits, vision_outputs, outputs)
+            return ((loss,) + output) if loss is not None else output
+
+        return LlavaForConditionalGenerationModelOutput(
+            loss=loss,
+            logits=logits,
+            vision_outputs=vision_outputs,
+            language_model_outputs=outputs,
+        )
+
+    def get_image_embeds(self, pixel_values: torch.FloatTensor):
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_hidden_states=True,
+        )
+        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
+        if self.config.vision_select_feature == "patch":
+            image_embeds = image_embeds[:, 1:]
+        elif self.config.vision_select_feature == "cls_patch":
+            image_embeds = image_embeds
+        else:
+            raise ValueError(f"Unexpected select feature: {self.config.vision_select_feature}")
+
+        image_embeds = self.mlp(image_embeds)
+        image_attention_mask = torch.ones(
+            image_embeds.size()[:-1], device=image_embeds.device
+        )
+        return dict(
+            image_embeds=image_embeds,
+            image_attention_mask=image_attention_mask,
+        )
+
+    def prepare_for_lm_generation(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        batch_size = pixel_values.shape[0]
+        vision_outputs = self.get_image_embeds(pixel_values)
+        image_embeds = vision_outputs["image_embeds"]
+        image_attention_mask = vision_outputs["image_attention_mask"]
+
+        if input_ids is None:
+            input_ids = (
+                torch.LongTensor([[self.config.text_config.bos_token_id]])
+                .repeat(batch_size, 1)
+                .to(image_embeds.device)
+            )
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        attention_mask = torch.cat(
+            [
+                image_attention_mask,
+                attention_mask.to(image_attention_mask.device),
+            ],
+            dim=1,
+        )
+
+        # concatenate image embeddings with prompt embeddings
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        inputs_embeds = torch.cat(
+            [image_embeds, inputs_embeds.to(image_embeds.device)],
+            dim=1,
+        )
+        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        if hasattr(self, "hf_device_map"):
+            # preprocess for `accelerate`
+            self._preprocess_accelerate()
+        encodings = self.prepare_for_lm_generation(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+        outputs = self.language_model.generate(
+            **encodings,
+            **generate_kwargs,
+        )
+        return outputs
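
Putting the pieces together: `generate` projects the selected CLIP patch features through the MLP, prepends them to the prompt embeddings, and delegates to the GPT-NeoX language model. A hedged end-to-end sketch follows; the image processor, tokenizer, prompt text, and repository id are assumptions for illustration, and the prompt format used during fine-tuning is not documented in this commit.

# End-to-end generation sketch; repo id and prompt are illustrative assumptions.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoTokenizer, CLIPImageProcessor

repo_id = "<namespace>/<repo-id>"  # hypothetical placeholder

model = AutoModelForVision2Seq.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-small", use_fast=False)
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("example.jpg")  # any RGB image
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

prompt = "これは何の写真ですか?"  # "What is this a photo of?" (illustrative prompt)
inputs = tokenizer(prompt, return_tensors="pt")

# generate() is already wrapped in @torch.no_grad(); it returns only newly generated token ids.
output_ids = model.generate(
    pixel_values=pixel_values,
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))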