import atexit
import glob
import os
import re
import numpy as np

from tensorboardX import SummaryWriter

import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity


def unique_dllogger_fpath(log_fpath):
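    """Return a variant of `log_fpath` that will not overwrite existing logs.

    Example (hypothetical paths): with 'nvlog.json' and 'nvlog.json.1'
    already on disk, this returns 'nvlog.json.2'.
    """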

    if not os.path.isfile(log_fpath):
        return log_fpath

    # Avoid overwriting old logs: find the highest numeric suffix so far,
    # skipping any matching files without one
    matches = [re.search(r'\.(\d+)$', f) for f in glob.glob(f'{log_fpath}.*')]
    saved = sorted(int(m.group(1)) for m in matches if m is not None)

    log_num = (saved[-1] if saved else 0) + 1
    return f'{log_fpath}.{log_num}'


def stdout_step_format(step):
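    """Format a step tuple (epoch, iter[, total_iters]) for stdout.

    Example: stdout_step_format((3, 12, 500)) -> 'epoch    3 | iter  12/500'
    """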
    if isinstance(step, str):
        return step
    fields = []
    if len(step) > 0:
        fields.append("epoch {:>4}".format(step[0]))
    if len(step) > 1:
        fields.append("iter {:>3}".format(step[1]))
    if len(step) > 2:
        fields[-1] += "/{}".format(step[2])
    return " | ".join(fields)


def stdout_metric_format(metric, metadata, value):
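    """Format a single metric for the stdout backend.

    Example: with metadata {'name': 'loss', 'format': ':>5.2f'}, a value
    of 3.14159 renders as '| loss  3.14'.
    """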
    name = metadata.get("name", metric + " : ")
    unit = metadata.get("unit")
    fmt = "{" + metadata["format"] + "}" if "format" in metadata else "{}"
    fields = [name, fmt.format(value) if value is not None else value, unit]
    fields = [f for f in fields if f is not None]
    return "| " + " ".join(fields)


def init_dllogger(log_fpath=None, dummy=False):
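    """Set up DLLogger backends and per-metric formatting metadata.

    A minimal usage sketch (step and metric values are illustrative):

        init_dllogger(unique_dllogger_fpath('nvlog.json'))
        DLLogger.log(step=(1, 10, 100), data={'train_loss': 3.14})
        DLLogger.flush()
    """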
    if dummy:
        DLLogger.init(backends=[])
        return
    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE, step_format=stdout_step_format,
                      metric_format=stdout_metric_format)
    ])
    DLLogger.metadata("train_loss", {"name": "loss", "format": ":>5.2f"})
    DLLogger.metadata("train_mel_loss", {"name": "mel loss", "format": ":>5.2f"})
    DLLogger.metadata("avg_train_loss", {"name": "avg train loss", "format": ":>5.2f"})
    DLLogger.metadata("avg_train_mel_loss", {"name": "avg train mel loss", "format": ":>5.2f"})
    DLLogger.metadata("val_loss", {"name": "  avg val loss", "format": ":>5.2f"})
    DLLogger.metadata("val_mel_loss", {"name": "  avg val mel loss", "format": ":>5.2f"})
    DLLogger.metadata(
        "val_ema_loss",
        {"name": "  EMA val loss", "format": ":>5.2f"})
    DLLogger.metadata(
        "val_ema_mel_loss",
        {"name": "  EMA val mel loss", "format": ":>5.2f"})
    DLLogger.metadata(
        "train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata(
        "avg_train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata(
        "val_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata(
        "val_ema_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata(
        "took", {"name": "took", "unit": "s", "format": ":>3.2f"})
    DLLogger.metadata("lrate_change", {"name": "lrate"})


class TBLogger(object):
    """
    xyz_dummies: stretch the screen with empty plots so that the legend
                 always fits the remaining plots
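
    A minimal usage sketch (argument values are illustrative):

        tb_logger = TBLogger(local_rank=0, log_dir='./logs', name='train',
                             interval=10)
        tb_logger.log_value(step=100, key='loss', val=0.25)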
    """
    def __init__(self, local_rank, log_dir, name, interval=1, dummies=False):
        self.enabled = (local_rank == 0)
        self.interval = interval
        self.cache = {}
        if self.enabled:
            self.summary_writer = SummaryWriter(
                log_dir=os.path.join(log_dir, name),
                flush_secs=120, max_queue=200)
            atexit.register(self.summary_writer.close)
            if dummies:
                for key in ('aaa', 'zzz'):
                    self.summary_writer.add_scalar(key, 0.0, 1)

    def log_value(self, step, key, val, stat='mean'):
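        """Cache `val` under `key`; once `interval` values accumulate,
        write their `stat` aggregate (a numpy reduction, e.g. np.mean)
        to TensorBoard at `step`."""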
        if self.enabled:
            self.cache.setdefault(key, []).append(val)
            if len(self.cache[key]) == self.interval:
                agg_val = getattr(np, stat)(self.cache[key])
                self.summary_writer.add_scalar(key, agg_val, step)
                del self.cache[key]

    def log_meta(self, step, meta):
        # Values are expected to be scalar tensors/arrays exposing .item()
        for k, v in meta.items():
            self.log_value(step, k, v.item())

    def log_grads(self, step, model):
        if self.enabled:
            # Collect per-parameter gradient norms and log summary statistics
            norms = [p.grad.norm().item() for p in model.parameters()
                     if p.grad is not None]
            for stat in ('max', 'min', 'mean'):
                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
                               stat=stat)