import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTM(nn.Module):
    """LSTM text encoder with a learned (diagonal) bilinear word-attention pooling."""

    def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx
        )
        self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
        # Diagonal of the bilinear attention matrix (see word_attention).
        self.w_attn = nn.Parameter(torch.Tensor(1, text_dim))
        nn.init.xavier_uniform_(self.w_attn)
    def forward(self, padded_tokens, dropout=0.5):
        w_emb = self.word_embedding(padded_tokens)
        w_emb = F.dropout(w_emb, dropout, self.training)

        # True sequence lengths, i.e. the number of non-padding tokens per example.
        len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
        x_packed = pack_padded_sequence(
            w_emb, len_seq, enforce_sorted=False, batch_first=True
        )

        B = padded_tokens.shape[0]
        rnn_out, _ = self.rnn(x_packed)
        rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)

        # Hidden state at the last valid (non-padding) time step of each sequence.
        h = rnn_out[torch.arange(B), len_seq - 1]
        final_feat, attn = self.word_attention(rnn_out, h, len_seq)
        return final_feat, attn
    def word_attention(self, R, h, len_seq):
        """
        Input:
            R: per-word hidden states over the whole sequence, shape (B, N, D)
            h: final hidden state after reading the whole sequence, shape (B, D)
            len_seq: number of valid (non-padding) tokens per sequence, shape (B,)
        Output:
            final_feat: attention-pooled feature after the bilinear attention, shape (B, D)
            attn: word attention weights, shape (B, N)
        """
        B, N, D = R.shape
        device = R.device
        len_seq = len_seq.to(device)

        # Diagonal bilinear score: score_i = R_i^T diag(w_attn) h.
        W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1)
        score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))

        # Mask out padding positions before the softmax over words.
        mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device)
        mask = mask < len_seq.reshape(B, 1, 1)
        score = score.masked_fill(mask == 0, -1e9)

        attn = F.softmax(score, 1)
        final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)
        return final_feat, attn.squeeze(-1)
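

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the dimensions,
    # vocabulary size, and token ids below are made-up placeholders chosen only
    # to exercise the encoder; real values come from the surrounding pipeline.
    torch.manual_seed(0)
    model = LSTM(text_dim=128, embedding_dim=64, vocab_size=1000, padding_idx=0)
    model.eval()  # disable dropout for a deterministic check

    # Batch of 3 right-padded token sequences (0 is the padding index).
    tokens = torch.tensor([
        [5, 12, 7, 3, 0, 0],
        [9, 4, 0, 0, 0, 0],
        [2, 8, 6, 1, 11, 10],
    ])
    feat, attn = model(tokens)
    print(feat.shape)  # torch.Size([3, 128])
    print(attn.shape)  # torch.Size([3, 6])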