Spaces:
Sleeping
Sleeping
from lib import * | |
import contextlib | |
import io | |
import laion_clap | |
import torch | |
class AudioCaptioner(torch.nn.Module):
    """Caption audio by prefix-conditioning GPT-2 on a CLAP audio embedding.

    A waveform is embedded with a LAION-CLAP audio encoder (HTSAT-base). The
    L2-normalised embedding is mapped by ``clip_project`` (an ``MLP``) to
    ``prefix_length`` pseudo-token embeddings of GPT-2's width. Those are
    prepended to the GPT-2 token embeddings of the caption.
    """

    def get_dummy_token(self, batch_size: int,
                        device: Optional[torch.device] = None) -> torch.Tensor:
        """Return an all-zero ``(batch_size, prefix_length)`` int64 tensor.

        Used as placeholder label ids for the prefix positions.

        Args:
            batch_size: Number of rows.
            device: Device to allocate on. ``None`` keeps torch's default
                device (the previous one-argument behaviour).
        """
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def embed_waveform(self, waveform):
        """Embed a batch of waveforms with the CLAP audio branch.

        Args:
            waveform: Audio tensor handed to CLAP under the ``'waveform'`` key.
                Presumably (batch, samples) at CLAP's sample rate; TODO confirm.

        Returns:
            The output of CLAP's audio projection, L2-normalised along the
            last dimension.
        """
        # compute the prefix
        input_dict = {
            'waveform': waveform  # you can add more key-values
        }
        audio_embeds = self.clap_model.model.encode_audio(
            input_dict,
            device=waveform.device
        )
        # get BxD-dim embedding (last layer) D = 1024 -> 512 after audio projection
        audio_embedding = torch.nn.functional.normalize(
            self.clap_model.model.audio_projection(audio_embeds['embedding']), dim=-1)
        return audio_embedding

    def create_prefix(self, waveform, batch_size):
        """Build the GPT-2 prefix embeddings for a batch.

        Args:
            waveform: Audio batch, or ``None`` to condition on an all-zero
                audio embedding of shape ``(batch_size, prefix_size)``.
            batch_size: Only used when ``waveform`` is ``None``.

        Returns:
            ``clip_project`` output reshaped to
            ``(-1, prefix_length, gpt_embedding_size)``.
        """
        if waveform is not None:
            audio_embedding = self.embed_waveform(waveform)
        else:
            # Allocate next to GPT-2 instead of hard-coding CUDA, so the
            # zero-prefix path also works on CPU-only machines.
            audio_embedding = torch.zeros(
                batch_size, self.prefix_size,
                device=self.gpt.transformer.wte.weight.device)
        # Project through the mapping network and split into prefix_length token embeddings.
        prefix_projections = self.clip_project(audio_embedding).view(
            -1, self.prefix_length, self.gpt_embedding_size)
        return prefix_projections

    def forward(self, tokens: torch.Tensor, waveform: Optional[torch.Tensor], mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, freeze_gpt=False):
        """Run GPT-2 on ``[audio prefix ; caption token embeddings]``.

        Args:
            tokens: Caption token ids, indexed as (batch, seq_len).
            waveform: Audio batch, or ``None`` for an all-zero audio embedding.
            mask: Attention mask passed straight to GPT-2. It presumably must
                cover ``prefix_length + seq_len`` positions; confirm with callers.
            labels: Only its presence matters. When not ``None``, the labels
                sent to GPT-2 are rebuilt as
                ``[zeros(prefix_length) ; tokens]``, so the passed tensor's
                values are ignored.
            freeze_gpt: Run GPT-2 under ``torch.no_grad()``.

        Returns:
            Whatever ``self.gpt`` returns. For ``GPT2LMHeadModel`` that is an
            output object with ``logits`` and, when labels are given, ``loss``.
        """
        # embed the text
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.create_prefix(waveform, tokens.shape[0])
        embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        # offset labels so they line up with the prepended prefix
        if labels is not None:
            # NOTE(review): 0 is a real GPT-2 token id, and the HF LM loss only
            # ignores -100. The prefix positions are therefore scored; confirm
            # this is intended.
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        # NOTE(review): no_grad also blocks gradients from reaching
        # clip_project, so freeze_gpt=True yields outputs without an autograd
        # graph. It is not usable to train only the mapper.
        grad_ctx = torch.no_grad() if freeze_gpt else contextlib.nullcontext()
        with grad_ctx:
            out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8,
                 ckpt_path: str = 'checkpoints/music_audioset_epoch_15_esc_90.14.pt'):
        """Build GPT-2, the prefix mapping MLP and the CLAP audio encoder.

        Args:
            prefix_length: Number of pseudo-tokens the audio embedding becomes.
            clip_length: Not used by this constructor; kept for interface compatibility.
            prefix_size: Width of the audio embedding fed to ``clip_project``.
            num_layers: Not used by this constructor; kept for interface compatibility.
            ckpt_path: LAION-CLAP checkpoint loaded into the audio encoder.
        """
        super().__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        # Hidden width is half of the output width (prefix_length * embedding size).
        self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))
        self.clap_model = laion_clap.CLAP_Module(
            enable_fusion=False,
            amodel='HTSAT-base'
        )
        # Discard anything load_ckpt prints to stdout.
        with contextlib.redirect_stdout(io.StringIO()):
            self.clap_model.load_ckpt(ckpt=ckpt_path)