riship-nv commited on
Commit
1a1dde9
1 Parent(s): 9068837

basic model definition

Browse files
Files changed (1) hide show
  1. model_definition.py +14 -0
model_definition.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Net(torch.nn.Module):
2
+ def __init__(self, num_relations, num_classes, num_nodes=None, input_dim=None, hidden_dim=16, num_bases=30):
3
+ super().__init__()
4
+ assert num_nodes is not None or input_dim is not None, "Please provide input feature dimensionality or number of nodes"
5
+ self.conv1 = RGCNConv(num_nodes if input_dim is None else input_dim, hidden_dim, num_relations,
6
+ num_bases)
7
+ self.conv2 = RGCNConv(hidden_dim, num_classes, dataset.num_relations,
8
+ num_bases)
9
+
10
+ def forward(self, x, edge_index, edge_type):
11
+ # if x is None, uses an embedding based on num_nodes
12
+ x = F.relu(self.conv1(x, edge_index, edge_type))
13
+ x = self.conv2(x, edge_index, edge_type)
14
+ return F.log_softmax(x, dim=1)