alaeddine-13
commited on
Commit
•
b4f2b16
1
Parent(s):
e36c994
rename to jina bert
Browse files- modeling_bert.py +75 -129
modeling_bert.py
CHANGED
@@ -54,7 +54,7 @@ from transformers.utils import (
|
|
54 |
logging,
|
55 |
replace_return_docstrings,
|
56 |
)
|
57 |
-
from .configuration_bert import
|
58 |
|
59 |
try:
|
60 |
from tqdm.autonotebook import trange
|
@@ -66,7 +66,7 @@ except ImportError:
|
|
66 |
logger = logging.get_logger(__name__)
|
67 |
|
68 |
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
69 |
-
_CONFIG_FOR_DOC = "
|
70 |
|
71 |
# TokenClassification docstring
|
72 |
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = (
|
@@ -197,10 +197,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|
197 |
return model
|
198 |
|
199 |
|
200 |
-
class
|
201 |
"""Construct the embeddings from word, position and token_type embeddings."""
|
202 |
|
203 |
-
def __init__(self, config:
|
204 |
super().__init__()
|
205 |
self.word_embeddings = nn.Embedding(
|
206 |
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
@@ -280,7 +280,7 @@ class MyBertEmbeddings(nn.Module):
|
|
280 |
return embeddings
|
281 |
|
282 |
|
283 |
-
class
|
284 |
def __init__(self, config, position_embedding_type=None):
|
285 |
super().__init__()
|
286 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
@@ -448,7 +448,7 @@ class MyBertSelfAttention(nn.Module):
|
|
448 |
return outputs
|
449 |
|
450 |
|
451 |
-
class
|
452 |
def __init__(self, config):
|
453 |
super().__init__()
|
454 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
@@ -464,13 +464,13 @@ class MyBertSelfOutput(nn.Module):
|
|
464 |
return hidden_states
|
465 |
|
466 |
|
467 |
-
class
|
468 |
def __init__(self, config, position_embedding_type=None):
|
469 |
super().__init__()
|
470 |
-
self.self =
|
471 |
config, position_embedding_type=position_embedding_type
|
472 |
)
|
473 |
-
self.output =
|
474 |
self.pruned_heads = set()
|
475 |
|
476 |
def prune_heads(self, heads):
|
@@ -524,7 +524,7 @@ class MyBertAttention(nn.Module):
|
|
524 |
return outputs
|
525 |
|
526 |
|
527 |
-
class
|
528 |
def __init__(self, config):
|
529 |
super().__init__()
|
530 |
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
@@ -539,8 +539,8 @@ class MyBertIntermediate(nn.Module):
|
|
539 |
return hidden_states
|
540 |
|
541 |
|
542 |
-
class
|
543 |
-
def __init__(self, config:
|
544 |
super().__init__()
|
545 |
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
546 |
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
@@ -555,8 +555,8 @@ class MyBertOutput(nn.Module):
|
|
555 |
return hidden_states
|
556 |
|
557 |
|
558 |
-
class
|
559 |
-
def __init__(self, config:
|
560 |
super().__init__()
|
561 |
self.config = config
|
562 |
self.gated_layers = nn.Linear(
|
@@ -589,12 +589,12 @@ class MyBertGLUMLP(nn.Module):
|
|
589 |
return hidden_states
|
590 |
|
591 |
|
592 |
-
class
|
593 |
-
def __init__(self, config:
|
594 |
super().__init__()
|
595 |
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
596 |
self.seq_len_dim = 1
|
597 |
-
self.attention =
|
598 |
self.is_decoder = config.is_decoder
|
599 |
self.add_cross_attention = config.add_cross_attention
|
600 |
self.feed_forward_type = config.feed_forward_type
|
@@ -603,14 +603,14 @@ class MyBertLayer(nn.Module):
|
|
603 |
raise ValueError(
|
604 |
f"{self} should be used as a decoder model if cross attention is added"
|
605 |
)
|
606 |
-
self.crossattention =
|
607 |
config, position_embedding_type="absolute"
|
608 |
)
|
609 |
if self.feed_forward_type.endswith('glu'):
|
610 |
-
self.mlp =
|
611 |
else:
|
612 |
-
self.intermediate =
|
613 |
-
self.output =
|
614 |
|
615 |
def forward(
|
616 |
self,
|
@@ -699,12 +699,12 @@ class MyBertLayer(nn.Module):
|
|
699 |
return layer_output
|
700 |
|
701 |
|
702 |
-
class
|
703 |
-
def __init__(self, config:
|
704 |
super().__init__()
|
705 |
self.config = config
|
706 |
self.layer = nn.ModuleList(
|
707 |
-
[
|
708 |
)
|
709 |
self.gradient_checkpointing = False
|
710 |
self.num_attention_heads = config.num_attention_heads
|
@@ -724,26 +724,6 @@ class MyBertEncoder(nn.Module):
|
|
724 |
# will be applied, it is necessary to construct the diagonal mask.
|
725 |
n_heads = self.num_attention_heads
|
726 |
|
727 |
-
# Mosaics one
|
728 |
-
# def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
729 |
-
# def get_slopes_power_of_2(n_heads: int) -> List[float]:
|
730 |
-
# start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
|
731 |
-
# ratio = start
|
732 |
-
# return [start * ratio**i for i in range(n_heads)]
|
733 |
-
|
734 |
-
# # In the paper, they only train models that have 2^a heads for some a. This function
|
735 |
-
# # has some good properties that only occur when the input is a power of 2. To
|
736 |
-
# # maintain that even when the number of heads is not a power of 2, we use a
|
737 |
-
# # workaround.
|
738 |
-
# if math.log2(n_heads).is_integer():
|
739 |
-
# return get_slopes_power_of_2(n_heads)
|
740 |
-
|
741 |
-
# closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
|
742 |
-
# slopes_a = get_slopes_power_of_2(closest_power_of_2)
|
743 |
-
# slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
|
744 |
-
# slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
|
745 |
-
# return slopes_a + slopes_b
|
746 |
-
|
747 |
def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
748 |
def get_slopes_power_of_2(n):
|
749 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
@@ -893,7 +873,7 @@ class MyBertEncoder(nn.Module):
|
|
893 |
)
|
894 |
|
895 |
|
896 |
-
class
|
897 |
def __init__(self, config):
|
898 |
super().__init__()
|
899 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
@@ -908,7 +888,7 @@ class MyBertPooler(nn.Module):
|
|
908 |
return pooled_output
|
909 |
|
910 |
|
911 |
-
class
|
912 |
def __init__(self, config):
|
913 |
super().__init__()
|
914 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
@@ -925,10 +905,10 @@ class MyBertPredictionHeadTransform(nn.Module):
|
|
925 |
return hidden_states
|
926 |
|
927 |
|
928 |
-
class
|
929 |
def __init__(self, config):
|
930 |
super().__init__()
|
931 |
-
self.transform =
|
932 |
|
933 |
# The output weights are the same as the input embeddings, but there is
|
934 |
# an output-only bias for each token.
|
@@ -945,17 +925,17 @@ class MyBertLMPredictionHead(nn.Module):
|
|
945 |
return hidden_states
|
946 |
|
947 |
|
948 |
-
class
|
949 |
def __init__(self, config):
|
950 |
super().__init__()
|
951 |
-
self.predictions =
|
952 |
|
953 |
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
954 |
prediction_scores = self.predictions(sequence_output)
|
955 |
return prediction_scores
|
956 |
|
957 |
|
958 |
-
class
|
959 |
def __init__(self, config):
|
960 |
super().__init__()
|
961 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
@@ -965,10 +945,10 @@ class MyBertOnlyNSPHead(nn.Module):
|
|
965 |
return seq_relationship_score
|
966 |
|
967 |
|
968 |
-
class
|
969 |
def __init__(self, config):
|
970 |
super().__init__()
|
971 |
-
self.predictions =
|
972 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
973 |
|
974 |
def forward(self, sequence_output, pooled_output):
|
@@ -977,13 +957,13 @@ class MyBertPreTrainingHeads(nn.Module):
|
|
977 |
return prediction_scores, seq_relationship_score
|
978 |
|
979 |
|
980 |
-
class
|
981 |
"""
|
982 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
983 |
models.
|
984 |
"""
|
985 |
|
986 |
-
config_class =
|
987 |
load_tf_weights = load_tf_weights_in_bert
|
988 |
base_model_prefix = "bert"
|
989 |
supports_gradient_checkpointing = True
|
@@ -1005,12 +985,12 @@ class MyBertPreTrainedModel(PreTrainedModel):
|
|
1005 |
module.weight.data.fill_(1.0)
|
1006 |
|
1007 |
def _set_gradient_checkpointing(self, module, value=False):
|
1008 |
-
if isinstance(module,
|
1009 |
module.gradient_checkpointing = value
|
1010 |
|
1011 |
|
1012 |
@dataclass
|
1013 |
-
class
|
1014 |
"""
|
1015 |
Output type of [`BertForPreTraining`].
|
1016 |
|
@@ -1113,7 +1093,7 @@ BERT_INPUTS_DOCSTRING = r"""
|
|
1113 |
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
1114 |
BERT_START_DOCSTRING,
|
1115 |
)
|
1116 |
-
class
|
1117 |
"""
|
1118 |
|
1119 |
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
@@ -1126,7 +1106,7 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
1126 |
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
1127 |
"""
|
1128 |
|
1129 |
-
def __init__(self, config:
|
1130 |
super().__init__(config)
|
1131 |
self.config = config
|
1132 |
|
@@ -1137,17 +1117,17 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
1137 |
|
1138 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
1139 |
|
1140 |
-
self.embeddings =
|
1141 |
-
self.encoder =
|
1142 |
|
1143 |
-
self.pooler =
|
1144 |
|
1145 |
# Initialize weights and apply final processing
|
1146 |
self.post_init()
|
1147 |
|
1148 |
@torch.inference_mode()
|
1149 |
def encode(
|
1150 |
-
self: '
|
1151 |
sentences: Union[str, List[str]],
|
1152 |
batch_size: int = 32,
|
1153 |
show_progress_bar: Optional[bool] = None,
|
@@ -1479,14 +1459,14 @@ class MyBertModel(MyBertPreTrainedModel):
|
|
1479 |
""",
|
1480 |
BERT_START_DOCSTRING,
|
1481 |
)
|
1482 |
-
class
|
1483 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1484 |
|
1485 |
def __init__(self, config):
|
1486 |
super().__init__(config)
|
1487 |
|
1488 |
-
self.bert =
|
1489 |
-
self.cls =
|
1490 |
|
1491 |
# Initialize weights and apply final processing
|
1492 |
self.post_init()
|
@@ -1501,7 +1481,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
1501 |
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
|
1502 |
)
|
1503 |
@replace_return_docstrings(
|
1504 |
-
output_type=
|
1505 |
)
|
1506 |
def forward(
|
1507 |
self,
|
@@ -1516,7 +1496,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
1516 |
output_attentions: Optional[bool] = None,
|
1517 |
output_hidden_states: Optional[bool] = None,
|
1518 |
return_dict: Optional[bool] = None,
|
1519 |
-
) -> Union[Tuple[torch.Tensor],
|
1520 |
r"""
|
1521 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1522 |
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
@@ -1532,22 +1512,6 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
1532 |
Used to hide legacy arguments that have been deprecated.
|
1533 |
|
1534 |
Returns:
|
1535 |
-
|
1536 |
-
Example:
|
1537 |
-
|
1538 |
-
```python
|
1539 |
-
>>> from transformers import AutoTokenizer, MyBertForPreTraining
|
1540 |
-
>>> import torch
|
1541 |
-
|
1542 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
1543 |
-
>>> model = MyBertForPreTraining.from_pretrained("bert-base-uncased")
|
1544 |
-
|
1545 |
-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1546 |
-
>>> outputs = model(**inputs)
|
1547 |
-
|
1548 |
-
>>> prediction_logits = outputs.prediction_logits
|
1549 |
-
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
1550 |
-
```
|
1551 |
"""
|
1552 |
return_dict = (
|
1553 |
return_dict if return_dict is not None else self.config.use_return_dict
|
@@ -1585,7 +1549,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
1585 |
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
1586 |
return ((total_loss,) + output) if total_loss is not None else output
|
1587 |
|
1588 |
-
return
|
1589 |
loss=total_loss,
|
1590 |
prediction_logits=prediction_scores,
|
1591 |
seq_relationship_logits=seq_relationship_score,
|
@@ -1595,10 +1559,10 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
|
|
1595 |
|
1596 |
|
1597 |
@add_start_docstrings(
|
1598 |
-
"""
|
1599 |
BERT_START_DOCSTRING,
|
1600 |
)
|
1601 |
-
class
|
1602 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1603 |
|
1604 |
def __init__(self, config):
|
@@ -1606,11 +1570,11 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):
|
|
1606 |
|
1607 |
if not config.is_decoder:
|
1608 |
logger.warning(
|
1609 |
-
"If you want to use `
|
1610 |
)
|
1611 |
|
1612 |
-
self.bert =
|
1613 |
-
self.cls =
|
1614 |
|
1615 |
# Initialize weights and apply final processing
|
1616 |
self.post_init()
|
@@ -1755,9 +1719,9 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):
|
|
1755 |
|
1756 |
|
1757 |
@add_start_docstrings(
|
1758 |
-
"""
|
1759 |
)
|
1760 |
-
class
|
1761 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1762 |
|
1763 |
def __init__(self, config):
|
@@ -1765,12 +1729,12 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):
|
|
1765 |
|
1766 |
if config.is_decoder:
|
1767 |
logger.warning(
|
1768 |
-
"If you want to use `
|
1769 |
"bi-directional self-attention."
|
1770 |
)
|
1771 |
|
1772 |
-
self.bert =
|
1773 |
-
self.cls =
|
1774 |
|
1775 |
# Initialize weights and apply final processing
|
1776 |
self.post_init()
|
@@ -1880,15 +1844,15 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):
|
|
1880 |
|
1881 |
|
1882 |
@add_start_docstrings(
|
1883 |
-
"""
|
1884 |
BERT_START_DOCSTRING,
|
1885 |
)
|
1886 |
-
class
|
1887 |
def __init__(self, config):
|
1888 |
super().__init__(config)
|
1889 |
|
1890 |
-
self.bert =
|
1891 |
-
self.cls =
|
1892 |
|
1893 |
# Initialize weights and apply final processing
|
1894 |
self.post_init()
|
@@ -1922,24 +1886,6 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
|
|
1922 |
- 1 indicates sequence B is a random sequence.
|
1923 |
|
1924 |
Returns:
|
1925 |
-
|
1926 |
-
Example:
|
1927 |
-
|
1928 |
-
```python
|
1929 |
-
>>> from transformers import AutoTokenizer, MyBertForNextSentencePrediction
|
1930 |
-
>>> import torch
|
1931 |
-
|
1932 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
1933 |
-
>>> model = MyBertForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
1934 |
-
|
1935 |
-
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1936 |
-
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
1937 |
-
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
1938 |
-
|
1939 |
-
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
1940 |
-
>>> logits = outputs.logits
|
1941 |
-
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
1942 |
-
```
|
1943 |
"""
|
1944 |
|
1945 |
if "next_sentence_label" in kwargs:
|
@@ -1995,18 +1941,18 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
|
|
1995 |
|
1996 |
@add_start_docstrings(
|
1997 |
"""
|
1998 |
-
|
1999 |
output) e.g. for GLUE tasks.
|
2000 |
""",
|
2001 |
BERT_START_DOCSTRING,
|
2002 |
)
|
2003 |
-
class
|
2004 |
def __init__(self, config):
|
2005 |
super().__init__(config)
|
2006 |
self.num_labels = config.num_labels
|
2007 |
self.config = config
|
2008 |
|
2009 |
-
self.bert =
|
2010 |
classifier_dropout = (
|
2011 |
config.classifier_dropout
|
2012 |
if config.classifier_dropout is not None
|
@@ -2106,16 +2052,16 @@ class MyBertForSequenceClassification(MyBertPreTrainedModel):
|
|
2106 |
|
2107 |
@add_start_docstrings(
|
2108 |
"""
|
2109 |
-
|
2110 |
softmax) e.g. for RocStories/SWAG tasks.
|
2111 |
""",
|
2112 |
BERT_START_DOCSTRING,
|
2113 |
)
|
2114 |
-
class
|
2115 |
def __init__(self, config):
|
2116 |
super().__init__(config)
|
2117 |
|
2118 |
-
self.bert =
|
2119 |
classifier_dropout = (
|
2120 |
config.classifier_dropout
|
2121 |
if config.classifier_dropout is not None
|
@@ -2222,17 +2168,17 @@ class MyBertForMultipleChoice(MyBertPreTrainedModel):
|
|
2222 |
|
2223 |
@add_start_docstrings(
|
2224 |
"""
|
2225 |
-
|
2226 |
Named-Entity-Recognition (NER) tasks.
|
2227 |
""",
|
2228 |
BERT_START_DOCSTRING,
|
2229 |
)
|
2230 |
-
class
|
2231 |
def __init__(self, config):
|
2232 |
super().__init__(config)
|
2233 |
self.num_labels = config.num_labels
|
2234 |
|
2235 |
-
self.bert =
|
2236 |
classifier_dropout = (
|
2237 |
config.classifier_dropout
|
2238 |
if config.classifier_dropout is not None
|
@@ -2311,17 +2257,17 @@ class MyBertForTokenClassification(MyBertPreTrainedModel):
|
|
2311 |
|
2312 |
@add_start_docstrings(
|
2313 |
"""
|
2314 |
-
|
2315 |
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
2316 |
""",
|
2317 |
BERT_START_DOCSTRING,
|
2318 |
)
|
2319 |
-
class
|
2320 |
def __init__(self, config):
|
2321 |
super().__init__(config)
|
2322 |
self.num_labels = config.num_labels
|
2323 |
|
2324 |
-
self.bert =
|
2325 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
2326 |
|
2327 |
# Initialize weights and apply final processing
|
|
|
54 |
logging,
|
55 |
replace_return_docstrings,
|
56 |
)
|
57 |
+
from .configuration_bert import JinaBertConfig
|
58 |
|
59 |
try:
|
60 |
from tqdm.autonotebook import trange
|
|
|
66 |
logger = logging.get_logger(__name__)
|
67 |
|
68 |
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
69 |
+
_CONFIG_FOR_DOC = "JinaBertConfig"
|
70 |
|
71 |
# TokenClassification docstring
|
72 |
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = (
|
|
|
197 |
return model
|
198 |
|
199 |
|
200 |
+
class JinaBertEmbeddings(nn.Module):
|
201 |
"""Construct the embeddings from word, position and token_type embeddings."""
|
202 |
|
203 |
+
def __init__(self, config: JinaBertConfig):
|
204 |
super().__init__()
|
205 |
self.word_embeddings = nn.Embedding(
|
206 |
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
|
|
280 |
return embeddings
|
281 |
|
282 |
|
283 |
+
class JinaBertSelfAttention(nn.Module):
|
284 |
def __init__(self, config, position_embedding_type=None):
|
285 |
super().__init__()
|
286 |
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
|
|
448 |
return outputs
|
449 |
|
450 |
|
451 |
+
class JinaBertSelfOutput(nn.Module):
|
452 |
def __init__(self, config):
|
453 |
super().__init__()
|
454 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
464 |
return hidden_states
|
465 |
|
466 |
|
467 |
+
class JinaBertAttention(nn.Module):
|
468 |
def __init__(self, config, position_embedding_type=None):
|
469 |
super().__init__()
|
470 |
+
self.self = JinaBertSelfAttention(
|
471 |
config, position_embedding_type=position_embedding_type
|
472 |
)
|
473 |
+
self.output = JinaBertSelfOutput(config)
|
474 |
self.pruned_heads = set()
|
475 |
|
476 |
def prune_heads(self, heads):
|
|
|
524 |
return outputs
|
525 |
|
526 |
|
527 |
+
class JinaBertIntermediate(nn.Module):
|
528 |
def __init__(self, config):
|
529 |
super().__init__()
|
530 |
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
|
|
539 |
return hidden_states
|
540 |
|
541 |
|
542 |
+
class JinaBertOutput(nn.Module):
|
543 |
+
def __init__(self, config: JinaBertConfig):
|
544 |
super().__init__()
|
545 |
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
546 |
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
|
|
555 |
return hidden_states
|
556 |
|
557 |
|
558 |
+
class JinaBertGLUMLP(nn.Module):
|
559 |
+
def __init__(self, config: JinaBertConfig):
|
560 |
super().__init__()
|
561 |
self.config = config
|
562 |
self.gated_layers = nn.Linear(
|
|
|
589 |
return hidden_states
|
590 |
|
591 |
|
592 |
+
class JinaBertLayer(nn.Module):
|
593 |
+
def __init__(self, config: JinaBertConfig):
|
594 |
super().__init__()
|
595 |
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
596 |
self.seq_len_dim = 1
|
597 |
+
self.attention = JinaBertAttention(config)
|
598 |
self.is_decoder = config.is_decoder
|
599 |
self.add_cross_attention = config.add_cross_attention
|
600 |
self.feed_forward_type = config.feed_forward_type
|
|
|
603 |
raise ValueError(
|
604 |
f"{self} should be used as a decoder model if cross attention is added"
|
605 |
)
|
606 |
+
self.crossattention = JinaBertAttention(
|
607 |
config, position_embedding_type="absolute"
|
608 |
)
|
609 |
if self.feed_forward_type.endswith('glu'):
|
610 |
+
self.mlp = JinaBertGLUMLP(config)
|
611 |
else:
|
612 |
+
self.intermediate = JinaBertIntermediate(config)
|
613 |
+
self.output = JinaBertOutput(config)
|
614 |
|
615 |
def forward(
|
616 |
self,
|
|
|
699 |
return layer_output
|
700 |
|
701 |
|
702 |
+
class JinaBertEncoder(nn.Module):
|
703 |
+
def __init__(self, config: JinaBertConfig):
|
704 |
super().__init__()
|
705 |
self.config = config
|
706 |
self.layer = nn.ModuleList(
|
707 |
+
[JinaBertLayer(config) for _ in range(config.num_hidden_layers)]
|
708 |
)
|
709 |
self.gradient_checkpointing = False
|
710 |
self.num_attention_heads = config.num_attention_heads
|
|
|
724 |
# will be applied, it is necessary to construct the diagonal mask.
|
725 |
n_heads = self.num_attention_heads
|
726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
728 |
def get_slopes_power_of_2(n):
|
729 |
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
|
|
873 |
)
|
874 |
|
875 |
|
876 |
+
class JinaBertPooler(nn.Module):
|
877 |
def __init__(self, config):
|
878 |
super().__init__()
|
879 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
888 |
return pooled_output
|
889 |
|
890 |
|
891 |
+
class JinaBertPredictionHeadTransform(nn.Module):
|
892 |
def __init__(self, config):
|
893 |
super().__init__()
|
894 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
905 |
return hidden_states
|
906 |
|
907 |
|
908 |
+
class JinaBertLMPredictionHead(nn.Module):
|
909 |
def __init__(self, config):
|
910 |
super().__init__()
|
911 |
+
self.transform = JinaBertPredictionHeadTransform(config)
|
912 |
|
913 |
# The output weights are the same as the input embeddings, but there is
|
914 |
# an output-only bias for each token.
|
|
|
925 |
return hidden_states
|
926 |
|
927 |
|
928 |
+
class JinaBertOnlyMLMHead(nn.Module):
|
929 |
def __init__(self, config):
|
930 |
super().__init__()
|
931 |
+
self.predictions = JinaBertLMPredictionHead(config)
|
932 |
|
933 |
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
934 |
prediction_scores = self.predictions(sequence_output)
|
935 |
return prediction_scores
|
936 |
|
937 |
|
938 |
+
class JinaBertOnlyNSPHead(nn.Module):
|
939 |
def __init__(self, config):
|
940 |
super().__init__()
|
941 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
|
|
945 |
return seq_relationship_score
|
946 |
|
947 |
|
948 |
+
class JinaBertPreTrainingHeads(nn.Module):
|
949 |
def __init__(self, config):
|
950 |
super().__init__()
|
951 |
+
self.predictions = JinaBertLMPredictionHead(config)
|
952 |
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
953 |
|
954 |
def forward(self, sequence_output, pooled_output):
|
|
|
957 |
return prediction_scores, seq_relationship_score
|
958 |
|
959 |
|
960 |
+
class JinaBertPreTrainedModel(PreTrainedModel):
|
961 |
"""
|
962 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
963 |
models.
|
964 |
"""
|
965 |
|
966 |
+
config_class = JinaBertConfig
|
967 |
load_tf_weights = load_tf_weights_in_bert
|
968 |
base_model_prefix = "bert"
|
969 |
supports_gradient_checkpointing = True
|
|
|
985 |
module.weight.data.fill_(1.0)
|
986 |
|
987 |
def _set_gradient_checkpointing(self, module, value=False):
|
988 |
+
if isinstance(module, JinaBertEncoder):
|
989 |
module.gradient_checkpointing = value
|
990 |
|
991 |
|
992 |
@dataclass
|
993 |
+
class JinaBertForPreTrainingOutput(ModelOutput):
|
994 |
"""
|
995 |
Output type of [`BertForPreTraining`].
|
996 |
|
|
|
1093 |
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
1094 |
BERT_START_DOCSTRING,
|
1095 |
)
|
1096 |
+
class JinaBertModel(JinaBertPreTrainedModel):
|
1097 |
"""
|
1098 |
|
1099 |
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
|
|
1106 |
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
1107 |
"""
|
1108 |
|
1109 |
+
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
1110 |
super().__init__(config)
|
1111 |
self.config = config
|
1112 |
|
|
|
1117 |
|
1118 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
1119 |
|
1120 |
+
self.embeddings = JinaBertEmbeddings(config)
|
1121 |
+
self.encoder = JinaBertEncoder(config)
|
1122 |
|
1123 |
+
self.pooler = JinaBertPooler(config) if add_pooling_layer else None
|
1124 |
|
1125 |
# Initialize weights and apply final processing
|
1126 |
self.post_init()
|
1127 |
|
1128 |
@torch.inference_mode()
|
1129 |
def encode(
|
1130 |
+
self: 'JinaBertModel',
|
1131 |
sentences: Union[str, List[str]],
|
1132 |
batch_size: int = 32,
|
1133 |
show_progress_bar: Optional[bool] = None,
|
|
|
1459 |
""",
|
1460 |
BERT_START_DOCSTRING,
|
1461 |
)
|
1462 |
+
class JinaBertForPreTraining(JinaBertPreTrainedModel):
|
1463 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1464 |
|
1465 |
def __init__(self, config):
|
1466 |
super().__init__(config)
|
1467 |
|
1468 |
+
self.bert = JinaBertModel(config)
|
1469 |
+
self.cls = JinaBertPreTrainingHeads(config)
|
1470 |
|
1471 |
# Initialize weights and apply final processing
|
1472 |
self.post_init()
|
|
|
1481 |
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
|
1482 |
)
|
1483 |
@replace_return_docstrings(
|
1484 |
+
output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
1485 |
)
|
1486 |
def forward(
|
1487 |
self,
|
|
|
1496 |
output_attentions: Optional[bool] = None,
|
1497 |
output_hidden_states: Optional[bool] = None,
|
1498 |
return_dict: Optional[bool] = None,
|
1499 |
+
) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
|
1500 |
r"""
|
1501 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1502 |
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
|
|
1512 |
Used to hide legacy arguments that have been deprecated.
|
1513 |
|
1514 |
Returns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1515 |
"""
|
1516 |
return_dict = (
|
1517 |
return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
1549 |
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
1550 |
return ((total_loss,) + output) if total_loss is not None else output
|
1551 |
|
1552 |
+
return JinaBertForPreTrainingOutput(
|
1553 |
loss=total_loss,
|
1554 |
prediction_logits=prediction_scores,
|
1555 |
seq_relationship_logits=seq_relationship_score,
|
|
|
1559 |
|
1560 |
|
1561 |
@add_start_docstrings(
|
1562 |
+
"""JinaBert Model with a `language modeling` head on top for CLM fine-tuning.""",
|
1563 |
BERT_START_DOCSTRING,
|
1564 |
)
|
1565 |
+
class JinaBertLMHeadModel(JinaBertPreTrainedModel):
|
1566 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1567 |
|
1568 |
def __init__(self, config):
|
|
|
1570 |
|
1571 |
if not config.is_decoder:
|
1572 |
logger.warning(
|
1573 |
+
"If you want to use `JinaBertLMHeadModel` as a standalone, add `is_decoder=True.`"
|
1574 |
)
|
1575 |
|
1576 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
1577 |
+
self.cls = JinaBertOnlyMLMHead(config)
|
1578 |
|
1579 |
# Initialize weights and apply final processing
|
1580 |
self.post_init()
|
|
|
1719 |
|
1720 |
|
1721 |
@add_start_docstrings(
|
1722 |
+
"""JinaBert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING
|
1723 |
)
|
1724 |
+
class JinaBertForMaskedLM(JinaBertPreTrainedModel):
|
1725 |
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1726 |
|
1727 |
def __init__(self, config):
|
|
|
1729 |
|
1730 |
if config.is_decoder:
|
1731 |
logger.warning(
|
1732 |
+
"If you want to use `JinaBertForMaskedLM` make sure `config.is_decoder=False` for "
|
1733 |
"bi-directional self-attention."
|
1734 |
)
|
1735 |
|
1736 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
1737 |
+
self.cls = JinaBertOnlyMLMHead(config)
|
1738 |
|
1739 |
# Initialize weights and apply final processing
|
1740 |
self.post_init()
|
|
|
1844 |
|
1845 |
|
1846 |
@add_start_docstrings(
|
1847 |
+
"""JinaBert Model with a `next sentence prediction (classification)` head on top.""",
|
1848 |
BERT_START_DOCSTRING,
|
1849 |
)
|
1850 |
+
class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
|
1851 |
def __init__(self, config):
|
1852 |
super().__init__(config)
|
1853 |
|
1854 |
+
self.bert = JinaBertModel(config)
|
1855 |
+
self.cls = JinaBertOnlyNSPHead(config)
|
1856 |
|
1857 |
# Initialize weights and apply final processing
|
1858 |
self.post_init()
|
|
|
1886 |
- 1 indicates sequence B is a random sequence.
|
1887 |
|
1888 |
Returns:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1889 |
"""
|
1890 |
|
1891 |
if "next_sentence_label" in kwargs:
|
|
|
1941 |
|
1942 |
@add_start_docstrings(
|
1943 |
"""
|
1944 |
+
JinaBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
1945 |
output) e.g. for GLUE tasks.
|
1946 |
""",
|
1947 |
BERT_START_DOCSTRING,
|
1948 |
)
|
1949 |
+
class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
|
1950 |
def __init__(self, config):
|
1951 |
super().__init__(config)
|
1952 |
self.num_labels = config.num_labels
|
1953 |
self.config = config
|
1954 |
|
1955 |
+
self.bert = JinaBertModel(config)
|
1956 |
classifier_dropout = (
|
1957 |
config.classifier_dropout
|
1958 |
if config.classifier_dropout is not None
|
|
|
2052 |
|
2053 |
@add_start_docstrings(
|
2054 |
"""
|
2055 |
+
JinaBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
2056 |
softmax) e.g. for RocStories/SWAG tasks.
|
2057 |
""",
|
2058 |
BERT_START_DOCSTRING,
|
2059 |
)
|
2060 |
+
class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
|
2061 |
def __init__(self, config):
|
2062 |
super().__init__(config)
|
2063 |
|
2064 |
+
self.bert = JinaBertModel(config)
|
2065 |
classifier_dropout = (
|
2066 |
config.classifier_dropout
|
2067 |
if config.classifier_dropout is not None
|
|
|
2168 |
|
2169 |
@add_start_docstrings(
|
2170 |
"""
|
2171 |
+
JinaBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
2172 |
Named-Entity-Recognition (NER) tasks.
|
2173 |
""",
|
2174 |
BERT_START_DOCSTRING,
|
2175 |
)
|
2176 |
+
class JinaBertForTokenClassification(JinaBertPreTrainedModel):
|
2177 |
def __init__(self, config):
|
2178 |
super().__init__(config)
|
2179 |
self.num_labels = config.num_labels
|
2180 |
|
2181 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
2182 |
classifier_dropout = (
|
2183 |
config.classifier_dropout
|
2184 |
if config.classifier_dropout is not None
|
|
|
2257 |
|
2258 |
@add_start_docstrings(
|
2259 |
"""
|
2260 |
+
JinaBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
2261 |
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
2262 |
""",
|
2263 |
BERT_START_DOCSTRING,
|
2264 |
)
|
2265 |
+
class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
|
2266 |
def __init__(self, config):
|
2267 |
super().__init__(config)
|
2268 |
self.num_labels = config.num_labels
|
2269 |
|
2270 |
+
self.bert = JinaBertModel(config, add_pooling_layer=False)
|
2271 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
2272 |
|
2273 |
# Initialize weights and apply final processing
|