Spaces:
Runtime error
Runtime error
File size: 2,432 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class MeanAggregator(nn.Module):
    """Aggregate neighbour features with a batched matrix product.

    For every graph in the batch this computes ``A @ features``. The result
    is a neighbour *mean* only when each row of ``A`` sums to one
    (presumably the caller row-normalizes the adjacency — TODO confirm
    against the data pipeline).
    """

    def forward(self, features, A):
        """Return ``A @ features`` batched over the leading dimension.

        Args:
            features (Tensor): Node features of shape (B, N, D).
            A (Tensor): Aggregation matrices of shape (B, M, N).

        Returns:
            Tensor: Aggregated features of shape (B, M, D).
        """
        return A.bmm(features)
class GraphConv(nn.Module):
    """Graph convolution layer: [self || aggregated neighbours] -> linear -> ReLU.

    Each node's own features are concatenated with the features produced by
    ``MeanAggregator`` and projected with a learned weight and bias.

    Args:
        in_dim (int): Per-node input feature length.
        out_dim (int): Per-node output feature length.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        # The projection consumes the concatenation of self and aggregated
        # features, so its input length is twice ``in_dim``.
        self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)
        self.aggregator = MeanAggregator()

    def forward(self, features, A):
        """Apply one graph convolution.

        Args:
            features (Tensor): Node features of shape (B, N, in_dim).
            A (Tensor): Aggregation matrices of shape (B, N, N).

        Returns:
            Tensor: ReLU-activated node features of shape (B, N, out_dim).
        """
        _, _, feat_dim = features.shape
        assert feat_dim == self.in_dim
        neighbour_feats = self.aggregator(features, A)
        stacked = torch.cat((features, neighbour_feats), dim=2)
        # (B, N, 2*in_dim) x (2*in_dim, out_dim) -> (B, N, out_dim)
        projected = torch.einsum('bnd,df->bnf', stacked, self.weight)
        return F.relu(projected + self.bias)
class GCN(nn.Module):
    """Graph convolutional network for clustering.

    This was from repo https://github.com/Zhongdao/gcn_clustering licensed
    under the MIT license.

    Node features are normalized (``BatchNorm1d`` without affine parameters),
    passed through four ``GraphConv`` layers (feat_len -> 512 -> 256 -> 128
    -> 64), and the features of the ``k`` nodes selected by ``knn_inds`` in
    every local graph are classified into two classes.

    Args:
        feat_len (int): The input node feature length.
    """

    def __init__(self, feat_len):
        super(GCN, self).__init__()
        self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
        self.conv1 = GraphConv(feat_len, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))

    @staticmethod
    def _gather_edge_feats(x, knn_inds):
        """Collect, for each local graph, the node features at ``knn_inds``.

        Args:
            x (Tensor): Node features of shape (num_local_graphs,
                num_max_nodes, C).
            knn_inds (Tensor): Integer node indices of shape
                (num_local_graphs, k).

        Returns:
            Tensor: Gathered features of shape (num_local_graphs * k, C),
            with the same dtype and device as ``x``.
        """
        num_local_graphs = x.size(0)
        # A (B, 1) graph index broadcast against the (B, k) node index picks
        # x[b, knn_inds[b, j]] for every (b, j) -> (B, k, C) in one gather.
        # Unlike writing into a torch.zeros buffer, this keeps x's dtype.
        graph_inds = torch.arange(
            num_local_graphs, device=knn_inds.device).unsqueeze(1)
        edge_feat = x[graph_inds, knn_inds]
        return edge_feat.reshape(-1, x.size(-1))

    def forward(self, x, A, knn_inds):
        """Predict two-class logits for the selected nodes of each graph.

        Args:
            x (Tensor): Node features of shape (num_local_graphs,
                num_max_nodes, feat_len).
            A (Tensor): Aggregation matrices of shape (num_local_graphs,
                num_max_nodes, num_max_nodes).
            knn_inds (Tensor): Integer node indices of shape
                (num_local_graphs, k).

        Returns:
            Tensor: Logits of shape (num_local_graphs * k, 2).
        """
        num_local_graphs, num_max_nodes, feat_len = x.shape
        # Normalize each feature channel over all nodes of all graphs;
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(-1, feat_len)
        x = self.bn0(x)
        x = x.view(num_local_graphs, num_max_nodes, feat_len)

        x = self.conv1(x, A)
        x = self.conv2(x, A)
        x = self.conv3(x, A)
        x = self.conv4(x, A)

        edge_feat = self._gather_edge_feats(x, knn_inds)
        pred = self.classifier(edge_feat)
        return pred
|