alfiannajih
committed on
Commit
•
a1cb05b
1
Parent(s):
74c1449
Update g_retriever_pipeline.py
Browse files- g_retriever_pipeline.py +60 -51
g_retriever_pipeline.py
CHANGED
@@ -1,51 +1,60 @@
|
|
1 |
-
from transformers import Pipeline, AutoTokenizer
|
2 |
-
from torch_geometric.data import Batch
|
3 |
-
import torch
|
4 |
-
|
5 |
-
class GRetrieverPipeline(Pipeline):
|
6 |
-
def __init__(self, **kwargs):
|
7 |
-
Pipeline.__init__(self, **kwargs)
|
8 |
-
|
9 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
|
10 |
-
self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
|
11 |
-
self.max_txt_len = self.model.config.max_txt_len
|
12 |
-
self.bos_length = len(self.model.config.bos_id)
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
model_inputs
|
40 |
-
"
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Pipeline, AutoTokenizer
|
2 |
+
from torch_geometric.data import Batch
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class GRetrieverPipeline(Pipeline):
    """Custom pipeline that generates text from a question plus a graph.

    Each call takes the question as the regular pipeline input and two extra
    keyword arguments: ``textualized_graph`` (a text rendering of the graph)
    and ``graph`` (the graph object itself, wrapped into a
    ``torch_geometric`` ``Batch`` before it reaches the model). An optional
    ``generate_kwargs`` dict is merged into the arguments handed to
    ``self.model.generate``.
    """

    def __init__(self, **kwargs):
        """Initialise the base pipeline and cache model-specific settings.

        All keyword arguments are forwarded unchanged to ``Pipeline``.
        """
        super().__init__(**kwargs)

        # The tokenizer is always reloaded from the path recorded in the
        # model config, replacing whatever the base class may have set.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
        # Marker text tokenized and appended after the question in every prompt.
        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        # Upper bound on the number of textualized-graph tokens kept.
        self.max_txt_len = self.model.config.max_txt_len
        self.bos_length = len(self.model.config.bos_id)
        # Length of the most recently preprocessed prompt; read by postprocess().
        self.input_length = 0

    def _sanitize_parameters(self, **kwargs):
        """Route the supported call-time keyword arguments to ``preprocess``.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple. Only ``textualized_graph``, ``graph`` and
            ``generate_kwargs`` are forwarded, and only when present; the
            forward and postprocess dicts are always empty.
        """
        forwarded_keys = ("textualized_graph", "graph", "generate_kwargs")
        preprocess_kwargs = {
            key: kwargs[key] for key in forwarded_keys if key in kwargs
        }

        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
        """Build the keyword arguments passed to ``self.model.generate``.

        Args:
            inputs: The question text.
            textualized_graph: Text rendering of the graph; its token ids are
                truncated to ``self.max_txt_len``.
            graph: A single graph object accepted by
                ``Batch.from_data_list``.
            generate_kwargs: Optional dict of extra entries merged into the
                returned dict (and therefore passed on to ``generate``).

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``graph``, plus
            any entries from ``generate_kwargs``.
        """

        def encode(text):
            # Special tokens are never added automatically; the prompt layout
            # is assembled by hand below.
            return self.tokenizer(text, add_special_tokens=False)["input_ids"]

        textualized_graph_ids = encode(textualized_graph)[: self.max_txt_len]
        question_ids = encode(inputs)
        eos_user_ids = encode(self.eos_user)

        # ``bos_length + 1`` leading positions are filled with -1.
        # NOTE(review): presumably placeholders the model swaps for its BOS
        # and graph embeddings -- confirm against the model implementation.
        placeholder_ids = [-1] * (self.bos_length + 1)
        input_ids = torch.tensor(
            [placeholder_ids + textualized_graph_ids + question_ids + eos_user_ids]
        )

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "graph": Batch.from_data_list([graph]),
        }

        if generate_kwargs is not None:
            model_inputs.update(generate_kwargs)

        # NOTE(review): the prompt length is handed to postprocess() through
        # instance state, so it only matches when preprocess and postprocess
        # run back-to-back on the same object for the same input -- confirm
        # this holds for every way the pipeline is invoked.
        self.input_length = input_ids.shape[1]

        return model_inputs

    def _forward(self, model_inputs):
        """Run generation, unpacking ``model_inputs`` as keyword arguments."""
        return self.model.generate(**model_inputs)

    def postprocess(self, model_outputs):
        """Decode the first output sequence, skipping the prompt positions.

        The first ``self.input_length`` positions of row 0 are dropped before
        decoding.
        """
        return self.tokenizer.decode(model_outputs[0, self.input_length:])
|