# Page-header residue kept for provenance (not part of the module):
# seungheondoh / add model / e48ca55 / raw / history blame / 6.52 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
class BartCaptionModel(nn.Module):
    """Audio captioning model: an ``AudioEncoder`` front end feeding a BART encoder-decoder.

    Audio is embedded by ``self.audio_encoder``, the embeddings are passed to
    BART's encoder as ``inputs_embeds``, and BART's decoder is trained with
    (label-smoothed) cross-entropy against the tokenized caption.
    """

    def __init__(
        self,
        n_mels=128,
        num_of_conv=6,
        sr=16000,
        duration=10,
        max_length=128,
        label_smoothing=0.1,
        bart_type="facebook/bart-base",
        audio_dim=768,
    ):
        """Build the tokenizer, the BART network and the audio encoder.

        Args:
            n_mels: Forwarded to ``AudioEncoder`` as ``n_mels``.
            num_of_conv: Number of conv layers; ``num_of_conv - 1`` of them are
                treated as strided when computing the context length.
            sr: Sample rate used to derive the number of samples and hop length.
            duration: Clip duration; ``sr * duration`` gives ``n_sample``.
            max_length: Maximum number of caption tokens kept by the tokenizer
                in ``forward_decoder``.
            label_smoothing: Label smoothing factor of the cross-entropy loss.
            bart_type: Name/path used to load the BART config and tokenizer.
            audio_dim: Forwarded to ``AudioEncoder`` as ``audio_dim``.
        """
        super().__init__()
        # Non-finetuning case: only the config and tokenizer are loaded from
        # `bart_type`; the BART network itself is constructed from the config,
        # no pretrained weights are loaded here.
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard-coded hop size: 0.01 * sr samples
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # Frame count divided by 2 for every strided conv, plus one.
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1

        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )
        self.max_length = max_length
        # Targets equal to -100 (padding, see `forward_decoder`) are ignored.
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter of this module."""
        # `next` avoids materialising the full parameter list just to read one entry.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
        """Shift input ids one token to the right.

        The first column becomes ``decoder_start_token_id``, the last column of
        ``input_ids`` is dropped, and any ``-100`` entries in the result are
        replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

        Raises:
            ValueError: If ``pad_token_id`` is ``None``.
        """
        # Validate before doing any tensor work.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        # Replace possible -100 values in labels by `pad_token_id`.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        """Embed ``audio`` and run the embeddings through BART's encoder.

        Returns:
            Tuple ``(encoder_outputs, audio_embs)``: the encoder's
            ``last_hidden_state`` and the raw output of ``self.audio_encoder``.
        """
        audio_embs = self.audio_encoder(audio)
        encoder_result = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True,
        )
        return encoder_result["last_hidden_state"], audio_embs

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss for ``text`` given encoder hidden states.

        Args:
            text: Caption(s) handed to ``self.tokenizer`` (padded to the longest
                item, truncated to ``self.max_length`` tokens).
            encoder_outputs: Encoder hidden states, as returned first by
                ``forward_encoder``.

        Returns:
            Scalar cross-entropy loss over the non-padding target tokens.
        """
        tokens = self.tokenizer(
            text,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        device = self.device
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        # Padding positions become -100 so the loss skips them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )
        # Teacher forcing: decoder input is the target sequence shifted right.
        decoder_input_ids = self.shift_tokens_right(
            decoder_targets,
            self.bart.config.pad_token_id,
            self.bart.config.decoder_start_token_id,
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True,
        )
        lm_logits = decoder_outputs["logits"]
        # Flatten using the logits' own vocabulary dimension: the model's output
        # size is what matters here and it may differ from
        # `tokenizer.vocab_size` (e.g. resized embeddings or added tokens).
        loss = self.loss_fct(
            lm_logits.reshape(-1, lm_logits.size(-1)),
            decoder_targets.reshape(-1),
        )
        return loss

    def forward(self, audio, text):
        """Return the captioning loss for a batch of ``audio`` and its ``text``."""
        encoder_outputs, _ = self.forward_encoder(audio)
        return self.forward_decoder(text, encoder_outputs)

    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=128,
        min_length=2,
        top_p=0.9,
        repetition_penalty=1.0,
    ):
        """Generate captions for ``samples``.

        Args:
            samples: Audio input passed to ``self.audio_encoder``.
            use_nucleus_sampling: If true, sample with ``top_p``; otherwise use
                beam search with ``num_beams`` beams.
            num_beams: Beam width (beam-search branch only).
            max_length: Maximum generated length.
            min_length: Minimum generated length.
            top_p: Nucleus probability mass (sampling branch only).
            repetition_penalty: Repetition penalty for the beam-search branch.
                The sampling branch uses a fixed value of 1.1 (see note below).

        Returns:
            List of decoded caption strings with special tokens removed.
        """
        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True,
        )

        batch_size = encoder_outputs['last_hidden_state'].size(0)
        device = self.device
        # Every sequence starts from the decoder start token; create the
        # tensors directly on the target device with the final dtype.
        input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=device,
        )
        decoder_attention_mask = torch.ones((batch_size, 1), dtype=torch.long, device=device)

        if use_nucleus_sampling:
            # NOTE(review): this branch has always used a fixed repetition
            # penalty of 1.1 and ignores the `repetition_penalty` argument.
            # Kept as-is so existing sampling behaviour does not change.
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1,
            )
        else:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                head_mask=None,
                decoder_head_mask=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
            )

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions