消息传递模型(Message Passing Model)基于拉普拉斯平滑假设(邻居是相似的),试图聚合图中的邻居的信息来获取足够的依据,以实现更鲁棒的半监督节点分类。
图神经网络(Graph Neural Networks, GNN)和标签传播算法(Label Propagation, LPA)均为消息传递算法,其中GNN主要基于传播特征来提升预测效果,而LPA基于迭代式的标签传播来作预测。
一些工作要么用LPA对GNN预测结果做后处理,要么用LPA对GNN进行正则化。但是,它们仍不能直接将GNN和LPA有效地整合到消息传递模型中。
为解决这个问题,本文提出了统一消息传递模型(UNIMP)[1],它可以在训练和推理时结合特征和标签传播。UniMP基于两个简单而有效的想法:
- 将特征嵌入和标签嵌入同时作为输入信息进行传播
- 随机掩码部分标签信息,并在训练时对其进行预测
UniMP在概念上统一了特征传播和标签传播,具有强大的经验能力。
二、实现
2.1 关键部分
- 将标签进行嵌入(原有的C类One-hot标签,通过线性变换成与原始节点特征相同的维度)。
- 然后,将标签嵌入和节点特征相加作为GNN输入。 为避免训练时使用标签导致标签泄露,这里使用了掩码标签训练的策略。每个Epoch随机将训练集中部分节点的标签置(掩码)0(视为训练监督信号),然后利用节点特征 \(\mathbf{X}\) 和 \(\mathbf{A}\)以及剩余的标签去预测被掩码的标签)。
2.2 模型部分
UniMP中使用了GraphTransformer(Transformer中的Q、K、V注意力形式,加上边特征),同时引入了H-GCN的门控残差机制来缓解过平滑。
三、个人实验
将标签作为输入,在ArixV数据集节点分类任务上,能在小数点后第2位提升接近2个点。
论文BOT[2]中也对标签作为输入做了阐述,其作者还发表了相应的论文来论证标签作为输入的有效性的原因。
四、总结
标签有效的直觉就是,在图上的节点分类任务中,邻居标签也是预测目标节点标签的关键特征(这也和标签传播的思想一致)
标签嵌入和掩码标签预测是提升节点分类任务简单有效的方法。
这也印证了,特征层面的改进有时或许比模型结构提升来得更快。
参考文献
[1] Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification
[2] Bag of Tricks for Node Classification with Graph Neural Networks