Readme update

Files changed:
- config.json: +8 -0
- configuration_vgcn.py: +2 -0
- modeling_vcgn.py: +40 -17
config.json
CHANGED

```diff
@@ -1,5 +1,12 @@
 {
+  "architectures": [
+    "VCGNModelForTextClassification"
+  ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_vgcn.VGCNConfig",
+    "AutoModelForSequenceClassification": "modeling_vcgn.VCGNModelForTextClassification"
+  },
   "bert_model": "readerbench/RoBERT-base",
   "classifier_dropout": null,
   "do_lower_case": 1,
@@ -34,6 +41,7 @@
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
   "tf_threshold": 0.0,
+  "torch_dtype": "float32",
   "transformers_version": "4.31.0",
   "type_vocab_size": 2,
   "use_cache": true,
```
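The new `auto_map` entries route the Auto classes to the custom code in this repo, so the model can be loaded with `trust_remote_code=True`. A minimal sketch, assuming the files are hosted in a Hub repo; the repo id below is a placeholder, not taken from this commit:

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

repo_id = "readerbench/vgcn-bert"  # hypothetical repo id

# auto_map resolves AutoConfig to configuration_vgcn.VGCNConfig and
# AutoModelForSequenceClassification to modeling_vcgn.VCGNModelForTextClassification
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
```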
configuration_vgcn.py
CHANGED

```diff
@@ -6,6 +6,7 @@ class VGCNConfig(BertConfig):
 
     def __init__(
         self,
+        bert_model='readerbench/RoBERT-base',
         gcn_adj_matrix: str ='',
         max_seq_len: int = 256,
         npmi_threshold: float = 0.2,
@@ -29,5 +30,6 @@ class VGCNConfig(BertConfig):
         self.tf_threshold = tf_threshold
         self.vocab_type = vocab_type
         self.gcn_embedding_dim = gcn_embedding_dim
+        self.bert_model = bert_model
 
         super().__init__(**kwargs)
```
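The config now records the backbone checkpoint itself, so the modeling code can read it from `config.bert_model` instead of hard-coding it. A minimal sketch of constructing the updated config; the adjacency-matrix path is a placeholder:

```python
from configuration_vgcn import VGCNConfig

config = VGCNConfig(
    bert_model="readerbench/RoBERT-base",    # new kwarg added in this commit (also the default)
    gcn_adj_matrix="path/to/adj_matrix.pkl",  # placeholder path to the pickled graph
    max_seq_len=256,
    npmi_threshold=0.2,
)
```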
modeling_vcgn.py
CHANGED

```diff
@@ -64,27 +64,51 @@ def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj,gcn_config:VGCNConfig):
 class VCGNModelForTextClassification(PreTrainedModel):
     config_class = VGCNConfig
 
-    def __init__(self, config):
+    def __init__(self, config, load_adjacency_matrix=True,):
         super().__init__(config)
-
-        self.pre_trained_model_name = ''
-        self.remove_stop_words = False
-        self.tokenizer = None
-        self.norm_gcn_vocab_adj_list = None
-        self.gcn_vocab_size = config.vocab_size
 
+        self.tokenizer = BertTokenizer.from_pretrained(config.bert_model)
 
-
+        if load_adjacency_matrix:
+            norm_gcn_vocab_adj_list = self.load_adj_matrix(config.gcn_adj_matrix)
+        else:
+            norm_gcn_vocab_adj_list = []
+            for _ in range(2 if config.vocab_type=='all' else 1):
+                norm_gcn_vocab_adj_list.append(torch.sparse.FloatTensor(torch.LongTensor([[0],[0]]), torch.Tensor([0]), (config.vocab_size, config.vocab_size)))
 
         self.model = VGCN_Bert(
             config,
-            gcn_adj_matrix=self.norm_gcn_vocab_adj_list,
+            gcn_adj_matrix=norm_gcn_vocab_adj_list,
             gcn_adj_dim=config.vocab_size,
-            gcn_adj_num=len(self.norm_gcn_vocab_adj_list),
+            gcn_adj_num=len(norm_gcn_vocab_adj_list),
             gcn_embedding_dim=config.gcn_embedding_dim,
 
         )
 
+    @classmethod
+    def from_pretrained(cls, *model_args, reload_adjacency_matrix=False, **kwargs):
+        model = super().from_pretrained(*model_args, **kwargs, load_adjacency_matrix=False)
+
+        if reload_adjacency_matrix:
+            norm_gcn_vocab_adj_list = model.load_adj_matrix(model.config.gcn_adj_matrix)
+            model.model.embeddings.vocab_gcn.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in norm_gcn_vocab_adj_list])
+            for p in model.model.embeddings.vocab_gcn.adj_matrix:
+                p.requires_grad=False
+
+        return model
+
+    def set_adjacency_matrix(self, adj_matrix:Union[List, np.ndarray, sp.csr_matrix, torch.Tensor]):
+
+        if isinstance(adj_matrix, np.ndarray):
+            adj_matrix = [torch.from_numpy(adj_matrix)]
+        else:
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
+
+        self.model.embeddings.vocab_gcn.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
+        for p in self.model.embeddings.vocab_gcn.adj_matrix:
+            p.requires_grad=False
+
+
     def load_adj_matrix(self, adj_matrix):
         filename = None
         if Path(adj_matrix).is_file():
@@ -98,11 +122,8 @@ class VCGNModelForTextClassification(PreTrainedModel):
 
         gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(open(filename, 'rb'))
 
-
-        self.pre_trained_model_name = adj_config['bert_model']
-        self.remove_stop_words = adj_config['remove_stop_words']
-        self.tokenizer = BertTokenizer.from_pretrained(self.pre_trained_model_name)
-        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
+        self.tokenizer = BertTokenizer.from_pretrained(adj_config['bert_model'])
+        return get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
 
     def _prep_batch(self, batch: torch.Tensor):
 
@@ -207,12 +228,14 @@ class VocabGraphConvolution(nn.Module):
     """
    def __init__(self,adj_matrix,voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
         super(VocabGraphConvolution, self).__init__()
-        if isinstance(adj_matrix, nn.Parameter):
+        if isinstance(adj_matrix, nn.Parameter) or isinstance(adj_matrix, nn.ParameterList):
             self.adj_matrix=adj_matrix
-        else:
+        elif isinstance(adj_matrix, list):
             self.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
             for p in self.adj_matrix:
                 p.requires_grad=False
+        else:
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
 
         self.voc_dim=voc_dim
         self.num_adj=num_adj
```
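Taken together, the new `load_adjacency_matrix` flag, the `from_pretrained` override, and `set_adjacency_matrix` let a checkpoint be loaded without the original pickled graph and have one attached afterwards. A minimal usage sketch based only on the signatures above (the repo id is hypothetical); note that, as written, `set_adjacency_matrix` accepts only a NumPy array even though its type hint is broader:

```python
import numpy as np
from modeling_vcgn import VCGNModelForTextClassification

# Load weights and rebuild the graph from config.gcn_adj_matrix:
model = VCGNModelForTextClassification.from_pretrained(
    "readerbench/vgcn-bert",   # hypothetical repo id
    reload_adjacency_matrix=True,
)

# Or load with the empty placeholder matrices and attach a graph later:
model = VCGNModelForTextClassification.from_pretrained("readerbench/vgcn-bert")
adj = np.eye(model.config.vocab_size, dtype=np.float32)  # illustrative only; a real vocabulary graph would be sparse
model.set_adjacency_matrix(adj)
```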