import pytorch_lightning as pl
import torch
import torchmetrics
# transformers.optimization.AdamW is deprecated; use the PyTorch implementation.
from torch.optim import AdamW


class DualEncoderModule(pl.LightningModule):
    """Fine-tunes a sequence-classification model on paired positive and
    negative examples, labeling positives 1 and negatives 0."""

    def __init__(self, tokenizer, model, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        # Keep a separate accuracy metric per stage so their states do not mix.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        # Negatives may carry an extra candidate dimension; flatten them into one batch.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        # Positives are labeled 1, negatives 0; labels must be long for the
        # model's classification loss.
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        # Weight the negative loss term by loss_scale before combining.
        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.train_acc(pos_preds, pos_labels)
        self.train_acc(neg_preds, neg_labels)
        self.log("train_acc", self.train_acc)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.val_acc(pos_preds, pos_labels)
        self.val_acc(neg_preds, neg_labels)
        self.log("val_acc", self.val_acc)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )
        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.test_acc(pos_preds, pos_labels)
        self.test_acc(neg_preds, neg_labels)
        self.log("test_acc", self.test_acc)
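

# A minimal usage sketch, not part of the original module: it assumes a binary
# sequence-classification checkpoint and DataLoaders that yield
# (pos_ids, pos_mask, neg_ids, neg_mask) batches. The names `train_loader` and
# `val_loader` are hypothetical placeholders to be provided by the caller.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    module = DualEncoderModule(tokenizer, model, learning_rate=1e-3)
    trainer = pl.Trainer(max_epochs=3)
    # Build train_loader / val_loader elsewhere from paired positive/negative
    # examples, then uncomment:
    # trainer.fit(module, train_loader, val_loader)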