Spaces:
Sleeping
Sleeping
File size: 1,205 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from torch.optim import Optimizer
from timm.scheduler import CosineLRScheduler
# We must also subclass torch.optim.lr_scheduler._LRScheduler; otherwise PyTorch Lightning rejects the scheduler.
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
    """Adapter around ``timm.scheduler.CosineLRScheduler`` whose ``step()`` needs no epoch argument.

    The wrapper keeps its own counter in ``self._last_epoch``:

    * each argument-less ``step()`` advances the counter by one;
    * an explicit ``step(epoch=k)`` sets it to ``k``.

    Depending on ``self.t_in_epochs``, the counter is forwarded either to timm's per-epoch
    ``step`` or to its per-update ``step_update``. It supports resuming as well, because the
    counter is stored on the instance (presumably persisted via the scheduler's state dict;
    confirm against the timm version in use).

    All constructor arguments are passed through unchanged to ``CosineLRScheduler``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Begin one position before the start, then immediately apply the schedule at index 0.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Move the internal counter and apply the timm schedule at the new position.

        Args:
            epoch: Position to jump to. When ``None``, the counter is incremented by one instead.
        """
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        position = self._last_epoch

        # Lightning always calls ``step`` (nominally once per epoch). When the scheduler interval
        # is "step", the counter actually tracks optimizer updates. In that case it must be routed
        # to ``step_update``; otherwise the learning-rate schedule would be wrong.
        if not self.t_in_epochs:
            super().step_update(num_updates=position)
            return
        super().step(epoch=position)
|