ydshieh
commited on
Commit
•
e30ab96
1
Parent(s):
5081c5d
Fix project_encoder
Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -553,8 +553,10 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
553 |
|
554 |
# instantiate config with corresponding kwargs
|
555 |
dtype = kwargs.pop("dtype", jnp.float32)
|
|
|
|
|
556 |
config = ViTGPT2Config.from_vision_text_configs(
|
557 |
-
vision_model.config, text_model.config, **kwargs
|
558 |
)
|
559 |
|
560 |
# init model
|
|
|
553 |
|
554 |
# instantiate config with corresponding kwargs
|
555 |
dtype = kwargs.pop("dtype", jnp.float32)
|
556 |
+
project_encoder = kwargs.pop("project_encoder", None)
|
557 |
+
|
558 |
config = ViTGPT2Config.from_vision_text_configs(
|
559 |
+
vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
|
560 |
)
|
561 |
|
562 |
# init model
|