vit-COVID-19-severity / LightningViTRegressor.py
ludolara's picture
Create LightningViTRegressor.py
5cd6bff
raw
history blame
2.24 kB
import lightning.pytorch as pl
import torchmetrics
from torch.optim import AdamW
from transformers import ViTForImageClassification
from torch import nn
from transformers.optimization import get_scheduler
class LightningViTRegressor(pl.LightningModule):
    """Vision Transformer fine-tuned as a single-output regressor.

    Wraps ``ViTForImageClassification`` configured with ``num_labels=1`` and
    trains it with a mean-squared-error loss. MSE, MAE and R2 are logged for
    every train / validation / test step under ``<stage>_<metric>`` keys.

    Args:
        model_name: Hugging Face checkpoint passed to ``from_pretrained``.
        learning_rate: Initial learning rate handed to AdamW.
    """

    def __init__(
        self,
        model_name: str = "google/vit-base-patch16-224-in21k",
        learning_rate: float = 1e-5,
    ):
        super().__init__()
        # Record the constructor arguments in ``self.hparams`` so they are
        # stored in checkpoints and restored by ``load_from_checkpoint``.
        self.save_hyperparameters()
        self.model = ViTForImageClassification.from_pretrained(
            model_name,
            num_labels=1,
        )
        self.mse = torchmetrics.MeanSquaredError()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.r2_score = torchmetrics.R2Score()

    def common_step(self, step_type, batch, batch_idx):
        """Run one forward pass, log loss and metrics, and return the loss.

        Args:
            step_type: Prefix for the logged keys ("train", "val" or "test").
            batch: An ``(inputs, targets)`` pair; ``inputs`` is fed straight
                to the ViT model, ``targets`` holds one value per sample.
            batch_idx: Index of the batch (unused, kept for the hook
                signature).

        Returns:
            The MSE loss tensor for this batch.
        """
        inputs, targets = batch
        # With num_labels=1 the head emits one value per sample in a trailing
        # dimension of size 1; drop it so predictions are one-dimensional.
        preds = self.model(inputs).logits.squeeze(-1)
        # Align the target with the predictions. Without this a (batch,)
        # target would broadcast against (batch, 1) logits inside mse_loss,
        # and integer / float64 targets would fail on dtype. reshape() raises
        # if the element counts differ, so a real mismatch is not hidden.
        targets = targets.to(dtype=preds.dtype).reshape(preds.shape)

        loss = nn.functional.mse_loss(preds, targets)
        # NOTE(review): the same metric objects are shared by all stages and
        # their per-batch forward values are what gets logged here; R2 may
        # also reject batches with fewer than two samples - confirm against
        # the torchmetrics version in use.
        to_log = {
            step_type + "_loss": loss,
            step_type + "_mse": self.mse(preds, targets),
            step_type + "_mae": self.mae(preds, targets),
            step_type + "_r2_score": self.r2_score(preds, targets),
        }  # add more items if needed
        self.log_dict(to_log)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics for one batch."""
        return self.common_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/metrics for one batch."""
        return self.common_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/metrics for one batch."""
        return self.common_step("test", batch, batch_idx)

    def configure_optimizers(self):
        """Build AdamW with a linear learning-rate decay stepped every batch.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair in the format
            Lightning expects.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)
        # Linear schedule with no warmup, spanning the total number of
        # optimizer steps the trainer estimates for this run.
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [scheduler_config]