zhihan1996 committed on
Commit
3183ac4
1 Parent(s): a79a8fd

Upload 2 files


Enable inputs longer than 512 tokens by splitting them into multiple 512-length chunks, encoding each chunk separately, and using the average of the chunk embeddings as the sequence embedding.

Files changed (2)
  1. configuration_bert.py +23 -0
  2. dnabert_layer.py +110 -0
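
For orientation, the chunk-and-average scheme described in the commit message reduces to the following standalone sketch (the shapes and the random `pooled` tensor are illustrative assumptions, not part of the committed files):

import torch

# One pooled embedding per 512-token chunk; rows are grouped by original
# sequence, exactly as produced by view(-1, 512) followed by the encoder.
batch_size, num_chunks, hidden_size = 2, 4, 768   # 4 chunks = 2048 tokens
pooled = torch.randn(batch_size * num_chunks, hidden_size)

# Regroup per sequence and average over the chunk dimension, giving one
# hidden_size-dimensional embedding per input sequence.
pooled = pooled.view(batch_size, num_chunks, hidden_size).mean(dim=1)
print(pooled.shape)  # torch.Size([2, 768])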
configuration_bert.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright 2022 MosaicML Examples authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ from transformers import BertConfig as TransformersBertConfig
+
+
+ class BertConfig(TransformersBertConfig):
+
+     def __init__(
+         self,
+         alibi_starting_size: int = 512,
+         attention_probs_dropout_prob: float = 0.0,
+         **kwargs,
+     ):
+         """Configuration class for MosaicBert.
+
+         Args:
+             alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
+                 create when initializing the model. You should be able to ignore this parameter in most cases.
+                 Defaults to 512.
+             attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
+                 (otherwise, Flash Attention will be off by default). Defaults to 0.0.
+         """
+         super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
+         self.alibi_starting_size = alibi_starting_size
+
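
As a quick sanity check (hypothetical values, assuming configuration_bert.py is importable locally; not part of the commit), the subclass behaves like any other transformers config, and keyword arguments such as alibi_starting_size end up as attributes on the config object:

from configuration_bert import BertConfig

# Hypothetical values: num_labels feeds the classification head,
# alibi_starting_size controls the initial ALiBi tensor size.
config = BertConfig(num_labels=2, alibi_starting_size=512)
print(config.alibi_starting_size)  # 512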
dnabert_layer.py ADDED
@@ -0,0 +1,110 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ class DNABertForSequenceClassification(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.bert = BertModel(config)
+         classifier_dropout = (
+             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # get the size of input_ids
+         batch_size, seq_len = input_ids.shape
+         if seq_len > 512:
+             assert seq_len % 512 == 0, "seq_len should be a multiple of 512"
+             # split the input_ids into multiple chunks
+             input_ids = input_ids.view(-1, 512)
+             attention_mask = attention_mask.view(-1, 512) if attention_mask is not None else None
+             token_type_ids = token_type_ids.view(-1, 512) if token_type_ids is not None else None
+             position_ids = None
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]
+
+         if seq_len > 512:
+             # reshape the pooled_output
+             pooled_output = pooled_output.view(batch_size, -1, pooled_output.shape[-1])
+             # take the mean of the pooled_output
+             pooled_output = torch.mean(pooled_output, dim=1)
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
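
A minimal end-to-end sketch of the long-input path, assuming dnabert_layer.py is importable locally and using a tiny randomly initialized config purely to exercise the code (a real checkpoint would be loaded with from_pretrained instead):

import torch
from transformers import BertConfig
from dnabert_layer import DNABertForSequenceClassification

# Tiny config so the example runs quickly; all values are illustrative.
config = BertConfig(
    vocab_size=100,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    max_position_embeddings=512,
    num_labels=2,
)
model = DNABertForSequenceClassification(config)
model.eval()

# A 1024-token batch: each sequence is split into two 512-token chunks,
# encoded separately, and the two pooled embeddings are averaged before
# the classifier head.
input_ids = torch.randint(0, config.vocab_size, (3, 1024))
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.logits.shape)  # torch.Size([3, 2])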