alfiannajih committed on
Commit
a1cb05b
1 Parent(s): 74c1449

Update g_retriever_pipeline.py

Browse files
Files changed (1) hide show
  1. g_retriever_pipeline.py +60 -51
g_retriever_pipeline.py CHANGED
@@ -1,51 +1,60 @@
1
- from transformers import Pipeline, AutoTokenizer
2
- from torch_geometric.data import Batch
3
- import torch
4
-
5
class GRetrieverPipeline(Pipeline):
    """Pipeline that builds a graph-augmented prompt and calls ``model.generate``.

    The prompt is the concatenation of placeholder ids, the tokenized
    textualized graph (truncated to ``max_txt_len`` tokens), the tokenized
    question and the tokenized ``eos_user`` marker. The graph object itself is
    forwarded to the model as a single-element ``Batch``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        config = self.model.config
        # The tokenizer is always reloaded from the model's own checkpoint path.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        # Marker appended after the question in every prompt.
        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        self.max_txt_len = config.max_txt_len
        self.bos_length = len(config.bos_id)

    def _sanitize_parameters(self, **kwargs):
        """Route ``textualized_graph`` and ``graph`` to ``preprocess``.

        Returns the (preprocess, forward, postprocess) parameter dicts; the
        last two are always empty.
        """
        routed_keys = ("textualized_graph", "graph")
        preprocess_params = {key: kwargs[key] for key in routed_keys if key in kwargs}
        return preprocess_params, {}, {}

    def preprocess(self, inputs, textualized_graph, graph):
        """Tokenize the question and graph text and assemble the model inputs.

        Args:
            inputs: The question text.
            textualized_graph: Text description of the graph; only its first
                ``max_txt_len`` token ids are kept.
            graph: A single graph data object, wrapped into a ``Batch``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (all ones) and ``graph``.
        """

        def _encode(text):
            return self.tokenizer(text, add_special_tokens=False)["input_ids"]

        graph_text_ids = _encode(textualized_graph)[: self.max_txt_len]
        question_ids = _encode(inputs)
        suffix_ids = _encode(self.eos_user)

        # -1 marks ``bos_length + 1`` placeholder positions at the start of the
        # prompt (presumably filled in by the model -- confirm against the model code).
        placeholder_ids = [-1] * (self.bos_length + 1)
        prompt_ids = torch.tensor(
            [placeholder_ids + graph_text_ids + question_ids + suffix_ids]
        )

        return {
            "input_ids": prompt_ids,
            "attention_mask": torch.ones_like(prompt_ids),
            "graph": Batch.from_data_list([graph]),
        }

    def _forward(self, model_inputs):
        """Run ``model.generate`` on the prepared inputs."""
        return self.model.generate(**model_inputs)

    def postprocess(self, model_outputs):
        """Return the raw generate output unchanged."""
        return model_outputs
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline, AutoTokenizer
2
+ from torch_geometric.data import Batch
3
+ import torch
4
+
5
class GRetrieverPipeline(Pipeline):
    """Pipeline that builds a graph-augmented prompt, generates, and decodes the answer.

    The prompt is the concatenation of placeholder ids, the tokenized
    textualized graph (truncated to ``max_txt_len`` tokens), the tokenized
    question and the tokenized ``eos_user`` marker. The graph object itself is
    forwarded to the model as a single-element ``Batch``. After generation only
    the tokens that follow the prompt are decoded and returned as text.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        config = self.model.config
        # The tokenizer is always reloaded from the model's own checkpoint path.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        # Marker appended after the question in every prompt.
        self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        self.max_txt_len = config.max_txt_len
        self.bos_length = len(config.bos_id)
        # Length of the most recently preprocessed prompt. Kept for backward
        # compatibility; postprocess() prefers the per-call length produced by
        # _forward() so that calls cannot interfere with each other.
        self.input_length = 0

    def _sanitize_parameters(self, **kwargs):
        """Route ``textualized_graph``, ``graph`` and ``generate_kwargs`` to ``preprocess``.

        Returns the (preprocess, forward, postprocess) parameter dicts; the
        last two are always empty.
        """
        routed_keys = ("textualized_graph", "graph", "generate_kwargs")
        preprocess_kwargs = {key: kwargs[key] for key in routed_keys if key in kwargs}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, textualized_graph, graph, generate_kwargs=None):
        """Tokenize the question and graph text and assemble the model inputs.

        Args:
            inputs: The question text.
            textualized_graph: Text description of the graph; only its first
                ``max_txt_len`` token ids are kept.
            graph: A single graph data object, wrapped into a ``Batch``.
            generate_kwargs: Optional dict of extra keyword arguments merged
                into the returned inputs and therefore passed on to
                ``model.generate``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (all ones), ``graph``
            and any entries from ``generate_kwargs``.
        """
        textualized_graph_ids = self.tokenizer(
            textualized_graph, add_special_tokens=False
        )["input_ids"][: self.max_txt_len]
        question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
        eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]

        # -1 marks ``bos_length + 1`` placeholder positions at the start of the
        # prompt (presumably filled in by the model -- confirm against the model code).
        input_ids = torch.tensor([
            [-1] * (self.bos_length + 1)
            + textualized_graph_ids
            + question_ids
            + eos_user_ids
        ])
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "graph": Batch.from_data_list([graph]),
        }

        if generate_kwargs is not None:
            model_inputs.update(generate_kwargs)

        self.input_length = input_ids.shape[1]

        return model_inputs

    def _forward(self, model_inputs):
        """Run ``model.generate`` and return the output together with the prompt length.

        The prompt length is read from the ``input_ids`` actually passed to the
        model, so postprocess() does not depend on instance state written by a
        possibly different (or out-of-process) preprocess() call.
        """
        input_length = model_inputs["input_ids"].shape[1]
        generated_ids = self.model.generate(**model_inputs)

        return {"generated_ids": generated_ids, "input_length": input_length}

    def postprocess(self, model_outputs):
        """Decode the generated tokens that follow the prompt.

        Args:
            model_outputs: The dict produced by ``_forward``. A bare tensor of
                generated ids is also accepted, in which case the prompt length
                falls back to ``self.input_length``.

        Returns:
            The decoded text of the first sequence, with the prompt tokens removed.
        """
        if isinstance(model_outputs, dict) and "generated_ids" in model_outputs:
            generated_ids = model_outputs["generated_ids"]
            input_length = model_outputs["input_length"]
        else:
            generated_ids = model_outputs
            input_length = self.input_length

        return self.tokenizer.decode(generated_ids[0, input_length:])