# vit-COVID-19-severity / LightningViTRegressor.py
# Committed by ludolara
# Commit 5cd6bff: "Create LightningViTRegressor.py"
import lightning.pytorch as pl
import torchmetrics
from torch.optim import AdamW
from transformers import ViTForImageClassification
from torch import nn
from transformers.optimization import get_scheduler
class LightningViTRegressor(pl.LightningModule):
    """Lightning module that fine-tunes a pretrained ViT as a scalar regressor.

    The backbone is a ``ViTForImageClassification`` loaded with
    ``num_labels=1``, so its ``logits`` output carries one value per image.
    Training minimises the mean squared error between that value and the
    target; MSE, MAE and R2 are logged for every train/val/test step under
    the keys ``<stage>_loss``, ``<stage>_mse``, ``<stage>_mae`` and
    ``<stage>_r2_score``.

    Args:
        model_name: Hugging Face checkpoint passed to ``from_pretrained``.
        learning_rate: Initial learning rate given to ``AdamW``.
        num_warmup_steps: Warmup steps of the linear learning-rate schedule.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224-in21k",
        learning_rate: float = 1e-5,
        num_warmup_steps: int = 0,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.model = ViTForImageClassification.from_pretrained(
            model_name,
            num_labels=1,
        )
        self.mse = torchmetrics.MeanSquaredError()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.r2_score = torchmetrics.R2Score()

    def common_step(self, step_type: str, batch, batch_idx: int):
        """Run one forward pass, log loss and metrics, and return the loss.

        Args:
            step_type: Prefix for the logged keys ("train", "val" or "test").
            batch: Pair ``(inputs, targets)``; ``inputs`` is fed to the ViT
                model and ``targets`` holds one regression value per sample.
            batch_idx: Index of the batch (unused, kept for the interface).

        Returns:
            The MSE loss tensor for this batch.
        """
        inputs, targets = batch
        logits = self.model(inputs).logits

        # With num_labels=1 the logits are (batch, 1). Comparing them with
        # (batch,) targets would broadcast to (batch, batch) inside mse_loss
        # and give a wrong loss, so both sides are brought to the same 1-D
        # shape first. reshape() fails loudly if the element counts differ.
        preds = logits.squeeze(-1)
        targets = targets.reshape(preds.shape)
        if not targets.is_floating_point():
            # Integer-typed labels are promoted so loss and metrics receive
            # floating-point targets.
            targets = targets.to(preds.dtype)

        loss = nn.functional.mse_loss(preds, targets)
        to_log = {
            step_type + "_loss": loss,
            step_type + "_mse": self.mse(preds, targets),
            step_type + "_mae": self.mae(preds, targets),
        }
        # R2 is undefined for fewer than two samples and torchmetrics raises
        # in that case, so it is skipped for a trailing single-sample batch
        # instead of aborting the run.
        if preds.numel() >= 2:
            to_log[step_type + '_r2_score'] = self.r2_score(preds, targets)

        self.log_dict(to_log)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and metrics for one batch."""
        return self.common_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and metrics for one batch."""
        return self.common_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and metrics for one batch."""
        return self.common_step("test", batch, batch_idx)

    def configure_optimizers(self):
        """Build AdamW with a per-step linear learning-rate schedule.

        Returns:
            A pair ``([optimizer], [scheduler_config])`` where the scheduler
            config asks Lightning to step the scheduler after every batch.
        """
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        # The schedule length is taken from the trainer so the linear decay
        # spans the whole run.
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [scheduler_config]