oweller2 committed on
Commit
4753b37
1 Parent(s): c87aa93
Files changed (2) hide show
  1. loss.py +30 -0
  2. modeling_flexbert.py +1 -1
loss.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ import inspect
5
+ import torch.nn as nn
6
+ from .configuration_bert import FlexBertConfig
7
+
8
+ try:
9
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
10
+ except ImportError:
11
+ CrossEntropyLoss = None
12
+
13
+ LOSS2CLS = {
14
+ "cross_entropy": nn.CrossEntropyLoss,
15
+ "binary_cross_entropy": nn.BCEWithLogitsLoss,
16
+ "mean_squared_error": nn.MSELoss,
17
+ }
18
+
19
+ if CrossEntropyLoss is not None:
20
+ LOSS2CLS["fa_cross_entropy"] = CrossEntropyLoss
21
+
22
+
23
+ def get_loss_fn(config: FlexBertConfig) -> nn.Module:
24
+ try:
25
+ loss_class = LOSS2CLS[config.loss_function]
26
+ signature = inspect.signature(loss_class)
27
+ loss_kwargs = {k: v for k, v in config.loss_kwargs.items() if k in signature.parameters}
28
+ return loss_class(**loss_kwargs)
29
+ except KeyError:
30
+ raise ValueError(f"Invalid loss function type: {config.loss_function}, must be one of {LOSS2CLS.keys()}.")
modeling_flexbert.py CHANGED
@@ -116,7 +116,7 @@ from .layers import (
116
  from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
117
  from .normalization import get_norm_layer
118
  from .padding import pad_input, unpad_input
119
- from .bert_layers.loss import get_loss_fn
120
 
121
  # TODO: This is not used here, but this is so these files are copied when saving the model in ST/PyLate
122
  from .utils import StrEnum
 
116
  from .mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
117
  from .normalization import get_norm_layer
118
  from .padding import pad_input, unpad_input
119
+ from .loss import get_loss_fn
120
 
121
  # TODO: This is not used here, but this is so these files are copied when saving the model in ST/PyLate
122
  from .utils import StrEnum