mimbres committed on
Commit 5f9fe7f
1 Parent(s): d06af5f

Update amt/src/model/init_train.py

Files changed (1)
  1. amt/src/model/init_train.py +10 -9
amt/src/model/init_train.py CHANGED
@@ -46,11 +46,12 @@ def initialize_trainer(args: argparse.Namespace,
     if shared_cfg["WANDB"].get("cache_dir", None) is not None:
         os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir")
         del shared_cfg["WANDB"]["cache_dir"]  # remove cache_dir from shared_cfg
-    wandb_logger = WandbLogger(log_model="all",
-                               project=args.project,
-                               id=args.exp_id,
-                               allow_val_change=True,
-                               **shared_cfg['WANDB'])
+    # wandb_logger = WandbLogger(log_model="all",
+    #                            project=args.project,
+    #                            id=args.exp_id,
+    #                            allow_val_change=True,
+    #                            **shared_cfg['WANDB'])
+    wandb_logger = None
 
     # check if any checkpoint exists
     last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
@@ -109,14 +110,14 @@ def initialize_trainer(args: argparse.Namespace,
         precision=args.precision,
         max_epochs=args.max_epochs if stage == 'train' else None,
         max_steps=args.max_steps if stage == 'train' else -1,
-        logger=wandb_logger,
+        # logger=wandb_logger,
         callbacks=[checkpoint_callback, lr_monitor],
         sync_batchnorm=sync_batchnorm)
     trainer = pl.trainer.trainer.Trainer(**train_params)
 
-    # Update wandb logger (for DDP)
-    if trainer.global_rank == 0:
-        wandb_logger.experiment.config.update(args, allow_val_change=True)
+    # # Update wandb logger (for DDP)
+    # if trainer.global_rank == 0:
+    #     wandb_logger.experiment.config.update(args, allow_val_change=True)
 
     return trainer, wandb_logger, dir_info, shared_cfg
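
For reference, this commit disables Weights & Biases logging by commenting out the WandbLogger construction, setting wandb_logger = None, dropping the logger= argument from the Trainer, and skipping the rank-0 config update. A minimal sketch of an alternative that gates the logger behind a flag instead of commenting it out is below; the use_wandb flag and make_wandb_logger helper are hypothetical and not part of this repository, while the project, id, and log_model arguments are standard pytorch_lightning.loggers.WandbLogger parameters.

from pytorch_lightning.loggers import WandbLogger

def make_wandb_logger(args, wandb_cfg: dict):
    # Hypothetical helper: build the W&B logger only when requested; otherwise
    # return None so the caller can omit `logger=` when constructing Trainer,
    # matching the behavior after this commit.
    if not getattr(args, "use_wandb", False):  # `use_wandb` is an assumed flag
        return None
    return WandbLogger(log_model="all",
                       project=args.project,
                       id=args.exp_id,
                       **wandb_cfg)

With this shape, the DDP-time config update would be guarded by a check such as `if wandb_logger is not None and trainer.global_rank == 0:` rather than being commented out.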