feat: log epoch + check params
- dev/seq2seq/do_big_run.sh +4 -4
- dev/seq2seq/do_small_run.sh +3 -3
- dev/seq2seq/run_seq2seq_flax.py +21 -30
dev/seq2seq/do_big_run.sh
CHANGED
@@ -1,16 +1,16 @@
 python run_seq2seq_flax.py \
-    --max_source_length 128 \
     --dataset_repo_or_path dalle-mini/encoded \
     --train_file **/train/*/*.jsonl \
     --validation_file **/valid/*/*.jsonl \
+    --len_train 42684248 \
+    --len_eval 34328 \
     --streaming \
-    --len_train … \
-    --len_eval 100 \
+    --normalize_text \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
     --preprocessing_num_workers 80 \
-    --warmup_steps … \
+    --warmup_steps 500 \
     --gradient_accumulation_steps 8 \
     --do_train \
     --do_eval \
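With --streaming the dataset is an iterable with no len(), so the script cannot derive steps per epoch on its own; the explicit --len_train / --len_eval counts supply it (and run_seq2seq_flax.py below now refuses to stream without them). A back-of-the-envelope sketch of how the count turns into a step budget; the device count is an assumption, not something this commit specifies:

    # Hypothetical arithmetic: how --len_train becomes a steps-per-epoch figure
    # when the streamed dataset itself has no len(). num_devices is an assumption.
    len_train = 42_684_248           # --len_train
    per_device_batch_size = 56       # --per_device_train_batch_size
    grad_accum_steps = 8             # --gradient_accumulation_steps
    num_devices = 8                  # e.g. one TPU v3-8 host; assumption

    examples_per_update = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = len_train // examples_per_update
    print(steps_per_epoch)           # 42_684_248 // 3_584 = 11_909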
dev/seq2seq/do_small_run.sh
CHANGED
@@ -2,9 +2,9 @@ python run_seq2seq_flax.py \
     --dataset_repo_or_path dalle-mini/encoded \
     --train_file **/train/*/*.jsonl \
     --validation_file **/valid/*/*.jsonl \
+    --len_train 42684248 \
+    --len_eval 34328 \
     --streaming \
-    --len_train 1000000 \
-    --len_eval 1000 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
@@ -15,5 +15,5 @@ python run_seq2seq_flax.py \
     --do_eval \
     --adafactor \
     --num_train_epochs 1 \
-    --max_train_samples … \
+    --max_train_samples 10000 \
     --learning_rate 0.005
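For the small run, --max_train_samples 10000 caps how much of the stream is consumed. With Hugging Face datasets in streaming mode, one plausible way to apply such a cap is IterableDataset.take; a sketch under that assumption, not necessarily the script's exact mechanism:

    # Sketch: capping a streamed dataset, assuming IterableDataset.take is used.
    from datasets import load_dataset

    train_data = load_dataset(
        "json",
        data_files={"train": "**/train/*/*.jsonl"},
        split="train",
        streaming=True,
    )
    train_data = train_data.take(10_000)  # --max_train_samples 10000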
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -138,16 +138,6 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """

-    dataset_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "The name of the dataset to use (via the datasets library)."},
-    )
-    dataset_config_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "The configuration name of the dataset to use (via the datasets library)."
-        },
-    )
     text_column: Optional[str] = field(
         default="caption",
         metadata={
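The removed fields follow the HfArgumentParser convention, where every dataclass field doubles as a command-line flag; dropping them narrows the script to dalle-mini's own dataset path. A minimal self-contained sketch of that convention (the single field shown and its help text are illustrative):

    # Minimal sketch of the HfArgumentParser dataclass-to-CLI pattern
    # (single illustrative field; not the script's full argument set).
    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class DataTrainingArguments:
        text_column: Optional[str] = field(
            default="caption",
            metadata={"help": "Column containing the image captions."},
        )

    parser = HfArgumentParser((DataTrainingArguments,))
    (data_args,) = parser.parse_args_into_dataclasses()  # e.g. --text_column caption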
@@ -260,14 +250,10 @@ class DataTrainingArguments:
     )

     def __post_init__(self):
-        if (
-            self.dataset_name is None
-            and self.train_file is None
-            and self.validation_file is None
-        ):
-            raise ValueError(
-                "Need either a dataset name or a training/validation file."
-            )
+        if self.dataset_repo_or_path is None:
+            raise ValueError("Need a dataset repository or path.")
+        if self.train_file is None or self.validation_file is None:
+            raise ValueError("Need training/validation file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
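Putting these checks in __post_init__ means they run the moment HfArgumentParser constructs the dataclass, so a bad invocation fails before any data is touched. A generic sketch of the pattern with hypothetical stand-in fields:

    # Generic sketch: fail-fast validation in a dataclass __post_init__.
    # Fields are hypothetical stand-ins for the script's arguments.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Args:
        dataset_repo_or_path: Optional[str] = None
        train_file: Optional[str] = None
        validation_file: Optional[str] = None

        def __post_init__(self):
            if self.dataset_repo_or_path is None:
                raise ValueError("Need a dataset repository or path.")
            if self.train_file is None or self.validation_file is None:
                raise ValueError("Need training/validation file.")

    Args("dalle-mini/encoded", "train.jsonl", "valid.jsonl")  # constructs fine
    # Args("dalle-mini/encoded")  # would raise: "Need training/validation file."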
@@ -287,6 +273,10 @@ class DataTrainingArguments:
         ], "`validation_file` should be a tsv, csv or json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
+        if self.streaming and (self.len_train is None or self.len_eval is None):
+            raise ValueError(
+                "Streaming requires providing length of training and validation datasets"
+            )


 class TrainState(train_state.TrainState):
@@ -467,18 +457,6 @@ def main():
             "Use --overwrite_output_dir to overcome."
         )

-    # Set up wandb run
-    wandb.init(
-        entity="dalle-mini",
-        project="dalle-mini",
-        job_type="Seq2Seq",
-        config=parser.parse_args(),
-    )
-
-    # set default x-axis as 'train/step'
-    wandb.define_metric("train/step")
-    wandb.define_metric("*", step_metric="train/step")
-
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -528,6 +506,18 @@
 
         return step, optimizer_step, opt_state

+    # Set up wandb run
+    wandb.init(
+        entity="dalle-mini",
+        project="dalle-mini",
+        job_type="Seq2Seq",
+        config=parser.parse_args(),
+    )
+
+    # set default x-axis as 'train/step'
+    wandb.define_metric("train/step")
+    wandb.define_metric("*", step_metric="train/step")
+
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
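Relocating the block also puts wandb.init right before the from_checkpoint branch, which calls wandb.run.use_artifact and therefore needs an active run. The two define_metric calls pin the default x-axis: "train/step" is declared as a step metric, and "*" binds every other logged key to it, so charts plot against the training step rather than wandb's internal counter. The same setup in isolation (logged values are illustrative):

    # Standalone illustration of the wandb setup above; values are made up.
    import wandb

    wandb.init(entity="dalle-mini", project="dalle-mini", job_type="Seq2Seq")

    wandb.define_metric("train/step")                   # custom x-axis
    wandb.define_metric("*", step_metric="train/step")  # everything plots against it

    wandb.log({"train/step": 100, "train/loss": 2.3})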
@@ -1006,6 +996,7 @@
 
     for epoch in epochs:
         # ======================== Training ================================
+        wandb_log({"train/epoch": epoch}, step=global_step)

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
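wandb_log is the script's own logging helper; its definition sits outside this diff. A plausible minimal form consistent with this call site, assuming it only logs from the primary process, is:

    # Plausible minimal wandb_log consistent with the call above; the real
    # helper is defined elsewhere in run_seq2seq_flax.py and may differ.
    import jax
    import wandb

    def wandb_log(metrics, step=None, prefix=None):
        if jax.process_index() == 0:  # log from the primary host only
            log_metrics = {f"{prefix}/{k}" if prefix else k: v for k, v in metrics.items()}
            if step is not None:
                log_metrics["train/step"] = step
            wandb.log(log_metrics)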