# OpenSound's picture
# Upload 211 files
# 9d3cb0a verified
import os
import typing
import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel
from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler
class Accelerator:  # pragma: no cover
    """Prepare models and dataloaders for use with DDP or DP.

    Use the methods ``prepare_model`` and ``prepare_dataloader`` to prepare
    the respective objects. Models are moved to the appropriate device; when
    DDP is used, their BatchNorm layers are also converted to SyncBatchNorm
    before wrapping. Dataloaders are built around a resumable sampler: a
    distributed sampler under DDP, a sequential sampler otherwise.

    If at most one GPU is visible, no DDP/DP wrapping happens: models are
    only moved to the device and dataloaders use the sequential sampler.
    If the environment variable ``LOCAL_RANK`` is not set, the script was
    launched without ``torchrun``, and ``DataParallel`` will be used instead
    of ``DistributedDataParallel`` (not recommended) when more than one GPU
    is visible.

    The object can also be used as a context manager, which enters the CUDA
    device context for ``local_rank`` when CUDA is available.

    Parameters
    ----------
    amp : bool, optional
        Whether or not to enable automatic mixed precision, by default False
    """

    def __init__(self, amp: bool = False):
        raw_local_rank = os.getenv("LOCAL_RANK", None)
        # Always normalise to int when present: downstream consumers
        # (torch.cuda.device, DDP device_ids, the sampler rank) need an
        # integer, even when DDP ends up disabled (e.g. one visible GPU).
        local_rank = None if raw_local_rank is None else int(raw_local_rank)

        # Number of visible GPUs; also used as the DDP world size.
        self.world_size = torch.cuda.device_count()
        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

        class DummyScaler:
            """No-op stand-in for a GradScaler, used when ``amp`` is off."""

            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepare a model for DDP or DP after moving it to ``self.device``.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.
        **kwargs
            Forwarded to ``DistributedDataParallel`` or ``DataParallel``.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or the (device-moved) original model if DDP and
            DP are turned off.
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Return an autocast context manager.

        ``self.amp`` is passed as the first positional (``enabled``)
        argument; any further arguments go to ``torch.cuda.amp.autocast``.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Run the backward pass, after scaling the loss if ``amp`` is enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Step the optimizer through the scaler.

        With a real GradScaler (``amp`` enabled), the step may be skipped
        when non-finite gradients are detected; otherwise this is a plain
        ``optimizer.step()``.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Update the scale factor (no-op when ``amp`` is disabled)."""
        self.scaler.update()

    def prepare_dataloader(
        self, dataset: typing.Iterable, start_idx: int = None, **kwargs
    ):
        """Wrap a dataset in a DataLoader, using the correct sampler for DDP.

        Under DDP, ``num_workers`` and ``batch_size`` (when given) are treated
        as totals across all ranks and divided by the world size, keeping at
        least 1 per rank. ``num_workers=0`` (in-process loading) and an
        absent or ``None`` ``batch_size`` are passed through unchanged.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build Dataloader around.
        start_idx : int, optional
            Start index of sampler, useful if resuming from some epoch,
            by default None
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``.

        Returns
        -------
        torch.utils.data.DataLoader
            Dataloader iterating ``dataset`` with the chosen sampler.
        """
        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            # A falsy num_workers (0) means "load in the main process";
            # forcing it to 1 would spawn a worker the caller did not ask for.
            if kwargs.get("num_workers"):
                kwargs["num_workers"] = max(
                    kwargs["num_workers"] // self.world_size, 1
                )
            # batch_size may be absent (DataLoader default) or None
            # (automatic batching disabled); only split a concrete value.
            if kwargs.get("batch_size") is not None:
                kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
        else:
            sampler = SequentialSampler(dataset, start_idx)

        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
        return dataloader

    @staticmethod
    def unwrap(model):
        """Return the model underneath a DDP/DP wrapper, if any.

        Returns ``model.module`` when the object has a ``module`` attribute
        (as ``DataParallel`` / ``DistributedDataParallel`` wrappers do),
        otherwise ``model`` itself. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
        """
        if hasattr(model, "module"):
            return model.module
        return model