Spaces:
Sleeping
Sleeping
from typing import Any | |
from pytorch_lightning import Callback, Trainer, LightningModule | |
from pytorch_lightning.utilities import rank_zero_only | |
from pytorch_lightning.utilities.parsing import AttributeDict | |
class ParamsLog(Callback):
    """Log the number of parameters of the model.

    When fitting starts, counts the parameters of ``pl_module`` and passes the
    enabled counts to ``trainer.logger.log_hyperparams`` under these keys:

    - ``"model/params_total"``: all parameters
    - ``"model/params_trainable"``: parameters with ``requires_grad=True``
    - ``"model/params_not_trainable"``: parameters with ``requires_grad=False``

    Args:
        total_params_log: Whether to log the total parameter count.
        trainable_params_log: Whether to log the trainable parameter count.
        non_trainable_params_log: Whether to log the non-trainable parameter count.
    """

    def __init__(
        self,
        total_params_log: bool = True,
        trainable_params_log: bool = True,
        non_trainable_params_log: bool = True,
    ):
        super().__init__()
        # Which counts to emit; read as attributes in on_fit_start.
        self._log_stats = AttributeDict(
            {
                "total_params_log": total_params_log,
                "trainable_params_log": trainable_params_log,
                "non_trainable_params_log": non_trainable_params_log,
            }
        )

    @rank_zero_only
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Count the module's parameters and log the enabled counts.

        Wrapped in ``rank_zero_only`` so that in multi-process training only the
        rank-zero process does the counting and logging; other ranks skip it.

        Args:
            trainer: The running trainer; its ``logger`` (if any) receives the counts.
            pl_module: The module whose ``parameters()`` are counted.
        """
        # Single pass over the parameters instead of one pass per enabled count.
        trainable = 0
        non_trainable = 0
        for param in pl_module.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                non_trainable += param.numel()

        logs = {}
        if self._log_stats.total_params_log:
            logs["model/params_total"] = trainable + non_trainable
        if self._log_stats.trainable_params_log:
            logs["model/params_trainable"] = trainable
        if self._log_stats.non_trainable_params_log:
            logs["model/params_not_trainable"] = non_trainable

        # Skip the call when nothing is enabled or no logger is configured.
        if logs and trainer.logger is not None:
            trainer.logger.log_hyperparams(logs)