ydshieh committed
Commit • 165ad1e
1 Parent(s): a01b02a

Change Flax GPT2 with cross-attn outputs to be the same as PyTorch's version

vit_gpt2/modeling_flax_gpt2.py  +24 -40

vit_gpt2/modeling_flax_gpt2.py  CHANGED
@@ -593,28 +593,21 @@ class FlaxGPT2BlockCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

+        # In Flax, `past_key_values` is not contained in modules' outputs.
         outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]

         if not return_dict:
             return tuple(v for v in outputs if v is not None)

-
-
-
-
-
-
-
-
-
-        # with cross_attn
-        return FlaxBaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-            cross_attentions=all_cross_attentions,
-        )
+        # with cross_attn
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+

 class FlaxGPT2Module(nn.Module):
     config: GPT2Config
@@ -676,19 +669,13 @@ class FlaxGPT2Module(nn.Module):
         if not return_dict:
             return (hidden_states,) + outputs[1:]

-
-
-
-
-
-
-
-        return FlaxBaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions,
-        )
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+

 @add_start_docstrings(
     "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -753,16 +740,13 @@ class FlaxGPT2LMHeadModule(nn.Module):
         if not return_dict:
             return (lm_logits,) + outputs[1:]

-
-
-
-
-
-
-
-            attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions
-        )
+        return FlaxCausalLMOutputWithCrossAttentions(
+            logits=lm_logits,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions
+        )

 @add_start_docstrings(
     """