# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
    """Configuration for the ``pass_through`` scheduler; declares no fields of its own."""


@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
    """Delegate lr scheduling to the optimizer.

    Every method forwards to ``self.optimizer.lr_scheduler``; this class
    holds no scheduling state of its own.
    """

    def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
        """Initialise the base scheduler and check the optimizer has a scheduler.

        Args:
            cfg: configuration object, passed unchanged to the base class.
            optimizer: optimizer expected to expose a non-``None``
                ``lr_scheduler`` attribute.

        Raises:
            AssertionError: if ``optimizer`` has no ``lr_scheduler`` attribute
                or that attribute is ``None``.
        """
        super().__init__(cfg, optimizer)
        # NOTE(review): this check is an ``assert`` and is skipped under
        # ``python -O``; kept as-is so the raised exception type is unchanged.
        delegate = getattr(optimizer, "lr_scheduler", None)
        assert (
            delegate is not None
        ), "Pass-through schedule can only be used with optimizers with their own schedulers"

    def state_dict(self):
        """Return the state dict reported by the optimizer's own scheduler."""
        inner = self.optimizer.lr_scheduler
        return inner.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the optimizer's own scheduler."""
        inner = self.optimizer.lr_scheduler
        inner.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Forward the begin-of-epoch step to the optimizer's scheduler.

        Returns whatever the wrapped scheduler's ``step_begin_epoch`` returns.
        """
        inner = self.optimizer.lr_scheduler
        return inner.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Forward the per-update step to the optimizer's scheduler.

        Returns whatever the wrapped scheduler's ``step_update`` returns.
        """
        inner = self.optimizer.lr_scheduler
        return inner.step_update(num_updates)