import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    def __init__(self, task):
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
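            # the -100 fallback above mirrors the default ignore_index of
            # torch's NLL / cross-entropy losses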

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
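        # the config dataclass (if any) is attached to the class by the
        # criterion registry; without one there are no extra args to add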
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion from command-line args."""
        init_args = {}
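        # inspect the concrete criterion's __init__ and fill each argument
        # from the task or from the matching field on cfg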
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
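                # positional-only and *args/**kwargs parameters cannot be
                # inferred from a config object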
raise NotImplementedError("{} not supported".format(p.kind)) |
|
|
|
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} |
|
|
|
if p.name == "task": |
|
init_args["task"] = task |
|
elif p.name == "cfg": |
|
init_args["cfg"] = cfg |
|
elif hasattr(cfg, p.name): |
|
init_args[p.name] = getattr(cfg, p.name) |
|
elif p.default != p.empty: |
|
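                # the parameter's declared default will be used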
                pass
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
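        # skip count-style entries; they act as normalizers rather than metrics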
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


class LegacyFairseqCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        return cls(args, task)
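

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the fairseq API): a minimal criterion built
# on FairseqCriterion, showing the forward()/reduce_metrics() contract and how
# build_criterion() fills __init__ arguments from a config dataclass. The class
# and field names here are invented for illustration; the model helpers
# (model.get_normalized_probs, model.get_targets) and the sample layout
# ("net_input", "target", "ntokens") follow common fairseq conventions. Real
# criterions live in fairseq/criterions/ and are registered with
# @register_criterion so they can be selected via --criterion.
# ---------------------------------------------------------------------------
from dataclasses import dataclass, field

import torch.nn.functional as F


@dataclass
class ExampleCriterionConfig(FairseqDataclass):
    # build_criterion() copies this field into the matching __init__ argument
    sentence_avg: bool = field(
        default=False,
        metadata={"help": "normalize the loss by #sentences instead of #tokens"},
    )


class ExampleNLLCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg=False):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        # padding_idx is set in FairseqCriterion.__init__ when the task
        # exposes a target_dictionary
        loss = F.nll_loss(
            lprobs.view(-1, lprobs.size(-1)),
            target.view(-1),
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed():
        # the logging_output above only contains summable scalars
        return True


# A registered version of the class above would typically be constructed as
# ExampleNLLCriterion.build_criterion(cfg, task), with `sentence_avg` filled in
# from the config by the signature-matching logic in build_criterion().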