ydshieh
commited on
Commit
•
03d8c80
1
Parent(s):
9aceda3
fix decoder_position_ids in decode()
Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -289,7 +289,7 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
|
|
289 |
"Make sure to provide `position_ids` when passing `past_key_values`."
|
290 |
)
|
291 |
|
292 |
-
|
293 |
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
294 |
)
|
295 |
|
|
|
289 |
"Make sure to provide `position_ids` when passing `past_key_values`."
|
290 |
)
|
291 |
|
292 |
+
decoder_position_ids = jnp.broadcast_to(
|
293 |
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
294 |
)
|
295 |
|