oweller2
committed on
Commit
•
4753b37
1
Parent(s):
c87aa93
loss
Browse files- loss.py +30 -0
- modeling_flexbert.py +1 -1
loss.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 **AUTHORS_TODO**
|
2 |
+
# License: Apache-2.0
|
3 |
+
|
4 |
+
import inspect
|
5 |
+
import torch.nn as nn
|
6 |
+
from .configuration_bert import FlexBertConfig
|
7 |
+
|
8 |
+
# flash_attn is an optional dependency: when it cannot be imported, the name
# is bound to None so the rest of the module can test for its availability.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = None

# Registry mapping a loss name (the value looked up from
# ``config.loss_function``) to the loss class to instantiate.
LOSS2CLS = {
    "cross_entropy": nn.CrossEntropyLoss,
    "binary_cross_entropy": nn.BCEWithLogitsLoss,
    "mean_squared_error": nn.MSELoss,
}

# Only expose the flash-attn cross entropy when the import above succeeded.
if CrossEntropyLoss is not None:
    LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
|
21 |
+
|
22 |
+
|
23 |
+
def get_loss_fn(config: FlexBertConfig) -> nn.Module:
    """Instantiate the loss module selected by ``config.loss_function``.

    The loss class is looked up in ``LOSS2CLS``. Entries of
    ``config.loss_kwargs`` whose key is not a parameter of the class's
    constructor signature are dropped, so a single kwargs dict can be shared
    across loss types.

    Args:
        config: Configuration providing ``loss_function`` (a key of
            ``LOSS2CLS``) and ``loss_kwargs`` (a mapping of constructor
            keyword arguments; ``None`` is treated as empty).

    Returns:
        A new instance of the selected loss class.

    Raises:
        ValueError: If ``config.loss_function`` is not a key of ``LOSS2CLS``.
    """
    # Keep the try body to the lookup only, so a KeyError raised while
    # inspecting or constructing the loss is not misreported as an unknown
    # loss name.
    try:
        loss_class = LOSS2CLS[config.loss_function]
    except KeyError:
        raise ValueError(
            f"Invalid loss function type: {config.loss_function}, must be one of {list(LOSS2CLS)}."
        ) from None

    accepted_params = inspect.signature(loss_class).parameters
    configured_kwargs = config.loss_kwargs or {}
    loss_kwargs = {k: v for k, v in configured_kwargs.items() if k in accepted_params}
    return loss_class(**loss_kwargs)
|
modeling_flexbert.py
CHANGED
@@ -116,7 +116,7 @@ from .layers import (
|
|
116 |
from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
117 |
from .normalization import get_norm_layer
|
118 |
from .padding import pad_input, unpad_input
|
119 |
-
from .
|
120 |
|
121 |
# TODO: This is not used here, but this is so these files are copied when saving the model in ST/PyLate
|
122 |
from .utils import StrEnum
|
|
|
116 |
from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
|
117 |
from .normalization import get_norm_layer
|
118 |
from .padding import pad_input, unpad_input
|
119 |
+
from .loss import get_loss_fn
|
120 |
|
121 |
# TODO: This is not used here, but this is so these files are copied when saving the model in ST/PyLate
|
122 |
from .utils import StrEnum
|