ydshieh committed
Commit b00cdfe
1 Parent(s): b31314b

update model script

Files changed (1):
  1. vit_gpt2/modeling_flax_vit_gpt2_lm.py (+59 -51)
vit_gpt2/modeling_flax_vit_gpt2_lm.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Callable, Optional, Tuple
+import os
+from typing import Callable, Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
@@ -32,8 +33,8 @@ class FlaxViTGPT2LMModule(nn.Module):
 
     def setup(self):
 
-        self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
-        self.decoder = FlaxGPT2LMHeadModule(self.config.gpt2_config, dtype=self.dtype)
+        self.encoder = FlaxViTModule(self.config.vision_config, dtype=self.dtype)
+        self.decoder = FlaxGPT2LMHeadModule(self.config.text_config, dtype=self.dtype)
 
     def _get_encoder_module(self):
         return self.encoder
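For downstream callers, this hunk means the composite config's sub-configs are now reached as `config.vision_config` / `config.text_config` instead of `config.vit_config` / `config.gpt2_config`. A minimal sketch of the two sub-configs those attributes point at (stand-in defaults, not values pinned by this commit):

```python
from transformers import GPT2Config, ViTConfig

# stand-ins for the sub-configs exposed as `vision_config` / `text_config`
vision_config = ViTConfig()
text_config = GPT2Config()

print(vision_config.image_size)   # 224 by default
print(text_config.eos_token_id)   # 50256 by default
```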
@@ -147,7 +148,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     ):
         if input_shape is None:
             input_shape = (
-                (1, config.vit_config.image_size, config.vit_config.image_size, 3),
+                (1, config.vision_config.image_size, config.vision_config.image_size, 3),
                 (1, 1),
             )
 
@@ -164,7 +165,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         attention_mask = None
         decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
         # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
-        decoder_input_ids = jax.ops.index_update(decoder_input_ids, (..., -1), self.config.eos_token_id)
+        decoder_input_ids = jax.ops.index_update(decoder_input_ids, (..., -1), self.config.text_config.eos_token_id)
         decoder_attention_mask = jnp.ones_like(decoder_input_ids)
 
         batch_size, sequence_length = decoder_input_ids.shape
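Side note: `jax.ops.index_update` was deprecated in the JAX 0.2.x series and later removed, so this line only runs on older JAX releases. On current JAX the same out-of-place update is written with the `.at[]` syntax; a minimal sketch (50256 stands in for GPT-2's EOS token id):

```python
import jax.numpy as jnp

decoder_input_ids = jnp.zeros((1, 1), dtype="i4")
# functional (out-of-place) update: set the last position of every sequence to EOS
decoder_input_ids = decoder_input_ids.at[..., -1].set(50256)
```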
@@ -221,11 +222,11 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
-        output_attentions = (output_attentions if output_attentions is not None else self.config.vit_config.output_attentions)
+        output_attentions = (output_attentions if output_attentions is not None else self.config.vision_config.output_attentions)
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.vit_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.vit_config.return_dict
+        return_dict = return_dict if return_dict is not None else self.config.vision_config.return_dict
 
         # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
         pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
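For context on the transpose above: callers feed channels-first (NCHW) `pixel_values`, while Flax ViT consumes channels-last (NHWC), which also matches the `(1, image_size, image_size, 3)` default input shape earlier in this diff. A quick sketch:

```python
import jax.numpy as jnp

pixel_values = jnp.zeros((1, 3, 224, 224))                 # NCHW, as fed by callers
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))   # NHWC, as Flax ViT expects
print(pixel_values.shape)                                  # (1, 224, 224, 3)
```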
@@ -267,12 +268,12 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
     ):
 
         output_attentions = (
-            output_attentions if output_attentions is not None else self.config.gpt2_config.output_attentions
+            output_attentions if output_attentions is not None else self.config.text_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.gpt2_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
         )
-        return_dict = return_dict if return_dict is not None else self.config.gpt2_config.return_dict
+        return_dict = return_dict if return_dict is not None else self.config.text_config.return_dict
 
         encoder_hidden_states = encoder_outputs[0]
         if encoder_attention_mask is None:
@@ -486,71 +487,78 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
         return model_kwargs
 
     @classmethod
-    def from_vit_gpt2_pretrained(
+    def from_vision_text_pretrained(
         cls,
-        vit_model_name_or_path: str = None,
-        gpt2_model_name_or_path: str = None,
+        vision_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        text_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         *model_args,
         **kwargs,
     ) -> FlaxViTGPT2LMPreTrainedModel:
 
-        kwargs_gpt2 = {
-            argument[len("gpt2_") :]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("gpt2_")
+        vision_kwargs = {
+            kwarg[len("vision_"):]: value
+            for kwarg, value in kwargs.items()
+            if kwarg.startswith("vision_")
         }
 
-        kwargs_vit = {
-            argument[len("vit_") :]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("vit_")
+        text_kwargs = {
+            kwarg[len("text_"):]: value
+            for kwarg, value in kwargs.items()
+            if kwarg.startswith("text_")
         }
 
-        # remove gpt2, vit kwargs from kwargs
-        for key in kwargs_gpt2.keys():
-            del kwargs["gpt2_" + key]
-        for key in kwargs_vit.keys():
-            del kwargs["vit_" + key]
+        # remove vision & text kwargs from kwargs
+        for key in vision_kwargs.keys():
+            del kwargs["vision_" + key]
+        for key in text_kwargs.keys():
+            del kwargs["text_" + key]
 
-        # Load and initialize the gpt2 and vit model
-        gpt2_model = kwargs_gpt2.pop("model", None)
-        if gpt2_model is None:
-            assert (
-                gpt2_model_name_or_path is not None
-            ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"
-
-            if "config" not in kwargs_gpt2:
-                gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
-                kwargs_gpt2["config"] = gpt2_config
-
-            kwargs_gpt2["config"].add_cross_attention = True
-            gpt2_model = FlaxGPT2LMHeadModel.from_pretrained(
-                gpt2_model_name_or_path, *model_args, **kwargs_gpt2
-            )
-
-        vit_model = kwargs_vit.pop("model", None)
-        if vit_model is None:
-            assert (
-                vit_model_name_or_path is not None
-            ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"
-
-            if "config" not in kwargs_vit:
-                vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
-                kwargs_vit["config"] = vit_config
-
-            vit_model = FlaxViTModel.from_pretrained(
-                vit_model_name_or_path, *model_args, **kwargs_vit
-            )
+        vision_model_args = vision_kwargs.pop('model_args', None)
+        text_model_args = text_kwargs.pop('model_args', None)
+
+        # Load and initialize the vit & gpt2 model
+        vision_model = vision_kwargs.pop("model", None)
+        text_model = text_kwargs.pop("model", None)
+
+        if vision_model is None:
+            assert (
+                vision_pretrained_model_name_or_path is not None
+            ), "If `model` is not defined as an argument, a `vision_pretrained_model_name_or_path` has to be defined"
+
+            if "config" not in vision_kwargs:
+                vision_config = ViTConfig.from_pretrained(vision_pretrained_model_name_or_path)
+                vision_kwargs["config"] = vision_config
+
+            # TODO: How to deal with model_args?
+            vision_model = FlaxViTModel.from_pretrained(
+                vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
+            )
+
+        if text_model is None:
+            assert (
+                text_pretrained_model_name_or_path is not None
+            ), "If `model` is not defined as an argument, a `text_pretrained_model_name_or_path` has to be defined"
+
+            if "config" not in text_kwargs:
+                text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
+                text_kwargs["config"] = text_config
+
+            text_kwargs["config"].add_cross_attention = True
+
+            # TODO: How to deal with model_args?
+            text_model = FlaxGPT2LMHeadModel.from_pretrained(
+                text_pretrained_model_name_or_path, *text_model_args, **text_kwargs
+            )
 
         # instantiate config with corresponding kwargs
         dtype = kwargs.pop("dtype", jnp.float32)
-        config = ViTGPT2Config.from_vit_gpt2_configs(
-            vit_model.config, gpt2_model.config, **kwargs
+        config = ViTGPT2Config.from_vision_text_configs(
+            vision_model.config, text_model.config, **kwargs
        )
 
         # init model
         model = cls(config, *model_args, dtype=dtype, **kwargs)
-        model.params["model"]["encoder"] = vit_model.params
-        model.params["model"]["decoder"] = gpt2_model.params
+        model.params["model"]["encoder"] = vision_model.params
+        model.params["model"]["decoder"] = text_model.params
 
         return model
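A hypothetical end-to-end use of the renamed classmethod (the checkpoint names are illustrative, not pinned by this commit). Note that as written, the new code unpacks `*vision_model_args` / `*text_model_args`, which default to `None` when the caller passes nothing, so explicitly passing empty tuples appears to be required to avoid a `TypeError`:

```python
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

# per-submodel kwargs are routed by their "vision_" / "text_" prefixes, which are
# stripped before being forwarded to the underlying from_pretrained calls
model = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    "google/vit-base-patch16-224-in21k",  # vision encoder checkpoint (illustrative)
    "gpt2",                               # text decoder checkpoint (illustrative)
    vision_model_args=(),                 # see note above: avoids unpacking None
    text_model_args=(),
)
```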