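"""Transformer-based generator and discriminator networks for graph
generation: a graph transformer Generator/Discriminator pair plus a
simple MLP discriminator (simple_disc) over flattened graph features."""
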
import torch
import torch.nn as nn
from layers import TransformerEncoder
class Generator(nn.Module):
"""Generator network."""
def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
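        """
        Args:
            act: activation name; one of "relu", "leaky", "sigmoid", "tanh".
            vertexes: number of vertices per graph.
            edges: number of edge feature channels (edge types).
            nodes: number of node feature channels (node types).
            dropout: dropout rate used throughout the network.
            dim: transformer hidden dimension.
            depth: number of transformer encoder blocks.
            heads: number of attention heads.
            mlp_ratio: width multiplier for the transformer MLP blocks.
        """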
        super().__init__()
self.vertexes = vertexes
self.edges = edges
self.nodes = nodes
self.depth = depth
self.dim = dim
self.heads = heads
self.mlp_ratio = mlp_ratio
self.dropout = dropout
if act == "relu":
act = nn.ReLU()
elif act == "leaky":
act = nn.LeakyReLU()
elif act == "sigmoid":
act = nn.Sigmoid()
elif act == "tanh":
act = nn.Tanh()
self.features = vertexes * vertexes * edges + vertexes * nodes
self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
self.pos_enc_dim = 5
self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act=act,
                                                     mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
self.readout_e = nn.Linear(self.dim, edges)
self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)
def _generate_square_subsequent_mask(self, sz):
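        """Build a causal attention mask: position i may attend only to
        positions <= i; masked entries are -inf, allowed entries are 0."""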
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
    def laplacian_positional_enc(self, adj):
        """Laplacian positional encoding: the first pos_enc_dim non-trivial
        eigenvectors of the symmetric normalized Laplacian
        L = I - D^(-1/2) A D^(-1/2)."""
        A = adj.float()
        # Inverse square root of the degree matrix; degrees are clamped to
        # avoid division by zero on isolated nodes.
        deg = torch.count_nonzero(A, dim=-1).clamp(min=1).float()
        D_inv_sqrt = torch.diag(deg ** -0.5)
        L = torch.eye(A.shape[0], device=A.device) - D_inv_sqrt @ A @ D_inv_sqrt
        # L is symmetric, so eigh yields real eigenvalues already sorted in
        # ascending order; no manual argsort over complex values is needed.
        EigVal, EigVec = torch.linalg.eigh(L)
        # Drop the trivial constant eigenvector, keep the next pos_enc_dim.
        pos_enc = EigVec[:, 1:self.pos_enc_dim + 1]
        return pos_enc
    def forward(self, z_e, z_n):
        """z_n: node noise of shape (batch, vertexes, nodes);
        z_e: edge noise of shape (batch, vertexes, vertexes, edges)."""
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features so that edge[:, i, j] == edge[:, j, i].
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        node, edge = self.TransformerEncoder(node, edge)
        # Project hidden representations back to node/edge label logits.
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample
class Discriminator(nn.Module):
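    """Graph transformer discriminator: encodes (node, edge) features with the
    same transformer backbone as the Generator, then scores the graph with an
    MLP over the flattened node representations."""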
def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        super().__init__()
self.vertexes = vertexes
self.edges = edges
self.nodes = nodes
self.depth = depth
self.dim = dim
self.heads = heads
self.mlp_ratio = mlp_ratio
self.dropout = dropout
if act == "relu":
act = nn.ReLU()
elif act == "leaky":
act = nn.LeakyReLU()
elif act == "sigmoid":
act = nn.Sigmoid()
elif act == "tanh":
act = nn.Tanh()
self.features = vertexes * vertexes * edges + vertexes * nodes
self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64, dim), act, nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act=act,
                                                     mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
self.node_features = vertexes * dim
self.edge_features = vertexes * vertexes * dim
self.node_mlp = nn.Sequential(nn.Linear(self.node_features, 64), act, nn.Linear(64, 32), act, nn.Linear(32, 16), act, nn.Linear(16, 1))
    def forward(self, z_e, z_n):
        """z_n: node features of shape (batch, vertexes, nodes);
        z_e: edge features of shape (batch, vertexes, vertexes, edges)."""
        b, n, c = z_n.shape
        _, _, _, d = z_e.shape
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features so that edge[:, i, j] == edge[:, j, i].
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        node, edge = self.TransformerEncoder(node, edge)
        # Flatten the node representations and score the graph with the MLP head.
        node = node.view(b, -1)
        prediction = self.node_mlp(node)
        return prediction
class simple_disc(nn.Module):
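    """Plain MLP discriminator over a flattened graph vector of size
    vertexes * m_dim + vertexes * vertexes * b_dim."""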
def __init__(self, act, m_dim, vertexes, b_dim):
super().__init__()
if act == "relu":
act = nn.ReLU()
elif act == "leaky":
act = nn.LeakyReLU()
elif act == "sigmoid":
act = nn.Sigmoid()
elif act == "tanh":
act = nn.Tanh()
else:
raise ValueError("Unsupported activation function: {}".format(act))
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(nn.Linear(features, 256), act, nn.Linear(256, 128), act, nn.Linear(128, 64), act,
                                       nn.Linear(64, 32), act, nn.Linear(32, 16), act,
                                       nn.Linear(16, 1))
def forward(self, x):
prediction = self.predictor(x)
        return prediction
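

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original
    # module. The dimensions below are arbitrary assumptions, and it presumes
    # TransformerEncoder (from layers) preserves the (batch, vertexes, dim)
    # node and (batch, vertexes, vertexes, dim) edge shapes it receives.
    batch, vertexes, nodes, edges, dim = 2, 9, 5, 4, 32

    gen = Generator(act="relu", vertexes=vertexes, edges=edges, nodes=nodes,
                    dropout=0.1, dim=dim, depth=2, heads=4, mlp_ratio=4)
    z_n = torch.randn(batch, vertexes, nodes)
    z_e = torch.randn(batch, vertexes, vertexes, edges)
    node, edge, node_sample, edge_sample = gen(z_e, z_n)
    # The readout layers project back to the label spaces.
    assert node_sample.shape == (batch, vertexes, nodes)
    assert edge_sample.shape == (batch, vertexes, vertexes, edges)

    disc = Discriminator(act="relu", vertexes=vertexes, edges=edges, nodes=nodes,
                         dropout=0.1, dim=dim, depth=2, heads=4, mlp_ratio=4)
    assert disc(edge_sample, node_sample).shape == (batch, 1)

    # simple_disc expects the node and edge tensors flattened into one vector.
    simple = simple_disc(act="relu", m_dim=nodes, vertexes=vertexes, b_dim=edges)
    flat = torch.cat([node_sample.view(batch, -1), edge_sample.view(batch, -1)], dim=-1)
    assert simple(flat).shape == (batch, 1)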