Spaces:
Sleeping
Sleeping
import torch | |
from torch.optim import Optimizer | |
from timm.scheduler import CosineLRScheduler | |
# We need to subclass torch.optim.lr_scheduler._LRScheduler, or PyTorch Lightning will complain.
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
    """Adapter over ``timm.scheduler.CosineLRScheduler`` with an argument-free ``step()``.

    The wrapper keeps its own position counter (``_last_epoch``) so callers can
    invoke ``scheduler.step()`` without supplying the epoch themselves. The
    original author notes that resuming is supported as well.

    All constructor arguments are forwarded unchanged to the parent
    initializer via ``super().__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start one position before the beginning, then move straight to
        # position 0 through our own ``step`` so the counter and the parent
        # scheduler agree from the outset.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the schedule by one position, or jump to ``epoch`` if given.

        Args:
            epoch: Explicit position to move to. When ``None``, the internal
                counter is incremented by one instead.
        """
        self._last_epoch = (self._last_epoch + 1) if epoch is None else epoch

        # A training loop such as Lightning always calls ``step`` (the
        # per-epoch entry point). If the scheduler interval is configured as
        # "step", forwarding to the parent's ``step`` would produce the wrong
        # learning rate, so per-update schedules are routed to ``step_update``.
        # NOTE(review): ``t_in_epochs`` is presumably set by the timm base
        # class -- confirm against the installed timm version.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)