import lightning.pytorch as pl
import torchmetrics
from torch import nn
from torch.optim import AdamW
from transformers import ViTForImageClassification
from transformers.optimization import get_scheduler


class LightningViTRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Load a ViT backbone pretrained on ImageNet-21k; num_labels=1 replaces
        # the classification head with a single-output regression head.
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=1,
        )
        self.mse = torchmetrics.MeanSquaredError()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.r2_score = torchmetrics.R2Score()

    def common_step(self, step_type, batch, batch_idx):
        x, y = batch
        # The head outputs logits of shape (batch, 1); squeeze to (batch,) so
        # the shapes match the targets and mse_loss does not silently broadcast.
        preds = self.model(x).logits.squeeze(-1)
        loss = nn.functional.mse_loss(preds, y)
        to_log = {
            step_type + "_loss": loss,
            step_type + "_mse": self.mse(preds, y),
            step_type + "_mae": self.mae(preds, y),
            step_type + "_r2_score": self.r2_score(preds, y),
        }  # add more items if needed
        self.log_dict(to_log)
        return loss

    def training_step(self, batch, batch_idx):
        return self.common_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.common_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.common_step("test", batch, batch_idx)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-5)
        # Linear decay with no warmup, stepped once per optimizer step.
        scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        lr_scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [lr_scheduler]
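

# Below is a minimal usage sketch, not part of the module above: it assumes
# a DataLoader yielding (image_tensor, target) pairs with images already
# preprocessed to 3x224x224 and float targets. The random-tensor dataset is a
# hypothetical stand-in, and the Trainer settings are illustrative only.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in data: 64 random "images" with scalar regression targets.
    images = torch.randn(64, 3, 224, 224)
    targets = torch.randn(64)
    loader = DataLoader(TensorDataset(images, targets), batch_size=8)

    model = LightningViTRegressor()
    trainer = pl.Trainer(max_epochs=1, accelerator="auto", log_every_n_steps=1)
    trainer.fit(model, train_dataloaders=loader)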