import torch
from torch.optim import Optimizer

from timm.scheduler import CosineLRScheduler


# torch.optim.lr_scheduler._LRScheduler is listed as a base class only because
# PyTorch Lightning complains about schedulers that do not derive from it.
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
    """Adapter that lets timm's ``CosineLRScheduler`` be stepped without an argument.

    The wrapper keeps its own position counter in ``self._last_epoch`` and
    forwards it to the timm scheduler on every ``step()`` call, so callers can
    simply invoke ``scheduler.step()``. Passing an explicit ``epoch`` overwrites
    the counter; the original author notes that resuming is supported as well.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying timm scheduler and apply the schedule at position 0.

        All positional and keyword arguments are forwarded untouched to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Start "before the first tick"; the step(epoch=0) call below then
        # moves the counter to 0 and pushes that position to the parent.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the schedule, or jump to ``epoch`` when one is given.

        Args:
            epoch: Position to move to. When ``None`` the internal counter is
                incremented by one instead.
        """
        if epoch is not None:
            self._last_epoch = epoch
        else:
            self._last_epoch += 1

        # Lightning always calls step() (the per-epoch entry point), even when
        # the scheduler interval is configured as "step". Dispatching on
        # t_in_epochs here keeps the learning-rate update correct in that case:
        # the counter is forwarded as an update count rather than an epoch.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)