Update amt/src/model/init_train.py

amt/src/model/init_train.py  CHANGED  (+10 -9)
@@ -46,11 +46,12 @@ def initialize_trainer(args: argparse.Namespace,
     if shared_cfg["WANDB"].get("cache_dir", None) is not None:
         os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir")
         del shared_cfg["WANDB"]["cache_dir"]  # remove cache_dir from shared_cfg
-    wandb_logger = WandbLogger(log_model="all",
-                               project=args.project,
-                               id=args.exp_id,
-                               allow_val_change=True,
-                               **shared_cfg['WANDB'])
+    # wandb_logger = WandbLogger(log_model="all",
+    #                            project=args.project,
+    #                            id=args.exp_id,
+    #                            allow_val_change=True,
+    #                            **shared_cfg['WANDB'])
+    wandb_logger = None
 
     # check if any checkpoint exists
     last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
@@ -109,14 +110,14 @@
         precision=args.precision,
         max_epochs=args.max_epochs if stage == 'train' else None,
         max_steps=args.max_steps if stage == 'train' else -1,
-        logger=wandb_logger,
+        # logger=wandb_logger,
         callbacks=[checkpoint_callback, lr_monitor],
         sync_batchnorm=sync_batchnorm)
     trainer = pl.trainer.trainer.Trainer(**train_params)
 
-    # Update wandb logger (for DDP)
-    if trainer.global_rank == 0:
-        wandb_logger.experiment.config.update(args, allow_val_change=True)
+    # # Update wandb logger (for DDP)
+    # if trainer.global_rank == 0:
+    #     wandb_logger.experiment.config.update(args, allow_val_change=True)
 
     return trainer, wandb_logger, dir_info, shared_cfg
 
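The change above hard-disables W&B logging by commenting out the WandbLogger construction, setting wandb_logger = None, and dropping the logger from the Trainer arguments and the rank-0 config update. As a point of comparison, here is a minimal sketch of how the same logger could be made switchable instead of commented out. It reuses the args/shared_cfg names shown in the diff; the create_logger helper and the WANDB_MODE check are illustrative assumptions, not part of the repository:

import os
import argparse
from pytorch_lightning.loggers import WandbLogger

def create_logger(args: argparse.Namespace, shared_cfg: dict):
    """Return a WandbLogger, or None when W&B logging is disabled (illustrative sketch)."""
    # WANDB_MODE=disabled is a standard wandb environment variable; checking it here
    # reproduces the effect of the `wandb_logger = None` change without editing code.
    if os.environ.get("WANDB_MODE", "").lower() == "disabled":
        return None
    return WandbLogger(log_model="all",
                       project=args.project,
                       id=args.exp_id,
                       allow_val_change=True,
                       **shared_cfg.get("WANDB", {}))

With such a helper, the call sites guard on the result, e.g. `if wandb_logger is not None: wandb_logger.experiment.config.update(args, allow_val_change=True)`, and pass `logger=wandb_logger or False` to the Trainer so that a disabled logger does not silently fall back to Lightning's default logger.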