Atin Sakkeer Hussain committed
Commit b5e6f78
1 Parent(s): 795ce43
util/.ipynb_checkpoints/misc-checkpoint.py ADDED
@@ -0,0 +1,422 @@
util/lr_sched.py ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
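For context, a minimal sketch of how this scheduler is typically driven and the curve it produces. The argparse fields (lr, min_lr, warmup_epochs, epochs) follow the attribute accesses above; the concrete values, the SGD optimizer, and the `util.lr_sched` import path are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; values are made up to show the warmup + half-cycle cosine shape.
import argparse
import torch

from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=50)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in range(args.epochs):
    # A fractional epoch (e.g. epoch + step / steps_per_epoch) gives per-iteration updates.
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(f"epoch {epoch:02d}: lr = {lr:.6f}")  # rises linearly for 5 epochs, then decays to min_lr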
util/misc.py ADDED
@@ -0,0 +1,422 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import urllib.request
from tqdm import tqdm

import torch
import torch.utils.data
import torch.distributed as dist
from torch import inf


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def setup_for_distributed(is_master):
    """
    Disable printing when this is not the master process (rank 0).
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    print("GPU::", args.gpu)
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / 'checkpoint.pth']
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint", client_state=client_state)


def load_model(model_without_ddp, optimizer, loss_scaler, path):
    if path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            path, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(path, map_location='cpu')
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if loss_scaler is not None:
        loss_scaler.load_state_dict(checkpoint['scaler'])
    print(checkpoint.keys())
    # strip a possible DistributedDataParallel "module." prefix before loading
    new_ckpt = {}
    for key, value in checkpoint['model'].items():
        key = key.replace("module.", "")
        new_ckpt[key] = value

    load_result = model_without_ddp.load_state_dict(new_ckpt, strict=True)
    assert len(load_result.unexpected_keys) == 0, f"Unexpected keys: {load_result.unexpected_keys}"
    print("Load checkpoint %s" % path)
    return checkpoint['epoch']


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]


class DistributedSubEpochSampler(torch.utils.data.Sampler):

    def __init__(self, dataset, num_replicas, rank, shuffle, split_epoch=1, seed=42):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.split_epoch = split_epoch
        self.seed = seed
        self.epoch = 0  # overwritten by set_epoch(); default keeps __iter__ usable before the first call

        self.num_samples = len(dataset) // (num_replicas * split_epoch)

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch // self.split_epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        indices = indices[self.rank * self.split_epoch + self.epoch % self.split_epoch::self.num_replicas * self.split_epoch]
        assert len(indices) >= self.num_samples
        indices = indices[:self.num_samples]

        return iter(indices)

    def set_epoch(self, epoch):
        self.epoch = epoch

def download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    return download_target
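To make the roles of these helpers concrete, here is a minimal single-process sketch wiring them into a training loop. The toy model, random data, hyperparameter values, and `util.*` import paths are illustrative assumptions rather than part of this commit.

# Hypothetical single-process wiring of util.misc and util.lr_sched; the toy model,
# random data, and hyperparameter values below are illustrative assumptions only.
import argparse
import os

import torch

from util import lr_sched, misc

args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=1, epochs=3, output_dir='./out')
os.makedirs(args.output_dir, exist_ok=True)

model = torch.nn.Linear(16, 1)
param_groups = misc.add_weight_decay(model, weight_decay=0.05)   # no decay on biases / 1-D params
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = misc.NativeScalerWithGradNormCount()               # AMP scaler + grad-norm reporting

data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(8)]

for epoch in range(args.epochs):
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}]'.format(epoch)
    for step, (x, y) in enumerate(metric_logger.log_every(data, 4, header)):
        lr = lr_sched.adjust_learning_rate(optimizer, epoch + step / len(data), args)
        loss = torch.nn.functional.mse_loss(model(x), y)
        # backward, optional clipping, optimizer step, and scaler update all happen inside __call__
        grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
        optimizer.zero_grad()
        metric_logger.update(loss=loss.item(), lr=lr, grad_norm=grad_norm)
    misc.save_model(args, epoch, model, model, optimizer, loss_scaler)  # writes ./out/checkpoint.pth on rank 0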