ydshieh committed
Commit: 64afcd5
Parent(s): ec3ceb6

Fix style

vit_gpt2/modeling_flax_gpt2.py  (+27 -11)
CHANGED
@@ -24,7 +24,10 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+)
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -301,7 +304,9 @@ class FlaxGPT2Block(nn.Module):
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
         if self.config.add_cross_attention:
-            self.crossattention = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True)
+            self.crossattention = FlaxGPT2Attention(
+                config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
+            )
             self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
 
         project_encoder = getattr(self.config, "project_encoder", None)
@@ -337,7 +342,6 @@ class FlaxGPT2Block(nn.Module):
         hidden_states = attn_output + residual
 
         # Cross-Attention Block
-        cross_attn_weights = None
         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
             if not hasattr(self, "crossattention"):
@@ -413,13 +417,16 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
             encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
             encoder_attention_mask = attention_mask
             module_init_outputs = self.module.init(
-                rngs,
-                input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False
+                rngs,
+                input_ids,
+                attention_mask,
+                position_ids,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                return_dict=False,
             )
         else:
-            module_init_outputs = self.module.init(
-                rngs, input_ids, attention_mask, position_ids, return_dict=False
-            )
+            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
 
         return module_init_outputs["params"]
 
@@ -660,7 +667,11 @@ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPastAndCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2Model,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    _CONFIG_FOR_DOC,
 )
 
 
@@ -718,9 +729,10 @@ class FlaxGPT2LMHeadModule(nn.Module):
             logits=lm_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions
+            cross_attentions=outputs.cross_attentions,
         )
 
+
 @add_start_docstrings(
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
@@ -759,5 +771,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutputWithCrossAttentions, _CONFIG_FOR_DOC
+    FlaxGPT2LMHeadModel,
+    _TOKENIZER_FOR_DOC,
+    _CHECKPOINT_FOR_DOC,
+    FlaxCausalLMOutputWithCrossAttentions,
+    _CONFIG_FOR_DOC,
 )
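For context, a minimal sketch of how the code paths touched above are exercised; it is not part of this commit. It assumes a transformers version whose Flax GPT-2 classes support cross-attention (which the classes defined in this file mirror), and all model sizes and shapes below are illustrative.

# Illustrative only: build a tiny GPT-2 with cross-attention enabled so that
# weight initialization runs the multi-argument module.init(...) branch
# reformatted above, and the forward pass returns the cross_attentions field.
import jax.numpy as jnp
from transformers import GPT2Config, FlaxGPT2Model  # stand-ins for the classes defined in this file

config = GPT2Config(n_layer=2, n_head=4, n_embd=64, add_cross_attention=True)
model = FlaxGPT2Model(config)  # init_weights takes the encoder_hidden_states branch

input_ids = jnp.ones((1, 8), dtype="i4")
attention_mask = jnp.ones_like(input_ids)
encoder_hidden_states = jnp.zeros((1, 8, config.n_embd))  # stand-in for encoder features (e.g. image features in this vit_gpt2 setup)

outputs = model(
    input_ids,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
    output_attentions=True,
)
# A FlaxBaseModelOutputWithPastAndCrossAttentions with cross_attentions populated.
print(type(outputs).__name__, outputs.cross_attentions is not None)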