# RGMC / audiocaptioner.py
# NikitaSrivatsan — "Fixed typo in checkpoint name" (commit 8971856)
from lib import *
import contextlib
import io
import laion_clap
import torch
class AudioCaptioner(torch.nn.Module):
    """Caption audio by prefixing GPT-2 with a projected LAION-CLAP audio embedding.

    A waveform is encoded by a CLAP audio model, projected and L2-normalised,
    then mapped by ``clip_project`` to ``prefix_length`` vectors of GPT-2's
    embedding width. Those vectors are prepended to the embedded caption
    tokens and the combined sequence is fed to GPT-2.
    """

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8,
                 clap_ckpt: str = 'checkpoints/music_audioset_epoch_15_esc_90.14.pt'):
        """Build GPT-2, the prefix-mapping MLP and the CLAP audio encoder.

        Args:
            prefix_length: number of GPT-2 input positions generated from one
                audio embedding.
            clip_length: not used by this class (kept for signature compatibility).
            prefix_size: width of the audio embedding fed to ``clip_project``.
            num_layers: not used by this class (kept for signature compatibility).
            clap_ckpt: path of the LAION-CLAP checkpoint to load.
        """
        super().__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        # MLP (from lib) mapping one prefix_size vector to prefix_length GPT-2
        # embeddings, with a hidden layer half the output width.
        out_size = self.gpt_embedding_size * prefix_length
        self.clip_project = MLP((prefix_size, out_size // 2, out_size))
        self.clap_model = laion_clap.CLAP_Module(
            enable_fusion=False,
            amodel='HTSAT-base'
        )
        # Silence anything load_ckpt prints to stdout.
        with contextlib.redirect_stdout(io.StringIO()):
            self.clap_model.load_ckpt(ckpt=clap_ckpt)

    def get_dummy_token(self, batch_size: int, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return a ``(batch_size, prefix_length)`` int64 tensor of zeros.

        Used as the label placeholder for the prefix positions in ``forward``.

        Args:
            batch_size: number of rows.
            device: where to allocate the tensor; ``None`` uses torch's default
                device (the previous behaviour).
        """
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def embed_waveform(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` with CLAP and return L2-normalised projected embeddings."""
        # NOTE(review): assumes waveform is shaped as the CLAP HTSAT audio
        # branch expects (presumably (batch, samples)) — confirm with callers.
        input_dict = {
            'waveform': waveform  # you can add more key-values
        }
        audio_embeds = self.clap_model.model.encode_audio(
            input_dict,
            device=waveform.device
        )
        # get BxD-dim embedding (last layer) D = 1024 -> 512 after audio projection
        projected = self.clap_model.model.audio_projection(audio_embeds['embedding'])
        return torch.nn.functional.normalize(projected, dim=-1)

    def create_prefix(self, waveform: Optional[torch.Tensor], batch_size: int) -> torch.Tensor:
        """Map audio to GPT-2 prefix embeddings of shape (batch, prefix_length, gpt_embedding_size).

        Args:
            waveform: audio batch, or ``None`` to use an all-zero audio embedding.
            batch_size: number of rows to produce when ``waveform`` is ``None``.
        """
        if waveform is not None:
            audio_embedding = self.embed_waveform(waveform)
        else:
            # Allocate the zero embedding where clip_project lives (instead of a
            # hard-coded .cuda(), which broke CPU runs and non-default GPUs).
            param = next(self.clip_project.parameters(), None)
            device = param.device if param is not None else None
            dtype = param.dtype if param is not None else None
            audio_embedding = torch.zeros(batch_size, self.prefix_size, device=device, dtype=dtype)
        # project the prefix through the mapping network and split it into positions
        return self.clip_project(audio_embedding).view(-1, self.prefix_length, self.gpt_embedding_size)

    def forward(self, tokens: torch.Tensor, waveform: Optional[torch.Tensor], mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, freeze_gpt=False):
        """Run GPT-2 on ``[audio prefix ; embedded tokens]`` and return its output.

        Args:
            tokens: integer caption token ids, (batch, seq).
            waveform: audio batch passed to ``create_prefix`` (may be ``None``).
            mask: attention mask forwarded to GPT-2 as-is.
                NOTE(review): it must presumably cover prefix_length + seq
                positions — confirm with callers.
            labels: only its presence matters. When it is not ``None``, the
                labels actually used are ``tokens`` preceded by
                ``prefix_length`` zeros.
            freeze_gpt: run the GPT-2 call under ``torch.no_grad()``.

        Returns:
            Whatever ``self.gpt`` returns for these inputs.
        """
        # embed the text and prepend the audio prefix
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.create_prefix(waveform, tokens.shape[0])
        embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            # Offset the labels so they line up with the prefixed sequence.
            # NOTE(review): prefix targets are token id 0 rather than an ignore
            # index, so they may contribute to the LM loss — confirm intended.
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        # NOTE(review): no_grad also blocks gradients from reaching clip_project
        # and CLAP through the prefix, so freeze_gpt=True gives a loss that cannot
        # be backpropagated at all — confirm that is the intent.
        grad_ctx = torch.no_grad() if freeze_gpt else contextlib.nullcontext()
        with grad_ctx:
            out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out