File size: 933 Bytes
4753b37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0
import inspect
import torch.nn as nn
from .configuration_bert import FlexBertConfig
# flash-attn is an optional dependency: when it cannot be imported the name is
# bound to None so the registry below simply omits the fused loss.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = None

# Registry mapping the ``loss_function`` config string to the loss class to
# instantiate.
LOSS2CLS = {
    "cross_entropy": nn.CrossEntropyLoss,
    "binary_cross_entropy": nn.BCEWithLogitsLoss,
    "mean_squared_error": nn.MSELoss,
}

# Only expose the flash-attn cross entropy when the import above succeeded.
if CrossEntropyLoss is not None:
    LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
def get_loss_fn(config: FlexBertConfig) -> nn.Module:
    """Instantiate the loss module selected by ``config.loss_function``.

    The loss class is looked up in ``LOSS2CLS`` and constructed with the
    entries of ``config.loss_kwargs`` whose names appear in the class's
    constructor signature; entries the constructor does not accept are
    dropped rather than passed through.

    Args:
        config: Configuration providing ``loss_function`` (a key of
            ``LOSS2CLS``) and ``loss_kwargs`` (a mapping of candidate
            constructor keyword arguments).

    Returns:
        A new instance of the selected loss class.

    Raises:
        ValueError: If ``config.loss_function`` is not a key of ``LOSS2CLS``.
    """
    # Keep the try body to the lookup alone: a KeyError raised while building
    # the loss (e.g. inside its constructor) must not be misreported as an
    # unknown loss function name.
    try:
        loss_class = LOSS2CLS[config.loss_function]
    except KeyError:
        # The KeyError carries no extra information, so suppress the chain.
        raise ValueError(
            f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}."
        ) from None

    # Filter kwargs against the constructor signature so one shared
    # ``loss_kwargs`` mapping can serve loss classes with different parameters.
    accepted_params = inspect.signature(loss_class).parameters
    loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in accepted_params}
    return loss_class(**loss_kwargs)
|