# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist
from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering
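
    At each sync point the block-level update applied below is, roughly (see
    ``_calc_grad``, ``_avg_grad_from_all_gpus`` and ``_update_global_model``)::

        grad          = mean over workers of (global_param - local_param)
        smoothed_grad = block_momentum * smoothed_grad + block_lr * grad
        param         = global_param - smoothed_grad
        param         = param - block_momentum * smoothed_grad  # only if use_nbm
        global_param  = param

    Illustrative usage sketch (``bmuf_cfg`` and ``base_optimizer`` are
    placeholders: a FairseqBMUFConfig instance and an already-constructed
    wrapped optimizer, respectively)::

        optimizer = FairseqBMUF(bmuf_cfg, base_optimizer)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # broadcasts at the end of warmup, then block-syncs
                          # every ``global_sync_iter`` updates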
    """

    def __init__(self, cfg: FairseqBMUFConfig, optimizer):
        super().__init__(cfg)
        self._optimizer = optimizer
        self._num_updates = 0
        self.sync_iter = cfg.global_sync_iter
        self.block_momentum = cfg.block_momentum
        self.block_lr = cfg.block_lr
        self._reset_local_data()
        self.warmup_iteration = cfg.warmup_iterations
        self.use_nbm = cfg.use_nbm
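        # snapshot of the wrapped optimizer's initial state; restored after a
        # warmup sync when average_sync is disabled (see _warmup_sync)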
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.cfg.average_sync
        self.world_size = self.cfg.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        return self._optimizer.optimizer_config

    def get_lr(self):
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        self._optimizer.set_lr(lr)

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        self._optimizer.average_params()

    def _block_sync(self):
        if self.world_size <= 1:
            return
        # Update the global model using local models from all GPUs
        # (Step-1) Calculate grad between the previously synced model and the
        # current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average gradient from all GPUs
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        # Check whether the number of training updates has reached the warmup iterations
        if self.get_num_updates() == self.warmup_iteration:
            return True
        return False

    def _is_bmuf_iter(self):
        # Check whether the current update (past warmup) falls on a BMUF sync iteration
        if (self.get_num_updates() > self.warmup_iteration) and (
            self.get_num_updates() % self.sync_iter == 0
        ):
            return True
        return False

    def _warmup_sync(self, root_rank=0):
        if self.world_size <= 1:
            return
        # Broadcast the local model to all gpus
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update local optimizer state
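        # (either average the optimizer state across workers, or reset it to the
        # snapshot taken at construction so every worker starts from identical state)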
        if self.average_sync:
            self._optimizer.average_params()
        else:
            self._optimizer.load_state_dict(self.initial_state)

        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        # (Step-0) Initialize global momentum parameters and store global copy on each gpu
        self.global_params = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

        # save the global model locally; it is used to compute the block-level
        # grad at the next BMUF sync
        for param, global_param in zip(self.params, self.global_params):
            global_param.copy_(param.data)

    @torch.no_grad()
    def _calc_grad(self):
        # global_params holds the global copy from the previously finished
        # synchronization, while param.data is the local parameter after the
        # current block of local updates on this GPU. The grad is therefore the
        # difference between the previously synced model and the current local
        # model.
        for index, (param, global_param) in enumerate(
            zip(self.params, self.global_params)
        ):
            self.grads[index] = global_param - param.data

    def _avg_grad_from_all_gpus(self):
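        # Dividing by the world size before the SUM all-reduce yields the mean
        # across workers. When block_momentum == 0 the parameters themselves are
        # averaged instead of the block-level grads.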
        for index, param in enumerate(self.params):
            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        for index, (param, global_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.global_params,
                self.smoothed_grads,
                # all gpus would share the same value of smoothed_grad, since it is
                # always computed on synchronized gradients.
                self.grads,
            )
        ):
            # global_param is the last synchronized parameter. Although
            # smoothed_grad is stored locally, every process computes the same
            # value (it is derived from the all-reduced grads), so param ends up
            # as a globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
            param.data.copy_(global_param - smoothed_grad)

            # Nesterov-style block momentum: apply a partial momentum update to
            # the weights before the next block's gradients are computed
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # backup for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)