File size: 6,479 Bytes
ba5dcdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Common definitions for GAN metrics."""

import os
import time
import hashlib
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

import config
from training import misc
from training import dataset

#----------------------------------------------------------------------------
# Standard metrics.

# Frechet Inception Distance, computed over 50,000 generated images.
fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
# Perceptual Path Length variants: latent space ('z' or 'w') x sampling mode ('full' path or 'end' points).
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)
ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)
# Linear Separability using 40 attribute classifiers (presumably the CelebA attribute set -- confirm against metrics.linear_separability).
ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)
dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging

#----------------------------------------------------------------------------
# Base class for metrics.

class MetricBase:
    """Base class for GAN quality metrics.

    Subclasses implement _evaluate() and report one or more scalar values
    via _report_result(); run() handles session setup, timing, and logging.
    """

    def __init__(self, name):
        self.name = name
        self._network_pkl = None    # path of the network pickle being evaluated
        self._dataset_args = None   # kwargs for dataset.load_dataset()
        self._mirror_augment = None # whether training used mirror augmentation
        self._results = []          # list of EasyDict(value, suffix, fmt)
        self._eval_time = None      # wall-clock seconds of the last run()

    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        """Evaluate the metric for the given network pickle.

        If dataset_args or mirror_augment are not given, they are recovered
        from the previous training run's config found in run_dir. Results are
        accumulated on the instance and optionally printed / appended to a
        per-metric log file in run_dir.
        """
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        # Fall back to the configuration recorded by the training run.
        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0 # iterate reals deterministically
            self._mirror_augment = run_config['train'].get('mirror_augment', False)

        # Evaluate inside a fresh graph/session so repeated runs do not leak state.
        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)

    def get_result_str(self):
        """Return a one-line summary: network name, eval time, and all results."""
        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
        if len(network_name) > 29: # keep the column at a fixed width
            network_name = '...' + network_name[-26:]
        # Build the parts in a list and join once instead of += in a loop.
        parts = ['%-30s' % network_name]
        parts.append(' time %-12s' % dnnlib.util.format_time(self._eval_time))
        for res in self._results:
            parts.append(' ' + self.name + res.suffix + ' ')
            parts.append(res.fmt % res.value)
        return ''.join(parts)

    def update_autosummaries(self):
        """Push every reported result to the TensorBoard autosummaries."""
        for res in self._results:
            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)

    def _evaluate(self, Gs, num_gpus):
        raise NotImplementedError # to be overridden by subclasses

    def _report_result(self, value, suffix='', fmt='%-10.4f'):
        # Record one scalar result; multiple calls accumulate.
        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]

    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
        """Return a cache-file path that is unique to this metric + dataset config."""
        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
        all_args.update(self._dataset_args)
        all_args.update(kwargs)
        # Hash the sorted argument items so the filename is stable across runs.
        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))

    def _iterate_reals(self, minibatch_size):
        """Yield minibatches of real images indefinitely (optionally mirror-augmented)."""
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images

    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
        """Yield minibatches of generated uint8 NHWC images indefinitely."""
        # The output transform is loop-invariant -- build it once, not per batch.
        fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        while True:
            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
            yield images

#----------------------------------------------------------------------------
# Group of multiple metrics.

class MetricGroup:
    """Bundle of metrics that are evaluated and reported together."""

    def __init__(self, metric_kwarg_list):
        # Instantiate each metric from its keyword-argument descriptor.
        self.metrics = []
        for kwargs in metric_kwarg_list:
            self.metrics.append(dnnlib.util.call_func_by_name(**kwargs))

    def run(self, *args, **kwargs):
        """Run every metric in the group with the same arguments."""
        for metric in self.metrics:
            metric.run(*args, **kwargs)

    def get_result_str(self):
        """Concatenate the result strings of all metrics, space-separated."""
        strings = [metric.get_result_str() for metric in self.metrics]
        return ' '.join(strings)

    def update_autosummaries(self):
        """Forward the autosummary update to every metric in the group."""
        for metric in self.metrics:
            metric.update_autosummaries()

#----------------------------------------------------------------------------
# Dummy metric for debugging purposes.

class DummyMetric(MetricBase):
    """No-op metric that always reports 0.0 -- useful for debugging the pipeline."""

    def _evaluate(self, Gs, num_gpus):
        # Inputs are deliberately ignored; report a constant value.
        del Gs, num_gpus
        self._report_result(0.0)

#----------------------------------------------------------------------------