|
import torch |
|
import torch.distributed.nn |
|
from torch import distributed as dist, nn as nn |
|
from torch.nn import functional as F |
|
import numpy as np |
|
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score |
|
|
|
try: |
|
import horovod.torch as hvd |
|
except ImportError: |
|
hvd = None |
|
|
|
|
|
def _gather_single(tensor, rank, world_size, use_horovod, gather_with_grad, local_loss):
    """Gather one feature tensor from every rank and concatenate along dim 0.

    When the gather itself is not differentiable (``gather_with_grad=False``)
    and the full batch is used (``local_loss=False``), this rank's slice is
    replaced by the original ``tensor`` so gradients still flow to it.
    """
    if use_horovod:
        if gather_with_grad:
            return hvd.allgather(tensor)
        with torch.no_grad():
            gathered = hvd.allgather(tensor)
        if local_loss:
            return gathered
        # Re-insert the local tensor so this rank's slice keeps its autograd graph.
        chunks = list(gathered.chunk(world_size, dim=0))
        chunks[rank] = tensor
        return torch.cat(chunks, dim=0)

    if gather_with_grad:
        # torch.distributed.nn.all_gather is autograd-aware.
        return torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)

    chunks = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(chunks, tensor)
    if not local_loss:
        # dist.all_gather outputs carry no gradient; restore the local slice.
        chunks[rank] = tensor
    return torch.cat(chunks, dim=0)


def gather_features(
    audio_features,
    text_features,
    audio_features_mlp=None,
    text_features_mlp=None,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
    mlp_loss=False,
):
    """Gather audio/text features (and optionally their MLP projections) from all ranks.

    Each tensor is gathered with Horovod (``use_horovod=True``) or with
    ``torch.distributed``, then concatenated along dim 0. Tensors are gathered
    in a fixed order on every rank: audio, text, audio_mlp, text_mlp.

    Args:
        audio_features: This rank's audio embeddings.
        text_features: This rank's text embeddings.
        audio_features_mlp: This rank's MLP-projected audio embeddings.
            Required when ``mlp_loss`` is True.
        text_features_mlp: This rank's MLP-projected text embeddings.
            Required when ``mlp_loss`` is True.
        local_loss: If False, this rank's slice of a non-differentiable gather
            is replaced by the local tensor to keep its gradient.
        gather_with_grad: Use an autograd-aware gather.
        rank: Index of this process.
        world_size: Number of processes participating in the gather.
        use_horovod: Gather with Horovod instead of ``torch.distributed``.
        mlp_loss: Also gather the two MLP feature tensors.

    Returns:
        ``(all_audio, all_text)``, or
        ``(all_audio, all_text, all_audio_mlp, all_text_mlp)`` when ``mlp_loss``.

    Raises:
        AssertionError: If ``use_horovod`` is set but Horovod is not installed.
        ValueError: If ``mlp_loss`` is set but an MLP feature tensor is None.
    """
    if use_horovod:
        assert hvd is not None, "Please install horovod"
    if mlp_loss and (audio_features_mlp is None or text_features_mlp is None):
        raise ValueError(
            "mlp_loss=True requires audio_features_mlp and text_features_mlp"
        )

    def gather(tensor):
        return _gather_single(
            tensor,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod,
            gather_with_grad=gather_with_grad,
            local_loss=local_loss,
        )

    # Collective calls must be issued in the same order on every rank.
    all_audio_features = gather(audio_features)
    all_text_features = gather(text_features)
    if not mlp_loss:
        return all_audio_features, all_text_features
    all_audio_features_mlp = gather(audio_features_mlp)
    all_text_features_mlp = gather(text_features_mlp)
    return (
        all_audio_features,
        all_text_features,
        all_audio_features_mlp,
        all_text_features_mlp,
    )
|
|
|
|
|
class ClipLoss(nn.Module):
    """CLIP-style contrastive loss between audio and text embeddings.

    Without ``mlp_loss`` the loss is the mean of two cross-entropies, over the
    audio->text and the text->audio similarity logits. With ``mlp_loss`` it is
    the mean of four cross-entropies. These pair each tower's features with
    the other tower's MLP-projected features, scaled by ``logit_scale_a`` and
    ``logit_scale_t`` respectively.

    Args:
        local_loss: When ``world_size > 1``, compute logits only for this
            rank's rows against the gathered columns. Labels are then offset
            by ``num_logits * rank``.
        gather_with_grad: Passed to ``gather_features``.
        cache_labels: Keep the label tensor per device between calls.
        rank: Index of this process.
        world_size: Number of processes; features are gathered when > 1.
        use_horovod: Passed to ``gather_features``.
        mlp_loss: Use the four-term MLP-projected variant.
        weight_loss_kappa: When non-zero, each class (logit column) is weighted
            in the cross-entropy by ``exp(sum_j <f_i, f_j> / (kappa * N))``.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
        mlp_loss=False,
        weight_loss_kappa=0,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.mlp_loss = mlp_loss
        self.weighted_loss = bool(weight_loss_kappa != 0)
        self.weight_loss_kappa = weight_loss_kappa

        # Label cache: keyed by device, valid while num_logits is unchanged.
        self.prev_num_logits = 0
        self.labels = {}

    def _get_labels(self, num_logits, device):
        """Return the target index for each logits row, using the cache when valid."""
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Local rows match columns offset by this rank's position.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
            return labels
        return self.labels[device]

    def _similarity_weight(self, features):
        """Return detached per-sample weights exp(sum_j <f_i, f_j> / (kappa * N))."""
        features = features.detach()
        similarity = features @ features.T
        return torch.exp(
            similarity.sum(dim=1) / (self.weight_loss_kappa * len(features))
        )

    def forward(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t=None,
        audio_features_mlp=None,
        text_features_mlp=None,
    ):
        """Return the scalar contrastive loss for one batch.

        ``logit_scale_t``, ``audio_features_mlp`` and ``text_features_mlp`` are
        only used (and required) when ``mlp_loss`` is set.
        """
        if self.mlp_loss:
            return self._forward_mlp(
                audio_features,
                text_features,
                logit_scale_a,
                logit_scale_t,
                audio_features_mlp,
                text_features_mlp,
            )
        return self._forward_clip(audio_features, text_features, logit_scale_a)

    def _forward_clip(self, audio_features, text_features, logit_scale_a):
        """Two-term loss on the (possibly gathered) audio/text features."""
        if self.world_size > 1:
            all_audio_features, all_text_features = gather_features(
                audio_features=audio_features,
                text_features=text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
                mlp_loss=self.mlp_loss,
            )
            if self.local_loss:
                logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
                logits_per_text = logit_scale_a * text_features @ all_audio_features.T
            else:
                logits_per_audio = (
                    logit_scale_a * all_audio_features @ all_text_features.T
                )
                logits_per_text = logits_per_audio.T
        else:
            # Single process: the "gathered" features are the local ones. They
            # must be bound here because the weighted branch below reads them.
            all_audio_features, all_text_features = audio_features, text_features
            logits_per_audio = logit_scale_a * audio_features @ text_features.T
            logits_per_text = logit_scale_a * text_features @ audio_features.T

        labels = self._get_labels(logits_per_audio.shape[0], audio_features.device)
        if not self.weighted_loss:
            return (
                F.cross_entropy(logits_per_audio, labels)
                + F.cross_entropy(logits_per_text, labels)
            ) / 2

        audio_weight = self._similarity_weight(all_audio_features)
        text_weight = self._similarity_weight(all_text_features)
        # cross_entropy's `weight` is per class (logit column): audio rows are
        # classified over texts and vice versa.
        return (
            F.cross_entropy(logits_per_audio, labels, weight=text_weight)
            + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
        ) / 2

    def _forward_mlp(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t,
        audio_features_mlp,
        text_features_mlp,
    ):
        """Four-term loss pairing each tower with the other tower's MLP features."""
        if self.world_size > 1:
            (
                all_audio_features,
                all_text_features,
                all_audio_features_mlp,
                all_text_features_mlp,
            ) = gather_features(
                audio_features=audio_features,
                text_features=text_features,
                audio_features_mlp=audio_features_mlp,
                text_features_mlp=text_features_mlp,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
                mlp_loss=self.mlp_loss,
            )
            if self.local_loss:
                a_logits_per_audio = (
                    logit_scale_a * audio_features @ all_text_features_mlp.T
                )
                a_logits_per_text = (
                    logit_scale_a * text_features_mlp @ all_audio_features.T
                )
                t_logits_per_audio = (
                    logit_scale_t * audio_features_mlp @ all_text_features.T
                )
                t_logits_per_text = (
                    logit_scale_t * text_features @ all_audio_features_mlp.T
                )
            else:
                a_logits_per_audio = (
                    logit_scale_a * all_audio_features @ all_text_features_mlp.T
                )
                a_logits_per_text = a_logits_per_audio.T
                t_logits_per_audio = (
                    logit_scale_t * all_audio_features_mlp @ all_text_features.T
                )
                t_logits_per_text = t_logits_per_audio.T
        else:
            all_audio_features, all_text_features = audio_features, text_features
            a_logits_per_audio = logit_scale_a * audio_features @ text_features_mlp.T
            a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
            t_logits_per_audio = logit_scale_t * audio_features_mlp @ text_features.T
            t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T

        labels = self._get_labels(a_logits_per_audio.shape[0], audio_features.device)
        if not self.weighted_loss:
            return (
                F.cross_entropy(a_logits_per_audio, labels)
                + F.cross_entropy(a_logits_per_text, labels)
                + F.cross_entropy(t_logits_per_audio, labels)
                + F.cross_entropy(t_logits_per_text, labels)
            ) / 4

        # Weights need one entry per logit column, i.e. they must span the
        # gathered batch (local features would mismatch when world_size > 1).
        audio_weight = self._similarity_weight(all_audio_features)
        text_weight = self._similarity_weight(all_text_features)
        # NOTE(review): the a_* terms use audio weights and the t_* terms use
        # text weights (kept as-is). This differs from the non-MLP pairing;
        # confirm this is intended.
        return (
            F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
            + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
            + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
            + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
        ) / 4
|
|
|
|
|
def lp_gather_features(pred, target, world_size=1, use_horovod=False):
    """Gather linear-probe predictions and targets from every rank.

    Args:
        pred: This rank's predictions.
        target: This rank's targets.
        world_size: Number of processes (used by the ``torch.distributed`` path).
        use_horovod: Gather with Horovod instead of ``torch.distributed``.

    Returns:
        ``(all_preds, all_targets)`` concatenated along dim 0.

    Raises:
        AssertionError: If ``use_horovod`` is set but Horovod is not installed.
    """
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        with torch.no_grad():
            all_preds = hvd.allgather(pred)
            all_targets = hvd.allgather(target)
        return all_preds, all_targets

    gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
    gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
    dist.all_gather(gathered_preds, pred)
    dist.all_gather(gathered_targets, target)
    return torch.cat(gathered_preds, dim=0), torch.cat(gathered_targets, dim=0)
|
|
|
|
|
def get_map(pred, target):
    """Return the mean over classes of the per-class average precision.

    ``pred`` holds logits and is passed through a sigmoid first. Inputs may
    live on any device and may require grad; they are detached and moved to
    the CPU before conversion to NumPy.
    """
    scores = torch.sigmoid(pred.detach()).cpu().numpy()
    labels = target.detach().cpu().numpy()
    return np.mean(average_precision_score(labels, scores, average=None))
|
|
|
|
|
def get_acc(pred, target):
    """Return the accuracy of argmax(pred) against argmax(target) along dim 1.

    Inputs may live on any device; the argmax indices are moved to the CPU
    before conversion to NumPy.
    """
    pred_idx = torch.argmax(pred, 1).cpu().numpy()
    target_idx = torch.argmax(target, 1).cpu().numpy()
    return accuracy_score(target_idx, pred_idx)
|
|
|
|
|
def get_mauc(pred, target):
    """Return the mean over classes of the per-class ROC AUC.

    ``pred`` holds logits and is passed through a sigmoid first. Inputs may
    live on any device and may require grad; they are detached and moved to
    the CPU before conversion to NumPy.
    """
    scores = torch.sigmoid(pred.detach()).cpu().numpy()
    labels = target.detach().cpu().numpy()
    return np.mean(roc_auc_score(labels, scores, average=None))
|
|
|
|
|
class LPMetrics(object):
    """Evaluate a configurable set of linear-probe metrics.

    Args:
        metric_names: Iterable of metric names, each one of "map", "acc" or
            "mauc". Defaults to all three.

    Raises:
        ValueError: If any name is not a supported metric.
    """

    def __init__(self, metric_names=None):
        # None instead of a list literal: a list default would be one object
        # shared by every instance (and exposed through self.metric_names).
        if metric_names is None:
            metric_names = ["map", "acc", "mauc"]
        # Copy so later mutation of the caller's list cannot desynchronise
        # metric_names from metrics.
        self.metric_names = list(metric_names)
        self.metrics = [self.get_metric(name) for name in self.metric_names]

    def get_metric(self, name):
        """Return the metric function registered under ``name``."""
        if name == "map":
            return get_map
        if name == "acc":
            return get_acc
        if name == "mauc":
            return get_mauc
        raise ValueError("the metric should be at least one of [map, acc, mauc]")

    def evaluate_mertics(self, pred, target):
        """Return ``{name: metric(pred, target)}`` for every configured metric."""
        return {
            name: metric(pred, target)
            for name, metric in zip(self.metric_names, self.metrics)
        }

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    evaluate_metrics = evaluate_mertics
|
|
|
|
|
def calc_celoss(pred, target):
    """Cross-entropy between logits ``pred`` and one-hot/score ``target``.

    The class index for each row is taken as the argmax of ``target`` along
    dim 1, and the mean cross-entropy over rows is returned.
    """
    class_index = target.argmax(dim=1)
    return F.cross_entropy(pred, class_index)
|
|
|
|
|
class LPLoss(nn.Module):
    """Linear-probe loss selected by name: "bce", "ce" or "mse".

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """

    def __init__(self, loss_name):
        super().__init__()
        # Zero-argument factories, checked in order; the first matching name wins.
        choices = (
            ("bce", nn.BCEWithLogitsLoss),
            ("ce", lambda: calc_celoss),
            ("mse", nn.MSELoss),
        )
        make_loss = next(
            (factory for key, factory in choices if loss_name == key), None
        )
        if make_loss is None:
            raise ValueError("the loss func should be at least one of [bce, ce, mse]")
        self.loss_func = make_loss()

    def forward(self, pred, target):
        """Apply the selected loss to ``(pred, target)``."""
        return self.loss_func(pred, target)
|
|