|
import lightning.pytorch as pl |
|
import torchmetrics |
|
from torch.optim import AdamW |
|
from transformers import ViTForImageClassification |
|
from torch import nn |
|
from transformers.optimization import get_scheduler |
|
|
|
class LightningViTRegressor(pl.LightningModule):
    """Fine-tune a Hugging Face ViT model as a single-output image regressor.

    The pretrained ``ViTForImageClassification`` is loaded with ``num_labels=1``
    so its head emits one value per image, trained against the target with
    mean-squared-error loss. MSE, MAE and R^2 are tracked with an independent
    metric collection per stage (train / val / test).

    Args:
        model_name_or_path: Checkpoint name or path passed to ``from_pretrained``.
        learning_rate: Learning rate for AdamW.
        num_warmup_steps: Warm-up steps for the linear learning-rate schedule.
    """

    def __init__(
        self,
        model_name_or_path: str = "google/vit-base-patch16-224-in21k",
        learning_rate: float = 1e-5,
        num_warmup_steps: int = 0,
    ):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` (and in checkpoints)
        # so ``load_from_checkpoint`` rebuilds the module with the same settings.
        self.save_hyperparameters()

        self.model = ViTForImageClassification.from_pretrained(
            model_name_or_path,
            num_labels=1,
        )

        # One independent collection per stage: a single shared metric object
        # would mix train/val/test samples into the same running state. The
        # dict keys plus the prefix reproduce the logged names, e.g. "val_mae".
        metrics = torchmetrics.MetricCollection(
            {
                "mse": torchmetrics.MeanSquaredError(),
                "mae": torchmetrics.MeanAbsoluteError(),
                "r2_score": torchmetrics.R2Score(),
            }
        )
        # Plain attributes rather than an nn.ModuleDict keyed by stage, because
        # "train" collides with nn.Module.train and is rejected as a key.
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

    def common_step(self, step_type, batch, batch_idx):
        """Run the forward pass, loss and metric logging shared by every stage.

        Args:
            step_type: Stage prefix for logged keys: "train", "val" or "test".
            batch: ``(images, targets)`` pair; the images are passed directly to
                the ViT model.
            batch_idx: Index of the batch (unused; kept for the hook signature).

        Returns:
            The scalar MSE loss for the batch.
        """
        images, targets = batch

        # With num_labels=1 the logits are (batch, 1); drop the trailing axis.
        preds = self.model(images).logits.squeeze(-1)
        # Align the target with the predictions. A (batch,) target against
        # (batch, 1) logits would silently broadcast to (batch, batch) inside
        # mse_loss and fail torchmetrics' shape check; mse_loss also needs
        # matching dtypes.
        targets = targets.reshape(preds.shape).to(preds.dtype)

        loss = nn.functional.mse_loss(preds, targets)
        self.log(step_type + "_loss", loss)

        # Accumulate metric state here and let Lightning compute and reset it
        # at epoch end. This yields R^2 over all of the epoch's samples, rather
        # than an average of per-batch R^2 values, and avoids R2Score's error
        # on single-sample batches.
        stage_metrics = getattr(self, step_type + "_metrics")
        stage_metrics.update(preds, targets)
        self.log_dict(stage_metrics, on_step=False, on_epoch=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics; return the loss to optimize."""
        return self.common_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/metrics for one batch."""
        return self.common_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute and log test loss/metrics for one batch."""
        return self.common_step("test", batch, batch_idx)

    def configure_optimizers(self):
        """Create AdamW plus the transformers "linear" schedule, stepped per batch.

        The schedule's total length is Lightning's estimate of the number of
        optimizer steps for the whole run.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        # "interval": "step" advances the scheduler after every optimizer step
        # instead of once per epoch.
        lr_scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [lr_scheduler_config]