ydshieh committed
Commit 87485e5
1 Parent(s): 3a17811
Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py +159 -37
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -24,7 +24,7 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_outputs import FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -117,6 +117,8 @@ class FlaxConv1D(nn.Module):
 class FlaxGPT2Attention(nn.Module):
     config: GPT2Config
     dtype: jnp.dtype = jnp.float32
+    causal: bool = True
+    is_cross_attention: bool = False
 
     def setup(self):
         config = self.config
@@ -124,10 +126,19 @@ class FlaxGPT2Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
 
-        self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
+        if self.is_cross_attention:
+            self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype)
+            self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+        else:
+            self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype)
         self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
+
         self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
-        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
+
+        if self.causal:
+            self.causal_mask = make_causal_mask(
+                jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
+            )
 
     def _split_heads(self, hidden_states):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
@@ -170,13 +181,26 @@ class FlaxGPT2Attention(nn.Module):
     def __call__(
         self,
         hidden_states,
+        key_value_states: Optional[jnp.ndarray] = None,
         attention_mask=None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
-        qkv_out = self.c_attn(hidden_states)
-        query, key, value = jnp.split(qkv_out, 3, axis=2)
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        batch_size = hidden_states.shape[0]
+
+        if not is_cross_attention:
+            qkv_out = self.c_attn(hidden_states)
+            query, key, value = jnp.split(qkv_out, 3, axis=2)
+        else:
+            q_out = self.q_attn(hidden_states)
+            (query,) = jnp.split(q_out, 1, axis=2)
+            kv_out = self.c_attn(key_value_states)
+            key, value = jnp.split(kv_out, 2, axis=2)
 
         query = self._split_heads(query)
         key = self._split_heads(key)
@@ -184,20 +208,25 @@ class FlaxGPT2Attention(nn.Module):
 
         query_length, key_length = query.shape[1], key.shape[1]
 
-        if self.has_variable("cache", "cached_key"):
-            mask_shift = self.variables["cache"]["cache_index"]
-            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
-            causal_mask = lax.dynamic_slice(
-                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
-            )
-        else:
-            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
-
-        batch_size = hidden_states.shape[0]
-        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
-
-        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
-        attention_mask = combine_masks(attention_mask, causal_mask)
+        if self.causal:
+            if self.has_variable("cache", "cached_key"):
+                mask_shift = self.variables["cache"]["cache_index"]
+                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+                causal_mask = lax.dynamic_slice(
+                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                )
+            else:
+                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+        # combine masks if needed
+        if attention_mask is not None and self.causal:
+            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+            attention_mask = combine_masks(attention_mask, causal_mask)
+        elif self.causal:
+            attention_mask = causal_mask
+        elif attention_mask is not None:
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
 
         dropout_rng = None
         if not deterministic and self.config.attn_pdrop > 0.0:
@@ -205,15 +234,18 @@ class FlaxGPT2Attention(nn.Module):
 
         # During fast autoregressive decoding, we feed one position at a time,
         # and cache the keys and values step by step.
-        if self.has_variable("cache", "cached_key") or init_cache:
+        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
             key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
 
         # transform boolean mask into float mask
-        attention_bias = lax.select(
-            attention_mask > 0,
-            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
-        )
+        if attention_mask is not None:
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+            )
+        else:
+            attention_bias = None
 
         # usual dot product attention
         attn_weights = dot_product_attention_weights(
@@ -267,19 +299,31 @@ class FlaxGPT2Block(nn.Module):
         self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
         self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+        if self.config.add_cross_attention:
+            self.crossattention = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True)
+            self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+
+        project_encoder = getattr(self.config, "project_encoder", None)
+        if project_encoder:
+            self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
+            self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype)
+
         self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
 
     def __call__(
         self,
         hidden_states,
         attention_mask=None,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
     ):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        outputs = self.attn(
+        attn_outputs = self.attn(
             hidden_states,
             attention_mask=attention_mask,
             deterministic=deterministic,
@@ -287,16 +331,53 @@ class FlaxGPT2Block(nn.Module):
             output_attentions=output_attentions,
         )
         # residual connection
-        attn_output = outputs[0]
+        attn_output = attn_outputs[0]  # output_attn: a, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
         hidden_states = attn_output + residual
 
+        # Cross-Attention Block
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
+            project_encoder = getattr(self.config, "project_encoder", None)
+            if project_encoder:
+                encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states)
+                feed_forward_hidden_states = self.encoder_projection_mlp(
+                    encoder_hidden_states, deterministic=deterministic
+                )
+                # residual connection
+                encoder_hidden_states = feed_forward_hidden_states
+
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
+            cross_attn_outputs = self.crossattention(
+                hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                deterministic=deterministic,
+                output_attentions=output_attentions,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = residual + attn_output
+            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights
+
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        return (hidden_states,) + outputs[1:]
+        outputs = (hidden_states,) + outputs
+
+        return outputs
 
 
 class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
@@ -328,7 +409,19 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
+        if self.config.add_cross_attention:
+            encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
+            encoder_attention_mask = attention_mask
+            module_init_outputs = self.module.init(
+                rngs, input_ids, attention_mask, position_ids,
+                encoder_hidden_states, encoder_attention_mask, return_dict=False
+            )
+        else:
+            module_init_outputs = self.module.init(
+                rngs, input_ids, attention_mask, position_ids, return_dict=False
+            )
+
+        return module_init_outputs["params"]
 
     def init_cache(self, batch_size, max_length):
         r"""
@@ -355,6 +448,8 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         input_ids,
         attention_mask=None,
         position_ids=None,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
         params: dict = None,
         past_key_values: dict = None,
         dropout_rng: jax.random.PRNGKey = None,
@@ -369,6 +464,10 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict
 
+        if encoder_hidden_states is not None and encoder_attention_mask is None:
+            batch_size, sequence_length = encoder_hidden_states.shape[:2]
+            encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
         batch_size, sequence_length = input_ids.shape
 
         if position_ids is None:
@@ -399,6 +498,8 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
             jnp.array(input_ids, dtype="i4"),
             jnp.array(attention_mask, dtype="i4"),
             jnp.array(position_ids, dtype="i4"),
+            encoder_hidden_states,
+            encoder_attention_mask,
             not train,
             False,
             output_attentions,
@@ -433,6 +534,8 @@ class FlaxGPT2BlockCollection(nn.Module):
         self,
         hidden_states,
         attention_mask=None,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
@@ -441,6 +544,7 @@ class FlaxGPT2BlockCollection(nn.Module):
     ):
         all_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 
         for block in self.blocks:
             if output_hidden_states:
@@ -449,6 +553,8 @@ class FlaxGPT2BlockCollection(nn.Module):
             layer_outputs = block(
                 hidden_states,
                 attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 deterministic=deterministic,
                 init_cache=init_cache,
                 output_attentions=output_attentions,
@@ -458,19 +564,22 @@ class FlaxGPT2BlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        outputs = (hidden_states,)
+        outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]
 
         if not return_dict:
             return tuple(v for v in outputs if v is not None)
 
-        return FlaxBaseModelOutputWithPast(
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=None,
             hidden_states=all_hidden_states,
             attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
         )
 
 
@@ -502,6 +611,8 @@ class FlaxGPT2Module(nn.Module):
         input_ids,
         attention_mask,
         position_ids,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic=True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -517,6 +628,8 @@ class FlaxGPT2Module(nn.Module):
         outputs = self.h(
             hidden_states,
             attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
             deterministic=deterministic,
             init_cache=init_cache,
             output_attentions=output_attentions,
@@ -530,10 +643,11 @@ class FlaxGPT2Module(nn.Module):
         if not return_dict:
             return (hidden_states,) + outputs[1:]
 
-        return FlaxBaseModelOutput(
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
        )
 
 
@@ -546,7 +660,7 @@ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
+    FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPastAndCrossAttentions, _CONFIG_FOR_DOC
 )
 
 
@@ -568,6 +682,8 @@ class FlaxGPT2LMHeadModule(nn.Module):
         input_ids,
         attention_mask,
         position_ids,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
         deterministic: bool = True,
         init_cache: bool = False,
         output_attentions: bool = False,
@@ -578,6 +694,8 @@ class FlaxGPT2LMHeadModule(nn.Module):
             input_ids,
             attention_mask,
             position_ids,
+            encoder_hidden_states,
+            encoder_attention_mask,
             deterministic=deterministic,
             init_cache=init_cache,
             output_attentions=output_attentions,
@@ -596,8 +714,12 @@ class FlaxGPT2LMHeadModule(nn.Module):
         if not return_dict:
             return (lm_logits,) + outputs[1:]
 
-        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
-
+        return FlaxCausalLMOutputWithCrossAttentions(
+            logits=lm_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions
+        )
 
 @add_start_docstrings(
     """
@@ -637,5 +759,5 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
 
 
 append_call_sample_docstring(
-    FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
+    FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutputWithCrossAttentions, _CONFIG_FOR_DOC
 )
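
The core of the change is that FlaxGPT2Attention can now act either as causal self-attention or, when key_value_states are passed, as a cross-attention layer: queries are projected from the decoder hidden states via q_attn, while keys and values come from the encoder states via a 2 * embed_dim c_attn. The standalone sketch below (not part of the commit; plain matrices stand in for FlaxConv1D, and all shapes are arbitrary toy values) illustrates how the two paths route their projections.

# Illustration only: projection routing of the modified FlaxGPT2Attention.
# Plain matmuls replace FlaxConv1D; shapes are made-up examples.
import jax.numpy as jnp
from jax import random

batch, tgt_len, src_len, embed_dim = 2, 5, 7, 16
rng = random.PRNGKey(0)
decoder_states = random.normal(rng, (batch, tgt_len, embed_dim))   # GPT-2 hidden_states
encoder_states = random.normal(rng, (batch, src_len, embed_dim))   # key_value_states (e.g. ViT features)

# Self-attention path (is_cross_attention=False): one c_attn of width 3 * embed_dim,
# split into query, key, value -- all derived from the decoder states.
w_c_attn_self = random.normal(rng, (embed_dim, 3 * embed_dim))
query, key, value = jnp.split(decoder_states @ w_c_attn_self, 3, axis=2)

# Cross-attention path (is_cross_attention=True): q_attn projects the decoder states,
# while c_attn (width 2 * embed_dim) projects the encoder states into key and value.
w_q_attn = random.normal(rng, (embed_dim, embed_dim))
w_c_attn_cross = random.normal(rng, (embed_dim, 2 * embed_dim))
query = decoder_states @ w_q_attn
key, value = jnp.split(encoder_states @ w_c_attn_cross, 2, axis=2)

print(query.shape, key.shape, value.shape)  # (2, 5, 16) (2, 7, 16) (2, 7, 16)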
 
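A second detail worth tracing is the mask handling: the causal mask is now built only when the layer is causal, the padding mask is optional, and the boolean mask is converted into an additive bias (0 where attending is allowed, -1e4 where it is not) only when a mask exists. The sketch below is an illustration with toy inputs, not code from the commit; it uses flax.linen.make_causal_mask and combine_masks as the diff does.

# Illustration only: combining a causal mask with a padding mask and turning the
# result into an additive attention bias, mirroring the rewritten __call__.
import jax.numpy as jnp
from jax import lax
from flax.linen import combine_masks, make_causal_mask

batch, seq = 1, 4
causal_mask = make_causal_mask(jnp.ones((batch, seq), dtype="bool"), dtype="bool")  # (1, 1, 4, 4)

padding_mask = jnp.array([[1, 1, 1, 0]])                         # last position is padding
padding_mask = jnp.expand_dims(padding_mask, axis=(-3, -2))      # (1, 1, 1, 4)
padding_mask = jnp.broadcast_to(padding_mask, causal_mask.shape)

attention_mask = combine_masks(padding_mask, causal_mask)        # logical AND of both masks

attention_bias = lax.select(
    attention_mask > 0,
    jnp.full(attention_mask.shape, 0.0),
    jnp.full(attention_mask.shape, -1e4),
)
print(attention_bias[0, 0])  # 0.0 where attention is permitted, -10000.0 elsewhere

For a pure cross-attention layer (causal=False), the diff skips the causal mask entirely and only expands the encoder padding mask, which is why attention_bias is allowed to be None when no mask is given.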