# hafidhsoekma - First commit (49bceed)
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import time
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torchmetrics as tm
from torch import optim
from utils import configs
from .backbone_model import CLIPModel, TorchModel
class ImageClassificationLightningModule(pl.LightningModule):
def __init__(
    self,
    num_classes: int = len(configs.CLASS_CHARACTERS) - 1,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.0,
    name_model: str = "resnet50",
    freeze_model: bool = True,
    pretrained_model: bool = True,
):
    """Store the hyperparameters, choose the loss, then build model and metrics.

    Args:
        num_classes: Number of target classes. A value of 1 or 2 selects
            ``nn.BCEWithLogitsLoss``; any other value selects
            ``nn.CrossEntropyLoss``.
        learning_rate: Learning rate passed to the Adam optimizer.
        weight_decay: Weight decay passed to the Adam optimizer.
        name_model: Backbone identifier. ``"clip"`` selects ``CLIPModel``;
            any other value is forwarded to ``TorchModel``.
        freeze_model: Forwarded unchanged to the backbone wrapper.
        pretrained_model: Forwarded unchanged to the backbone wrapper.
    """
    super().__init__()

    # Backbone-related settings, read later by create_models().
    self.name_model = name_model
    self.freeze_model = freeze_model
    self.pretrained_model = pretrained_model
    self.num_classes = num_classes

    # Optimizer settings, read later by configure_optimizers().
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay

    # One- and two-class setups are trained with a logit-based binary loss.
    if self.num_classes in (1, 2):
        self.criterion = nn.BCEWithLogitsLoss()
    else:
        self.criterion = nn.CrossEntropyLoss()

    self.create_models()
    self.create_metrics_models()
def create_models(self):
    """Instantiate the backbone wrapper and store it on ``self.model``.

    ``CLIPModel`` (named by ``configs.CLIP_NAME_MODEL``) is used when
    ``self.name_model`` is ``"clip"``; otherwise ``TorchModel`` is built with
    ``self.name_model``. Both receive the freeze flag, the pretrained flag and
    the number of classes as positional arguments, in that order.
    """
    if self.name_model == "clip":
        backbone_cls, backbone_name = CLIPModel, configs.CLIP_NAME_MODEL
    else:
        backbone_cls, backbone_name = TorchModel, self.name_model

    self.model = backbone_cls(
        backbone_name,
        self.freeze_model,
        self.pretrained_model,
        self.num_classes,
    )
def create_metrics_models(self):
    """Create the accuracy, precision, recall and F1 metric objects.

    When ``self.num_classes`` is 1 or 2 the model is trained with
    ``nn.BCEWithLogitsLoss`` on targets reshaped to ``(N, 1)`` floats, so the
    metrics are created with ``task="binary"``. The previous configuration
    (``task="multiclass"`` with ``num_classes=1``) is not a valid multiclass
    setup and cannot consume single-logit predictions.

    For every other class count the metrics are macro-averaged multiclass
    metrics over ``self.num_classes`` classes.
    """
    if self.num_classes in (1, 2):
        # Binary metrics threshold the single logit themselves; they take
        # neither a class count nor an averaging mode.
        metric_kwargs = {"task": "binary"}
    else:
        metric_kwargs = {
            "task": "multiclass",
            "num_classes": self.num_classes,
            "average": "macro",
        }

    # Every metric shares one configuration so they cannot drift apart.
    self.metrics_accuracy = tm.Accuracy(**metric_kwargs)
    self.metrics_precision = tm.Precision(**metric_kwargs)
    self.metrics_recall = tm.Recall(**metric_kwargs)
    self.metrics_f1 = tm.F1Score(**metric_kwargs)
def configure_optimizers(self):
    """Build the Adam optimizer and its cosine-shaped ``LambdaLR`` schedule.

    The learning-rate multiplier oscillates between 1.0 and 0.1 following a
    cosine whose period is 40 scheduler steps.

    Returns:
        A dict with the keys ``"optimizer"``, ``"lr_scheduler"`` and
        ``"monitor"`` (set to ``"metrics_f1_score"``).
    """

    def cosine_factor(step):
        # Cosine rescaled to [0, 1], then mapped onto [0.1, 1.0].
        cosine_term = (1 + np.cos(step * np.pi / 20)) / 2
        return (cosine_term ** 1.0) * 0.9 + 0.1

    adam = optim.Adam(
        self.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    scheduler = optim.lr_scheduler.LambdaLR(adam, lr_lambda=cosine_factor)

    return {
        "optimizer": adam,
        "lr_scheduler": scheduler,
        "monitor": "metrics_f1_score",
    }
def forward(self, x):
    """Run ``x`` through the wrapped backbone and return its output unchanged."""
    return self.model(x)
def training_step(self, batch, batch_idx):
    """Run one training batch: forward pass, loss, metric updates and logging.

    Args:
        batch: A pair ``(x, y)`` of inputs and targets.
        batch_idx: Index of the batch; not used here.

    Returns:
        The loss computed by ``self.criterion`` for this batch.
    """
    inputs, targets = batch
    if self.num_classes in (1, 2):
        # The binary loss is fed float targets with an added dimension 1.
        targets = targets.unsqueeze(1).float()

    # Wall-clock time of the forward pass only.
    tic = time.perf_counter()
    outputs = self(inputs)
    elapsed = time.perf_counter() - tic

    loss = self.criterion(outputs, targets)

    tracked = (
        ("metrics_accuracy", self.metrics_accuracy),
        ("metrics_precision", self.metrics_precision),
        ("metrics_recall", self.metrics_recall),
        ("metrics_f1_score", self.metrics_f1),
    )

    # Update every metric first, then register them all for logging.
    for _, metric in tracked:
        metric(outputs, targets)

    epoch_only = {"on_step": False, "on_epoch": True, "prog_bar": True}
    for log_name, metric in tracked:
        self.log(log_name, metric, **epoch_only)
    self.log("metrics_inference_time", elapsed, **epoch_only)

    return loss