|
class Net(torch.nn.Module):
    """Two-layer network built from ``RGCNConv`` layers.

    The first layer maps the input to ``hidden_dim`` channels and is followed
    by a ReLU; the second maps ``hidden_dim`` to ``num_classes`` channels and
    is followed by a log-softmax over dimension 1.
    """

    def __init__(self, num_relations, num_classes, num_nodes=None, input_dim=None, hidden_dim=16, num_bases=30):
        """Build the two convolution layers.

        Args:
            num_relations: Relation count passed to both ``RGCNConv`` layers.
            num_classes: Output size of the second layer.
            num_nodes: Used as the first layer's input size when
                ``input_dim`` is not given.
            input_dim: Input size of the first layer; takes precedence over
                ``num_nodes`` when both are provided.
            hidden_dim: Output size of the first layer and input size of the
                second.
            num_bases: Passed as the fourth positional argument to both
                ``RGCNConv`` layers.

        Raises:
            AssertionError: If both ``num_nodes`` and ``input_dim`` are None.
        """
        super().__init__()

        assert num_nodes is not None or input_dim is not None, "Please provide input feature dimensionality or number of nodes"

        in_channels = num_nodes if input_dim is None else input_dim

        self.conv1 = RGCNConv(in_channels, hidden_dim, num_relations, num_bases)
        # Use the constructor argument here, not a module-level ``dataset``
        # object, so both layers agree on the relation count and the class
        # does not depend on global state.
        self.conv2 = RGCNConv(hidden_dim, num_classes, num_relations, num_bases)

    def forward(self, x, edge_index, edge_type):
        """Run both layers and return log-probabilities along dimension 1.

        ``x``, ``edge_index`` and ``edge_type`` are forwarded unchanged to
        each convolution layer (``x`` being replaced by the hidden
        activations for the second layer).
        """
        hidden = F.relu(self.conv1(x, edge_index, edge_type))
        logits = self.conv2(hidden, edge_index, edge_type)
        return F.log_softmax(logits, dim=1)