# Spaces: Sleeping Sleeping  (page-status residue captured along with this file; not code)
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
class LightningModel(L.LightningModule):
    """LightningModule that trains a wrapped classifier with SGD and tracks accuracy.

    Args:
        model: The ``torch.nn.Module`` being trained. Its output is treated as
            class logits: cross-entropy is applied and ``argmax`` is taken over dim 1.
        learning_rate: Learning rate passed to ``torch.optim.SGD``.
        cosine_t_max: ``T_max`` for ``CosineAnnealingLR``. The scheduler is
            configured with ``interval="step"``, so this counts optimizer steps
            (batches), not epochs.
        mode: If ``"lrfind"``, ``configure_optimizers`` returns only the
            optimizer (no scheduler). Any other value enables the cosine schedule.
    """

    def __init__(self, model, learning_rate, cosine_t_max, mode):
        super().__init__()
        self.learning_rate = learning_rate
        self.cosine_t_max = cosine_t_max
        self.model = model
        # Dummy batch that Lightning can feed through forward() (e.g. for the
        # model summary's in/out sizes). torch.zeros gives defined values,
        # whereas torch.Tensor(*shape) returns uninitialized memory.
        # NOTE(review): the shape assumes 3x32x32 RGB inputs; confirm it matches the data.
        self.example_input_array = torch.zeros(1, 3, 32, 32)
        self.mode = mode
        # The wrapped nn.Module is excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["model"])

        # One metric object per stage so running accuracy state never mixes.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return the wrapped model's output (logits) for input ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Forward a ``(features, labels)`` batch and compute the loss.

        Returns:
            Tuple ``(loss, true_labels, predicted_labels)``:
            - ``loss`` is the cross-entropy between logits and labels.
            - ``predicted_labels`` is the argmax of the logits over dim 1.
        """
        features, true_labels = batch
        logits = self(features)
        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        """Log training loss and epoch-level accuracy; return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        # Accuracy is aggregated over the epoch only, not logged per step.
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy to the progress bar."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy (the test loss is computed but not logged)."""
        _, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        """Build SGD, plus a per-step cosine LR schedule unless ``mode == "lrfind"``.

        Returns:
            The bare optimizer in ``"lrfind"`` mode. Otherwise, a Lightning
            optimizer/scheduler config dict.
        """
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        if self.mode == "lrfind":
            # Presumably used with an LR finder: no schedule alters the LR.
            return optimizer

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.cosine_t_max
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Only metric-driven schedulers (e.g. ReduceLROnPlateau) consult
                # "monitor"; it is kept for config compatibility.
                "monitor": "train_loss",
                "interval": "step",  # step the schedule every batch (Lightning default: "epoch")
                "frequency": 1,  # every interval (Lightning default)
            },
        }