"""Weights & Biases utility callbacks for PyTorch Lightning."""
import subprocess | |
from pathlib import Path | |
from typing import List | |
import matplotlib.pyplot as plt | |
import seaborn as sn | |
import torch | |
import wandb | |
from pytorch_lightning import Callback, Trainer | |
from pytorch_lightning.loggers import LoggerCollection, WandbLogger | |
from pytorch_lightning.utilities import rank_zero_only | |
from sklearn import metrics | |
from sklearn.metrics import f1_score, precision_score, recall_score | |
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Safely get Weights&Biases logger from Trainer."""
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )

    # Normalize single-logger and logger-collection cases into one scan.
    if isinstance(trainer.logger, WandbLogger):
        candidates = [trainer.logger]
    elif isinstance(trainer.logger, LoggerCollection):
        candidates = list(trainer.logger)
    else:
        candidates = []

    for candidate in candidates:
        if isinstance(candidate, WandbLogger):
            return candidate

    raise Exception(
        "You are using wandb related callback, but WandbLogger was not found for some reason..."
    )
class WatchModel(Callback):
    """Make wandb watch model at the beginning of the run."""

    def __init__(self, log: str = "gradients", log_freq: int = 100):
        # what wandb.watch should record ("gradients", "parameters", "all")
        # and how often (in batches) it should record it
        self.log = log
        self.log_freq = log_freq

    def on_train_start(self, trainer, pl_module):
        wandb_logger = get_wandb_logger(trainer=trainer)
        wandb_logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
class UploadCodeAsArtifact(Callback):
    """Upload all code files to wandb as an artifact, at the beginning of the run."""

    def __init__(self, code_dir: str, use_git: bool = True):
        """
        Args:
            code_dir: the code directory
            use_git: if using git, then upload all files that are not ignored by git.
                if not using git, then upload all '*.py' file
        """
        self.code_dir = code_dir
        self.use_git = use_git

    def on_train_start(self, trainer, pl_module):
        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        code = wandb.Artifact("project-source", type="code")

        # Resolve the base directory once and reuse it below. rglob() on a
        # resolved path yields absolute paths, so calling relative_to() with
        # the raw (possibly relative) self.code_dir string would raise
        # ValueError; relative_to() must use the same resolved base.
        code_root = Path(self.code_dir).resolve()

        if self.use_git:
            # locate the .git folder so its internals can be excluded
            # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
            git_dir_path = Path(
                subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
            ).resolve()

            for path in code_root.rglob("*"):
                if (
                    path.is_file()
                    # ignore files in .git
                    and not str(path).startswith(str(git_dir_path))  # noqa: W503
                    # ignore files ignored by git (check-ignore exits 1 when NOT ignored)
                    and (  # noqa: W503
                        subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1
                    )
                ):
                    code.add_file(str(path), name=str(path.relative_to(code_root)))
        else:
            for path in code_root.rglob("*.py"):
                code.add_file(str(path), name=str(path.relative_to(code_root)))

        experiment.log_artifact(code)
class UploadCheckpointsAsArtifact(Callback):
    """Upload checkpoints to wandb as an artifact, at the end of run."""

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
        # directory scanned recursively for *.ckpt files, and whether to
        # restrict the upload to the best checkpoint only
        self.ckpt_dir = ckpt_dir
        self.upload_best_only = upload_best_only

    def on_keyboard_interrupt(self, trainer, pl_module):
        # still upload checkpoints when the run is interrupted manually
        self.on_train_end(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        experiment = get_wandb_logger(trainer=trainer).experiment

        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if self.upload_best_only:
            ckpts.add_file(trainer.checkpoint_callback.best_model_path)
        else:
            for ckpt_path in Path(self.ckpt_dir).rglob("*.ckpt"):
                ckpts.add_file(str(ckpt_path))

        experiment.log_artifact(ckpts)
class LogConfusionMatrix(Callback):
    """Generate confusion matrix every epoch and send it to wandb.
    Expects validation step to return predictions and targets.
    """

    def __init__(self):
        # per-epoch accumulators of batch predictions/targets (torch tensors)
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module) -> None:
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate confusion matrix."""
        if self.ready:
            logger = get_wandb_logger(trainer)
            experiment = logger.experiment

            preds = torch.cat(self.preds).cpu().numpy()
            targets = torch.cat(self.targets).cpu().numpy()

            confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)

            # set figure size
            plt.figure(figsize=(14, 8))

            # set labels size
            sn.set(font_scale=1.4)

            # set font size
            sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")

            # names should be uniqe or else charts from different experiments in wandb will overlap
            experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)

            # according to wandb docs this should also work but it crashes
            # experiment.log(f{"confusion_matrix/{experiment.name}": plt})

            # close (not merely clear) the figure: plt.figure() registers a new
            # figure each epoch, and plt.clf() alone leaks them until matplotlib
            # warns about too many open figures
            plt.close()

            self.preds.clear()
            self.targets.clear()
class LogF1PrecRecHeatmap(Callback):
    """Generate f1, precision, recall heatmap every epoch and send it to wandb.
    Expects validation step to return predictions and targets.
    """

    def __init__(self, class_names: List[str] = None):
        # keep class names for reference; previously this argument was
        # accepted but silently dropped
        self.class_names = class_names
        # per-epoch accumulators of batch predictions/targets (torch tensors)
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate f1, precision and recall heatmap."""
        if self.ready:
            logger = get_wandb_logger(trainer=trainer)
            experiment = logger.experiment

            preds = torch.cat(self.preds).cpu().numpy()
            targets = torch.cat(self.targets).cpu().numpy()

            # per-class scores (average=None keeps one value per class)
            f1 = f1_score(targets, preds, average=None)
            r = recall_score(targets, preds, average=None)
            p = precision_score(targets, preds, average=None)
            data = [f1, p, r]

            # set figure size
            plt.figure(figsize=(14, 3))

            # set labels size
            sn.set(font_scale=1.2)

            # set font size
            sn.heatmap(
                data,
                annot=True,
                annot_kws={"size": 10},
                fmt=".3f",
                yticklabels=["F1", "Precision", "Recall"],
            )

            # names should be uniqe or else charts from different experiments in wandb will overlap
            experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)

            # close (not merely clear) the figure: plt.figure() registers a new
            # figure each epoch, and plt.clf() alone leaks them until matplotlib
            # warns about too many open figures
            plt.close()

            self.preds.clear()
            self.targets.clear()
class LogImagePredictions(Callback):
    """Logs a validation batch and their predictions to wandb.
    Example adapted from:
        https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
    """

    def __init__(self, num_samples: int = 8):
        super().__init__()
        # how many images from the batch to log
        self.num_samples = num_samples
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_epoch_end(self, trainer, pl_module):
        if not self.ready:
            return

        experiment = get_wandb_logger(trainer=trainer).experiment

        # grab a single batch from the validation data loader
        imgs, labels = next(iter(trainer.datamodule.val_dataloader()))

        # run the batch through the network on the module's device
        imgs = imgs.to(device=pl_module.device)
        predictions = torch.argmax(pl_module(imgs), dim=-1)

        # log the images as wandb Image
        captioned = [
            wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
            for x, pred, y in zip(
                imgs[: self.num_samples],
                predictions[: self.num_samples],
                labels[: self.num_samples],
            )
        ]
        experiment.log({f"Images/{experiment.name}": captioned})