Do0rMaMu's picture
Upload folder using huggingface_hub
e45d058 verified
import subprocess
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import seaborn as sn
import torch
import wandb
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Look up the Weights & Biases logger attached to ``trainer``.

    Args:
        trainer: the Trainer whose ``logger`` attribute is inspected.

    Returns:
        ``trainer.logger`` itself when it is a ``WandbLogger``, otherwise the
        first ``WandbLogger`` found inside a ``LoggerCollection``.

    Raises:
        Exception: if ``trainer.fast_dev_run`` is truthy, or if no
            ``WandbLogger`` can be found.
    """
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )

    attached = trainer.logger
    if isinstance(attached, WandbLogger):
        return attached

    # Only a LoggerCollection is searched; any other logger type falls through to the error.
    members = attached if isinstance(attached, LoggerCollection) else ()
    match = next((item for item in members if isinstance(item, WandbLogger)), None)
    if match is not None:
        return match

    raise Exception(
        "You are using wandb related callback, but WandbLogger was not found for some reason..."
    )
class WatchModel(Callback):
    """Ask the wandb logger to watch the model as soon as training starts."""

    def __init__(self, log: str = "gradients", log_freq: int = 100):
        """
        Args:
            log: value forwarded as ``log`` to the logger's ``watch`` call.
            log_freq: value forwarded as ``log_freq`` to the logger's ``watch`` call.
        """
        self.log_freq = log_freq
        self.log = log

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Register ``trainer.model`` with the wandb logger using the stored settings."""
        wandb_logger = get_wandb_logger(trainer=trainer)
        watch_settings = {"log": self.log, "log_freq": self.log_freq}
        wandb_logger.watch(model=trainer.model, **watch_settings)
class UploadCodeAsArtifact(Callback):
    """Upload all code files to wandb as an artifact, at the beginning of the run."""

    def __init__(self, code_dir: str, use_git: bool = True):
        """
        Args:
            code_dir: the code directory (may be relative or absolute)
            use_git: if using git, then upload all files that are not ignored by git.
                if not using git, then upload all '*.py' file
        """
        self.code_dir = code_dir
        self.use_git = use_git

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Collect the source files under ``code_dir`` and log them as a wandb artifact.

        Files are stored in the artifact under their path relative to ``code_dir``.
        """
        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        code = wandb.Artifact("project-source", type="code")

        # Resolve once and reuse: every path produced by rglob() below is rooted at this
        # resolved directory, so relative_to() must be given the same resolved root.
        # Using the raw ``self.code_dir`` fails with ValueError when it is relative or
        # contains symlinks.
        code_root = Path(self.code_dir).resolve()

        if self.use_git:
            # get .git folder
            # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
            git_dir_path = Path(
                subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
            ).resolve()

            for path in code_root.rglob("*"):
                if not path.is_file():
                    continue

                # ignore files in .git -- compare path components rather than string
                # prefixes, so that siblings such as ".gitignore" or ".github/" are kept
                if git_dir_path in path.parents:
                    continue

                # ignore files ignored by git: only exit status 1 counts as "not ignored";
                # any other status (ignored, or a git error) skips the file
                not_ignored = (
                    subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1
                )
                if not not_ignored:
                    continue

                code.add_file(str(path), name=str(path.relative_to(code_root)))
        else:
            for path in code_root.rglob("*.py"):
                code.add_file(str(path), name=str(path.relative_to(code_root)))

        experiment.log_artifact(code)
class UploadCheckpointsAsArtifact(Callback):
    """Send checkpoint files to wandb as one artifact when the run finishes."""

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
        """
        Args:
            ckpt_dir: directory searched recursively for ``*.ckpt`` files.
            upload_best_only: when true, upload only the path reported by
                ``trainer.checkpoint_callback.best_model_path``.
        """
        self.upload_best_only = upload_best_only
        self.ckpt_dir = ckpt_dir

    @rank_zero_only
    def on_keyboard_interrupt(self, trainer, pl_module):
        """Run the same upload as ``on_train_end``."""
        self.on_train_end(trainer, pl_module)

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """Build the checkpoint artifact and log it to the wandb experiment."""
        experiment = get_wandb_logger(trainer=trainer).experiment
        ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if self.upload_best_only:
            selected = [trainer.checkpoint_callback.best_model_path]
        else:
            # Lazy: each checkpoint path is converted to str as it is discovered.
            selected = map(str, Path(self.ckpt_dir).rglob("*.ckpt"))

        for ckpt_file in selected:
            ckpts.add_file(ckpt_file)

        experiment.log_artifact(ckpts)
class LogConfusionMatrix(Callback):
    """Generate confusion matrix every epoch and send it to wandb.

    Expects validation step to return predictions and targets, i.e. ``outputs["preds"]``
    and ``outputs["targets"]`` (both are concatenated with ``torch.cat``).
    """

    def __init__(self):
        super().__init__()
        self.preds = []
        self.targets = []
        # False only while the validation sanity check is running.
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module) -> None:
        """Pause collection so sanity-check batches are not logged."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate confusion matrix, log it to wandb and reset the collected batches."""
        if not self.ready:
            return

        logger = get_wandb_logger(trainer)
        experiment = logger.experiment

        preds = torch.cat(self.preds).cpu().numpy()
        targets = torch.cat(self.targets).cpu().numpy()

        confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)

        # set figure size; keep the handle so the figure can be released afterwards
        figure = plt.figure(figsize=(14, 8))
        try:
            # set labels size
            sn.set(font_scale=1.4)

            # set font size
            sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")

            # names should be unique or else charts from different experiments in wandb
            # will overlap.
            # NOTE: the original author recorded that logging `plt` directly crashed,
            # hence the wandb.Image wrapper.
            experiment.log(
                {f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False
            )
        finally:
            # Close (not merely clear) the figure: a new one is created every epoch, and
            # pyplot keeps each figure alive until it is explicitly closed.
            plt.close(figure)

        self.preds.clear()
        self.targets.clear()
class LogF1PrecRecHeatmap(Callback):
    """Generate f1, precision, recall heatmap every epoch and send it to wandb.

    Expects validation step to return predictions and targets, i.e. ``outputs["preds"]``
    and ``outputs["targets"]`` (both are concatenated with ``torch.cat``).
    """

    def __init__(self, class_names: List[str] = None):
        """
        Args:
            class_names: optional column labels for the heatmap. They are used only when
                their count equals the number of per-class scores computed for the
                epoch; otherwise seaborn's automatic tick labels are kept.
        """
        super().__init__()
        self.class_names = class_names
        self.preds = []
        self.targets = []
        # False only while the validation sanity check is running.
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        """Pause collection so sanity-check batches are not logged."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate f1, precision and recall heatmap, log it and reset collected batches."""
        if not self.ready:
            return

        logger = get_wandb_logger(trainer=trainer)
        experiment = logger.experiment

        preds = torch.cat(self.preds).cpu().numpy()
        targets = torch.cat(self.targets).cpu().numpy()
        f1 = f1_score(targets, preds, average=None)
        r = recall_score(targets, preds, average=None)
        p = precision_score(targets, preds, average=None)
        data = [f1, p, r]

        # Label the columns only when the supplied names line up one-to-one with the
        # score columns; "auto" is seaborn's own default, so nothing changes otherwise.
        xticklabels = "auto"
        if self.class_names is not None and len(self.class_names) == len(f1):
            xticklabels = list(self.class_names)

        # set figure size; keep the handle so the figure can be released afterwards
        figure = plt.figure(figsize=(14, 3))
        try:
            # set labels size
            sn.set(font_scale=1.2)

            # set font size
            sn.heatmap(
                data,
                annot=True,
                annot_kws={"size": 10},
                fmt=".3f",
                xticklabels=xticklabels,
                yticklabels=["F1", "Precision", "Recall"],
            )

            # names should be unique or else charts from different experiments in wandb
            # will overlap
            experiment.log(
                {f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False
            )
        finally:
            # Close (not merely clear) the figure: a new one is created every epoch, and
            # pyplot keeps each figure alive until it is explicitly closed.
            plt.close(figure)

        self.preds.clear()
        self.targets.clear()
class LogImagePredictions(Callback):
    """Log one validation batch together with the model's predictions to wandb.

    Example adapted from:
        https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
    """

    def __init__(self, num_samples: int = 8):
        """
        Args:
            num_samples: upper bound on how many items of the batch are logged.
        """
        super().__init__()
        self.num_samples = num_samples
        # False only while the validation sanity check is running.
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        """Disable logging for the duration of the sanity check."""
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Re-enable logging once the validation sanity check has finished."""
        self.ready = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Predict on the first validation batch and log captioned images."""
        if not self.ready:
            return

        experiment = get_wandb_logger(trainer=trainer).experiment

        # take the first batch yielded by the datamodule's validation dataloader
        val_imgs, val_labels = next(iter(trainer.datamodule.val_dataloader()))

        # move the inputs to the module's device and take the arg-max over the last axis
        val_imgs = val_imgs.to(device=pl_module.device)
        preds = torch.argmax(pl_module(val_imgs), dim=-1)

        # wrap each (image, prediction, label) triple as a captioned wandb Image
        captioned = []
        for x, pred, y in zip(
            val_imgs[: self.num_samples],
            preds[: self.num_samples],
            val_labels[: self.num_samples],
        ):
            captioned.append(wandb.Image(x, caption=f"Pred:{pred}, Label:{y}"))

        experiment.log({f"Images/{experiment.name}": captioned})