# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from accelerate import Accelerator, DistributedType


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
    on each device, and averages model weights every K steps.

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.

    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
    """

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            accelerator (`Accelerator`):
                Accelerator object.
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            local_sgd_steps (`int`):
                A number of local SGD steps (before model parameters are synchronized). Must be at least 1 when
                local SGD is effectively enabled.
            enabled (`bool`):
                Local SGD is disabled if this parameter is set to `False`. It is also disabled automatically when
                `accelerator.distributed_type` is `DistributedType.NO`.

        Raises:
            NotImplementedError: If `accelerator.distributed_type` is not one of `NO`, `MULTI_CPU` or `MULTI_GPU`.
            ValueError: If local SGD is enabled and `local_sgd_steps` is smaller than 1.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        # Without a distributed setup there is nothing to average, so the helper degrades to a no-op.
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        self.num_steps = 0
        if self.enabled:
            # Fail early: a zero interval would otherwise surface later as a ZeroDivisionError inside `step()`.
            if local_sgd_steps < 1:
                raise ValueError(f"local_sgd_steps must be a positive integer, got {local_sgd_steps!r}")
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def __enter__(self):
        """
        Enter the model's `no_sync()` context when local SGD is enabled, and return this object.
        """
        if self.enabled:
            # The context object is stored so that `__exit__` can close the very same one.
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        """
        Average the model parameters one last time and leave the model's `no_sync()` context.

        The parameter names (including `type`, which shadows the builtin) are kept for interface stability.
        """
        if not self.enabled:
            return

        try:
            # Average all models on exit
            self._sync_and_avg_model_params()
        finally:
            # Always close the `no_sync()` context, even if the final averaging raised.
            self.model_sync_obj.__exit__(type, value, tb)

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.

        The step counter is incremented even when local SGD is disabled; parameters are averaged on every
        `local_sgd_steps`-th call when it is enabled.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + Average model parameters across all GPUs
        """
        # Wait for every process before reducing.
        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            for param in self.model.parameters():
                # Replace each parameter's data with the mean returned by the accelerator's reduce.
                param.data = self.accelerator.reduce(param.data, reduction="mean")