# Provenance (file-viewer header captured with this source):
# author 欧卫, commit 58627fa "add_app_files", 2.61 kB.
import torch
import torch.nn as nn
from transformers import ElectraPreTrainedModel, ElectraModel
class ElectraReader(ElectraPreTrainedModel):
    """ELECTRA encoder with a linear scoring head over token representations.

    The head and the output of ``forward`` depend on ``learn_labels``:

    * ``learn_labels=False``: each token gets one scalar score. ``forward``
      keeps only the scores at candidate positions (tokens whose id equals
      ``candidate_token_id``). It packs them left-aligned into a dense
      ``(batch, max_candidates)`` tensor padded with ``-inf``.
    * ``learn_labels=True``: the head emits two logits per token. ``forward``
      returns the two logits of the token at position 0.
    """

    # Token id that marks candidate positions in ``input_ids``. 103 is
    # ``[MASK]`` in the standard BERT/ELECTRA WordPiece vocabulary.
    # NOTE(review): confirm this matches the tokenizer that builds
    # ``encoding``. Override on a subclass or instance to use another marker.
    candidate_token_id = 103

    def __init__(self, config, learn_labels=False):
        super().__init__(config)

        self.electra = ElectraModel(config)

        # Not read by forward(). Presumably kept so that saved checkpoints
        # containing these weights still load cleanly -- TODO confirm
        # before removing.
        self.relevance = nn.Linear(config.hidden_size, 1)

        # Two logits per token in label mode, one score per token otherwise.
        out_dim = 2 if learn_labels else 1
        self.linear = nn.Linear(config.hidden_size, out_dim)

        self.init_weights()

        self.learn_labels = learn_labels

    def forward(self, encoding):
        """Score an encoded batch.

        Args:
            encoding: object exposing ``input_ids``, ``attention_mask`` and
                ``token_type_ids`` tensors (presumably a tokenizer
                ``BatchEncoding``; assumed ``(batch, seq_len)`` -- TODO
                confirm with callers).

        Returns:
            In label mode, ``scores[:, 0]``: the logits of the first token,
            shape ``(batch, 2)``. Otherwise, a ``(batch, max_candidates)``
            tensor of candidate-token scores, ``-inf``-padded per row (see
            ``_mask_2d_index``).
        """
        hidden = self.electra(
            encoding.input_ids,
            attention_mask=encoding.attention_mask,
            token_type_ids=encoding.token_type_ids,
        )[0]

        scores = self.linear(hidden)

        if self.learn_labels:
            # Logits of the first token only.
            return scores[:, 0]

        # One scalar per token: (batch, seq_len, 1) -> (batch, seq_len).
        scores = scores.squeeze(-1)
        candidates = encoding.input_ids == self.candidate_token_id
        return self._mask_2d_index(scores, candidates)

    def _mask_2d_index(self, scores, mask):
        """Gather ``scores[mask]`` row by row into a dense, left-aligned tensor.

        Row ``b`` of the result holds ``scores[b, j]`` for every ``j`` where
        ``mask[b, j]`` is true, in increasing ``j`` order. The row is then
        padded with ``-inf`` up to the largest per-row count in the batch.
        Rows with no true entry are entirely ``-inf``.

        Args:
            scores: ``(batch, seq_len)`` tensor.
            mask: ``(batch, seq_len)`` boolean tensor of positions to keep.

        Returns:
            ``(batch, max_count)`` tensor with the dtype and device of
            ``scores``. ``max_count`` is 0 when ``mask`` has no true entry.
        """
        bsize, maxlen = scores.size()
        bsize_, maxlen_ = mask.size()
        assert bsize == bsize_, (scores.size(), mask.size())
        assert maxlen == maxlen_, (scores.size(), mask.size())

        mask = mask.bool()
        counts = mask.sum(dim=1)
        # Guard: .max() on an empty tensor raises.
        max_count = int(counts.max().item()) if bsize > 0 else 0

        # Build the padding in scores' own dtype/device so no promotion or
        # cross-device copy happens.
        out = scores.new_full((bsize, max_count), float('-inf'))
        if max_count == 0:
            return out

        # Output column of each kept entry: 0, 1, 2, ... within its row.
        slot = mask.long().cumsum(dim=1) - 1
        row_ids, col_ids = mask.nonzero(as_tuple=True)
        # Out-of-place index_put keeps the result differentiable w.r.t. scores.
        return out.index_put(
            (row_ids, slot[row_ids, col_ids]),
            scores[row_ids, col_ids],
        )

    def _2d_index(self, embeddings, positions):
        """Select per-row positions from a batch of sequences.

        Args:
            embeddings: ``(batch, seq_len, dim)`` tensor.
            positions: integer ``(batch, k)`` tensor. ``positions[b, i]`` is a
                position within row ``b`` of ``embeddings``.

        Returns:
            ``(batch, k, dim)`` tensor with
            ``out[b, i] == embeddings[b, positions[b, i]]``.
        """
        bsize, maxlen, hdim = embeddings.size()
        bsize_, _ = positions.size()
        assert bsize == bsize_
        # Negative positions would silently read from the previous row after
        # the offset shift below, so reject them alongside too-large ones.
        assert positions.numel() == 0 or (
            positions.min() >= 0 and positions.max() < maxlen
        )

        # Flatten to (batch * seq_len, dim). reshape (unlike view) also
        # accepts non-contiguous inputs.
        flat = embeddings.reshape(bsize * maxlen, hdim)
        # Shift each row's positions by its row offset so one gather suffices.
        offsets = torch.arange(bsize, device=positions.device).unsqueeze(-1) * maxlen
        return flat[positions + offsets]