# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is adapted from
https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
"""

from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warms up (increases) the learning rate in the optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0.
            If multiplier = 1.0, the learning rate starts from 0 and ends at base_lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater than or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        self.last_epoch = max(1, self.last_epoch)  # avoid last_epoch=0, which would yield lr=0
        if self.last_epoch > self.total_epoch:
            # warmup is over: hand off to the after_scheduler if one was given
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            # linear ramp from 0 up to base_lr
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # linear ramp from base_lr up to base_lr * multiplier
            return [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_reduce_lr_on_plateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            # still warming up: set the ramped learning rate directly on each param group
            warmup_lr = [
                base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            # past warmup: delegate to ReduceLROnPlateau, shifting the epoch count
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            # ReduceLROnPlateau needs the metric value, so it gets a dedicated path
            self.step_reduce_lr_on_plateau(metrics, epoch)
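

# A minimal usage sketch (illustrative only, not part of the adapted module):
# warm the learning rate up over the first 5 epochs to 10x the base lr, then
# hand off to CosineAnnealingLR. The model, optimizer, and epoch counts below
# are placeholder assumptions chosen for demonstration.
if __name__ == "__main__":
    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    cosine = CosineAnnealingLR(optimizer, T_max=95)  # runs after the 5 warmup epochs
    scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5, after_scheduler=cosine)

    for epoch in range(100):
        optimizer.step()  # stands in for one epoch of training
        scheduler.step()  # advance the warmup (and, once finished, the cosine schedule)
        print(epoch, optimizer.param_groups[0]["lr"])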