Commit b0636ca
CorvaeOboro committed
1 parent: 6b7f20a

Upload training_stats.py

torch_utils/training_stats.py  ADDED  (+268 -0)
@@ -0,0 +1,268 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3              # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32  # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64  # Data type to use for the internal counters.
_rank           = 0              # Rank of the current process.
_sync_device    = None           # Device to use for multiprocess communication. None = single-process.
_sync_called    = False          # Has _sync() been called yet?
_counters       = dict()         # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()         # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------
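# Editor's note (illustrative, not part of the original file): each statistic is
# tracked as three running moments [n, sum(x), sum(x^2)]. Collector later derives
# mean = sum(x)/n and std = sqrt(sum(x^2)/n - mean^2) from these. For example,
# reporting the values [1, 2, 3] accumulates n=3, sum=6, sum_sq=14, which yields
# mean = 2.0 and std = sqrt(14/3 - 4) ~= 0.816.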

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),
        elems.sum(),
        elems.square().sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    report(name, value if _rank == 0 else [])
    return value

#----------------------------------------------------------------------------
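# Illustrative usage sketch (editor's addition; the statistic names and the
# `training_stats` import alias are hypothetical). Inside a loss function or
# training loop one would typically write:
#
#   training_stats.report('Loss/D/loss', loss_d)             # called by every process
#   training_stats.report0('Progress/kimg', cur_nimg / 1e3)  # only rank 0's value counts
#
# report() returns its input unchanged, so it can also be chained inline:
#
#   loss_g = training_stats.report('Loss/G/loss', loss_g)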

class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        return np.sqrt(max(raw_var - np.square(mean), 0))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as an `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------
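# Illustrative usage sketch (editor's addition; statistic names, the interval,
# and the `training_stats` import alias are hypothetical):
#
#   collector = training_stats.Collector(regex='Loss/.*')
#   for step in range(num_steps):
#       ...                                  # forward/backward passes call report()
#       if step % 100 == 0:
#           collector.update()               # one all_reduce + GPU-to-CPU transfers
#           stats = collector.as_dict()      # {'Loss/G/loss': EasyDict(num, mean, std), ...}
#           print(collector['Loss/G/loss'])  # shorthand for collector.mean('Loss/G/loss')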

def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Collect deltas within current rank.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------
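For multi-GPU training, a minimal setup sketch (editor's addition; the function name, rank/world_size handling, and the NCCL/env:// rendezvous are assumptions, not part of this file) would initialize the statistics module right after the process group, as the init_multiprocessing() docstring requires:

    import torch
    from torch_utils import training_stats

    def setup_stats(rank, world_size):
        # Assumes one GPU per process, NCCL backend, and env:// rendezvous
        # (MASTER_ADDR/MASTER_PORT set by the launcher); adjust as needed.
        if world_size > 1:
            torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
        sync_device = torch.device('cuda', rank) if world_size > 1 else None
        training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

With sync_device=None the module stays in single-process mode and Collector.update() skips the all_reduce.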