# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
    pass


@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
    """Delegate lr scheduling to the optimizer."""

    def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        assert (
            hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
        ), "Pass-through schedule can only be used with optimizers with their own schedulers"

    def state_dict(self):
        return self.optimizer.lr_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.lr_scheduler.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        return self.optimizer.lr_scheduler.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.lr_scheduler.step_update(num_updates)
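

# Usage sketch (illustrative only, not part of fairseq): "pass_through" is meant
# for optimizer wrappers that carry their own scheduler, i.e. expose a non-None
# `lr_scheduler` attribute so the assertion in __init__ passes. The class below
# is a hypothetical stand-in showing the minimal interface that
# PassThroughScheduleSchedule delegates to; it is not a fairseq API.
class _ExampleInnerScheduler:
    """Minimal interface expected on `optimizer.lr_scheduler`."""

    def state_dict(self):
        # No state to checkpoint in this sketch.
        return {}

    def load_state_dict(self, state_dict):
        pass

    def step_begin_epoch(self, epoch):
        # e.g. adjust per-group learning rates at epoch boundaries
        pass

    def step_update(self, num_updates):
        # e.g. per-update warmup or decay
        pass


# With `--lr-scheduler pass_through`, every scheduling call the trainer makes on
# PassThroughScheduleSchedule is forwarded verbatim to `optimizer.lr_scheduler`.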