Graph neural networks with the DGL framework: graph classification training with multiple node features of different dimensions

Date: 2024-02-17 22:21:12

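My nodes have two features, n_weight and n_community, and both need to be used in training.
In forward, each feature is passed through its own graph convolution layers, because the two features can have different dimensions. The resulting node embeddings are then concatenated with torch.cat. If there are many features, write a loop over them instead of repeating the code.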
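If the features are instead concatenated before any graph convolution (early fusion, an alternative to giving each feature its own layers), the loop mentioned above could look like the minimal sketch below. The helper name `concat_node_features` is hypothetical. It accepts `g.ndata` or any mapping from feature name to a tensor with one row per node. The result would then feed a single `GraphConv` whose input size is the sum of the feature widths.

```python
import torch


def concat_node_features(ndata, feature_names):
    """Concatenate several per-node feature tensors along the feature axis.

    `ndata` can be a DGL graph's `g.ndata` or any mapping name -> tensor of
    shape (num_nodes, dim_i); the widths dim_i may differ between features.
    """
    feats = []
    for name in feature_names:
        x = ndata[name].float()
        # a 1-D feature (one scalar per node) becomes shape (num_nodes, 1)
        if x.dim() == 1:
            x = x.unsqueeze(1)
        feats.append(x)
    # torch.cat along dim=1 only requires the same number of rows (nodes)
    return torch.cat(feats, dim=1)


# usage: 4 nodes, n_weight is 1-dimensional, n_community is 3-dimensional
ndata = {
    'n_weight': torch.rand(4),
    'n_community': torch.rand(4, 3),
}
x = concat_node_features(ndata, ['n_weight', 'n_community'])
print(x.shape)  # torch.Size([4, 4])
```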

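import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv


class GraphClassifier(nn.Module):
    def __init__(self, w_dim, c_dim, hidden_dim, n_classes):
        super().__init__()
        # each feature gets its own conv stack because the input widths may differ
        self.w_conv1 = GraphConv(w_dim, hidden_dim)
        self.w_conv2 = GraphConv(hidden_dim, hidden_dim)
        self.c_conv1 = GraphConv(c_dim, hidden_dim)
        self.c_conv2 = GraphConv(hidden_dim, hidden_dim)
        # the concatenated graph embedding has 2 * hidden_dim entries;
        # output raw logits, nn.CrossEntropyLoss applies the softmax itself
        self.classify = nn.Linear(2 * hidden_dim, n_classes)

    def forward(self, g):
        # g should have self-loops (dgl.add_self_loop), since GraphConv
        # raises an error on zero-in-degree nodes by default

        # run the weight feature through its own layers
        w = g.ndata['n_weight'].float()
        w = F.relu(self.w_conv1(g, w))
        w = F.relu(self.w_conv2(g, w))

        # run the community feature through its own layers
        c = g.ndata['n_community'].float()
        c = F.relu(self.c_conv1(g, c))
        c = F.relu(self.c_conv2(g, c))

        # combine both features into one tensor per node
        wc = torch.cat((w, c), dim=1)

        # graph classification needs one vector per graph: average the node
        # embeddings of each graph (local_scope leaves g.ndata unchanged)
        with g.local_scope():
            g.ndata['h'] = wc
            hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)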