import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Elect(nn.Module):
    def __init__(self, args, device):
        super(Elect, self).__init__()
        self.device = device
        # Pretrained language model backbone and its matching tokenizer.
        self.plm = AutoModel.from_pretrained(args.ckpt_dir)
        self.hidden_size = self.plm.config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
        # Classification head over the label set.
        self.clf = nn.Linear(self.hidden_size, len(args.labels))
        self.dropout = nn.Dropout(0.3)
        self.p2l = nn.Linear(self.hidden_size, 256)   # defined but unused in forward()
        self.proj = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.l2a = nn.Linear(11, 256)                 # defined but unused in forward()
        # Learnable label embeddings, one row per label.
        self.la = nn.Parameter(torch.zeros(len(args.labels), self.hidden_size))

    def forward(self, batch):
        ids = batch['ids'].to(self.device, dtype=torch.long)
        mask = batch['mask'].to(self.device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(self.device, dtype=torch.long)  # not passed to the PLM
        hidden_state = self.plm(input_ids=ids, attention_mask=mask)[0]
        pooler = hidden_state[:, 0]                                      # [batch_size, hidden_size] ([CLS] token)
        pooler = self.dropout(pooler)                                    # [batch_size, hidden_size]
        # Attention of the sentence representation over the label embeddings.
        attn = torch.softmax(pooler @ self.la.transpose(0, 1), dim=-1)   # [batch_size, num_labels]
        art = attn @ self.la                                             # [batch_size, hidden_size]
        # Fuse the label-attended vector with the sentence representation.
        oa = F.relu(self.proj(torch.cat([art, pooler], dim=-1)))         # [batch_size, hidden_size]
        output = self.clf(oa)                                            # [batch_size, num_labels]
        return output
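

# --- Usage sketch (not part of the original file) ---
# A minimal smoke test, assuming `args.ckpt_dir` points to any BERT-style
# checkpoint and `args.labels` is the label list used at training time.
# Both values below are placeholders, not the authors' actual configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        ckpt_dir="bert-base-uncased",                     # placeholder checkpoint
        labels=[f"label_{i}" for i in range(11)],         # placeholder label set
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Elect(args, device).to(device)

    enc = model.tokenizer(["an example sentence"], padding=True,
                          truncation=True, return_tensors="pt")
    batch = {
        "ids": enc["input_ids"],
        "mask": enc["attention_mask"],
        # Some tokenizers (e.g. RoBERTa) do not return token_type_ids.
        "token_type_ids": enc.get("token_type_ids",
                                  torch.zeros_like(enc["input_ids"])),
    }
    with torch.no_grad():
        logits = model(batch)                             # [1, len(args.labels)]
    print(logits.shape)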