cpi-connect
committed on
Commit
•
9df8979
1
Parent(s):
008fd4d
Upload model
Browse files
model.py
CHANGED
@@ -88,43 +88,43 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
|
|
88 |
structured_output.extend(batch_output)
|
89 |
|
90 |
|
91 |
-
args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
|
92 |
|
93 |
-
entities = []
|
94 |
-
current_entity = None
|
95 |
-
for position, label, token in args:
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
for entity in entities:
|
106 |
-
|
107 |
-
|
108 |
|
109 |
-
for entity in entities:
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
|
123 |
-
for item in structured_output:
|
124 |
-
|
125 |
-
for entity in entities:
|
126 |
-
|
127 |
-
|
128 |
return structured_output
|
129 |
|
130 |
def forward_model(self, model, dataloader):
|
|
|
88 |
structured_output.extend(batch_output)
|
89 |
|
90 |
|
91 |
+
# args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
|
92 |
|
93 |
+
# entities = []
|
94 |
+
# current_entity = None
|
95 |
+
# for position, label, token in args:
|
96 |
+
# if label.startswith('B-'):
|
97 |
+
# if current_entity is not None:
|
98 |
+
# entities.append(current_entity)
|
99 |
+
# current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
|
100 |
+
# elif label.startswith('I-'):
|
101 |
+
# if current_entity is not None:
|
102 |
+
# current_entity['text'] += ' ' + token.replace(" ", "")
|
103 |
+
# current_entity['end'] = position
|
104 |
+
|
105 |
+
# for entity in entities:
|
106 |
+
# context = self.tokenizer.decode([item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]])
|
107 |
+
# entity["context"] = context
|
108 |
|
109 |
+
# for entity in entities:
|
110 |
+
# if len(self.arg_2_role[entity["label"]]) > 1:
|
111 |
+
# sent_embed = self.embed_model.encode(entity["context"])
|
112 |
+
# arg_embed = self.embed_model.encode(entity["text"])
|
113 |
+
# embed = np.concatenate((sent_embed, arg_embed))
|
114 |
+
|
115 |
+
# arg_clf = self.role_classifiers[entity["label"]]
|
116 |
+
# role_id = arg_clf.predict(embed.reshape(1, -1))
|
117 |
+
# role = self.arg_2_role[entity["label"]][role_id[0]]
|
118 |
+
|
119 |
+
# entity["role"] = role
|
120 |
+
# else:
|
121 |
+
# entity["role"] = self.arg_2_role[entity["label"]][0]
|
122 |
|
123 |
+
# for item in structured_output:
|
124 |
+
# item["role"] = "O"
|
125 |
+
# for entity in entities:
|
126 |
+
# for i in range(entity["start"], entity["end"] + 1):
|
127 |
+
# structured_output[i]["role"] = entity["role"]
|
128 |
return structured_output
|
129 |
|
130 |
def forward_model(self, model, dataloader):
|