Natthaphon committed on
Commit
8a7aa7a
1 Parent(s): 9a80df1

Added config files

Files changed (3)
  1. configuration_cap.py +30 -0
  2. modeling_cap.py +268 -0
  3. readme.md +5 -0
configuration_cap.py ADDED
@@ -0,0 +1,30 @@
from transformers import PretrainedConfig, AutoConfig


class CLIPEncoderDecoderConfig(PretrainedConfig):
    model_type = "clip-encoder-decoder"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The encoder/decoder sub-configs are always reloaded from these default checkpoints,
        # regardless of any `encoder`/`decoder` kwargs passed in.
        self.encoder = AutoConfig.from_pretrained('microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft')
        self.decoder = AutoConfig.from_pretrained('airesearch/wangchanberta-base-att-spm-uncased')
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`CLIPEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
        configuration and decoder model configuration.

        Returns:
            [`CLIPEncoderDecoderConfig`]: An instance of a configuration object
        """
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
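
For orientation, a minimal sketch of how this config could be built from the two sub-model configs. The checkpoint names are the defaults hard-coded in `__init__`, and note that `__init__` as written reloads those defaults whatever `encoder`/`decoder` kwargs are passed in; the import path assumes `configuration_cap.py` is importable as a top-level module.

    from transformers import AutoConfig
    from configuration_cap import CLIPEncoderDecoderConfig

    encoder_config = AutoConfig.from_pretrained("microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft")
    decoder_config = AutoConfig.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")

    # Flags the decoder config with is_decoder / add_cross_attention, then wraps both sub-configs.
    config = CLIPEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
    print(config.model_type)  # clip-encoder-decoder
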
modeling_cap.py ADDED
@@ -0,0 +1,268 @@
from transformers import (
    PreTrainedModel,
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoConfig
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from torch import nn
from .configuration_cap import CLIPEncoderDecoderConfig
from typing import Optional, Tuple, Union
import torch
import gc
import os
import tempfile


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


class CLIPEncoderDecoderModel(PreTrainedModel):
    config_class = CLIPEncoderDecoderConfig
    base_model_prefix = "clip_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config=None,
        encoder=None,
        decoder=None,
    ):
        config.tie_word_embeddings = False
        super().__init__(config)

        # NOTE: the `encoder` argument is ignored; the encoder is always re-created from `config.encoder`.
        encoder = AutoModel.from_config(config.encoder)
        encoder_hidden_size = encoder.config.hidden_size

        if decoder is None:
            config.decoder.is_decoder = True
            config.decoder.add_cross_attention = True
            decoder = AutoModelForCausalLM.from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        # keep the sub-model configs in sync with the top-level config
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # project encoder hidden states to the decoder's hidden size for cross-attention
        self.enc_to_dec_proj = nn.Linear(encoder_hidden_size, self.decoder.config.hidden_size)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> PreTrainedModel:
        kwargs_encoder = {
            argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )

                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                )

                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # instantiate config with corresponding kwargs
        config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        # make sure input & output embeddings are not tied
        config.tie_word_embeddings = False
        return cls(encoder=encoder, decoder=decoder, config=config)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # project the encoder hidden states into the decoder's hidden size
        encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # the image patch sequence is never padded, so no encoder attention mask is used
        encoder_attention_mask = None

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute the loss independently of the decoder (some decoders shift the logits internally)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the CLIPEncoderDecoderModel directly is not supported. Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)
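
A hedged end-to-end sketch of the class above (checkpoint names mirror the config defaults; the image and caption tensors are random placeholders, so this only illustrates the training-style forward pass and shapes, not a tested recipe; the `cap` package name is hypothetical, since `modeling_cap.py` uses a relative import and must live in a package):

    import torch
    from cap.modeling_cap import CLIPEncoderDecoderModel  # hypothetical package layout

    model = CLIPEncoderDecoderModel.from_encoder_decoder_pretrained(
        "microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft",
        "airesearch/wangchanberta-base-att-spm-uncased",
    )
    # Note: __init__ re-creates the encoder from its config, so the encoder starts with random weights here.

    # shift_tokens_right needs these when only `labels` are supplied.
    model.config.decoder_start_token_id = model.decoder.config.bos_token_id
    model.config.pad_token_id = model.decoder.config.pad_token_id

    pixel_values = torch.randn(1, 3, 256, 256)                            # dummy 256x256 RGB image
    labels = torch.randint(0, model.decoder.config.vocab_size, (1, 16))   # dummy caption token ids

    outputs = model(pixel_values=pixel_values, labels=labels)
    print(outputs.loss, outputs.logits.shape)
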
readme.md CHANGED
@@ -0,0 +1,5 @@
# Thai Image Captioning
An encoder-decoder image captioning model for Thai: a SwinV2 vision encoder (`microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft`) is bridged to a WangchanBERTa causal-LM decoder (`airesearch/wangchanberta-base-att-spm-uncased`) through a linear projection and cross-attention.

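# Usage
A minimal inference sketch, not a tested recipe: the repository id is a placeholder, `trust_remote_code=True` assumes this repo's `config.json` maps the custom classes via `auto_map`, and generation assumes `decoder_start_token_id`/`pad_token_id` are set in the saved config.

```python
from transformers import AutoModel, AutoImageProcessor, AutoTokenizer
from PIL import Image

repo_id = "<this-repo-id>"  # placeholder

model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft")
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")

image = Image.open("example.jpg").convert("RGB")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=40)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```
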
# Acknowledgement
This work is partially supported by the Program Management Unit for Human Resources & Institutional Development, Research and Innovation (PMU-B) [Grant number B04G640107].