# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
class SpeedMonitor(Callback):
    """Log wall-clock timings of training steps and epochs.

    Three metrics can be reported, each toggled by a constructor flag:

    - ``time/intra_step (ms)``: time spent inside a training step
      (``on_train_batch_start`` -> ``on_train_batch_end``).
    - ``time/inter_step (ms)``: time spent between two consecutive training
      steps (``on_train_batch_end`` -> next ``on_train_batch_start``),
      typically dominated by data loading.
    - ``time/epoch (s)``: duration of one full training epoch.

    Args:
        intra_step_time: if True, log the per-step duration.
        inter_step_time: if True, log the gap between consecutive steps.
        epoch_time: if True, log the per-epoch duration.
        verbose: if True, additionally print the intra-step time via
            ``pl_module.print`` at every batch end.
    """

    def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True,
                 epoch_time: bool = True, verbose: bool = False):
        super().__init__()
        self._log_stats = AttributeDict(
            {
                'intra_step_time': intra_step_time,
                'inter_step_time': inter_step_time,
                'epoch_time': epoch_time,
            }
        )
        self.verbose = verbose
        # Initialize every snapshot up-front so hooks that fire in an
        # unexpected order (e.g. sanity-check validation before the first
        # training epoch) can never hit an AttributeError.
        self._snap_intra_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = None

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._snap_epoch_time = None

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Reset per-step snapshots and start the epoch timer.
        self._snap_intra_step_time = None
        self._snap_inter_step_time = None
        self._snap_epoch_time = time.time()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Validation interleaves with training; don't count it as inter-step time.
        self._snap_inter_step_time = None

    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self._snap_inter_step_time = None

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.intra_step_time:
            self._snap_intra_step_time = time.time()
        # NOTE(review): relies on the private Lightning API
        # `_logger_connector.should_update_logs`; may break across versions.
        if not trainer._logger_connector.should_update_logs:
            return
        logs = {}
        if self._log_stats.inter_step_time and self._snap_inter_step_time:
            # First logged at the beginning of the second step, since the
            # gap only exists once a previous batch has ended.
            logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
        # Skip the logger call entirely when there is nothing to report.
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()
        if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time:
            pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}")
        if not trainer._logger_connector.should_update_logs:
            return
        logs = {}
        if self._log_stats.intra_step_time and self._snap_intra_step_time:
            logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
        # Skip the logger call entirely when there is nothing to report.
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        logs = {}
        if self._log_stats.epoch_time and self._snap_epoch_time:
            logs["time/epoch (s)"] = time.time() - self._snap_epoch_time
        if logs and trainer.logger is not None:
            trainer.logger.log_metrics(logs, step=trainer.global_step)
|