import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class HubertCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

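        # Cross-entropy on the masked frames, one term per entry in the
        # model's target_list; each term is logged separately.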
        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

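        # Same cross-entropy over the unmasked frames; inactive by default
        # since pred_nomask_weight defaults to 0.0.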
        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

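        # Auxiliary losses exposed by the model via get_extra_losses, scaled
        # by sample_size so their contribution survives the per-sample-size
        # normalization applied in reduce_metrics.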
        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

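        # Core logging entries; the per-term losses recorded above are merged in.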
        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

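        # Copy any extra scalar model outputs that were requested via log_keys.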
        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

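        # Accuracy helper: the logits are arranged so that index 0 is the
        # positive class, so a correct prediction is argmax == 0; rows where
        # all logits are equal (argmax == argmin == 0) are not counted.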
        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                # Avoid shadowing the built-ins `max` and `min`.
                max_hit = logits.argmax(-1) == 0
                min_hit = logits.argmin(-1) == 0
                both = max_hit & min_hit
                corr = max_hit.long().sum().item() - both.long().sum().item()
                count = max_hit.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

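        # Aggregate the counts first so the per-key accuracies below can
        # divide by the matching count.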
        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
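
# A minimal sketch of selecting this criterion from a fairseq-hydra-train
# YAML config; the weight values below are illustrative, not prescriptive:
#
#   criterion:
#     _name: hubert
#     pred_masked_weight: 1.0
#     pred_nomask_weight: 0.0
#     loss_weights: [10.0]
#     log_keys: []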