import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTM(nn.Module):
    """LSTM text encoder with a bilinear word-attention head over its hidden states."""

    def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx
        )
        self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
        # Diagonal bilinear attention weight, one entry per hidden dimension.
        self.w_attn = nn.Parameter(torch.Tensor(1, text_dim))
        nn.init.xavier_uniform_(self.w_attn)

    def forward(self, padded_tokens, dropout=0.5):
        # Embed tokens and apply dropout (active only in training mode).
        w_emb = self.word_embedding(padded_tokens)
        w_emb = F.dropout(w_emb, dropout, self.training)

        # Per-example lengths, excluding padding; must live on the CPU for packing.
        len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
        x_packed = pack_padded_sequence(
            w_emb, len_seq, enforce_sorted=False, batch_first=True
        )

        B = padded_tokens.shape[0]
        rnn_out, _ = self.rnn(x_packed)
        rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)

        # Hidden state at the last non-padded position of each sequence.
        h = rnn_out[torch.arange(B), len_seq - 1]
        final_feat, attn = self.word_attention(rnn_out, h, len_seq)
        return final_feat, attn

    def word_attention(self, R, h, len_seq):
        """
        Bilinear attention over the per-word hidden states.

        Input:
            R: hidden states of all words, shape (B, N, D)
            h: final hidden state after processing each sequence, shape (B, D)
            len_seq: lengths of the sequences, shape (B,)
        Output:
            final_feat: attention-weighted sentence feature, shape (B, D)
            attn: word attention weights, shape (B, N)
        """
        B, N, D = R.shape
        device = R.device
        len_seq = len_seq.to(device)

        # Expand the learned vector into a (B, D, D) diagonal bilinear weight.
        W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1)
        # Bilinear score between each word state and the final hidden state.
        score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))

        # Mask out padded positions before the softmax.
        mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device)
        mask = mask < len_seq.reshape(B, 1, 1)
        score = score.masked_fill(mask == 0, -1e9)

        attn = F.softmax(score, dim=1)
        # Weighted sum of word states: (B, D, N) x (B, N, 1) -> (B, D).
        final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)
        return final_feat, attn.squeeze(-1)
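

# A minimal usage sketch, not part of the original file: the vocabulary size,
# dimensions, and dummy token batch below are illustrative assumptions only,
# chosen to show the expected input/output shapes of the encoder.
if __name__ == "__main__":
    vocab_size, embedding_dim, text_dim = 1000, 64, 128  # assumed sizes
    model = LSTM(text_dim, embedding_dim, vocab_size, padding_idx=0)

    # Batch of 2 padded token sequences (0 is the padding index).
    padded_tokens = torch.tensor(
        [
            [5, 12, 7, 3, 0, 0],
            [9, 4, 0, 0, 0, 0],
        ]
    )
    final_feat, attn = model(padded_tokens)
    print(final_feat.shape)  # torch.Size([2, 128])
    # pad_packed_sequence pads only to the longest sequence in the batch (4 here),
    # so the attention weights cover 4 positions rather than 6.
    print(attn.shape)  # torch.Size([2, 4])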