Spaces:
Paused
Paused
import os | |
import typing | |
import torch | |
import torch.distributed as dist | |
from torch.nn.parallel import DataParallel | |
from torch.nn.parallel import DistributedDataParallel | |
from ..data.datasets import ResumableDistributedSampler as DistributedSampler | |
from ..data.datasets import ResumableSequentialSampler as SequentialSampler | |
class Accelerator: # pragma: no cover | |
"""This class is used to prepare models and dataloaders for | |
usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to | |
prepare the respective objects. In the case of models, they are moved to | |
the appropriate GPU and SyncBatchNorm is applied to them. In the case of | |
dataloaders, a sampler is created and the dataloader is initialized with | |
that sampler. | |
If the world size is 1, prepare_model and prepare_dataloader are | |
no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the | |
script was launched without ``torchrun``, and ``DataParallel`` | |
will be used instead of ``DistributedDataParallel`` (not recommended), if | |
the world size (number of GPUs) is greater than 1. | |
Parameters | |
---------- | |
amp : bool, optional | |
Whether or not to enable automatic mixed precision, by default False | |
""" | |
def __init__(self, amp: bool = False): | |
local_rank = os.getenv("LOCAL_RANK", None) | |
self.world_size = torch.cuda.device_count() | |
self.use_ddp = self.world_size > 1 and local_rank is not None | |
self.use_dp = self.world_size > 1 and local_rank is None | |
self.device = "cpu" if self.world_size == 0 else "cuda" | |
if self.use_ddp: | |
local_rank = int(local_rank) | |
dist.init_process_group( | |
"nccl", | |
init_method="env://", | |
world_size=self.world_size, | |
rank=local_rank, | |
) | |
self.local_rank = 0 if local_rank is None else local_rank | |
self.amp = amp | |
class DummyScaler: | |
def __init__(self): | |
pass | |
def step(self, optimizer): | |
optimizer.step() | |
def scale(self, loss): | |
return loss | |
def unscale_(self, optimizer): | |
return optimizer | |
def update(self): | |
pass | |
self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler() | |
self.device_ctx = ( | |
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None | |
) | |
def __enter__(self): | |
if self.device_ctx is not None: | |
self.device_ctx.__enter__() | |
return self | |
def __exit__(self, exc_type, exc_value, traceback): | |
if self.device_ctx is not None: | |
self.device_ctx.__exit__(exc_type, exc_value, traceback) | |
def prepare_model(self, model: torch.nn.Module, **kwargs): | |
"""Prepares model for DDP or DP. The model is moved to | |
the device of the correct rank. | |
Parameters | |
---------- | |
model : torch.nn.Module | |
Model that is converted for DDP or DP. | |
Returns | |
------- | |
torch.nn.Module | |
Wrapped model, or original model if DDP and DP are turned off. | |
""" | |
model = model.to(self.device) | |
if self.use_ddp: | |
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | |
model = DistributedDataParallel( | |
model, device_ids=[self.local_rank], **kwargs | |
) | |
elif self.use_dp: | |
model = DataParallel(model, **kwargs) | |
return model | |
# Automatic mixed-precision utilities | |
def autocast(self, *args, **kwargs): | |
"""Context manager for autocasting. Arguments | |
go to ``torch.cuda.amp.autocast``. | |
""" | |
return torch.cuda.amp.autocast(self.amp, *args, **kwargs) | |
def backward(self, loss: torch.Tensor): | |
"""Backwards pass, after scaling the loss if ``amp`` is | |
enabled. | |
Parameters | |
---------- | |
loss : torch.Tensor | |
Loss value. | |
""" | |
self.scaler.scale(loss).backward() | |
def step(self, optimizer: torch.optim.Optimizer): | |
"""Steps the optimizer, using a ``scaler`` if ``amp`` is | |
enabled. | |
Parameters | |
---------- | |
optimizer : torch.optim.Optimizer | |
Optimizer to step forward. | |
""" | |
self.scaler.step(optimizer) | |
def update(self): | |
"""Updates the scale factor.""" | |
self.scaler.update() | |
def prepare_dataloader( | |
self, dataset: typing.Iterable, start_idx: int = None, **kwargs | |
): | |
"""Wraps a dataset with a DataLoader, using the correct sampler if DDP is | |
enabled. | |
Parameters | |
---------- | |
dataset : typing.Iterable | |
Dataset to build Dataloader around. | |
start_idx : int, optional | |
Start index of sampler, useful if resuming from some epoch, | |
by default None | |
Returns | |
------- | |
_type_ | |
_description_ | |
""" | |
if self.use_ddp: | |
sampler = DistributedSampler( | |
dataset, | |
start_idx, | |
num_replicas=self.world_size, | |
rank=self.local_rank, | |
) | |
if "num_workers" in kwargs: | |
kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1) | |
kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1) | |
else: | |
sampler = SequentialSampler(dataset, start_idx) | |
dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs) | |
return dataloader | |
def unwrap(model): | |
"""Unwraps the model if it was wrapped in DDP or DP, otherwise | |
just returns the model. Use this to unwrap the model returned by | |
:py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`. | |
""" | |
if hasattr(model, "module"): | |
return model.module | |
return model | |