Aaditya1 committed on
Commit
c36b641
1 Parent(s): a599152

Create modeling_graphormer.pyx

Files changed (1)
  1. modeling_graphormer.pyx +921 -0
modeling_graphormer.pyx ADDED
@@ -0,0 +1,921 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft, clefourrier The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Graphormer model."""
16
+
17
+ import math
18
+ from typing import Iterable, Iterator, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import (
26
+ BaseModelOutputWithNoAttention,
27
+ SequenceClassifierOutput,
28
+ )
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import logging
31
+ from .configuration_graphormer import GraphormerConfig
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ _CHECKPOINT_FOR_DOC = "graphormer-base-pcqm4mv1"
37
+ _CONFIG_FOR_DOC = "GraphormerConfig"
38
+
39
+
40
+ GRAPHORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
+ "clefourrier/graphormer-base-pcqm4mv1",
42
+ "clefourrier/graphormer-base-pcqm4mv2",
43
+ # See all Graphormer models at https://huggingface.co/models?filter=graphormer
44
+ ]
45
+
46
+
47
+ def quant_noise(module: nn.Module, p: float, block_size: int):
48
+ """
49
+ From:
50
+ https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py
51
+
52
+ Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product
53
+ Quantization as described in "Training with Quantization Noise for Extreme Model Compression"
54
+
55
+ Args:
56
+ - module: nn.Module
57
+ - p: amount of Quantization Noise
58
+ - block_size: size of the blocks for subsequent quantization with iPQ
59
+
60
+ Remarks:
61
+ - Module weights must have the right sizes wrt the block size
62
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
63
+ - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down:
64
+ Revisiting the Quantization of Neural Networks"
65
+ - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping
66
+ blocks
67
+ """
68
+
69
+ # if no quantization noise, don't register hook
70
+ if p <= 0:
71
+ return module
72
+
73
+ # supported modules
74
+ if not isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)):
75
+ raise NotImplementedError("Module unsupported for quant_noise.")
76
+
77
+ # test whether module.weight has the right sizes wrt block_size
78
+ is_conv = module.weight.ndim == 4
79
+
80
+ # 2D matrix
81
+ if not is_conv:
82
+ if module.weight.size(1) % block_size != 0:
83
+ raise AssertionError("Input features must be a multiple of block sizes")
84
+
85
+ # 4D matrix
86
+ else:
87
+ # 1x1 convolutions
88
+ if module.kernel_size == (1, 1):
89
+ if module.in_channels % block_size != 0:
90
+ raise AssertionError("Input channels must be a multiple of block sizes")
91
+ # regular convolutions
92
+ else:
93
+ k = module.kernel_size[0] * module.kernel_size[1]
94
+ if k % block_size != 0:
95
+ raise AssertionError("Kernel size must be a multiple of block size")
96
+
97
+ def _forward_pre_hook(mod, input):
98
+ # no noise for evaluation
99
+ if mod.training:
100
+ if not is_conv:
101
+ # gather weight and sizes
102
+ weight = mod.weight
103
+ in_features = weight.size(1)
104
+ out_features = weight.size(0)
105
+
106
+ # split weight matrix into blocks and randomly drop selected blocks
107
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
108
+ mask.bernoulli_(p)
109
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
110
+
111
+ else:
112
+ # gather weight and sizes
113
+ weight = mod.weight
114
+ in_channels = mod.in_channels
115
+ out_channels = mod.out_channels
116
+
117
+ # split weight matrix into blocks and randomly drop selected blocks
118
+ if mod.kernel_size == (1, 1):
119
+ mask = torch.zeros(
120
+ int(in_channels // block_size * out_channels),
121
+ device=weight.device,
122
+ )
123
+ mask.bernoulli_(p)
124
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
125
+ else:
126
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
127
+ mask.bernoulli_(p)
128
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
129
+
130
+ # scale weights and apply mask
131
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
132
+ s = 1 / (1 - p)
133
+ mod.weight.data = s * weight.masked_fill(mask, 0)
134
+
135
+ module.register_forward_pre_hook(_forward_pre_hook)
136
+ return module
137
+
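+ # Illustrative sketch (not part of the upstream file): wrapping a layer with
+ # quant_noise. The in_features of the Linear (8 here, an assumption) must be a
+ # multiple of block_size; during training the pre-forward hook zeroes roughly a
+ # fraction p of the weight blocks and rescales the remaining weights by 1 / (1 - p).
+ #
+ #     noisy_linear = quant_noise(nn.Linear(8, 16), p=0.1, block_size=4)
+ #     out = noisy_linear(torch.randn(2, 8))  # noise is only applied in .train() mode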
138
+
139
+ class LayerDropModuleList(nn.ModuleList):
140
+ """
141
+ From:
142
+ https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/layer_drop.py
143
+ A LayerDrop implementation based on [`torch.nn.ModuleList`]. LayerDrop as described in
144
+ https://arxiv.org/abs/1909.11556.
145
+
146
+ We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During
147
+ evaluation we always iterate over all layers.
148
+
149
+ Usage:
150
+
151
+ ```python
152
+ layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
153
+ for layer in layers: # this might iterate over layers 1 and 3
154
+ x = layer(x)
155
+ for layer in layers: # this might iterate over all layers
156
+ x = layer(x)
157
+ for layer in layers: # this might not iterate over any layers
158
+ x = layer(x)
159
+ ```
160
+
161
+ Args:
162
+ p (float): probability of dropping out each layer
163
+ modules (iterable, optional): an iterable of modules to add
164
+ """
165
+
166
+ def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
167
+ super().__init__(modules)
168
+ self.p = p
169
+
170
+ def __iter__(self) -> Iterator[nn.Module]:
171
+ dropout_probs = torch.empty(len(self)).uniform_()
172
+ for i, m in enumerate(super().__iter__()):
173
+ if not self.training or (dropout_probs[i] > self.p):
174
+ yield m
175
+
176
+
177
+ class GraphormerGraphNodeFeature(nn.Module):
178
+ """
179
+ Compute node features for each node in the graph.
180
+ """
181
+
182
+ def __init__(self, config: GraphormerConfig):
183
+ super().__init__()
184
+ self.num_heads = config.num_attention_heads
185
+ self.num_atoms = config.num_atoms
186
+
187
+ self.atom_encoder = nn.Embedding(config.num_atoms + 1, config.hidden_size, padding_idx=config.pad_token_id)
188
+ self.in_degree_encoder = nn.Embedding(
189
+ config.num_in_degree, config.hidden_size, padding_idx=config.pad_token_id
190
+ )
191
+ self.out_degree_encoder = nn.Embedding(
192
+ config.num_out_degree, config.hidden_size, padding_idx=config.pad_token_id
193
+ )
194
+
195
+ self.graph_token = nn.Embedding(1, config.hidden_size)
196
+
197
+ def forward(
198
+ self,
199
+ input_nodes: torch.LongTensor,
200
+ in_degree: torch.LongTensor,
201
+ out_degree: torch.LongTensor,
202
+ ) -> torch.Tensor:
203
+ n_graph, n_node = input_nodes.size()[:2]
204
+
205
+ node_feature = ( # node feature + graph token
206
+ self.atom_encoder(input_nodes).sum(dim=-2) # [n_graph, n_node, n_hidden]
207
+ + self.in_degree_encoder(in_degree)
208
+ + self.out_degree_encoder(out_degree)
209
+ )
210
+
211
+ graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
212
+
213
+ graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)
214
+
215
+ return graph_node_feature
216
+
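+ # Shape note (added for clarity, not in the upstream file): input_nodes is a
+ # LongTensor of shape [n_graph, n_node, n_node_features]; the atom embeddings are
+ # summed over the feature axis and combined with the in/out degree embeddings,
+ # then a learned graph token is prepended, giving an output of shape
+ # [n_graph, n_node + 1, hidden_size].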
217
+
218
+ class GraphormerGraphAttnBias(nn.Module):
219
+ """
220
+ Compute attention bias for each head.
221
+ """
222
+
223
+ def __init__(self, config: GraphormerConfig):
224
+ super().__init__()
225
+ self.num_heads = config.num_attention_heads
226
+ self.multi_hop_max_dist = config.multi_hop_max_dist
227
+
228
+ # We do not change edge feature embedding learning, as edge embeddings are represented as a combination of the original features
229
+ # + shortest path
230
+ self.edge_encoder = nn.Embedding(config.num_edges + 1, config.num_attention_heads, padding_idx=0)
231
+
232
+ self.edge_type = config.edge_type
233
+ if self.edge_type == "multi_hop":
234
+ self.edge_dis_encoder = nn.Embedding(
235
+ config.num_edge_dis * config.num_attention_heads * config.num_attention_heads,
236
+ 1,
237
+ )
238
+
239
+ self.spatial_pos_encoder = nn.Embedding(config.num_spatial, config.num_attention_heads, padding_idx=0)
240
+
241
+ self.graph_token_virtual_distance = nn.Embedding(1, config.num_attention_heads)
242
+
243
+ def forward(
244
+ self,
245
+ input_nodes: torch.LongTensor,
246
+ attn_bias: torch.Tensor,
247
+ spatial_pos: torch.LongTensor,
248
+ input_edges: torch.LongTensor,
249
+ attn_edge_type: torch.LongTensor,
250
+ ) -> torch.Tensor:
251
+ n_graph, n_node = input_nodes.size()[:2]
252
+ graph_attn_bias = attn_bias.clone()
253
+ graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
254
+ 1, self.num_heads, 1, 1
255
+ ) # [n_graph, n_head, n_node+1, n_node+1]
256
+
257
+ # spatial pos
258
+ # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
259
+ spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
260
+ graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
261
+
262
+ # reset spatial pos here
263
+ t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
264
+ graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
265
+ graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
266
+
267
+ # edge feature
268
+ if self.edge_type == "multi_hop":
269
+ spatial_pos_ = spatial_pos.clone()
270
+
271
+ spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
272
+ # set 1 to 1, input_nodes > 1 to input_nodes - 1
273
+ spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
274
+ if self.multi_hop_max_dist > 0:
275
+ spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
276
+ input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :]
277
+ # [n_graph, n_node, n_node, max_dist, n_head]
278
+
279
+ input_edges = self.edge_encoder(input_edges).mean(-2)
280
+ max_dist = input_edges.size(-2)
281
+ edge_input_flat = input_edges.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
282
+ edge_input_flat = torch.bmm(
283
+ edge_input_flat,
284
+ self.edge_dis_encoder.weight.reshape(-1, self.num_heads, self.num_heads)[:max_dist, :, :],
285
+ )
286
+ input_edges = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.num_heads).permute(
287
+ 1, 2, 3, 0, 4
288
+ )
289
+ input_edges = (input_edges.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
290
+ else:
291
+ # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
292
+ input_edges = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
293
+
294
+ graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + input_edges
295
+ graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
296
+
297
+ return graph_attn_bias
298
+
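+ # Shape note (added for clarity, not in the upstream file): the returned bias has
+ # shape [n_graph, n_head, n_node + 1, n_node + 1] and is later flattened to
+ # [n_graph * n_head, n_node + 1, n_node + 1] before being added to the raw
+ # attention scores. Row and column 0 correspond to the virtual graph token, whose
+ # distance to every node is the learned graph_token_virtual_distance.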
299
+
300
+ class GraphormerMultiheadAttention(nn.Module):
301
+ """Multi-headed attention.
302
+
303
+ See "Attention Is All You Need" for more details.
304
+ """
305
+
306
+ def __init__(self, config: GraphormerConfig):
307
+ super().__init__()
308
+ self.embedding_dim = config.embedding_dim
309
+ self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
310
+ self.vdim = config.vdim if config.vdim is not None else config.embedding_dim
311
+ self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim
312
+
313
+ self.num_heads = config.num_attention_heads
314
+ self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
315
+
316
+ self.head_dim = config.embedding_dim // config.num_attention_heads
317
+ if not (self.head_dim * config.num_attention_heads == self.embedding_dim):
318
+ raise AssertionError("The embedding_dim must be divisible by num_heads.")
319
+ self.scaling = self.head_dim**-0.5
320
+
321
+ self.self_attention = True # config.self_attention
322
+ if not (self.self_attention):
323
+ raise NotImplementedError("The Graphormer model only supports self attention for now.")
324
+ if self.self_attention and not self.qkv_same_dim:
325
+ raise AssertionError("Self-attention requires query, key and value to be of the same size.")
326
+
327
+ self.k_proj = quant_noise(
328
+ nn.Linear(self.kdim, config.embedding_dim, bias=config.bias),
329
+ config.q_noise,
330
+ config.qn_block_size,
331
+ )
332
+ self.v_proj = quant_noise(
333
+ nn.Linear(self.vdim, config.embedding_dim, bias=config.bias),
334
+ config.q_noise,
335
+ config.qn_block_size,
336
+ )
337
+ self.q_proj = quant_noise(
338
+ nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
339
+ config.q_noise,
340
+ config.qn_block_size,
341
+ )
342
+
343
+ self.out_proj = quant_noise(
344
+ nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
345
+ config.q_noise,
346
+ config.qn_block_size,
347
+ )
348
+
349
+ self.onnx_trace = False
350
+
351
+ def reset_parameters(self):
352
+ if self.qkv_same_dim:
353
+ # Empirically observed the convergence to be much better with
354
+ # the scaled initialization
355
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
356
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
357
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
358
+ else:
359
+ nn.init.xavier_uniform_(self.k_proj.weight)
360
+ nn.init.xavier_uniform_(self.v_proj.weight)
361
+ nn.init.xavier_uniform_(self.q_proj.weight)
362
+
363
+ nn.init.xavier_uniform_(self.out_proj.weight)
364
+ if self.out_proj.bias is not None:
365
+ nn.init.constant_(self.out_proj.bias, 0.0)
366
+
367
+ def forward(
368
+ self,
369
+ query: torch.LongTensor,
370
+ key: Optional[torch.Tensor],
371
+ value: Optional[torch.Tensor],
372
+ attn_bias: Optional[torch.Tensor],
373
+ key_padding_mask: Optional[torch.Tensor] = None,
374
+ need_weights: bool = True,
375
+ attn_mask: Optional[torch.Tensor] = None,
376
+ before_softmax: bool = False,
377
+ need_head_weights: bool = False,
378
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
379
+ """
380
+ Args:
381
+ key_padding_mask (torch.ByteTensor, optional): mask to exclude
382
+ keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s.
383
+ need_weights (bool, optional): return the attention weights,
384
+ averaged over heads (default: True).
385
+ attn_mask (torch.ByteTensor, optional): typically used to
386
+ implement causal attention, where the mask prevents the attention from looking forward in time
387
+ (default: None).
388
+ before_softmax (bool, optional): return the raw attention
389
+ weights and values before the attention softmax.
390
+ need_head_weights (bool, optional): return the attention
391
+ weights for each head. Implies *need_weights*. Default: return the average attention weights over all
392
+ heads.
393
+ """
394
+ if need_head_weights:
395
+ need_weights = True
396
+
397
+ tgt_len, bsz, embedding_dim = query.size()
398
+ src_len = tgt_len
399
+ if not (embedding_dim == self.embedding_dim):
400
+ raise AssertionError(
401
+ f"The query embedding dimension {embedding_dim} is not equal to the expected embedding_dim"
402
+ f" {self.embedding_dim}."
403
+ )
404
+ if not (list(query.size()) == [tgt_len, bsz, embedding_dim]):
405
+ raise AssertionError("Query size incorrect in Graphormer, compared to model dimensions.")
406
+
407
+ if key is not None:
408
+ src_len, key_bsz, _ = key.size()
409
+ if not torch.jit.is_scripting():
410
+ if (key_bsz != bsz) or (value is None) or (value.shape[:2] != (src_len, bsz)):
411
+ raise AssertionError(
412
+ "The batch shape does not match the key or value shapes provided to the attention."
413
+ )
414
+
415
+ q = self.q_proj(query)
416
+ k = self.k_proj(query)
417
+ v = self.v_proj(query)
418
+
419
+ q *= self.scaling
420
+
421
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
422
+ if k is not None:
423
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
424
+ if v is not None:
425
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
426
+
427
+ if (k is None) or not (k.size(1) == src_len):
428
+ raise AssertionError("The shape of the key generated in the attention is incorrect")
429
+
430
+ # This is part of a workaround to get around fork/join parallelism
431
+ # not supporting Optional types.
432
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
433
+ key_padding_mask = None
434
+
435
+ if key_padding_mask is not None:
436
+ if key_padding_mask.size(0) != bsz or key_padding_mask.size(1) != src_len:
437
+ raise AssertionError(
438
+ "The shape of the generated padding mask for the key does not match expected dimensions."
439
+ )
440
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
441
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
442
+
443
+ if list(attn_weights.size()) != [bsz * self.num_heads, tgt_len, src_len]:
444
+ raise AssertionError("The attention weights generated do not match the expected dimensions.")
445
+
446
+ if attn_bias is not None:
447
+ attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
448
+
449
+ if attn_mask is not None:
450
+ attn_mask = attn_mask.unsqueeze(0)
451
+ attn_weights += attn_mask
452
+
453
+ if key_padding_mask is not None:
454
+ # don't attend to padding symbols
455
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
456
+ attn_weights = attn_weights.masked_fill(
457
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
458
+ )
459
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
460
+
461
+ if before_softmax:
462
+ return attn_weights, v
463
+
464
+ attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1)
465
+ attn_weights = attn_weights_float.type_as(attn_weights)
466
+ attn_probs = self.dropout_module(attn_weights)
467
+
468
+ if v is None:
469
+ raise AssertionError("No value generated")
470
+ attn = torch.bmm(attn_probs, v)
471
+ if list(attn.size()) != [bsz * self.num_heads, tgt_len, self.head_dim]:
472
+ raise AssertionError("The attention generated does not match the expected dimensions.")
473
+
474
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
475
+ attn: torch.Tensor = self.out_proj(attn)
476
+
477
+ attn_weights = None
478
+ if need_weights:
479
+ attn_weights = attn_weights_float.contiguous().view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
480
+ if not need_head_weights:
481
+ # average attention weights over heads
482
+ attn_weights = attn_weights.mean(dim=0)
483
+
484
+ return attn, attn_weights
485
+
486
+ def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
487
+ return attn_weights
488
+
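+ # Illustrative call sketch (assumed tensor names and shapes, not in the upstream
+ # file): the attention expects inputs of shape [seq_len, batch, embedding_dim],
+ # matching the transpose(0, 1) done in GraphormerGraphEncoder.forward:
+ #
+ #     attn_out, attn_weights = GraphormerMultiheadAttention(config)(
+ #         query=x, key=x, value=x,          # x: [seq_len, batch, embedding_dim]
+ #         attn_bias=bias,                   # [batch, n_head, seq_len, seq_len], flattened internally
+ #         key_padding_mask=padding_mask,    # [batch, seq_len]; True marks padded positions
+ #         need_weights=True,
+ #     )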
489
+
490
+ class GraphormerGraphEncoderLayer(nn.Module):
491
+ def __init__(self, config: GraphormerConfig) -> None:
492
+ super().__init__()
493
+
494
+ # Initialize parameters
495
+ self.embedding_dim = config.embedding_dim
496
+ self.num_attention_heads = config.num_attention_heads
497
+ self.attention_dropout = config.attention_dropout
498
+ self.q_noise = config.q_noise
499
+ self.qn_block_size = config.qn_block_size
500
+ self.pre_layernorm = config.pre_layernorm
501
+
502
+ self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
503
+
504
+ self.activation_dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
505
+
506
+ # Initialize blocks
507
+ self.activation_fn = ACT2FN[config.activation_fn]
508
+ self.self_attn = GraphormerMultiheadAttention(config)
509
+
510
+ # layer norm associated with the self attention layer
511
+ self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
512
+
513
+ self.fc1 = self.build_fc(
514
+ self.embedding_dim,
515
+ config.ffn_embedding_dim,
516
+ q_noise=config.q_noise,
517
+ qn_block_size=config.qn_block_size,
518
+ )
519
+ self.fc2 = self.build_fc(
520
+ config.ffn_embedding_dim,
521
+ self.embedding_dim,
522
+ q_noise=config.q_noise,
523
+ qn_block_size=config.qn_block_size,
524
+ )
525
+
526
+ # layer norm associated with the position wise feed-forward NN
527
+ self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
528
+
529
+ def build_fc(
530
+ self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
531
+ ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
532
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
533
+
534
+ def forward(
535
+ self,
536
+ input_nodes: torch.Tensor,
537
+ self_attn_bias: Optional[torch.Tensor] = None,
538
+ self_attn_mask: Optional[torch.Tensor] = None,
539
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
540
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
541
+ """
542
+ nn.LayerNorm is applied either before or after the self-attention/ffn modules similar to the original
543
+ Transformer implementation.
544
+ """
545
+ residual = input_nodes
546
+ if self.pre_layernorm:
547
+ input_nodes = self.self_attn_layer_norm(input_nodes)
548
+
549
+ input_nodes, attn = self.self_attn(
550
+ query=input_nodes,
551
+ key=input_nodes,
552
+ value=input_nodes,
553
+ attn_bias=self_attn_bias,
554
+ key_padding_mask=self_attn_padding_mask,
555
+ need_weights=False,
556
+ attn_mask=self_attn_mask,
557
+ )
558
+ input_nodes = self.dropout_module(input_nodes)
559
+ input_nodes = residual + input_nodes
560
+ if not self.pre_layernorm:
561
+ input_nodes = self.self_attn_layer_norm(input_nodes)
562
+
563
+ residual = input_nodes
564
+ if self.pre_layernorm:
565
+ input_nodes = self.final_layer_norm(input_nodes)
566
+ input_nodes = self.activation_fn(self.fc1(input_nodes))
567
+ input_nodes = self.activation_dropout_module(input_nodes)
568
+ input_nodes = self.fc2(input_nodes)
569
+ input_nodes = self.dropout_module(input_nodes)
570
+ input_nodes = residual + input_nodes
571
+ if not self.pre_layernorm:
572
+ input_nodes = self.final_layer_norm(input_nodes)
573
+
574
+ return input_nodes, attn
575
+
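+ # Design note (added for clarity, not in the upstream file): with
+ # config.pre_layernorm=True the layer norm runs before self-attention and the
+ # feed-forward block ("pre-LN"), which tends to be more stable for deep stacks;
+ # with the default post-LN ordering it runs after each residual addition,
+ # matching the original Transformer.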
576
+
577
+ class GraphormerGraphEncoder(nn.Module):
578
+ def __init__(self, config: GraphormerConfig):
579
+ super().__init__()
580
+
581
+ self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
582
+ self.layerdrop = config.layerdrop
583
+ self.embedding_dim = config.embedding_dim
584
+ self.apply_graphormer_init = config.apply_graphormer_init
585
+ self.traceable = config.traceable
586
+
587
+ self.graph_node_feature = GraphormerGraphNodeFeature(config)
588
+ self.graph_attn_bias = GraphormerGraphAttnBias(config)
589
+
590
+ self.embed_scale = config.embed_scale
591
+
592
+ if config.q_noise > 0:
593
+ self.quant_noise = quant_noise(
594
+ nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
595
+ config.q_noise,
596
+ config.qn_block_size,
597
+ )
598
+ else:
599
+ self.quant_noise = None
600
+
601
+ if config.encoder_normalize_before:
602
+ self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)
603
+ else:
604
+ self.emb_layer_norm = None
605
+
606
+ if config.pre_layernorm:
607
+ self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
608
+
609
+ if self.layerdrop > 0.0:
610
+ self.layers = LayerDropModuleList(p=self.layerdrop)
611
+ else:
612
+ self.layers = nn.ModuleList([])
613
+ self.layers.extend([GraphormerGraphEncoderLayer(config) for _ in range(config.num_hidden_layers)])
614
+
615
+ # Apply initialization of model params after building the model
616
+ if config.freeze_embeddings:
617
+ raise NotImplementedError("Freezing embeddings is not implemented yet.")
618
+
619
+ for layer in range(config.num_trans_layers_to_freeze):
620
+ m = self.layers[layer]
621
+ if m is not None:
622
+ for p in m.parameters():
623
+ p.requires_grad = False
624
+
625
+ def forward(
626
+ self,
627
+ input_nodes: torch.LongTensor,
628
+ input_edges: torch.LongTensor,
629
+ attn_bias: torch.Tensor,
630
+ in_degree: torch.LongTensor,
631
+ out_degree: torch.LongTensor,
632
+ spatial_pos: torch.LongTensor,
633
+ attn_edge_type: torch.LongTensor,
634
+ perturb=None,
635
+ last_state_only: bool = False,
636
+ token_embeddings: Optional[torch.Tensor] = None,
637
+ attn_mask: Optional[torch.Tensor] = None,
638
+ ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
639
+ # compute padding mask. This is needed for multi-head attention
640
+ data_x = input_nodes
641
+ n_graph, n_node = data_x.size()[:2]
642
+ padding_mask = (data_x[:, :, 0]).eq(0)
643
+ padding_mask_cls = torch.zeros(n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype)
644
+ padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)
645
+
646
+ attn_bias = self.graph_attn_bias(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type)
647
+
648
+ if token_embeddings is not None:
649
+ input_nodes = token_embeddings
650
+ else:
651
+ input_nodes = self.graph_node_feature(input_nodes, in_degree, out_degree)
652
+
653
+ if perturb is not None:
654
+ input_nodes[:, 1:, :] += perturb
655
+
656
+ if self.embed_scale is not None:
657
+ input_nodes = input_nodes * self.embed_scale
658
+
659
+ if self.quant_noise is not None:
660
+ input_nodes = self.quant_noise(input_nodes)
661
+
662
+ if self.emb_layer_norm is not None:
663
+ input_nodes = self.emb_layer_norm(input_nodes)
664
+
665
+ input_nodes = self.dropout_module(input_nodes)
666
+
667
+ input_nodes = input_nodes.transpose(0, 1)
668
+
669
+ inner_states = []
670
+ if not last_state_only:
671
+ inner_states.append(input_nodes)
672
+
673
+ for layer in self.layers:
674
+ input_nodes, _ = layer(
675
+ input_nodes,
676
+ self_attn_padding_mask=padding_mask,
677
+ self_attn_mask=attn_mask,
678
+ self_attn_bias=attn_bias,
679
+ )
680
+ if not last_state_only:
681
+ inner_states.append(input_nodes)
682
+
683
+ graph_rep = input_nodes[0, :, :]
684
+
685
+ if last_state_only:
686
+ inner_states = [input_nodes]
687
+
688
+ if self.traceable:
689
+ return torch.stack(inner_states), graph_rep
690
+ else:
691
+ return inner_states, graph_rep
692
+
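+ # Output note (added for clarity, not in the upstream file): forward returns
+ # (inner_states, graph_rep), where inner_states is a list of tensors of shape
+ # [n_node + 1, n_graph, embedding_dim] (or a stacked tensor when config.traceable
+ # is set) and graph_rep is the virtual graph token representation,
+ # inner_states[-1][0, :, :].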
693
+
694
+ class GraphormerDecoderHead(nn.Module):
695
+ def __init__(self, embedding_dim: int, num_classes: int):
696
+ super().__init__()
697
+ """num_classes should be 1 for regression, or the number of classes for classification"""
698
+ self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
699
+ self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
700
+ self.num_classes = num_classes
701
+
702
+ def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
703
+ input_nodes = self.classifier(input_nodes)
704
+ input_nodes = input_nodes + self.lm_output_learned_bias
705
+ return input_nodes
706
+
707
+
708
+ class GraphormerPreTrainedModel(PreTrainedModel):
709
+ """
710
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
711
+ models.
712
+ """
713
+
714
+ config_class = GraphormerConfig
715
+ base_model_prefix = "graphormer"
716
+ supports_gradient_checkpointing = True
717
+ main_input_name_nodes = "input_nodes"
718
+ main_input_name_edges = "input_edges"
719
+
720
+ def normal_(self, data: torch.Tensor):
721
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
722
+ # so that the RNG is consistent with and without FSDP
723
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
724
+
725
+ def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
726
+ """
727
+ Initialize the weights specific to the Graphormer Model.
728
+ """
729
+ if isinstance(module, nn.Linear):
730
+ self.normal_(module.weight.data)
731
+ if module.bias is not None:
732
+ module.bias.data.zero_()
733
+ if isinstance(module, nn.Embedding):
734
+ self.normal_(module.weight.data)
735
+ if module.padding_idx is not None:
736
+ module.weight.data[module.padding_idx].zero_()
737
+ if isinstance(module, GraphormerMultiheadAttention):
738
+ self.normal_(module.q_proj.weight.data)
739
+ self.normal_(module.k_proj.weight.data)
740
+ self.normal_(module.v_proj.weight.data)
741
+
742
+ def _init_weights(
743
+ self,
744
+ module: Union[
745
+ nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
746
+ ],
747
+ ):
748
+ """
749
+ Initialize the weights
750
+ """
751
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
752
+ # We might be missing part of the Linear init, dependent on the layer number
753
+ module.weight.data.normal_(mean=0.0, std=0.02)
754
+ if module.bias is not None:
755
+ module.bias.data.zero_()
756
+ elif isinstance(module, nn.Embedding):
757
+ module.weight.data.normal_(mean=0.0, std=0.02)
758
+ if module.padding_idx is not None:
759
+ module.weight.data[module.padding_idx].zero_()
760
+ elif isinstance(module, GraphormerMultiheadAttention):
761
+ module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
762
+ module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
763
+ module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
764
+ module.reset_parameters()
765
+ elif isinstance(module, nn.LayerNorm):
766
+ module.bias.data.zero_()
767
+ module.weight.data.fill_(1.0)
768
+ elif isinstance(module, GraphormerGraphEncoder):
769
+ if module.apply_graphormer_init:
770
+ module.apply(self.init_graphormer_params)
771
+
775
+
776
+ def _set_gradient_checkpointing(self, module, value=False):
777
+ if isinstance(module, GraphormerModel):
778
+ module.gradient_checkpointing = value
779
+
780
+
781
+ class GraphormerModel(GraphormerPreTrainedModel):
782
+ """The Graphormer model is a graph-encoder model.
783
+
784
+ It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
785
+ GraphormerForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
786
+ this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
787
+ """
788
+
789
+ def __init__(self, config: GraphormerConfig):
790
+ super().__init__(config)
791
+ self.max_nodes = config.max_nodes
792
+
793
+ self.graph_encoder = GraphormerGraphEncoder(config)
794
+
795
+ self.share_input_output_embed = config.share_input_output_embed
796
+ self.lm_output_learned_bias = None
797
+
798
+ # remove_head is set to True during fine-tuning
799
+ self.load_softmax = not getattr(config, "remove_head", False)
800
+
801
+ self.lm_head_transform_weight = nn.Linear(config.embedding_dim, config.embedding_dim)
802
+ self.activation_fn = ACT2FN[config.activation_fn]
803
+ self.layer_norm = nn.LayerNorm(config.embedding_dim)
804
+
805
+ self.post_init()
806
+
807
+ def reset_output_layer_parameters(self):
808
+ self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
809
+
810
+ def forward(
811
+ self,
812
+ input_nodes: torch.LongTensor,
813
+ input_edges: torch.LongTensor,
814
+ attn_bias: torch.Tensor,
815
+ in_degree: torch.LongTensor,
816
+ out_degree: torch.LongTensor,
817
+ spatial_pos: torch.LongTensor,
818
+ attn_edge_type: torch.LongTensor,
819
+ perturb=None,
820
+ masked_tokens=None,
821
+ return_dict: Optional[bool] = None,
822
+ **unused,
823
+ ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
824
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
825
+
826
+ inner_states, graph_rep = self.graph_encoder(
827
+ input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
828
+ )
829
+
830
+ # take the last inner state, then swap the batch and graph-length dimensions back
831
+ input_nodes = inner_states[-1].transpose(0, 1)
832
+
833
+ # project masked tokens only
834
+ if masked_tokens is not None:
835
+ raise NotImplementedError
836
+
837
+ input_nodes = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(input_nodes)))
838
+
839
+ # project back to size of vocabulary
840
+ if self.share_input_output_embed and hasattr(self.graph_encoder.embed_tokens, "weight"):
841
+ input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight)
842
+
843
+ if not return_dict:
844
+ return tuple(x for x in [input_nodes, inner_states] if x is not None)
845
+ return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)
846
+
847
+ def max_nodes(self):
848
+ """Maximum output length supported by the encoder."""
849
+ return self.max_nodes
850
+
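+ # Usage sketch (illustrative, not part of the upstream file; assumes a collated
+ # graph batch providing the keyword tensors below):
+ #
+ #     model = GraphormerModel(GraphormerConfig())
+ #     out = model(input_nodes=input_nodes, input_edges=input_edges,
+ #                 attn_bias=attn_bias, in_degree=in_degree, out_degree=out_degree,
+ #                 spatial_pos=spatial_pos, attn_edge_type=attn_edge_type)
+ #     hidden = out.last_hidden_state  # [n_graph, n_node + 1, embedding_dim]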
851
+
852
+ class GraphormerForGraphClassification(GraphormerPreTrainedModel):
853
+ """
854
+ This model can be used for graph-level classification or regression tasks.
855
+
856
+ It can be trained on
857
+ - regression (by setting config.num_classes to 1); there should be one float-type label per graph
858
+ - one task classification (by setting config.num_classes to the number of classes); there should be one integer
859
+ label per graph
860
+ - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
861
+ of integer labels for each graph.
862
+ """
863
+
864
+ def __init__(self, config: GraphormerConfig):
865
+ super().__init__(config)
866
+ self.encoder = GraphormerModel(config)
867
+ self.embedding_dim = config.embedding_dim
868
+ self.num_classes = config.num_classes
869
+ self.classifier = GraphormerDecoderHead(self.embedding_dim, self.num_classes)
870
+ self.is_encoder_decoder = True
871
+
872
+ # Initialize weights and apply final processing
873
+ self.post_init()
874
+
875
+ def forward(
876
+ self,
877
+ input_nodes: torch.LongTensor,
878
+ input_edges: torch.LongTensor,
879
+ attn_bias: torch.Tensor,
880
+ in_degree: torch.LongTensor,
881
+ out_degree: torch.LongTensor,
882
+ spatial_pos: torch.LongTensor,
883
+ attn_edge_type: torch.LongTensor,
884
+ labels: Optional[torch.LongTensor] = None,
885
+ return_dict: Optional[bool] = None,
886
+ **unused,
887
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
888
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
889
+
890
+ encoder_outputs = self.encoder(
891
+ input_nodes,
892
+ input_edges,
893
+ attn_bias,
894
+ in_degree,
895
+ out_degree,
896
+ spatial_pos,
897
+ attn_edge_type,
898
+ return_dict=True,
899
+ )
900
+ outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"]
901
+
902
+ head_outputs = self.classifier(outputs)
903
+ logits = head_outputs[:, 0, :].contiguous()
904
+
905
+ loss = None
906
+ if labels is not None:
907
+ mask = ~torch.isnan(labels)
908
+
909
+ if self.num_classes == 1: # regression
910
+ loss_fct = MSELoss()
911
+ loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
912
+ elif self.num_classes > 1 and len(labels.shape) == 1: # One task classification
913
+ loss_fct = CrossEntropyLoss()
914
+ loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
915
+ else: # Binary multi-task classification
916
+ loss_fct = BCEWithLogitsLoss(reduction="sum")
917
+ loss = loss_fct(logits[mask], labels[mask])
918
+
919
+ if not return_dict:
920
+ return tuple(x for x in [loss, logits, hidden_states] if x is not None)
921
+ return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
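+ # End-to-end sketch (illustrative, not part of the upstream file): for graph
+ # regression, set num_classes=1 and pass one float label per graph; the loss is
+ # then an MSE over the non-NaN labels.
+ #
+ #     config = GraphormerConfig(num_classes=1)
+ #     model = GraphormerForGraphClassification(config)
+ #     output = model(input_nodes=input_nodes, input_edges=input_edges,
+ #                    attn_bias=attn_bias, in_degree=in_degree, out_degree=out_degree,
+ #                    spatial_pos=spatial_pos, attn_edge_type=attn_edge_type,
+ #                    labels=labels, return_dict=True)
+ #     output.loss.backward()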