# Scraped from a Hugging Face Space ("Running on Zero"), revision e73da9c,
# original file size 6,342 bytes. The page's line-number gutter was removed.
import torch
import torch.nn as nn
from audioldm.clap.open_clip import create_model
from audioldm.clap.training.data import get_audio_features
import torchaudio
from transformers import RobertaTokenizer
import torch.nn.functional as F
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Frozen CLAP encoder producing audio or text conditioning embeddings.

    The CLAP model is built with ``create_model``. All of its parameters are
    frozen and it is put in eval mode. ``forward`` embeds either waveforms or
    text, selected by ``embed_mode``. With probability ``unconditional_prob``
    per sample, it swaps an embedding for the embedding of the empty string.
    Presumably this supports classifier-free guidance, per the class name.

    Args:
        pretrained_path: checkpoint passed to ``create_model`` as ``pretrained``.
        key: stored as ``self.key``; not read anywhere in this class.
        sampling_rate: input audio rate in Hz; only 16000 is accepted.
        embed_mode: ``"audio"`` or ``"text"``; selects the ``forward`` branch.
        amodel: audio-encoder name given to ``create_model``.
        unconditional_prob: per-sample probability of replacing an embedding
            with the unconditional (empty-text) embedding.
        random_mute: if true, zero a random span of each input waveform.
        max_random_mute_portion: upper bound of the muted span, as a fraction
            of the waveform length.
        training_mode: when false, ``forward`` reloads the CLAP model if it
            finds it in training mode.
    """

    def __init__(
        self,
        pretrained_path="",
        key="class",
        sampling_rate=16000,
        embed_mode="audio",
        amodel="HTSAT-tiny",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
    ):
        super().__init__()
        self.key = key
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self._load_frozen_model(self.device)

    def _load_frozen_model(self, device):
        """Build the CLAP model on *device*, freeze it and switch it to eval mode.

        Sets ``self.model`` and ``self.model_cfg`` from ``create_model``.
        """
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def _model_device(self):
        """Return the device that currently holds the CLAP model's parameters."""
        return next(self.model.parameters()).device

    def _compute_unconditional_token(self):
        """Embed the empty string, store it as ``self.unconditional_token`` and return it.

        Two empty strings are tokenized and only the first row is kept via
        ``[0:1]``. NOTE(review): the pair presumably avoids the size-1 squeeze
        done in ``tokenizer``. TODO confirm.
        """
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]
        return self.unconditional_token

    def get_unconditional_condition(self, batchsize):
        """Return the empty-text embedding repeated *batchsize* times along a new dim 0."""
        token = self._compute_unconditional_token()
        return torch.cat([token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        """Split *batch* along dim 0 into a list of per-sample tensor views."""
        return list(batch.unbind(0))

    def make_decision(self, probability):
        """Return True with chance *probability*, using one ``torch.rand`` draw."""
        return float(torch.rand(1)) < probability

    def random_uniform(self, start, end):
        """Return ``start + (end - start) * u`` with ``u`` drawn from ``torch.rand``."""
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        """Return a copy of *waveform* with one random span per sample zeroed.

        The span lies on the last (time) axis and is at most
        ``max_random_mute_portion`` of its length. Works for ``[bs, t]`` and
        ``[bs, 1, t]`` input. The caller's tensor is not modified.
        """
        waveform = waveform.clone()
        t_steps = waveform.size(-1)
        max_mute = int(t_steps * self.max_random_mute_portion)
        for i in range(waveform.size(0)):
            mute_size = int(self.random_uniform(0, end=max_mute))
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            # `...` makes the slice hit the time axis even with a channel dim.
            waveform[i, ..., mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        """Return the cosine similarity between audio and text CLAP embeddings.

        ``embed_mode`` is switched temporarily and restored afterwards, even on
        error. *waveform* is moved to the CLAP model's device. NOTE(review):
        ``forward`` still applies ``unconditional_prob`` replacement here; set
        it to 0 if deterministic scores are required.
        """
        previous_mode = self.embed_mode
        try:
            with torch.no_grad():
                self.embed_mode = "audio"
                audio_emb = self(waveform.to(self._model_device()))
                self.embed_mode = "text"
                text_emb = self(text)
                similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
        finally:
            self.embed_mode = previous_mode
        return similarity.squeeze()

    def forward(self, batch, key=None):
        """Embed *batch* according to ``self.embed_mode``.

        For a fully unconditional conditioner, set ``self.unconditional_prob``
        to 1.0. For a fully conditional one, set it to 0.0.

        Args:
            batch: waveforms when ``embed_mode == "audio"``. When
                ``embed_mode == "text"``, text accepted by the RoBERTa
                tokenizer (e.g. a list of strings).
            key: unused; kept for interface compatibility.

        Returns:
            A detached embedding with a singleton dim inserted at position 1
            (``[bs, 1, 512]`` per the original comments).

        Raises:
            ValueError: if ``embed_mode`` is neither ``"audio"`` nor ``"text"``.
        """
        self._ensure_eval_model()
        if self.embed_mode == "audio":
            embed = self._embed_audio(batch)
        elif self.embed_mode == "text":
            embed = self._embed_text(batch)
        else:
            raise ValueError(f"Unknown embed_mode: {self.embed_mode!r}")
        embed = embed.unsqueeze(1)
        self._apply_unconditional_dropout(embed)
        return embed.detach()

    def _ensure_eval_model(self):
        """Reload the frozen CLAP model if it was switched to training mode.

        Only runs when ``training_mode`` is false. The reload goes to the
        device that currently holds the model, so CPU-only setups keep working.
        """
        if self.model.training and not self.training_mode:
            print(
                "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
            )
            self._load_frozen_model(self._model_device())

    def _embed_audio(self, batch):
        """Return CLAP audio embeddings for a batch of 16 kHz waveforms.

        NOTE(review): the original comments disagree on the input shape
        (``[bs, t]`` vs ``[bs, 1, t]``). Confirm against callers.
        """
        with torch.no_grad():
            assert (
                self.sampling_rate == 16000
            ), "We only support 16000 sampling rate"
            if self.random_mute:
                batch = self._random_mute(batch)
            # The CLAP audio features are computed on 48 kHz audio.
            batch = torchaudio.functional.resample(
                batch, orig_freq=self.sampling_rate, new_freq=48000
            )
            # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
            audio_dict_list = [
                get_audio_features(
                    {},
                    waveform,
                    480000,  # 10 s at 48 kHz; presumably the max length
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                )
                for waveform in self.batch_to_list(batch)
            ]
            # [bs, 512]
            return self.model.get_audio_embedding(audio_dict_list)

    def _embed_text(self, batch):
        """Return CLAP text embeddings for *batch*, tokenized via ``tokenizer``."""
        with torch.no_grad():
            text_data = self.tokenizer(batch)
            return self.model.get_text_embedding(text_data)

    def _apply_unconditional_dropout(self, embed):
        """Replace rows of *embed* in place with the unconditional token.

        Each row is replaced with probability ``unconditional_prob``. All
        per-row decisions are drawn first, in row order. The empty string is
        only embedded when at least one row is replaced, which avoids a
        text-encoder pass per call when nothing is dropped.
        """
        drop = [
            self.make_decision(self.unconditional_prob)
            for _ in range(embed.size(0))
        ]
        if not any(drop):
            return
        token = self._compute_unconditional_token()
        for i, dropped in enumerate(drop):
            if dropped:
                embed[i] = token

    def tokenizer(self, text):
        """Tokenize *text* with RoBERTa, padded/truncated to 512 tokens.

        Returns a dict of tensors. ``squeeze(0)`` drops a leading size-1
        dimension, so a single string yields 1-D tensors.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}
# end of CLAPAudioEmbeddingClassifierFreev2 module