# Ultralytics YOLO 🚀, GPL-3.0 license
import os

import pkg_resources as pkg

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers

try:
    from importlib.metadata import version

    import dvclive

    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS['dvc'] is True  # verify integration is enabled

    ver = version('dvclive')
    if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
        LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11.0 required).')
        dvclive = None  # noqa: F811
except (ImportError, AssertionError, TypeError):
    dvclive = None

# DVCLive logger instance
live = None
_processed_plots = {}

# `on_fit_epoch_end` is also called during the final validation (this probably needs to be fixed).
# For now, this flag is how we distinguish the final evaluation of the best model from
# the last-epoch validation.
_training_epoch = False


def _logger_disabled():
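    """Return True if DVCLive logging is disabled via the `ULTRALYTICS_DVC_DISABLED` environment variable."""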
    return os.getenv('ULTRALYTICS_DVC_DISABLED', 'false').lower() == 'true'


def _log_images(image_path, prefix=''):
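    """Log the image at the given path to DVCLive, optionally grouped under a prefix."""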
    if live:
        live.log_image(os.path.join(prefix, image_path.name), image_path)


def _log_plots(plots, prefix=''):
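    """Log plot images whose timestamps have changed since they were last processed."""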
    for name, params in plots.items():
        timestamp = params['timestamp']
        if _processed_plots.get(name) != timestamp:
            _log_images(name, prefix)
            _processed_plots[name] = timestamp


def _log_confusion_matrix(validator):
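    """Log the validator's confusion matrix to DVCLive as a sklearn-style plot."""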
    targets = []
    preds = []
    matrix = validator.confusion_matrix.matrix
    names = list(validator.names.values())
    if validator.confusion_matrix.task == 'detect':
        names += ['background']

    for ti, pred in enumerate(matrix.T.astype(int)):
        for pi, num in enumerate(pred):
            targets.extend([names[ti]] * num)
            preds.extend([names[pi]] * num)

    live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)


def on_pretrain_routine_start(trainer):
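    """Initialize the DVCLive logger at the start of the pretraining routine, unless logging is disabled."""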
    try:
        global live
        if not _logger_disabled():
            live = dvclive.Live(save_dvc_exp=True, cache_images=True)
            LOGGER.info(
                'DVCLive is detected and auto logging is enabled (can be disabled with `ULTRALYTICS_DVC_DISABLED=true`).'
            )
        else:
            LOGGER.debug('DVCLive is detected and auto logging is disabled via `ULTRALYTICS_DVC_DISABLED`.')
            live = None
    except Exception as e:
        LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')


def on_pretrain_routine_end(trainer):
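    """Log training plots at the end of the pretraining routine."""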
    _log_plots(trainer.plots, 'train')


def on_train_start(trainer):
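    """Log the training parameters when training starts."""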
    if live:
        live.log_params(trainer.args)


def on_train_epoch_start(trainer):
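    """Mark the start of a training epoch so the next `on_fit_epoch_end` call logs per-epoch metrics."""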
    global _training_epoch
    _training_epoch = True


def on_fit_epoch_end(trainer):
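    """Log training metrics, model info, and plots at the end of each training epoch, then advance the DVCLive step."""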
    global _training_epoch
    if live and _training_epoch:
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value)

        if trainer.epoch == 0:
            for metric, value in model_info_for_loggers(trainer).items():
                live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, 'train')
        _log_plots(trainer.validator.plots, 'val')

        live.next_step()
        _training_epoch = False


def on_train_end(trainer):
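    """Log final metrics, plots, the confusion matrix, and the best model artifact at the end of training."""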
    if live:
        # At the end, log the best metrics. The validator runs on the best model internally.
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, 'val')
        _log_plots(trainer.validator.plots, 'val')
        _log_confusion_matrix(trainer.validator)

        if trainer.best.exists():
            live.log_artifact(trainer.best, copy=True)

        live.end()


callbacks = {
    'on_pretrain_routine_start': on_pretrain_routine_start,
    'on_pretrain_routine_end': on_pretrain_routine_end,
    'on_train_start': on_train_start,
    'on_train_epoch_start': on_train_epoch_start,
    'on_fit_epoch_end': on_fit_epoch_end,
    'on_train_end': on_train_end} if dvclive else {}