我的节点有 `n_weight` 和 `n_community` 两个特征,这两个特征都需要参与训练。
在 `forward` 方法中,需要将 `n_weight` 和 `n_community` 两个特征各自经过网络后用 `torch.cat` 拼接;如果特征数量很多,则改用循环逐个处理。
class GraphClassifier(nn.Module):
    """Two-layer graph-convolution classifier over several node features.

    Every node feature named in ``feat_names`` is passed through the same
    two ``GraphConv`` layers (the layer weights are shared between
    features). The resulting embeddings are concatenated along dimension 1
    and fed to a linear layer that produces ``n_classes`` logits.

    Args:
        in_dim: Input size of the first ``GraphConv`` layer. Because the
            layers are shared, every listed feature must have this width.
        hidden_dim: Output size of both ``GraphConv`` layers.
        n_classes: Number of output logits.
        feat_names: Keys of ``g.ndata`` to use, in concatenation order.
            A single string is treated as one feature name.

    Raises:
        ValueError: If ``feat_names`` is empty.

    NOTE(review): ``forward`` returns one row of logits per node. If a
    per-graph prediction is intended, a readout (pooling over nodes) is
    presumably needed before ``classify`` -- confirm against the caller.
    """

    def __init__(self, in_dim: int, hidden_dim: int, n_classes: int,
                 feat_names=("n_weight", "n_community")):
        if isinstance(feat_names, str):
            # tuple("abc") would split the name into characters.
            feat_names = (feat_names,)
        feat_names = tuple(feat_names)
        if not feat_names:
            raise ValueError("feat_names must contain at least one feature")

        super().__init__()
        self.feat_names = feat_names
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        # The classifier sees one hidden_dim-wide embedding per feature,
        # concatenated, so its input width scales with the feature count.
        self.classify = nn.Linear(hidden_dim * len(feat_names), n_classes)

    def _embed(self, g, feat):
        """Run one node feature through both graph-convolution layers."""
        if feat.dim() == 1:
            # Treat one scalar per node as a single-column feature matrix.
            feat = feat.unsqueeze(-1)
        # assumes feat is a floating-point tensor -- TODO confirm for
        # integer-valued features such as community ids
        hidden = F.relu(self.conv1(g, feat))
        return F.relu(self.conv2(g, hidden))

    def forward(self, g) -> torch.Tensor:
        """Return logits computed from every feature in ``feat_names``.

        The features stored in ``g.ndata`` are only read, never
        overwritten, so the same graph can be passed in repeatedly.

        Args:
            g: Graph object exposing the listed features through
                ``g.ndata``.

        Returns:
            The output of ``classify`` applied to the concatenated
            per-feature embeddings.
        """
        embeddings = [self._embed(g, g.ndata[name]) for name in self.feat_names]
        return self.classify(torch.cat(embeddings, dim=1))