# efficient_audio_captioning/models/transformer_decoder.py
import math
import torch
import torch.nn as nn
from models import BaseDecoder
from utils.model_util import generate_length_mask, PositionalEncoding
from utils.train_util import merge_load_state_dict
class TransformerDecoder(BaseDecoder):
def __init__(self,
emb_dim,
vocab_size,
fc_emb_dim,
attn_emb_dim,
dropout,
freeze=False,
tie_weights=False,
**kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout=dropout, tie_weights=tie_weights)
self.d_model = emb_dim
self.nhead = kwargs.get("nhead", self.d_model // 64)
self.nlayers = kwargs.get("nlayers", 2)
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
self.pos_encoder = PositionalEncoding(self.d_model, dropout)
layer = nn.TransformerDecoderLayer(d_model=self.d_model,
nhead=self.nhead,
dim_feedforward=self.dim_feedforward,
dropout=dropout)
self.model = nn.TransformerDecoder(layer, self.nlayers)
self.classifier = nn.Linear(self.d_model, vocab_size, bias=False)
if tie_weights:
self.classifier.weight = self.word_embedding.weight
self.attn_proj = nn.Sequential(
nn.Linear(self.attn_emb_dim, self.d_model),
nn.ReLU(),
nn.Dropout(dropout),
nn.LayerNorm(self.d_model)
)
self.init_params()
self.freeze = freeze
if freeze:
for p in self.parameters():
p.requires_grad = False
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        if next(iter(checkpoint)).startswith("decoder."):
            # strip the "decoder." prefix so the keys match this module
            state_dict = {k[len("decoder."):]: v for k, v in checkpoint.items()}
        else:
            state_dict = checkpoint
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
if self.freeze:
for name, param in self.named_parameters():
if name in loaded_keys:
param.requires_grad = False
else:
param.requires_grad = True
def generate_square_subsequent_mask(self, max_length):
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
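    # Illustration (not part of the model): for max_length = 3 the mask built
    # above is the standard causal mask expected by nn.TransformerDecoder,
    # with -inf blocking attention to future positions:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]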
def forward(self, input_dict):
word = input_dict["word"]
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
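
# Minimal usage sketch (illustrative only): shapes follow the comments in
# `forward`, and `BaseDecoder` is assumed to provide `word_embedding` and
# `in_dropout`. The values below are placeholders, not a trained setup.
#
#   decoder = TransformerDecoder(emb_dim=256, vocab_size=5000,
#                                fc_emb_dim=512, attn_emb_dim=512, dropout=0.2)
#   output = decoder({
#       "word": torch.randint(0, 5000, (4, 20)),             # [N, T] token ids
#       "attn_emb": torch.randn(4, 31, 512),                 # [N, T_src, attn_emb_dim]
#       "attn_emb_len": torch.tensor([31, 28, 31, 15]),      # valid frames per clip
#       "cap_padding_mask": torch.zeros(4, 20, dtype=torch.bool),
#   })
#   # output["logit"]: [4, 20, 5000], output["embed"]: [4, 20, 256]
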
class M2TransformerDecoder(BaseDecoder):
def __init__(self, vocab_size, fc_emb_dim, attn_emb_dim, dropout=0.1, **kwargs):
        super().__init__(attn_emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
                         dropout=dropout)
        try:
            from m2transformer.models.transformer import MeshedDecoder
        except ImportError:
            raise ImportError(
                "meshed-memory-transformer is not installed; please run "
                "`pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`")
del self.word_embedding
del self.in_dropout
self.d_model = attn_emb_dim
self.nhead = kwargs.get("nhead", self.d_model // 64)
self.nlayers = kwargs.get("nlayers", 2)
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
        # positional args per the m2transformer MeshedDecoder signature:
        # vocab_size, max_len (100), number of decoder layers, padding_idx (0)
        self.model = MeshedDecoder(vocab_size, 100, self.nlayers, 0,
                                   d_model=self.d_model,
                                   h=self.nhead,
                                   d_ff=self.dim_feedforward,
                                   dropout=dropout)
self.init_params()
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, input_dict):
word = input_dict["word"]
attn_emb = input_dict["attn_emb"]
attn_emb_mask = input_dict["attn_emb_mask"]
word = word.to(attn_emb.device)
embed, logit = self.model(word, attn_emb, attn_emb_mask)
output = {
"embed": embed,
"logit": logit,
}
return output
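
# Note: unlike TransformerDecoder, this class delegates word embedding,
# masking and classification to the meshed-memory MeshedDecoder, so its
# input_dict carries a precomputed "attn_emb_mask" instead of "attn_emb_len",
# and `forward` returns the library's (embed, logit) pair unchanged.
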
class EventTransformerDecoder(TransformerDecoder):
def forward(self, input_dict):
word = input_dict["word"] # index of word embeddings
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
event_emb = input_dict["event"] # [N, emb_dim]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
        embed += event_emb  # [N, emb_dim] event embedding broadcasts over the time axis
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
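
# Sketch of the extra input expected by EventTransformerDecoder (illustrative;
# the value is a placeholder): the only difference from TransformerDecoder is
# an "event" embedding of shape [N, emb_dim] that is broadcast-added to the
# word embeddings at every decoding step before positional encoding, e.g.
#
#   input_dict["event"] = torch.randn(4, 256)   # [N, emb_dim]
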
class KeywordProbTransformerDecoder(TransformerDecoder):
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, keyword_classes_num, **kwargs):
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
dropout, **kwargs)
self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
self.word_keyword_norm = nn.LayerNorm(self.d_model)
def forward(self, input_dict):
word = input_dict["word"] # index of word embeddings
attn_emb = input_dict["attn_emb"]
attn_emb_len = input_dict["attn_emb_len"]
cap_padding_mask = input_dict["cap_padding_mask"]
keyword = input_dict["keyword"] # [N, keyword_classes_num]
p_attn_emb = self.attn_proj(attn_emb)
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
word = word.to(attn_emb.device)
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
embed = embed.transpose(0, 1) # [T, N, emb_dim]
        embed += self.keyword_proj(keyword)  # [N, d_model] keyword embedding broadcasts over the time axis
embed = self.word_keyword_norm(embed)
embed = self.pos_encoder(embed)
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
tgt_key_padding_mask=cap_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
output = output.transpose(0, 1)
output = {
"embed": output,
"logit": self.classifier(output),
}
return output
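
# Sketch of the extra input expected by KeywordProbTransformerDecoder
# (illustrative; keyword_classes_num depends on the keyword vocabulary): a
# per-clip keyword probability vector is projected to d_model, added to the
# word embeddings and layer-normalised before decoding, e.g.
#
#   decoder = KeywordProbTransformerDecoder(emb_dim=256, vocab_size=5000,
#                                           fc_emb_dim=512, attn_emb_dim=512,
#                                           dropout=0.2, keyword_classes_num=300)
#   input_dict["keyword"] = torch.rand(4, 300)   # [N, keyword_classes_num]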