Ayush Chaurasia and glenn-jocher committed
Commit e88e8f7
Parent: 2683b18

W&B: Restructure code to support the new dataset_check() feature (#4197)


* Improve docstrings and run names

* default wandb login prompt with timeout

* return key

* Update api_key check logic

* Properly support zipped dataset feature

* update docstring

* Revert tutorial change

* extend changes to log_dataset

* add run name

* bug fix

* bug fix

* Update comment

* fix import check

* remove unused import

* Hardcode .yaml file extension

* reduce code

* Reformat using pycharm

Co-authored-by: Glenn Jocher <[email protected]>
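
For reference, the data_dict that threads through these changes is the parsed dataset config; from how it is consumed in train.py below, it contains at least 'train', 'val', 'nc' and 'names' keys. A hypothetical example of its shape (values illustrative only, not from this commit):

    # Illustrative only -- keys inferred from how data_dict is consumed in train.py
    data_dict = {
        'train': 'datasets/coco128/images/train2017',  # training images path
        'val': 'datasets/coco128/images/train2017',    # validation images path
        'nc': 80,                                      # number of classes
        'names': ['person', 'bicycle'],                # truncated; real list has nc entries
    }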

README.md CHANGED
File without changes
train.py CHANGED
@@ -73,24 +73,29 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         yaml.safe_dump(hyp, f, sort_keys=False)
     with open(save_dir / 'opt.yaml', 'w') as f:
         yaml.safe_dump(vars(opt), f, sort_keys=False)
+    data_dict = None
+
+    # Loggers
+    if RANK in [-1, 0]:
+        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start()  # loggers dict
+        if loggers.wandb:
+            data_dict = loggers.wandb.data_dict
+            if resume:
+                weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
 
+
     # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
     with torch_distributed_zero_first(RANK):
-        data_dict = check_dataset(data)  # check
+        data_dict = data_dict or check_dataset(data)  # check if None
     train_path, val_path = data_dict['train'], data_dict['val']
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
     is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
-    # Loggers
-    if RANK in [-1, 0]:
-        loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
-        if loggers.wandb and resume:
-            weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
 
     # Model
     pretrained = weights.endswith('.pt')
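
The net effect of the train.py change is that data_dict can now come from two places: the W&B logger (which may have resolved the dataset already, e.g. from a W&B artifact) or a plain check_dataset() call. A condensed sketch of the resulting flow, using only names that appear in the diff above:

    data_dict = None
    if RANK in [-1, 0]:  # loggers are only created on the main process
        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start()
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict  # W&B already resolved the dataset
            if resume:
                weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

    with torch_distributed_zero_first(RANK):
        data_dict = data_dict or check_dataset(data)  # fall back to a local check if still None

Note that non-zero DDP ranks never construct Loggers, so for them data_dict is always filled by check_dataset().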
utils/loggers/__init__.py CHANGED
@@ -1,9 +1,7 @@
 # YOLOv5 experiment logging utils
-
+import torch
 import warnings
 from threading import Thread
-
-import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from utils.general import colorstr, emojis
@@ -23,12 +21,11 @@ except (ImportError, AssertionError):
 
 class Loggers():
     # YOLOv5 Loggers class
-    def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
+    def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
         self.save_dir = save_dir
         self.weights = weights
         self.opt = opt
         self.hyp = hyp
-        self.data_dict = data_dict
         self.logger = logger  # for printing results to console
         self.include = include
         for k in LOGGERS:
@@ -38,9 +35,7 @@ class Loggers():
         self.csv = True  # always log to csv
 
         # Message
-        try:
-            import wandb
-        except ImportError:
+        if not wandb:
             prefix = colorstr('Weights & Biases: ')
             s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
             print(emojis(s))
@@ -57,7 +52,7 @@ class Loggers():
             assert 'wandb' in self.include and wandb
             run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
             self.opt.hyp = self.hyp  # add hyperparameters
-            self.wandb = WandbLogger(self.opt, run_id, self.data_dict)
+            self.wandb = WandbLogger(self.opt, run_id)
         except:
             self.wandb = None
 
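After this change, callers no longer pass data_dict into Loggers; they read it back from the wandb sub-logger, which now owns dataset resolution. A minimal sketch of a hypothetical call site under the new signature (argument names taken from the new __init__ above):

    # Hypothetical call site; start() returns the Loggers instance, as used in train.py
    loggers = Loggers(save_dir=save_dir, weights=weights, opt=opt, hyp=hyp, logger=LOGGER).start()
    data_dict = loggers.wandb.data_dict if loggers.wandb else None  # read back, not passed in
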
utils/loggers/wandb/log_dataset.py CHANGED
@@ -1,5 +1,4 @@
 import argparse
-
 import yaml
 
 from wandb_utils import WandbLogger
@@ -8,9 +7,7 @@ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
 
 
 def create_dataset_artifact(opt):
-    with open(opt.data, encoding='ascii', errors='ignore') as f:
-        data = yaml.safe_load(f)  # data dict
-    logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation')  # TODO: return value unused
+    logger = WandbLogger(opt, None, job_type='Dataset Creation')  # TODO: return value unused
 
 
 if __name__ == '__main__':
@@ -19,6 +16,7 @@ if __name__ == '__main__':
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
     parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
     parser.add_argument('--entity', default=None, help='W&B entity')
+    parser.add_argument('--name', type=str, default='log dataset', help='name of W&B run')
 
     opt = parser.parse_args()
     opt.resume = False  # Explicitly disallow resume check for dataset upload job
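
With the new --name argument, a dataset-upload job can carry a readable W&B run name. Assuming the script's pre-existing --data argument (defined outside the hunks shown here), an invocation might look like:

    $ python utils/loggers/wandb/log_dataset.py --data data/coco128.yaml --project YOLOv5 --name coco128-upload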
utils/loggers/wandb/sweep.py CHANGED
@@ -1,7 +1,6 @@
 import sys
-from pathlib import Path
-
 import wandb
+from pathlib import Path
 
 FILE = Path(__file__).absolute()
 sys.path.append(FILE.parents[2].as_posix())  # add utils/ to path
utils/loggers/wandb/wandb_utils.py CHANGED
@@ -3,10 +3,9 @@
 import logging
 import os
 import sys
+import yaml
 from contextlib import contextmanager
 from pathlib import Path
-
-import yaml
 from tqdm import tqdm
 
 FILE = Path(__file__).absolute()
@@ -99,7 +98,7 @@ class WandbLogger():
     https://docs.wandb.com/guides/integrations/yolov5
     """
 
-    def __init__(self, opt, run_id, data_dict, job_type='Training'):
+    def __init__(self, opt, run_id, job_type='Training'):
         """
         - Initialize WandbLogger instance
         - Upload dataset if opt.upload_dataset is True
@@ -108,7 +107,6 @@ class WandbLogger():
         arguments:
         opt (namespace) -- Commandline arguments for this run
         run_id (str) -- Run ID of W&B run to be resumed
-        data_dict (Dict) -- Dictionary conataining info about the dataset to be used
         job_type (str) -- To set the job_type for this run
 
         """
@@ -119,10 +117,11 @@ class WandbLogger():
         self.train_artifact_path, self.val_artifact_path = None, None
         self.result_artifact = None
         self.val_table, self.result_table = None, None
-        self.data_dict = data_dict
         self.bbox_media_panel_images = []
         self.val_table_path_map = None
         self.max_imgs_to_log = 16
+        self.wandb_artifact_data_dict = None
+        self.data_dict = None
         # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -148,11 +147,23 @@ class WandbLogger():
         if self.wandb_run:
             if self.job_type == 'Training':
                 if not opt.resume:
-                    wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
-                    # Info useful for resuming from artifacts
-                    self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
-                                                 allow_val_change=True)
-                self.data_dict = self.setup_training(opt, data_dict)
+                    if opt.upload_dataset:
+                        self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
+
+                    elif opt.data.endswith('_wandb.yaml'):  # When dataset is W&B artifact
+                        with open(opt.data, encoding='ascii', errors='ignore') as f:
+                            data_dict = yaml.safe_load(f)
+                        self.data_dict = data_dict
+                    else:  # Local .yaml dataset file or .zip file
+                        self.data_dict = check_dataset(opt.data)
+
+                self.setup_training(opt)
+                # write data_dict to config. useful for resuming from artifacts
+                if not self.wandb_artifact_data_dict:
+                    self.wandb_artifact_data_dict = self.data_dict
+                self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
+                                             allow_val_change=True)
+
             if self.job_type == 'Dataset Creation':
                 self.data_dict = self.check_and_upload_dataset(opt)
 
@@ -167,7 +178,7 @@ class WandbLogger():
            Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links.
         """
         assert wandb, 'Install wandb to upload dataset'
-        config_path = self.log_dataset_artifact(check_file(opt.data),
+        config_path = self.log_dataset_artifact(opt.data,
                                                 opt.single_cls,
                                                 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
         print("Created dataset config file ", config_path)
@@ -175,7 +186,7 @@ class WandbLogger():
             wandb_data_dict = yaml.safe_load(f)
         return wandb_data_dict
 
-    def setup_training(self, opt, data_dict):
+    def setup_training(self, opt):
         """
         Setup the necessary processes for training YOLO models:
         - Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX
@@ -184,10 +195,7 @@ class WandbLogger():
 
         arguments:
         opt (namespace) -- commandline arguments for this run
-        data_dict (Dict) -- Dataset dictionary for this run
 
-        returns:
-        data_dict (Dict) -- contains the updated info about the dataset to be used for training
         """
         self.log_dict, self.current_epoch = {}, 0
         self.bbox_interval = opt.bbox_interval
@@ -198,8 +206,10 @@ class WandbLogger():
                 config = self.wandb_run.config
                 opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
                     self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
-                    config.opt['hyp']
+                    config.hyp
             data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume
+        else:
+            data_dict = self.data_dict
         if self.val_artifact is None:  # If --upload_dataset is set, use the existing artifact, don't download
             self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
                                                                                            opt.artifact_alias)
@@ -221,7 +231,10 @@ class WandbLogger():
             self.map_val_table_path()
         if opt.bbox_interval == -1:
             self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
-        return data_dict
+        train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
+        # Update the the data_dict to point to local artifacts dir
+        if train_from_artifact:
+            self.data_dict = data_dict
 
     def download_dataset_artifact(self, path, alias):
         """
@@ -299,7 +312,8 @@ class WandbLogger():
         returns:
         the new .yaml file with artifact links. it can be used to start training directly from artifacts
         """
-        data = check_dataset(data_file)  # parse and check
+        self.data_dict = check_dataset(data_file)  # parse and check
+        data = dict(self.data_dict)
         nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
         names = {k: v for k, v in enumerate(names)}  # to index dictionary
         self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@@ -310,7 +324,8 @@ class WandbLogger():
             data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
         if data.get('val'):
             data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
-        path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))  # updated data.yaml path
+        path = Path(data_file).stem
+        path = (path if overwrite_config else path + '_wandb') + '.yaml'  # updated data.yaml path
        data.pop('download', None)
        data.pop('path', None)
        with open(path, 'w') as f:
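
The stem-based path construction at the end of log_dataset_artifact() is what the "Hardcode .yaml file extension" message above refers to: since opt.data may now be a .zip (zipped dataset support), the old rsplit-based join would have produced a _wandb.zip config name. A small self-contained illustration of the before/after behaviour, with hypothetical file names:

    from pathlib import Path

    def old_path(data_file, overwrite_config=False):
        # Pre-commit behaviour: splice '_wandb.' in front of the original extension
        return data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1))

    def new_path(data_file, overwrite_config=False):
        # Post-commit behaviour: always emit '<stem>[_wandb].yaml' (written to the working dir)
        path = Path(data_file).stem
        return (path if overwrite_config else path + '_wandb') + '.yaml'

    print(old_path('data/coco128.zip'))   # data/coco128_wandb.zip  (wrong extension for a config)
    print(new_path('data/coco128.zip'))   # coco128_wandb.yaml
    print(new_path('data/coco128.yaml'))  # coco128_wandb.yaml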