CorvaeOboro commited on
Commit
e8f045f
β€’
1 Parent(s): 6ed4042

Delete dnnlib/tflib/network.py

Browse files
Files changed (1) hide show
  1. dnnlib/tflib/network.py +0 -781
dnnlib/tflib/network.py DELETED
@@ -1,781 +0,0 @@
1
- ο»Ώ# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Helper for managing networks."""
10
-
11
- import types
12
- import inspect
13
- import re
14
- import uuid
15
- import sys
16
- import copy
17
- import numpy as np
18
- import tensorflow as tf
19
-
20
- from collections import OrderedDict
21
- from typing import Any, List, Tuple, Union, Callable
22
-
23
- from . import tfutil
24
- from .. import util
25
-
26
- from .tfutil import TfExpression, TfExpressionEx
27
-
28
- # pylint: disable=protected-access
29
- # pylint: disable=attribute-defined-outside-init
30
- # pylint: disable=too-many-public-methods
31
-
32
- _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
33
- _import_module_src = dict() # Source code for temporary modules created during pickle import.
34
-
35
-
36
- def import_handler(handler_func):
37
- """Function decorator for declaring custom import handlers."""
38
- _import_handlers.append(handler_func)
39
- return handler_func
40
-
41
-
42
- class Network:
43
- """Generic network abstraction.
44
-
45
- Acts as a convenience wrapper for a parameterized network construction
46
- function, providing several utility methods and convenient access to
47
- the inputs/outputs/weights.
48
-
49
- Network objects can be safely pickled and unpickled for long-term
50
- archival purposes. The pickling works reliably as long as the underlying
51
- network construction function is defined in a standalone Python module
52
- that has no side effects or application-specific imports.
53
-
54
- Args:
55
- name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
56
- func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
57
- static_kwargs: Keyword arguments to be passed in to the network construction function.
58
- """
59
-
60
- def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
61
- # Locate the user-specified build function.
62
- assert isinstance(func_name, str) or util.is_top_level_function(func_name)
63
- if util.is_top_level_function(func_name):
64
- func_name = util.get_top_level_function_name(func_name)
65
- module, func_name = util.get_module_from_obj_name(func_name)
66
- func = util.get_obj_from_module(module, func_name)
67
-
68
- # Dig up source code for the module containing the build function.
69
- module_src = _import_module_src.get(module, None)
70
- if module_src is None:
71
- module_src = inspect.getsource(module)
72
-
73
- # Initialize fields.
74
- self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)
75
-
76
- def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
77
- tfutil.assert_tf_initialized()
78
- assert isinstance(name, str)
79
- assert len(name) >= 1
80
- assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name)
81
- assert isinstance(static_kwargs, dict)
82
- assert util.is_pickleable(static_kwargs)
83
- assert callable(build_func)
84
- assert isinstance(build_func_name, str)
85
- assert isinstance(build_module_src, str)
86
-
87
- # Choose TensorFlow name scope.
88
- with tf.name_scope(None):
89
- scope = tf.get_default_graph().unique_name(name, mark_as_used=True)
90
-
91
- # Query current TensorFlow device.
92
- with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
93
- device = tf.no_op(name="_QueryDevice").device
94
-
95
- # Immutable state.
96
- self._name = name
97
- self._scope = scope
98
- self._device = device
99
- self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs))
100
- self._build_func = build_func
101
- self._build_func_name = build_func_name
102
- self._build_module_src = build_module_src
103
-
104
- # State before _init_graph().
105
- self._var_inits = dict() # var_name => initial_value, set to None by _init_graph()
106
- self._all_inits_known = False # Do we know for sure that _var_inits covers all the variables?
107
- self._components = None # subnet_name => Network, None if the components are not known yet
108
-
109
- # Initialized by _init_graph().
110
- self._input_templates = None
111
- self._output_templates = None
112
- self._own_vars = None
113
-
114
- # Cached values initialized the respective methods.
115
- self._input_shapes = None
116
- self._output_shapes = None
117
- self._input_names = None
118
- self._output_names = None
119
- self._vars = None
120
- self._trainables = None
121
- self._var_global_to_local = None
122
- self._run_cache = dict()
123
-
124
- def _init_graph(self) -> None:
125
- assert self._var_inits is not None
126
- assert self._input_templates is None
127
- assert self._output_templates is None
128
- assert self._own_vars is None
129
-
130
- # Initialize components.
131
- if self._components is None:
132
- self._components = util.EasyDict()
133
-
134
- # Choose build func kwargs.
135
- build_kwargs = dict(self.static_kwargs)
136
- build_kwargs["is_template_graph"] = True
137
- build_kwargs["components"] = self._components
138
-
139
- # Override scope and device, and ignore surrounding control dependencies.
140
- with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
141
- assert tf.get_variable_scope().name == self.scope
142
- assert tf.get_default_graph().get_name_scope() == self.scope
143
-
144
- # Create input templates.
145
- self._input_templates = []
146
- for param in inspect.signature(self._build_func).parameters.values():
147
- if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
148
- self._input_templates.append(tf.placeholder(tf.float32, name=param.name))
149
-
150
- # Call build func.
151
- out_expr = self._build_func(*self._input_templates, **build_kwargs)
152
-
153
- # Collect output templates and variables.
154
- assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
155
- self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
156
- self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
157
-
158
- # Check for errors.
159
- if len(self._input_templates) == 0:
160
- raise ValueError("Network build func did not list any inputs.")
161
- if len(self._output_templates) == 0:
162
- raise ValueError("Network build func did not return any outputs.")
163
- if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
164
- raise ValueError("Network outputs must be TensorFlow expressions.")
165
- if any(t.shape.ndims is None for t in self._input_templates):
166
- raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
167
- if any(t.shape.ndims is None for t in self._output_templates):
168
- raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
169
- if any(not isinstance(comp, Network) for comp in self._components.values()):
170
- raise ValueError("Components of a Network must be Networks themselves.")
171
- if len(self._components) != len(set(comp.name for comp in self._components.values())):
172
- raise ValueError("Components of a Network must have unique names.")
173
-
174
- # Initialize variables.
175
- if len(self._var_inits):
176
- tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
177
- remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
178
- if self._all_inits_known:
179
- assert len(remaining_inits) == 0
180
- else:
181
- tfutil.run(remaining_inits)
182
- self._var_inits = None
183
-
184
- @property
185
- def name(self):
186
- """User-specified name string."""
187
- return self._name
188
-
189
- @property
190
- def scope(self):
191
- """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
192
- return self._scope
193
-
194
- @property
195
- def device(self):
196
- """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
197
- return self._device
198
-
199
- @property
200
- def static_kwargs(self):
201
- """EasyDict of arguments passed to the user-supplied build func."""
202
- return copy.deepcopy(self._static_kwargs)
203
-
204
- @property
205
- def components(self):
206
- """EasyDict of sub-networks created by the build func."""
207
- return copy.copy(self._get_components())
208
-
209
- def _get_components(self):
210
- if self._components is None:
211
- self._init_graph()
212
- assert self._components is not None
213
- return self._components
214
-
215
- @property
216
- def input_shapes(self):
217
- """List of input tensor shapes, including minibatch dimension."""
218
- if self._input_shapes is None:
219
- self._input_shapes = [t.shape.as_list() for t in self.input_templates]
220
- return copy.deepcopy(self._input_shapes)
221
-
222
- @property
223
- def output_shapes(self):
224
- """List of output tensor shapes, including minibatch dimension."""
225
- if self._output_shapes is None:
226
- self._output_shapes = [t.shape.as_list() for t in self.output_templates]
227
- return copy.deepcopy(self._output_shapes)
228
-
229
- @property
230
- def input_shape(self):
231
- """Short-hand for input_shapes[0]."""
232
- return self.input_shapes[0]
233
-
234
- @property
235
- def output_shape(self):
236
- """Short-hand for output_shapes[0]."""
237
- return self.output_shapes[0]
238
-
239
- @property
240
- def num_inputs(self):
241
- """Number of input tensors."""
242
- return len(self.input_shapes)
243
-
244
- @property
245
- def num_outputs(self):
246
- """Number of output tensors."""
247
- return len(self.output_shapes)
248
-
249
- @property
250
- def input_names(self):
251
- """Name string for each input."""
252
- if self._input_names is None:
253
- self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
254
- return copy.copy(self._input_names)
255
-
256
- @property
257
- def output_names(self):
258
- """Name string for each output."""
259
- if self._output_names is None:
260
- self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
261
- return copy.copy(self._output_names)
262
-
263
- @property
264
- def input_templates(self):
265
- """Input placeholders in the template graph."""
266
- if self._input_templates is None:
267
- self._init_graph()
268
- assert self._input_templates is not None
269
- return copy.copy(self._input_templates)
270
-
271
- @property
272
- def output_templates(self):
273
- """Output tensors in the template graph."""
274
- if self._output_templates is None:
275
- self._init_graph()
276
- assert self._output_templates is not None
277
- return copy.copy(self._output_templates)
278
-
279
- @property
280
- def own_vars(self):
281
- """Variables defined by this network (local_name => var), excluding sub-networks."""
282
- return copy.copy(self._get_own_vars())
283
-
284
- def _get_own_vars(self):
285
- if self._own_vars is None:
286
- self._init_graph()
287
- assert self._own_vars is not None
288
- return self._own_vars
289
-
290
- @property
291
- def vars(self):
292
- """All variables (local_name => var)."""
293
- return copy.copy(self._get_vars())
294
-
295
- def _get_vars(self):
296
- if self._vars is None:
297
- self._vars = OrderedDict(self._get_own_vars())
298
- for comp in self._get_components().values():
299
- self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
300
- return self._vars
301
-
302
- @property
303
- def trainables(self):
304
- """All trainable variables (local_name => var)."""
305
- return copy.copy(self._get_trainables())
306
-
307
- def _get_trainables(self):
308
- if self._trainables is None:
309
- self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
310
- return self._trainables
311
-
312
- @property
313
- def var_global_to_local(self):
314
- """Mapping from variable global names to local names."""
315
- return copy.copy(self._get_var_global_to_local())
316
-
317
- def _get_var_global_to_local(self):
318
- if self._var_global_to_local is None:
319
- self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
320
- return self._var_global_to_local
321
-
322
- def reset_own_vars(self) -> None:
323
- """Re-initialize all variables of this network, excluding sub-networks."""
324
- if self._var_inits is None or self._components is None:
325
- tfutil.run([var.initializer for var in self._get_own_vars().values()])
326
- else:
327
- self._var_inits.clear()
328
- self._all_inits_known = False
329
-
330
- def reset_vars(self) -> None:
331
- """Re-initialize all variables of this network, including sub-networks."""
332
- if self._var_inits is None:
333
- tfutil.run([var.initializer for var in self._get_vars().values()])
334
- else:
335
- self._var_inits.clear()
336
- self._all_inits_known = False
337
- if self._components is not None:
338
- for comp in self._components.values():
339
- comp.reset_vars()
340
-
341
- def reset_trainables(self) -> None:
342
- """Re-initialize all trainable variables of this network, including sub-networks."""
343
- tfutil.run([var.initializer for var in self._get_trainables().values()])
344
-
345
- def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
346
- """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
347
- The graph is placed on the current TensorFlow device."""
348
- assert len(in_expr) == self.num_inputs
349
- assert not all(expr is None for expr in in_expr)
350
- self._get_vars() # ensure that all variables have been created
351
-
352
- # Choose build func kwargs.
353
- build_kwargs = dict(self.static_kwargs)
354
- build_kwargs.update(dynamic_kwargs)
355
- build_kwargs["is_template_graph"] = False
356
- build_kwargs["components"] = self._components
357
-
358
- # Build TensorFlow graph to evaluate the network.
359
- with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
360
- assert tf.get_variable_scope().name == self.scope
361
- valid_inputs = [expr for expr in in_expr if expr is not None]
362
- final_inputs = []
363
- for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
364
- if expr is not None:
365
- expr = tf.identity(expr, name=name)
366
- else:
367
- expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
368
- final_inputs.append(expr)
369
- out_expr = self._build_func(*final_inputs, **build_kwargs)
370
-
371
- # Propagate input shapes back to the user-specified expressions.
372
- for expr, final in zip(in_expr, final_inputs):
373
- if isinstance(expr, tf.Tensor):
374
- expr.set_shape(final.shape)
375
-
376
- # Express outputs in the desired format.
377
- assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
378
- if return_as_list:
379
- out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
380
- return out_expr
381
-
382
- def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
383
- """Get the local name of a given variable, without any surrounding name scopes."""
384
- assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
385
- global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
386
- return self._get_var_global_to_local()[global_name]
387
-
388
- def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
389
- """Find variable by local or global name."""
390
- assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
391
- return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
392
-
393
- def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
394
- """Get the value of a given variable as NumPy array.
395
- Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
396
- return self.find_var(var_or_local_name).eval()
397
-
398
- def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
399
- """Set the value of a given variable based on the given NumPy array.
400
- Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
401
- tfutil.set_vars({self.find_var(var_or_local_name): new_value})
402
-
403
- def __getstate__(self) -> dict:
404
- """Pickle export."""
405
- state = dict()
406
- state["version"] = 5
407
- state["name"] = self.name
408
- state["static_kwargs"] = dict(self.static_kwargs)
409
- state["components"] = dict(self.components)
410
- state["build_module_src"] = self._build_module_src
411
- state["build_func_name"] = self._build_func_name
412
- state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
413
- state["input_shapes"] = self.input_shapes
414
- state["output_shapes"] = self.output_shapes
415
- state["input_names"] = self.input_names
416
- state["output_names"] = self.output_names
417
- return state
418
-
419
- def __setstate__(self, state: dict) -> None:
420
- """Pickle import."""
421
-
422
- # Execute custom import handlers.
423
- for handler in _import_handlers:
424
- state = handler(state)
425
-
426
- # Get basic fields.
427
- assert state["version"] in [2, 3, 4, 5]
428
- name = state["name"]
429
- static_kwargs = state["static_kwargs"]
430
- build_module_src = state["build_module_src"]
431
- build_func_name = state["build_func_name"]
432
-
433
- # Create temporary module from the imported source code.
434
- module_name = "_tflib_network_import_" + uuid.uuid4().hex
435
- module = types.ModuleType(module_name)
436
- sys.modules[module_name] = module
437
- _import_module_src[module] = build_module_src
438
- exec(build_module_src, module.__dict__) # pylint: disable=exec-used
439
- build_func = util.get_obj_from_module(module, build_func_name)
440
-
441
- # Initialize fields.
442
- self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
443
- self._var_inits.update(copy.deepcopy(state["variables"]))
444
- self._all_inits_known = True
445
- self._components = util.EasyDict(state.get("components", {}))
446
- self._input_shapes = copy.deepcopy(state.get("input_shapes", None))
447
- self._output_shapes = copy.deepcopy(state.get("output_shapes", None))
448
- self._input_names = copy.deepcopy(state.get("input_names", None))
449
- self._output_names = copy.deepcopy(state.get("output_names", None))
450
-
451
- def clone(self, name: str = None, **new_static_kwargs) -> "Network":
452
- """Create a clone of this network with its own copy of the variables."""
453
- static_kwargs = dict(self.static_kwargs)
454
- static_kwargs.update(new_static_kwargs)
455
- net = object.__new__(Network)
456
- net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
457
- net.copy_vars_from(self)
458
- return net
459
-
460
- def copy_own_vars_from(self, src_net: "Network") -> None:
461
- """Copy the values of all variables from the given network, excluding sub-networks."""
462
-
463
- # Source has unknown variables or unknown components => init now.
464
- if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
465
- src_net._get_vars()
466
-
467
- # Both networks are inited => copy directly.
468
- if src_net._var_inits is None and self._var_inits is None:
469
- names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
470
- tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
471
- return
472
-
473
- # Read from source.
474
- if src_net._var_inits is None:
475
- value_dict = tfutil.run(src_net._get_own_vars())
476
- else:
477
- value_dict = src_net._var_inits
478
-
479
- # Write to destination.
480
- if self._var_inits is None:
481
- tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
482
- else:
483
- self._var_inits.update(value_dict)
484
-
485
- def copy_vars_from(self, src_net: "Network") -> None:
486
- """Copy the values of all variables from the given network, including sub-networks."""
487
-
488
- # Source has unknown variables or unknown components => init now.
489
- if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
490
- src_net._get_vars()
491
-
492
- # Source is inited, but destination components have not been created yet => set as initial values.
493
- if src_net._var_inits is None and self._components is None:
494
- self._var_inits.update(tfutil.run(src_net._get_vars()))
495
- return
496
-
497
- # Destination has unknown components => init now.
498
- if self._components is None:
499
- self._get_vars()
500
-
501
- # Both networks are inited => copy directly.
502
- if src_net._var_inits is None and self._var_inits is None:
503
- names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
504
- tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
505
- return
506
-
507
- # Copy recursively, component by component.
508
- self.copy_own_vars_from(src_net)
509
- for name, src_comp in src_net._components.items():
510
- if name in self._components:
511
- self._components[name].copy_vars_from(src_comp)
512
-
513
- def copy_trainables_from(self, src_net: "Network") -> None:
514
- """Copy the values of all trainable variables from the given network, including sub-networks."""
515
- names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
516
- tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
517
-
518
- def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
519
- """Create new network with the given parameters, and copy all variables from this network."""
520
- if new_name is None:
521
- new_name = self.name
522
- static_kwargs = dict(self.static_kwargs)
523
- static_kwargs.update(new_static_kwargs)
524
- net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
525
- net.copy_vars_from(self)
526
- return net
527
-
528
- def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
529
- """Construct a TensorFlow op that updates the variables of this network
530
- to be slightly closer to those of the given network."""
531
- with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
532
- ops = []
533
- for name, var in self._get_vars().items():
534
- if name in src_net._get_vars():
535
- cur_beta = beta if var.trainable else beta_nontrainable
536
- new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
537
- ops.append(var.assign(new_value))
538
- return tf.group(*ops)
539
-
540
- def run(self,
541
- *in_arrays: Tuple[Union[np.ndarray, None], ...],
542
- input_transform: dict = None,
543
- output_transform: dict = None,
544
- return_as_list: bool = False,
545
- print_progress: bool = False,
546
- minibatch_size: int = None,
547
- num_gpus: int = 1,
548
- assume_frozen: bool = False,
549
- **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
550
- """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
551
-
552
- Args:
553
- input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
554
- The dict must contain a 'func' field that points to a top-level function. The function is called with the input
555
- TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
556
- output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
557
- The dict must contain a 'func' field that points to a top-level function. The function is called with the output
558
- TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
559
- return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
560
- print_progress: Print progress to the console? Useful for very large input arrays.
561
- minibatch_size: Maximum minibatch size to use, None = disable batching.
562
- num_gpus: Number of GPUs to use.
563
- assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
564
- dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
565
- """
566
- assert len(in_arrays) == self.num_inputs
567
- assert not all(arr is None for arr in in_arrays)
568
- assert input_transform is None or util.is_top_level_function(input_transform["func"])
569
- assert output_transform is None or util.is_top_level_function(output_transform["func"])
570
- output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
571
- num_items = in_arrays[0].shape[0]
572
- if minibatch_size is None:
573
- minibatch_size = num_items
574
-
575
- # Construct unique hash key from all arguments that affect the TensorFlow graph.
576
- key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
577
- def unwind_key(obj):
578
- if isinstance(obj, dict):
579
- return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
580
- if callable(obj):
581
- return util.get_top_level_function_name(obj)
582
- return obj
583
- key = repr(unwind_key(key))
584
-
585
- # Build graph.
586
- if key not in self._run_cache:
587
- with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
588
- with tf.device("/cpu:0"):
589
- in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
590
- in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
591
-
592
- out_split = []
593
- for gpu in range(num_gpus):
594
- with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
595
- net_gpu = self.clone() if assume_frozen else self
596
- in_gpu = in_split[gpu]
597
-
598
- if input_transform is not None:
599
- in_kwargs = dict(input_transform)
600
- in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
601
- in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
602
-
603
- assert len(in_gpu) == self.num_inputs
604
- out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
605
-
606
- if output_transform is not None:
607
- out_kwargs = dict(output_transform)
608
- out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
609
- out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
610
-
611
- assert len(out_gpu) == self.num_outputs
612
- out_split.append(out_gpu)
613
-
614
- with tf.device("/cpu:0"):
615
- out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
616
- self._run_cache[key] = in_expr, out_expr
617
-
618
- # Run minibatches.
619
- in_expr, out_expr = self._run_cache[key]
620
- out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
621
-
622
- for mb_begin in range(0, num_items, minibatch_size):
623
- if print_progress:
624
- print("\r%d / %d" % (mb_begin, num_items), end="")
625
-
626
- mb_end = min(mb_begin + minibatch_size, num_items)
627
- mb_num = mb_end - mb_begin
628
- mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
629
- mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
630
-
631
- for dst, src in zip(out_arrays, mb_out):
632
- dst[mb_begin: mb_end] = src
633
-
634
- # Done.
635
- if print_progress:
636
- print("\r%d / %d" % (num_items, num_items))
637
-
638
- if not return_as_list:
639
- out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
640
- return out_arrays
641
-
642
- def list_ops(self) -> List[TfExpression]:
643
- _ = self.output_templates # ensure that the template graph has been created
644
- include_prefix = self.scope + "/"
645
- exclude_prefix = include_prefix + "_"
646
- ops = tf.get_default_graph().get_operations()
647
- ops = [op for op in ops if op.name.startswith(include_prefix)]
648
- ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
649
- return ops
650
-
651
- def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
652
- """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
653
- individual layers of the network. Mainly intended to be used for reporting."""
654
- layers = []
655
-
656
- def recurse(scope, parent_ops, parent_vars, level):
657
- if len(parent_ops) == 0 and len(parent_vars) == 0:
658
- return
659
-
660
- # Ignore specific patterns.
661
- if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
662
- return
663
-
664
- # Filter ops and vars by scope.
665
- global_prefix = scope + "/"
666
- local_prefix = global_prefix[len(self.scope) + 1:]
667
- cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
668
- cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
669
- if not cur_ops and not cur_vars:
670
- return
671
-
672
- # Filter out all ops related to variables.
673
- for var in [op for op in cur_ops if op.type.startswith("Variable")]:
674
- var_prefix = var.name + "/"
675
- cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
676
-
677
- # Scope does not contain ops as immediate children => recurse deeper.
678
- contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
679
- if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
680
- visited = set()
681
- for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
682
- token = rel_name.split("/")[0]
683
- if token not in visited:
684
- recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
685
- visited.add(token)
686
- return
687
-
688
- # Report layer.
689
- layer_name = scope[len(self.scope) + 1:]
690
- layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
691
- layer_trainables = [var for _name, var in cur_vars if var.trainable]
692
- layers.append((layer_name, layer_output, layer_trainables))
693
-
694
- recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
695
- return layers
696
-
697
- def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
698
- """Print a summary table of the network structure."""
699
- rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
700
- rows += [["---"] * 4]
701
- total_params = 0
702
-
703
- for layer_name, layer_output, layer_trainables in self.list_layers():
704
- num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
705
- weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
706
- weights.sort(key=lambda x: len(x.name))
707
- if len(weights) == 0 and len(layer_trainables) == 1:
708
- weights = layer_trainables
709
- total_params += num_params
710
-
711
- if not hide_layers_with_no_params or num_params != 0:
712
- num_params_str = str(num_params) if num_params > 0 else "-"
713
- output_shape_str = str(layer_output.shape)
714
- weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
715
- rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
716
-
717
- rows += [["---"] * 4]
718
- rows += [["Total", str(total_params), "", ""]]
719
-
720
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
721
- print()
722
- for row in rows:
723
- print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
724
- print()
725
-
726
- def setup_weight_histograms(self, title: str = None) -> None:
727
- """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
728
- if title is None:
729
- title = self.name
730
-
731
- with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
732
- for local_name, var in self._get_trainables().items():
733
- if "/" in local_name:
734
- p = local_name.split("/")
735
- name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
736
- else:
737
- name = title + "_toplevel/" + local_name
738
-
739
- tf.summary.histogram(name, var)
740
-
741
- #----------------------------------------------------------------------------
742
- # Backwards-compatible emulation of legacy output transformation in Network.run().
743
-
744
- _print_legacy_warning = True
745
-
746
- def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
747
- global _print_legacy_warning
748
- legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
749
- if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
750
- return output_transform, dynamic_kwargs
751
-
752
- if _print_legacy_warning:
753
- _print_legacy_warning = False
754
- print()
755
- print("WARNING: Old-style output transformations in Network.run() are deprecated.")
756
- print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
757
- print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
758
- print()
759
- assert output_transform is None
760
-
761
- new_kwargs = dict(dynamic_kwargs)
762
- new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
763
- new_transform["func"] = _legacy_output_transform_func
764
- return new_transform, new_kwargs
765
-
766
- def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
767
- if out_mul != 1.0:
768
- expr = [x * out_mul for x in expr]
769
-
770
- if out_add != 0.0:
771
- expr = [x + out_add for x in expr]
772
-
773
- if out_shrink > 1:
774
- ksize = [1, 1, out_shrink, out_shrink]
775
- expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
776
-
777
- if out_dtype is not None:
778
- if tf.as_dtype(out_dtype).is_integer:
779
- expr = [tf.round(x) for x in expr]
780
- expr = [tf.saturate_cast(x, out_dtype) for x in expr]
781
- return expr