File size: 1,906 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html
# We divide by world_size before converting to fp16, which is safer.
from typing import Any, Callable

import torch
import torch.distributed as dist


def fp16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    DDP communication hook that allreduces gradients in half precision.

    The ``GradBucket`` buffer is divided by the process group size first, and
    the quotient is written directly into a newly allocated ``torch.float16``
    tensor (division and cast are fused through the ``out`` argument of
    ``torch.div``). That ``float16`` tensor is then allreduced asynchronously.
    When the allreduce completes, a chained callback copies the result back
    into the bucket's buffer, which restores the buffer's own data type
    (such as ``float32``).

    Args:
        process_group: group to communicate over; when ``None``,
            ``dist.group.WORLD`` is used instead.
        bucket: gradient bucket whose buffer is compressed, reduced and then
            overwritten in place with the reduced values.

    Returns:
        A future that resolves to the bucket's buffer once the reduced
        gradients have been copied back into it.

    Example::

        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)

    """
    if process_group is None:
        comm_group = dist.group.WORLD
    else:
        comm_group = process_group
    num_ranks = comm_group.size()

    grads = bucket.buffer()
    # Divide before the fp16 conversion. Handing torch.div a float16 ``out``
    # tensor fuses the division with the cast.
    half_grads = torch.empty_like(grads, dtype=torch.float16)
    torch.div(grads, num_ranks, out=half_grads)

    pending = dist.all_reduce(half_grads, group=comm_group, async_op=True)
    reduced = pending.get_future()

    def restore_into_bucket(done):
        # Write the result back into the bucket's own buffer (in place) to
        # reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        target = bucket.buffer()
        target.copy_(done.value()[0])
        return target

    # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case
    # resend with fp32?
    return reduced.then(restore_into_bucket)