Update modeling_tunbert.py
Browse files- modeling_tunbert.py +2 -2
modeling_tunbert.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch.nn as nn
|
2 |
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, PreTrainedModel,AutoConfig, BertModel
|
3 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
4 |
-
|
5 |
class classifier(nn.Module):
|
6 |
def __init__(self,config):
|
7 |
super().__init__()
|
@@ -14,7 +14,7 @@ class classifier(nn.Module):
|
|
14 |
|
15 |
|
16 |
class TunBERT(PreTrainedModel):
|
17 |
-
config_class =
|
18 |
def __init__(self, config):
|
19 |
super().__init__(config)
|
20 |
self.BertModel = BertModel(config)
|
|
|
1 |
import torch.nn as nn
|
2 |
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, PreTrainedModel,AutoConfig, BertModel
|
3 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
4 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
5 |
class classifier(nn.Module):
|
6 |
def __init__(self,config):
|
7 |
super().__init__()
|
|
|
14 |
|
15 |
|
16 |
class TunBERT(PreTrainedModel):
|
17 |
+
config_class = BertConfig
|
18 |
def __init__(self, config):
|
19 |
super().__init__(config)
|
20 |
self.BertModel = BertModel(config)
|