# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation

import warnings

from .state import AcceleratorState, GradientState


warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")


class AcceleratedScheduler:
    """
    Learning rate scheduler wrapper that advances the wrapped scheduler only when the optimizer(s) actually performed
    a training step. This keeps the schedule from running ahead when an optimizer step was skipped (for instance
    after a gradient overflow in mixed precision training).

    Under gradient accumulation the scheduler length should not be modified by the user; the wrapper accounts for
    accumulation steps itself.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler being wrapped.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizer(s) whose skipped steps gate the scheduler. A single optimizer is stored as a one-element
            list.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether the scheduler is tied to the optimizer steps. When `False`, `step` simply forwards to the wrapped
            scheduler.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the dataloaders split one batch across the processes (batch size independent of the number of
            processes) or build a batch on each process (effective batch size is the original one multiplied by the
            number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        # Normalize `optimizers` so the rest of the class can always iterate over it.
        if isinstance(optimizers, (list, tuple)):
            optimizer_group = optimizers
        else:
            optimizer_group = [optimizers]

        self.scheduler = scheduler
        self.optimizers = optimizer_group
        self.split_batches = split_batches
        self.step_with_optimizer = step_with_optimizer
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        """
        Advance the wrapped scheduler, forwarding `*args` and `**kwargs` to its `step` method, unless the current
        gradient/optimizer state says the step should be skipped.
        """
        scheduler = self.scheduler

        if not self.step_with_optimizer:
            # Scheduler is independent of the optimizer: plain passthrough.
            scheduler.step(*args, **kwargs)
            return

        # From here on the scheduler only moves if the optimizer really stepped.
        grad_state = self.gradient_state
        if not grad_state.sync_gradients:
            # Gradients are not being synced on this call: no real scheduler step, only bump the internal step
            # counter when the gradient state asks for the scheduler to be adjusted.
            if grad_state.adjust_scheduler:
                scheduler._step_count += 1
            return

        # Any optimizer that skipped its step cancels the scheduler step.
        if any(opt.step_was_skipped for opt in self.optimizers):
            return

        if self.split_batches:
            # Batches are split across processes, so the dataloader batch size is unchanged: one scheduler step per
            # training step.
            scheduler.step(*args, **kwargs)
            return

        # The dataloader batch size was multiplied by `num_processes`, so take `num_processes` scheduler steps per
        # training step.
        for _ in range(AcceleratorState().num_processes):
            # Special case when using OneCycle and `drop_last` was not used: a scheduler exposing `total_steps` is
            # only stepped while its step counter has not gone past that total.
            if not hasattr(scheduler, "total_steps") or scheduler._step_count <= scheduler.total_steps:
                scheduler.step(*args, **kwargs)

    # Passthroughs
    def get_last_lr(self):
        """Return the wrapped scheduler's `get_last_lr()`."""
        return self.scheduler.get_last_lr()

    def state_dict(self):
        """Return the wrapped scheduler's `state_dict()`."""
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        """Load `state_dict` into the wrapped scheduler."""
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        """Return the wrapped scheduler's `get_lr()`."""
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        """Forward to the wrapped scheduler's `print_lr` and return its result."""
        return self.scheduler.print_lr(*args, **kwargs)