perin / utility /log.py
Larisa Kolesnichenko
Add the original perin code
1d5604f
raw
history blame contribute delete
No virus
6.32 kB
#!/usr/bin/env python3
# coding=utf-8
from utility.loading_bar import LoadingBar
import time
import torch
class Log:
def __init__(self, dataset, model, optimizer, args, directory, log_each: int, initial_epoch=-1, log_wandb=True):
self.dataset = dataset
self.model = model
self.args = args
self.optimizer = optimizer
self.loading_bar = LoadingBar(length=27)
self.best_f1_score = 0.0
self.log_each = log_each
self.epoch = initial_epoch
self.log_wandb = log_wandb
if self.log_wandb:
globals()["wandb"] = __import__("wandb") # ugly way to not require wandb if not needed
self.directory = directory
self.evaluation_results = f"{directory}/results_{{0}}_{{1}}.json"
self.full_evaluation_results = f"{directory}/full_results_{{0}}_{{1}}.json"
self.best_full_evaluation_results = f"{directory}/best_full_results_{{0}}_{{1}}.json"
self.result_history = {epoch: {} for epoch in range(args.epochs)}
self.best_checkpoint_filename = f"{self.directory}/best_checkpoint.h5"
self.last_checkpoint_filename = f"{self.directory}/last_checkpoint.h5"
self.step = 0
self.total_batch_size = 0
self.flushed = True
def train(self, len_dataset: int) -> None:
self.flush()
self.epoch += 1
if self.epoch == 0:
self._print_header()
self.is_train = True
self._reset(len_dataset)
def eval(self, len_dataset: int) -> None:
self.flush()
self.is_train = False
self._reset(len_dataset)
def __call__(self, batch_size, losses, grad_norm: float = None, learning_rates: float = None,) -> None:
if self.is_train:
self._train_step(batch_size, losses, grad_norm, learning_rates)
else:
self._eval_step(batch_size, losses)
self.flushed = False
def flush(self) -> None:
if self.flushed:
return
self.flushed = True
if self.is_train:
print(f"\r┃{self.epoch:12d} ┃{self._time():>12} β”‚", end="", flush=True)
else:
if self.losses is not None and self.log_wandb:
dictionary = {f"validation/{key}": value / self.step for key, value in self.losses.items()}
dictionary["epoch"] = self.epoch
wandb.log(dictionary)
self.losses = None
# self._save_model(save_as_best=False, performance=None)
def log_evaluation(self, scores, mode, epoch):
f1_score = scores["sentiment_tuple/f1"]
if self.log_wandb:
scores = {f"{mode}/{k}": v for k, v in scores.items()}
wandb.log({
"epoch": epoch,
**scores
})
if mode == "validation" and f1_score > self.best_f1_score:
if self.log_wandb:
wandb.run.summary["best sentiment tuple f1 score"] = f1_score
self.best_f1_score = f1_score
self._save_model(save_as_best=True, f1_score=f1_score)
def _save_model(self, save_as_best: bool, f1_score: float):
if not self.args.save_checkpoints:
return
state = {
"epoch": self.epoch,
"dataset": self.dataset.state_dict(),
"f1_score": f1_score,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"args": self.args.state_dict(),
}
filename = self.best_checkpoint_filename if save_as_best else self.last_checkpoint_filename
torch.save(state, filename)
if self.log_wandb:
wandb.save(filename)
def _train_step(self, batch_size, losses, grad_norm: float, learning_rates) -> None:
self.total_batch_size += batch_size
self.step += 1
if self.losses is None:
self.losses = losses
else:
for key, values in losses.items():
if key not in self.losses:
self.losses[key] = losses[key]
continue
self.losses[key] += losses[key]
if self.step % self.log_each == 0:
progress = self.total_batch_size / self.len_dataset
print(f"\r┃{self.epoch:12d} β”‚{self._time():>12} {self.loading_bar(progress)}", end="", flush=True)
if self.log_wandb:
dictionary = {f"train/{key}" if not key.startswith("weight/") else key: value / self.log_each for key, value in self.losses.items()}
dictionary["epoch"] = self.epoch
dictionary["learning_rate/encoder"] = learning_rates[0]
dictionary["learning_rate/decoder"] = learning_rates[-2]
dictionary["learning_rate/grad_norm"] = learning_rates[-1]
dictionary["gradient norm"] = grad_norm
wandb.log(dictionary)
self.losses = None
def _eval_step(self, batch_size, losses) -> None:
self.step += 1
if self.losses is None:
self.losses = losses
else:
for key, values in losses.items():
if key not in self.losses:
self.losses[key] = losses[key]
continue
self.losses[key] += losses[key]
def _reset(self, len_dataset: int) -> None:
self.start_time = time.time()
self.step = 0
self.total_batch_size = 0
self.len_dataset = len_dataset
self.losses = None
def _time(self) -> str:
time_seconds = int(time.time() - self.start_time)
return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"
def _print_header(self) -> None:
print(f"┏━━━━━━━━━━━━━━┳━━━╸Sβ•Ίβ•ΈEβ•Ίβ•ΈMβ•Ίβ•ΈAβ•Ίβ•ΈNβ•Ίβ•ΈTβ•Ίβ•ΈIβ•Ίβ•ΈSβ•Ίβ•ΈK╺━━━━━━━━━━━━━━┓")
print(f"┃ ┃ β•· ┃")
print(f"┃ epoch ┃ elapsed β”‚ progress bar ┃")
print(f"┠──────────────╂──────────────┼─────────────────────────────┨")