Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass | |
from fairseq.dataclass import FairseqDataclass | |
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler | |
@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
    """Configuration for the pass-through learning-rate schedule.

    This config adds no options of its own. The schedule that uses it
    forwards all learning-rate updates to the optimizer's own
    ``lr_scheduler``, so there is nothing to configure here.
    """
@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
    """Delegate lr scheduling to the optimizer.

    Every hook on this class (state save/restore, per-epoch and per-update
    steps) forwards to ``optimizer.lr_scheduler``. It can therefore only be
    used with an optimizer that carries its own scheduler.
    """

    def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
        """Wrap ``optimizer`` and verify it exposes a non-None ``lr_scheduler``.

        Args:
            cfg: the (option-less) pass-through config.
            optimizer: optimizer expected to expose an ``lr_scheduler``
                attribute.

        Raises:
            AssertionError: if ``optimizer`` has no ``lr_scheduler`` or it is
                ``None``.
        """
        # NOTE(review): ``self.optimizer`` used by the methods below is
        # presumably set by FairseqLRScheduler.__init__ — confirm.
        super().__init__(cfg, optimizer)
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``. AssertionError is kept for backward compatibility.
        if getattr(optimizer, "lr_scheduler", None) is None:
            raise AssertionError(
                "Pass-through schedule can only be used with optimizers with their own schedulers"
            )

    def state_dict(self):
        """Return the state dict of the optimizer's internal scheduler."""
        return self.optimizer.lr_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the optimizer's internal scheduler from ``state_dict``."""
        self.optimizer.lr_scheduler.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch.

        Returns whatever the optimizer's scheduler returns for this hook.
        """
        return self.optimizer.lr_scheduler.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update.

        Returns whatever the optimizer's scheduler returns for this hook.
        """
        return self.optimizer.lr_scheduler.step_update(num_updates)