ydshieh committed
Commit b00cdfe • 1 Parent(s): b31314b
update model script
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -1,4 +1,5 @@
-
+import os
+from typing import Callable, Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
@@ -32,8 +33,8 @@ class FlaxViTGPT2LMModule(nn.Module):
 
     def setup(self):
 
-        self.encoder = FlaxViTModule(self.config.
-        self.decoder = FlaxGPT2LMHeadModule(self.config.
+        self.encoder = FlaxViTModule(self.config.vision_config, dtype=self.dtype)
+        self.decoder = FlaxGPT2LMHeadModule(self.config.text_config, dtype=self.dtype)
 
     def _get_encoder_module(self):
         return self.encoder
@@ -147,7 +148,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     ):
         if input_shape is None:
             input_shape = (
-                (1, config.
+                (1, config.vision_config.image_size, config.vision_config.image_size, 3),
                 (1, 1),
             )
 
@@ -164,7 +165,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         attention_mask = None
         decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
         # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
-        decoder_input_ids = jax.ops.index_update(decoder_input_ids, (..., -1), self.config.eos_token_id)
+        decoder_input_ids = jax.ops.index_update(decoder_input_ids, (..., -1), self.config.text_config.eos_token_id)
         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
 
         batch_size, sequence_length = decoder_input_ids.shape
@@ -221,11 +222,11 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        output_attentions = (output_attentions if output_attentions is not None else self.config.
+        output_attentions = (output_attentions if output_attentions is not None else self.config.vision_config.output_attentions)
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.
+            output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.
+        return_dict = return_dict if return_dict is not None else self.config.vision_config.return_dict
 
         # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
@@ -267,12 +268,12 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     ):
 
         output_attentions = (
-            output_attentions if output_attentions is not None else self.config.
+            output_attentions if output_attentions is not None else self.config.text_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.
+            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.
+        return_dict = return_dict if return_dict is not None else self.config.text_config.return_dict
 
         encoder_hidden_states = encoder_outputs[0]
         if encoder_attention_mask is None:
@@ -486,71 +487,78 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
         return model_kwargs
 
     @classmethod
-    def
+    def from_vision_text_pretrained(
         cls,
-
-
+        vision_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        text_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         *model_args,
         **kwargs,
     ) -> FlaxViTGPT2LMPreTrainedModel:
 
-
-
-            for
-            if
+        vision_kwargs = {
+            kwarg[len("vision_"):]: value
+            for kwarg, value in kwargs.items()
+            if kwarg.startswith("vision_")
         }
 
-
-
-            for
-            if
+        text_kwargs = {
+            kwarg[len("text_"):]: value
+            for kwarg, value in kwargs.items()
+            if kwarg.startswith("text_")
         }
 
-        # remove gpt2
-        for key in
-            del kwargs["
-        for key in
-            del kwargs["
+        # remove vit & gpt2 kwargs from kwargs
+        for key in vision_kwargs.keys():
+            del kwargs["vision_" + key]
+        for key in text_kwargs.keys():
+            del kwargs["text_" + key]
 
-
-
-
+        vision_model_args = vision_kwargs.pop('model_args', None)
+        text_model_args = text_kwargs.pop('model_args', None)
+
+        # Load and initialize the vit & gpt2 model
+        vision_model = vision_kwargs.pop("model", None)
+        text_model = text_kwargs.pop("model", None)
+
+        if vision_model is None:
             assert (
-
-            ), "If `model` is not defined as an argument, a `
+                vision_pretrained_model_name_or_path is not None
+            ), "If `model` is not defined as an argument, a `vision_pretrained_model_name_or_path` has to be defined"
 
-            if "config" not in
-
-
+            if "config" not in vision_kwargs:
+                vision_config = ViTConfig.from_pretrained(vision_pretrained_model_name_or_path)
+                vision_kwargs["config"] = vision_config
 
-
-
-
+            # TODO: How to deal with model_args?
+            vision_model = FlaxViTModel.from_pretrained(
+                vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
            )
 
-
-        if vit_model is None:
+        if text_model is None:
             assert (
-
-            ), "If `model` is not defined as an argument, a `
+                text_pretrained_model_name_or_path is not None
+            ), "If `model` is not defined as an argument, a `text_pretrained_model_name_or_path` has to be defined"
+
+            if "config" not in text_kwargs:
+                text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
+                text_kwargs["config"] = text_config
 
-
-            vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
-            kwargs_vit["config"] = vit_config
+            text_kwargs["config"].add_cross_attention = True
 
-
-
+            # TODO: How to deal with model_args?
+            text_model = FlaxGPT2LMHeadModel.from_pretrained(
+                text_pretrained_model_name_or_path, *text_model_args, **text_kwargs
            )
 
         # instantiate config with corresponding kwargs
         dtype = kwargs.pop("dtype", jnp.float32)
-        config = ViTGPT2Config.
-
+        config = ViTGPT2Config.from_vision_text_configs(
+            vision_model.config, text_model.config, **kwargs
         )
 
         # init model
         model = cls(config, *model_args, dtype=dtype, **kwargs)
-        model.params["model"]["encoder"] =
-        model.params["model"]["decoder"] =
+        model.params["model"]["encoder"] = vision_model.params
+        model.params["model"]["decoder"] = text_model.params
 
         return model
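Note on the `@@ -164,7 +165,7 @@` hunk: `jax.ops.index_update` was deprecated in later JAX releases and has since been removed in favor of the `.at[]` indexed-update syntax. On a current JAX the same update would read roughly as follows (a one-line sketch under that assumption, not part of this commit):

    decoder_input_ids = decoder_input_ids.at[..., -1].set(self.config.text_config.eos_token_id)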
|