Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
•
a841a4c
1
Parent(s):
a104edb
Decoder: set eos to an unreachable value, set min_length=max_length to
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -258,6 +258,8 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
258 |
# the decoder has a different config
|
259 |
decoder_config = BartConfig(self.config.to_dict())
|
260 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
|
|
|
|
261 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
262 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
263 |
|
@@ -407,7 +409,9 @@ def main():
|
|
407 |
config.decoder_start_token_id = BOS_TOKEN_ID
|
408 |
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
409 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
410 |
-
config.eos_token_id =
|
|
|
|
|
411 |
|
412 |
|
413 |
# Create a custom model and initialize it randomly
|
|
|
258 |
# the decoder has a different config
|
259 |
decoder_config = BartConfig(self.config.to_dict())
|
260 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
261 |
+
decoder_config.min_length = OUTPUT_LENGTH
|
262 |
+
decoder_config.max_length = OUTPUT_LENGTH
|
263 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
264 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
265 |
|
|
|
409 |
config.decoder_start_token_id = BOS_TOKEN_ID
|
410 |
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
411 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
412 |
+
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
413 |
+
#config.min_length = data_args.max_target_length # Set only in decoder?
|
414 |
+
#config.max_length = data_args.max_target_length # Set only in decoder?
|
415 |
|
416 |
|
417 |
# Create a custom model and initialize it randomly
|