# Do0rMaMu's picture
# Upload folder using huggingface_hub
# e45d058 verified
import torch
from torch.optim import Optimizer
from timm.scheduler import CosineLRScheduler
# We also subclass torch.optim.lr_scheduler._LRScheduler; otherwise PyTorch Lightning will complain.
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
    """Adapter that lets ``timm.scheduler.CosineLRScheduler`` be stepped without an epoch.

    The scheduler keeps its own counter in ``self._last_epoch``, so callers can
    simply invoke ``scheduler.step()``. An explicit ``epoch`` may still be
    passed to overwrite the counter; the original author notes that this is
    what makes resuming work.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent constructor and prime the counter.

        After construction the counter is reset to -1 and ``step(epoch=0)`` is
        called once, so the schedule is positioned at index 0.
        """
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance (or reposition) the schedule and apply the update.

        Args:
            epoch: Index to jump to. When ``None`` the internal counter is
                incremented by one instead.
        """
        # Either move forward by one tick or jump to the index the caller asked for.
        self._last_epoch = (self._last_epoch + 1) if epoch is None else epoch

        # Lightning always calls ``step`` (the per-epoch entry point), even when
        # the scheduler interval is set to "step". Routing on ``t_in_epochs``
        # ourselves keeps the learning rate update correct in that case: the
        # counter is then treated as a number of updates rather than epochs.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)