Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
import random | |
import torch | |
import torch.nn as nn | |
from .base_model import CaptionModel | |
from .utils import repeat_tensor | |
import audio_to_text.captioning.models.decoder | |
class TransformerModel(CaptionModel):
    """Caption model whose decoder is a Transformer decoder.

    The methods here only assemble the dictionaries handed to
    ``self.decoder`` for full-sequence decoding, step-wise decoding and
    beam search.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """Set the default compatible decoder types, then run the base constructor.

        A ``compatible_decoders`` attribute that already exists (e.g. set by a
        subclass before calling this constructor) is left untouched.
        """
        already_declared = hasattr(self, "compatible_decoders")
        if not already_declared:
            decoder_module = audio_to_text.captioning.models.decoder
            self.compatible_decoders = (decoder_module.TransformerDecoder,)
        super().__init__(encoder, decoder, **kwargs)

    def seq_forward(self, input_dict):
        """Run the decoder once over the caption with its last position dropped.

        Reads ``cap``, ``attn_emb`` and ``attn_emb_len`` from ``input_dict``
        and returns whatever ``self.decoder`` returns.
        """
        captions = input_dict["cap"]
        # True wherever the caption holds the padding index; the final
        # position is cut off to line up with the decoder input words.
        padding_mask = (captions == self.pad_idx).to(captions.device)[:, :-1]
        return self.decoder({
            "word": captions[:, :-1],
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_len": input_dict["attn_emb_len"],
            "cap_padding_mask": padding_mask,
        })

    def prepare_decoder_input(self, input_dict, output):
        """Build the decoder input for step ``input_dict["t"]`` of step-wise decoding.

        In ``"train"`` mode, with probability ``ss_ratio``, the first ``t + 1``
        tokens of the reference caption are fed; otherwise the start token
        followed by the first ``t`` tokens of ``output["seq"]`` is fed.
        """
        attn_emb = input_dict["attn_emb"]
        attn_emb_len = input_dict["attn_emb_len"]
        step = input_dict["t"]

        # random.random() is only drawn in training mode (short-circuit `and`).
        feed_reference = (
            input_dict["mode"] == "train"
            and random.random() < input_dict["ss_ratio"]
        )
        if feed_reference:
            word = input_dict["cap"][:, :step + 1]
        else:
            batch_size = attn_emb.size(0)
            word = torch.tensor([self.start_idx] * batch_size).unsqueeze(1).long()
            if step != 0:
                word = torch.cat((word, output["seq"][:, :step]), dim=-1)

        # word: [N, T]
        return {
            "attn_emb": attn_emb,
            "attn_emb_len": attn_emb_len,
            "word": word,
            "cap_padding_mask": (word == self.pad_idx).to(attn_emb.device),
        }

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Build the decoder input for one beam-search step of sample ``sample_idx``.

        At ``t == 0`` the sample's attention embeddings and lengths are
        repeated ``beam_size`` times and cached in ``output_i``; every step
        reads them back from that cache.
        """
        step = input_dict["t"]
        sample = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        first_step = step == 0

        if first_step:
            repeated = {
                key: repeat_tensor(input_dict[key][sample], beam_size)
                for key in ("attn_emb", "attn_emb_len")
            }
            output_i.update(repeated)

        # Every beam starts with the start token; later steps append the
        # tokens stored so far in output_i["seq"].
        word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if not first_step:
            word = torch.cat((word, output_i["seq"]), dim=-1)

        return {
            "attn_emb": output_i["attn_emb"],
            "attn_emb_len": output_i["attn_emb_len"],
            "word": word,
            "cap_padding_mask": (word == self.pad_idx).to(
                input_dict["attn_emb"].device),
        }
class M2TransformerModel(CaptionModel):
    """Caption model pairing an M2 Transformer encoder with an M2 Transformer decoder.

    Unlike a length-based interface, the decoder here receives an
    ``attn_emb_mask`` entry and no caption padding mask.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """Set the default compatible decoder types, build the base model and
        verify the encoder type.

        A ``compatible_decoders`` attribute that already exists is kept.
        """
        if not hasattr(self, "compatible_decoders"):
            # The module is imported at file level under its full
            # ``audio_to_text.`` path; a bare ``captioning`` name is not bound.
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.M2TransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.check_encoder_compatibility()

    def check_encoder_compatibility(self):
        """Assert that ``self.encoder`` is an ``M2TransformerEncoder``.

        Raises:
            AssertionError: if the encoder has any other type.
        """
        # Imported here because the encoder module is not imported at file
        # level and is only needed for this check.
        import audio_to_text.captioning.models.encoder as encoder_module

        assert isinstance(self.encoder, encoder_module.M2TransformerEncoder), \
            f"only M2TransformerEncoder is compatible with {self.__class__.__name__}"

    def seq_forward(self, input_dict):
        """Run the decoder once over the caption with its last position dropped.

        Reads ``cap``, ``attn_emb`` and ``attn_emb_mask`` from ``input_dict``
        and returns whatever ``self.decoder`` returns.
        """
        cap = input_dict["cap"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_mask": input_dict["attn_emb_mask"],
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        """Build the decoder input for step ``input_dict["t"]`` of step-wise decoding.

        In ``"train"`` mode, with probability ``ss_ratio``, the first ``t + 1``
        tokens of the reference caption are fed; otherwise the start token
        followed by the first ``t`` tokens of ``output["seq"]`` is fed.
        """
        decoder_input = {
            "attn_emb": input_dict["attn_emb"],
            "attn_emb_mask": input_dict["attn_emb_mask"],
        }
        t = input_dict["t"]

        ###############
        # determine input word
        ################
        if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]:
            # training, scheduled sampling: feed the reference prefix
            word = input_dict["cap"][:, :t + 1]
        else:
            batch_size = input_dict["attn_emb"].size(0)
            start_word = torch.tensor([self.start_idx] * batch_size).unsqueeze(1).long()
            if t == 0:
                word = start_word
            else:
                word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
        # word: [N, T]
        decoder_input["word"] = word
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Build the decoder input for one beam-search step of sample ``sample_idx``.

        At ``t == 0`` the sample's attention embeddings and mask are repeated
        ``beam_size`` times and cached in ``output_i``; every step reads them
        back from that cache.
        """
        decoder_input = {}
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]

        ###############
        # prepare attn embeds
        ################
        if t == 0:
            attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
            attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
            output_i["attn_emb"] = attn_emb
            output_i["attn_emb_mask"] = attn_emb_mask
        decoder_input["attn_emb"] = output_i["attn_emb"]
        decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]

        ###############
        # determine input word
        ################
        start_word = torch.tensor([self.start_idx] * beam_size).unsqueeze(1).long()
        if t == 0:
            word = start_word
        else:
            word = torch.cat((start_word, output_i["seq"]), dim=-1)
        decoder_input["word"] = word
        return decoder_input
class EventEncoder(nn.Module):
    """
    Encode the Label information in AudioCaps and AudioSet

    Each input row is turned into weights that sum to one and used to average
    the rows of a learned ``(vocab_size, emb_dim)`` embedding table.
    """

    def __init__(self, emb_dim, vocab_size=527):
        """Create the learnable label embedding table.

        Args:
            emb_dim: size of each label embedding.
            vocab_size: number of labels (default 527 — presumably the
                AudioSet label count; confirm against the data pipeline).
        """
        super().__init__()
        self.label_embedding = nn.Parameter(
            torch.randn((vocab_size, emb_dim)), requires_grad=True)

    def forward(self, word_idxs):
        """Return the weighted average of label embeddings for each row.

        Args:
            word_idxs: per-label weights; the row sum is taken over dim 1 and
                the last dimension must equal ``vocab_size`` for the matrix
                product (assumed to be a multi-hot ``(batch, vocab_size)``
                tensor — confirm with the caller).

        Returns:
            Tensor with ``emb_dim`` as its last dimension. Rows whose weights
            sum to zero produce an all-zero embedding.
        """
        totals = word_idxs.sum(dim=1, keepdim=True)
        # A row with no active label sums to 0; dividing by it would produce
        # NaN embeddings. Swap only those denominators for 1 so the row stays
        # all-zero and every other row is normalized exactly as before.
        safe_totals = torch.where(totals == 0, torch.ones_like(totals), totals)
        weights = word_idxs / safe_totals
        embeddings = weights @ self.label_embedding
        return embeddings
class EventCondTransformerModel(TransformerModel):
    """Transformer caption model additionally conditioned on event labels.

    The ``events`` entry of the input is embedded with an ``EventEncoder``
    and passed to the decoder under the ``"events"`` key.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """Set the default compatible decoder types, build the base model and
        add the label encoder.

        ``"events"`` is appended to both ``train_forward_keys`` and
        ``inference_forward_keys``.
        """
        if not hasattr(self, "compatible_decoders"):
            # The module is imported at file level under its full
            # ``audio_to_text.`` path; a bare ``captioning`` name is not bound.
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.EventTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        self.label_encoder = EventEncoder(decoder.emb_dim, 527)
        # Rebind rather than extend in place, so a list object that might be
        # shared with the base class or other instances is never mutated.
        self.train_forward_keys = self.train_forward_keys + ["events"]
        self.inference_forward_keys = self.inference_forward_keys + ["events"]

    # NOTE(review): seq_forward is inherited from TransformerModel and does
    # not pass "events" to the decoder — confirm this is intended.

    def prepare_decoder_input(self, input_dict, output):
        """Extend the parent's step-wise decoder input with the embedded events."""
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["events"] = self.label_encoder(input_dict["events"])
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Extend the parent's beam-search decoder input with the embedded events.

        At ``t == 0`` the embedding of sample ``sample_idx`` is repeated
        ``beam_size`` times and cached in ``output_i``; every step reads it
        back from that cache.
        """
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        if t == 0:
            output_i["events"] = repeat_tensor(
                self.label_encoder(input_dict["events"])[i], beam_size)
        decoder_input["events"] = output_i["events"]
        return decoder_input
class KeywordCondTransformerModel(TransformerModel):
    """Transformer caption model additionally conditioned on a ``keyword`` input.

    The ``keyword`` entry of the input is forwarded unchanged to the decoder
    under the ``"keyword"`` key.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """Set the default compatible decoder types and build the base model.

        ``"keyword"`` is appended to both ``train_forward_keys`` and
        ``inference_forward_keys``.
        """
        if not hasattr(self, "compatible_decoders"):
            # The module is imported at file level under its full
            # ``audio_to_text.`` path; a bare ``captioning`` name is not bound.
            self.compatible_decoders = (
                audio_to_text.captioning.models.decoder.KeywordProbTransformerDecoder,
            )
        super().__init__(encoder, decoder, **kwargs)
        # Rebind rather than extend in place, so a list object that might be
        # shared with the base class or other instances is never mutated.
        self.train_forward_keys = self.train_forward_keys + ["keyword"]
        self.inference_forward_keys = self.inference_forward_keys + ["keyword"]

    def seq_forward(self, input_dict):
        """Run the decoder once over the caption with its last position dropped.

        Same inputs as the parent implementation, plus ``keyword``.
        """
        cap = input_dict["cap"]
        # True wherever the caption holds the padding index; the final
        # position is cut off to line up with the decoder input words.
        cap_padding_mask = (cap == self.pad_idx).to(cap.device)
        cap_padding_mask = cap_padding_mask[:, :-1]
        keyword = input_dict["keyword"]
        output = self.decoder(
            {
                "word": cap[:, :-1],
                "attn_emb": input_dict["attn_emb"],
                "attn_emb_len": input_dict["attn_emb_len"],
                "keyword": keyword,
                "cap_padding_mask": cap_padding_mask,
            }
        )
        return output

    def prepare_decoder_input(self, input_dict, output):
        """Extend the parent's step-wise decoder input with ``keyword``."""
        decoder_input = super().prepare_decoder_input(input_dict, output)
        decoder_input["keyword"] = input_dict["keyword"]
        return decoder_input

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        """Extend the parent's beam-search decoder input with ``keyword``.

        At ``t == 0`` the keyword entry of sample ``sample_idx`` is repeated
        ``beam_size`` times and cached in ``output_i``; every step reads it
        back from that cache.
        """
        decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
        t = input_dict["t"]
        i = input_dict["sample_idx"]
        beam_size = input_dict["beam_size"]
        if t == 0:
            output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
                                                beam_size)
        decoder_input["keyword"] = output_i["keyword"]
        return decoder_input