Spaces:
Runtime error
Runtime error
import pytorch_lightning as pl | |
from . import config | |
from .utils import ( | |
check_class_accuracy, | |
get_evaluation_bboxes, | |
mean_average_precision, | |
plot_couple_examples, | |
) | |
class PlotTestExamplesCallback(pl.Callback):
    """Plot model predictions on test data every ``every_n_epochs`` training epochs.

    Args:
        every_n_epochs: Plot after every this-many completed training epochs.
            Must be a positive integer.
        thresh: Forwarded to ``plot_couple_examples`` as ``thresh``
            (default 0.6).
        iou_thresh: Forwarded to ``plot_couple_examples`` as ``iou_thresh``
            (default 0.5).

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """

    def __init__(
        self,
        every_n_epochs: int = 1,
        thresh: float = 0.6,
        iou_thresh: float = 0.5,
    ) -> None:
        super().__init__()
        # Reject bad periods at construction: a period of 0 would otherwise
        # only blow up (ZeroDivisionError) once the first epoch finishes.
        if every_n_epochs < 1:
            raise ValueError(
                f"every_n_epochs must be a positive integer, got {every_n_epochs}"
            )
        self.every_n_epochs = every_n_epochs
        self.thresh = thresh
        self.iou_thresh = iou_thresh

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Plot test-set examples when the finished epoch hits the period."""
        # current_epoch starts at 0, so +1 counts completed epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        plot_couple_examples(
            model=pl_module,
            loader=trainer.datamodule.test_dataloader(),
            thresh=self.thresh,
            iou_thresh=self.iou_thresh,
            # Assumes the LightningModule exposes a ``scaled_anchors`` attribute.
            anchors=pl_module.scaled_anchors,
        )
class CheckClassAccuracyCallback(pl.Callback):
    """Periodically compute and log class/object accuracies on train and test data.

    At the end of a training epoch, accuracies are computed with
    ``check_class_accuracy`` (using ``config.CONF_THRESHOLD``). They are
    logged as ``<split>_class_acc``, ``<split>_no_obj_acc`` and
    ``<split>_obj_acc``.

    Args:
        train_every_n_epochs: Evaluate the train dataloader after every
            this-many completed epochs. Must be a positive integer.
        test_every_n_epochs: Evaluate the test dataloader after every
            this-many completed epochs. Must be a positive integer.

    Raises:
        ValueError: If either period is less than 1.
    """

    def __init__(
        self, train_every_n_epochs: int = 1, test_every_n_epochs: int = 3
    ) -> None:
        super().__init__()
        # Validate up front; a 0 period would otherwise raise
        # ZeroDivisionError only at the first epoch end.
        for name, value in (
            ("train_every_n_epochs", train_every_n_epochs),
            ("test_every_n_epochs", test_every_n_epochs),
        ):
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value}")
        self.train_every_n_epochs = train_every_n_epochs
        self.test_every_n_epochs = test_every_n_epochs

    @staticmethod
    def _evaluate_split(pl_module: pl.LightningModule, loader, split: str) -> None:
        """Compute accuracies on ``loader`` and log them under the ``split`` prefix."""
        print(f"+++ {split.upper()} ACCURACIES")
        class_acc, no_obj_acc, obj_acc = check_class_accuracy(
            model=pl_module,
            loader=loader,
            threshold=config.CONF_THRESHOLD,
        )
        pl_module.log_dict(
            {
                f"{split}_class_acc": class_acc,
                f"{split}_no_obj_acc": no_obj_acc,
                f"{split}_obj_acc": obj_acc,
            },
            logger=True,
        )

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Run the train and/or test accuracy checks whose period divides this epoch."""
        # current_epoch starts at 0, so +1 counts completed epochs.
        completed_epochs = trainer.current_epoch + 1
        # Dataloaders are only requested when the corresponding check runs.
        if completed_epochs % self.train_every_n_epochs == 0:
            self._evaluate_split(
                pl_module, trainer.datamodule.train_dataloader(), "train"
            )
        if completed_epochs % self.test_every_n_epochs == 0:
            self._evaluate_split(
                pl_module, trainer.datamodule.test_dataloader(), "test"
            )
class MAPCallback(pl.Callback):
    """Compute and log test-set mean average precision every ``every_n_epochs`` epochs.

    Predicted and ground-truth boxes come from ``get_evaluation_bboxes`` on
    the datamodule's test dataloader. They are scored with
    ``mean_average_precision`` (midpoint box format). The result is
    printed and logged under the key ``"MAP"``.

    Args:
        every_n_epochs: Evaluate after every this-many completed training
            epochs. Must be a positive integer.

    Raises:
        ValueError: If ``every_n_epochs`` is less than 1.
    """

    def __init__(self, every_n_epochs: int = 3) -> None:
        super().__init__()
        # Validate up front; a 0 period would otherwise raise
        # ZeroDivisionError only at the first epoch end.
        if every_n_epochs < 1:
            raise ValueError(
                f"every_n_epochs must be a positive integer, got {every_n_epochs}"
            )
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Evaluate mAP on the test set when the finished epoch hits the period."""
        # current_epoch starts at 0, so +1 counts completed epochs.
        if (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return
        pred_boxes, true_boxes = get_evaluation_bboxes(
            loader=trainer.datamodule.test_dataloader(),
            model=pl_module,
            iou_threshold=config.NMS_IOU_THRESH,
            # NOTE(review): unscaled config.ANCHORS here, whereas plotting uses
            # pl_module.scaled_anchors; presumably get_evaluation_bboxes scales
            # them itself — confirm against utils.
            anchors=config.ANCHORS,
            threshold=config.CONF_THRESHOLD,
            # NOTE(review): config.DEVICE may differ from pl_module.device if the
            # Trainer places the model elsewhere — confirm.
            device=config.DEVICE,
        )
        map_val = mean_average_precision(
            pred_boxes=pred_boxes,
            true_boxes=true_boxes,
            iou_threshold=config.MAP_IOU_THRESH,
            box_format="midpoint",
            num_classes=config.NUM_CLASSES,
        )
        # Convert once for both print and log; float() accepts a one-element
        # tensor as well as a plain Python number.
        map_value = float(map_val)
        print("+++ MAP: ", map_value)
        pl_module.log(
            "MAP",
            map_value,
            logger=True,
        )
        # Return the module to training mode after evaluation (presumably the
        # evaluation helper switched it to eval mode — confirm in utils).
        pl_module.train()