Spaces:
Running
Running
fix: labels array
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -479,7 +479,7 @@ def main():
|
|
479 |
# set up targets
|
480 |
# Note: labels correspond to our target indices
|
481 |
# decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
|
482 |
-
labels = [
|
483 |
labels = np.asarray(labels)
|
484 |
|
485 |
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|
|
|
479 |
# set up targets
|
480 |
# Note: labels correspond to our target indices
|
481 |
# decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
|
482 |
+
labels = [eval(indices) for indices in examples['encoding']]
|
483 |
labels = np.asarray(labels)
|
484 |
|
485 |
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|