alfiannajih commited on
Commit
74c1449
1 Parent(s): a2c9add

Delete g_retriever

Browse files
g_retriever/.gitkeep DELETED
File without changes
g_retriever/__init__.py DELETED
File without changes
g_retriever/g_retriever_config.py DELETED
@@ -1,31 +0,0 @@
1
- from transformers import LlamaConfig
2
-
3
- class GRetrieverConfig(LlamaConfig):
4
- model_type = "llama"
5
-
6
- def __init__(
7
- self,
8
- max_txt_len: int = 1024,
9
- max_new_tokens: int = 256,
10
- gnn_num_layers: int = 4,
11
- gnn_in_dim: int = 768,
12
- gnn_hidden_dim: int = 1024,
13
- gnn_num_heads: int = 4,
14
- gnn_dropout: int = 0,
15
- bos_id: list = [128000, 128006, 882, 128007],
16
- **kwargs
17
- ):
18
- pretrained_config = LlamaConfig.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
19
- pretrained_config.update(kwargs)
20
-
21
- self.max_txt_len = max_txt_len
22
- self.max_new_tokens = max_new_tokens
23
- self.gnn_num_layers = gnn_num_layers
24
- self.gnn_in_dim = gnn_in_dim
25
- self.gnn_hidden_dim = gnn_hidden_dim
26
- self.gnn_num_heads = gnn_num_heads
27
- self.gnn_dropout = gnn_dropout
28
- self.bos_id = bos_id
29
-
30
- super().__init__(**pretrained_config.to_dict())
31
- self.pad_token_id = pretrained_config.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
g_retriever/g_retriever_model.py DELETED
@@ -1,215 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import torch.nn.functional as F
4
-
5
- from transformers import LlamaForCausalLM
6
- from transformers.modeling_outputs import CausalLMOutputWithPast
7
- from transformers.cache_utils import StaticCache
8
- from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask_with_cache_position
9
- from .g_retriever_config import GRetrieverConfig
10
- from .gnn import GAT
11
-
12
- from functools import wraps
13
- from torch_geometric.nn.pool import global_mean_pool
14
-
15
- class GRetrieverModel(LlamaForCausalLM):
16
- config_class = GRetrieverConfig
17
-
18
- def __init__(self, config):
19
- super().__init__(config)
20
- self.graph_encoder = GAT(
21
- in_channels=config.gnn_in_dim,
22
- out_channels=config.gnn_hidden_dim,
23
- hidden_channels=config.gnn_hidden_dim,
24
- num_layers=config.gnn_num_layers,
25
- dropout=config.gnn_dropout,
26
- num_heads=config.gnn_num_heads,
27
- ).to(self.model.dtype)
28
-
29
- self.projector = nn.Sequential(
30
- nn.Linear(config.gnn_hidden_dim, 2048),
31
- nn.Sigmoid(),
32
- nn.Linear(2048, self.get_input_embeddings().embedding_dim),
33
- ).to(self.model.dtype)
34
-
35
- def encode_graphs(self, graph):
36
- n_embeds, _ = self.graph_encoder(
37
- graph.x.to(self.model.dtype),
38
- graph.edge_index.long(),
39
- graph.edge_attr.to(self.model.dtype)
40
- )
41
-
42
- # mean pooling
43
- g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
-
45
- return g_embeds
46
-
47
- @wraps(LlamaForCausalLM.forward)
48
- def forward(
49
- self,
50
- input_ids=None,
51
- graph=None,
52
- attention_mask=None,
53
- position_ids=None,
54
- past_key_values=None,
55
- inputs_embeds=None,
56
- labels=None,
57
- use_cache=None,
58
- output_attentions=None,
59
- output_hidden_states=None,
60
- return_dict=None,
61
- cache_position=None
62
- ):
63
- inputs = input_ids.clone()
64
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
65
- output_hidden_states = (
66
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
67
- )
68
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
69
-
70
- if (inputs==-1).any():
71
- # embed bos prompt
72
- bos_embeds = self.get_input_embeddings()(torch.tensor(
73
- self.config.bos_id,
74
- device=self.model.device
75
- ))
76
-
77
- # encode graph
78
- graph_embeds = self.encode_graphs(graph)
79
- graph_embeds = self.projector(graph_embeds).to(self.model.device)
80
-
81
- # prepare for reserved ids (bos+graph)
82
- non_tokenized_ids = (inputs == -1).nonzero()
83
- non_tokenized_shape = non_tokenized_ids[:, 0], non_tokenized_ids[:, 1]
84
-
85
- # embed inputs
86
- inputs[non_tokenized_shape] = self.config.pad_token_id
87
- temp_inputs_embeds = self.get_input_embeddings()(inputs)
88
- non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
89
-
90
- # replace reserved ids with bos+graph
91
- inputs_embeds = temp_inputs_embeds.clone()
92
- inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
93
-
94
- else:
95
- inputs_embeds = self.get_input_embeddings()(inputs)
96
-
97
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
98
- outputs = self.model(
99
- attention_mask=attention_mask,
100
- position_ids=position_ids,
101
- past_key_values=past_key_values,
102
- inputs_embeds=inputs_embeds,
103
- use_cache=use_cache,
104
- output_attentions=output_attentions,
105
- output_hidden_states=output_hidden_states,
106
- return_dict=return_dict,
107
- cache_position=cache_position,
108
- )
109
-
110
- hidden_states = outputs[0]
111
-
112
- if self.config.pretraining_tp > 1:
113
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
114
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
115
- logits = torch.cat(logits, dim=-1)
116
- else:
117
- logits = self.lm_head(hidden_states)
118
- logits = logits.float()
119
-
120
- loss = None
121
- if labels is not None:
122
- # Shift so that tokens < n predict n
123
- shift_logits = logits[..., :-1, :].contiguous()
124
- shift_labels = labels[..., 1:].contiguous()
125
- # Flatten the tokens
126
- loss_fct = nn.CrossEntropyLoss()
127
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
128
- shift_labels = shift_labels.view(-1)
129
- # Enable model parallelism
130
- shift_labels = shift_labels.to(shift_logits.device)
131
- loss = loss_fct(shift_logits, shift_labels)
132
-
133
- if not return_dict:
134
- output = (logits,) + outputs[1:]
135
- return (loss,) + output if loss is not None else output
136
-
137
- return CausalLMOutputWithPast(
138
- loss=loss,
139
- logits=logits,
140
- past_key_values=outputs.past_key_values,
141
- hidden_states=outputs.hidden_states,
142
- attentions=outputs.attentions,
143
- )
144
-
145
- def prepare_inputs_for_generation(
146
- self,
147
- input_ids,
148
- graph=None,
149
- past_key_values=None,
150
- attention_mask=None,
151
- inputs_embeds=None,
152
- cache_position=None,
153
- position_ids=None,
154
- use_cache=True,
155
- **kwargs,
156
- ):
157
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
158
- # Exception 1: when passing input_embeds, input_ids may be missing entries
159
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
160
- if past_key_values is not None:
161
- if inputs_embeds is not None: # Exception 1
162
- input_ids = input_ids[:, -cache_position.shape[0] :]
163
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
164
- input_ids = input_ids[:, cache_position]
165
-
166
- if attention_mask is not None and position_ids is None:
167
- # create position_ids on the fly for batch generation
168
- position_ids = attention_mask.long().cumsum(-1) - 1
169
- position_ids.masked_fill_(attention_mask == 0, 1)
170
- if past_key_values:
171
- position_ids = position_ids[:, -input_ids.shape[1] :]
172
-
173
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
174
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
175
-
176
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
177
- if inputs_embeds is not None and cache_position[0] == 0:
178
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
179
- else:
180
- # The clone here is for the same reason as for `position_ids`.
181
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
182
-
183
- if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
184
- if model_inputs["inputs_embeds"] is not None:
185
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
186
- device = model_inputs["inputs_embeds"].device
187
- else:
188
- batch_size, sequence_length = model_inputs["input_ids"].shape
189
- device = model_inputs["input_ids"].device
190
-
191
- dtype = self.lm_head.weight.dtype
192
- min_dtype = torch.finfo(dtype).min
193
-
194
- attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
195
- attention_mask,
196
- sequence_length=sequence_length,
197
- target_length=past_key_values.get_max_length(),
198
- dtype=dtype,
199
- device=device,
200
- min_dtype=min_dtype,
201
- cache_position=cache_position,
202
- batch_size=batch_size,
203
- )
204
-
205
- model_inputs.update(
206
- {
207
- "graph": graph,
208
- "position_ids": position_ids,
209
- "cache_position": cache_position,
210
- "past_key_values": past_key_values,
211
- "use_cache": use_cache,
212
- "attention_mask": attention_mask,
213
- }
214
- )
215
- return model_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
g_retriever/g_retriever_pipeline.py DELETED
@@ -1,51 +0,0 @@
1
- from transformers import Pipeline, AutoTokenizer
2
- from torch_geometric.data import Batch
3
- import torch
4
-
5
- class GRetrieverPipeline(Pipeline):
6
- def __init__(self, **kwargs):
7
- Pipeline.__init__(self, **kwargs)
8
-
9
- self.tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
10
- self.eos_user = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
11
- self.max_txt_len = self.model.config.max_txt_len
12
- self.bos_length = len(self.model.config.bos_id)
13
-
14
- def _sanitize_parameters(self, **kwargs):
15
- preprocess_kwargs = {}
16
- if "textualized_graph" in kwargs:
17
- preprocess_kwargs["textualized_graph"] = kwargs["textualized_graph"]
18
-
19
- if "graph" in kwargs:
20
- preprocess_kwargs["graph"] = kwargs["graph"]
21
-
22
- return preprocess_kwargs, {}, {}
23
-
24
- def preprocess(self, inputs, textualized_graph, graph):
25
- textualized_graph_ids = self.tokenizer(textualized_graph, add_special_tokens=False)["input_ids"][:self.max_txt_len]
26
- question_ids = self.tokenizer(inputs, add_special_tokens=False)["input_ids"]
27
- eos_user_ids = self.tokenizer(self.eos_user, add_special_tokens=False)["input_ids"]
28
-
29
- input_ids = torch.tensor([
30
- [-1]*(self.bos_length + 1)
31
- + textualized_graph_ids
32
- + question_ids
33
- + eos_user_ids
34
- ])
35
- model_inputs = {
36
- "input_ids": input_ids,
37
- "attention_mask": torch.ones_like(input_ids)
38
- }
39
- model_inputs.update({
40
- "graph": Batch.from_data_list([graph])
41
- })
42
-
43
- return model_inputs
44
-
45
- def _forward(self, model_inputs):
46
- model_outputs = self.model.generate(**model_inputs)
47
-
48
- return model_outputs
49
-
50
- def postprocess(self, model_outputs):
51
- return model_outputs