"""PyTorch Lightning callbacks for seq2seq training: metric logging, checkpointing, early stopping."""

import logging
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils.utils_verbalisation_module import save_json
from pytorch_lightning.utilities import rank_zero_info


# Entries of ``trainer.callback_metrics`` that are not plain scalar metrics and
# must be kept out of the logger / results files.
_EXCLUDED_METRIC_KEYS = ("log", "progress_bar", "preds")


def count_trainable_parameters(model):
    """Return the total number of elements in the parameters of *model* that require gradients.

    Args:
        model: Object exposing ``parameters()`` that yields tensors with
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of ``numel()`` over all parameters with ``requires_grad`` set.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
    """Logs learning rates, parameter counts and validation/test metrics during training."""

    def on_batch_end(self, trainer, pl_module):
        """Log the current learning rate of every param group of the first optimizer."""
        param_groups = pl_module.trainer.optimizers[0].param_groups
        lrs = {f"lr_group_{i}": group["lr"] for i, group in enumerate(param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        """Send callback metrics to the trainer's logger and append them to a results file.

        Args:
            trainer: Trainer whose ``callback_metrics`` and ``global_step`` are read.
            pl_module: Module whose ``hparams.output_dir`` is the output directory.
            type_path: Split name. ``"test"`` writes ``test_results.txt`` and
                ``test_generations.txt``; any other value writes per-step files
                under ``{type_path}_results/`` and ``{type_path}_generations/``.
            save_generations: When true and the metrics contain ``"preds"``,
                the predictions are written one per line to the generations file.
        """
        logger.info("***** %s results at step %05d *****", type_path, trainer.global_step)
        metrics = trainer.callback_metrics

        # NOTE: this filter is a *substring* match (e.g. "val_log_x" is dropped
        # too), whereas the results file below skips exact key matches only.
        new_metrics = {
            name: value
            for name, value in metrics.items()
            if not any(excluded in name for excluded in _EXCLUDED_METRIC_KEYS)
        }
        print(new_metrics)
        trainer.logger.log_metrics(new_metrics)

        # Resolve output locations.
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # Not reached by the callbacks in this class as written: intermediate
            # generations are not saved, and results are already in metrics.json.
            # Kept so the behaviour is easy to re-enable.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
        # parents=True so a missing output_dir does not abort result writing.
        results_file.parent.mkdir(parents=True, exist_ok=True)
        generations_file.parent.mkdir(parents=True, exist_ok=True)

        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in _EXCLUDED_METRIC_KEYS:
                    continue
                val = metrics[key]
                try:
                    if isinstance(val, torch.Tensor):
                        val = val.item()
                    msg = f"{key}: {val:.6f}\n"
                except (TypeError, ValueError, RuntimeError):
                    # Best effort: values that are not a single number (strings,
                    # lists, multi-element tensors) are left out of the file.
                    logger.debug("Skipping metric %r: value cannot be formatted as a float", key)
                    continue
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            # write_text closes the file handle, unlike a bare open().write().
            generations_file.write_text(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Log total and trainable parameter counts once training starts."""
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            # No nested ``.model`` (or it lacks ``num_parameters``); ask the model directly.
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save the module's metrics as JSON and write the test results/generations files."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        """Save the module's metrics as JSON and report validation metrics via ``rank_zero_info``."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)

        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in _EXCLUDED_METRIC_KEYS:
                # ``!s`` forces str(); plain formatting would render 0-dim tensors differently.
                rank_zero_info(f"{key} = {metrics[key]!s}\n")
        # To also save validation generations, call here:
        #   self._write_logs(trainer, pl_module, "valid")


def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
    """Build a ``ModelCheckpoint`` that keeps the best models by a validation metric.

    Args:
        output_dir: Directory the checkpoint files are written to.
        metric: One of ``"rouge2"``, ``"bleu"`` or ``"loss"``. The callback
            monitors ``val_{metric}``; checkpoint file names embed
            ``val_avg_{metric}`` and ``step_count``.
        save_top_k: Forwarded to ``ModelCheckpoint(save_top_k=...)``.
        lower_is_better: Currently unused, kept for interface compatibility.
            The mode is ``"min"`` when ``"loss"`` occurs in *metric* and
            ``"max"`` otherwise.

    Returns:
        The configured ``ModelCheckpoint`` callback.

    Raises:
        NotImplementedError: If *metric* is not one of the supported names.
    """
    if metric not in ("rouge2", "bleu", "loss"):
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
        )
    # Literal braces: these are checkpoint-name placeholders filled by
    # ModelCheckpoint, not by Python string formatting here.
    exp = "{val_avg_" + metric + ":.4f}-{step_count}"

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(output_dir, exp),
        monitor=f"val_{metric}",
        mode="min" if "loss" in metric else "max",
        save_top_k=save_top_k,
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback


def get_early_stopping_callback(metric, patience):
    """Build an ``EarlyStopping`` callback monitoring ``val_{metric}``.

    Args:
        metric: Metric name; the mode is ``"min"`` when it contains ``"loss"``,
            ``"max"`` otherwise.
        patience: Forwarded to ``EarlyStopping(patience=...)``.

    Returns:
        The configured ``EarlyStopping`` callback (``verbose=True``).
    """
    return EarlyStopping(
        # NOTE(review): checkpoint file names use ``val_avg_{metric}``; confirm
        # whether the monitored key should also carry the ``avg`` infix.
        monitor=f"val_{metric}",
        mode="min" if "loss" in metric else "max",
        patience=patience,
        verbose=True,
    )