File size: 3,735 Bytes
e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any

import pytorch_lightning as pl
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT


class SpeedMonitor(Callback):
    """Monitor the wall-clock speed of each training step and each epoch.

    Logs up to three metrics through the trainer's logger, each individually
    switchable via the constructor:

    * ``time/intra_step (ms)`` — time spent inside a training step
      (batch start → batch end).
    * ``time/inter_step (ms)`` — time spent between consecutive training
      steps (batch end → next batch start), e.g. data loading.
    * ``time/epoch (s)`` — duration of each training epoch.

    Adapted from PyTorch Lightning's ``GPUStatsMonitor``; only the speed
    monitoring is kept, not the GPU monitoring.

    Args:
        intra_step_time: if True, log the time spent inside each training step.
        inter_step_time: if True, log the time between consecutive training steps.
        epoch_time: if True, log the duration of each training epoch.
        verbose: if True, additionally print the intra-step time on every batch.
    """

    def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True,
                 epoch_time: bool = True, verbose: bool = False):
        super().__init__()
        self._log_stats = AttributeDict(
            {
                'intra_step_time': intra_step_time,
                'inter_step_time': inter_step_time,
                'epoch_time': epoch_time,
            }
        )
        self.verbose = verbose

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Reset so a stale timestamp from a previous fit() run is never reused.
        self._snap_epoch_time = None

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._snap_intra_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = time.time()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Time spent in validation must not count towards train inter-step time.
        self._snap_inter_step_time = None

    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Time spent in testing must not count towards train inter-step time.
        self._snap_inter_step_time = None

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.intra_step_time:
            self._snap_intra_step_time = time.time()

        # NOTE(review): relies on a private Lightning attribute; may break
        # across pytorch-lightning versions — confirm on upgrade.
        if not trainer._logger_connector.should_update_logs:
            return

        logs = {}
        if self._log_stats.inter_step_time and self._snap_inter_step_time:
            # First value is only available at the start of the second step.
            logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

        # Skip the logger call entirely when there is nothing to report.
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()

        if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time:
            pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}")

        if not trainer._logger_connector.should_update_logs:
            return

        logs = {}
        if self._log_stats.intra_step_time and self._snap_intra_step_time:
            logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

        # Skip the logger call entirely when there is nothing to report.
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        logs = {}
        if self._log_stats.epoch_time and self._snap_epoch_time:
            logs["time/epoch (s)"] = time.time() - self._snap_epoch_time
        # Skip the logger call entirely when there is nothing to report.
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)