元学习系列(三):Relation Network(关系网络)

时间:2024-03-15 07:49:18

对小样本学习,一开始介绍了孪生网络,它主要输入数据的任意两幅图像,学习计算他们的匹配程度,从而在测试集中计算测试样本和训练样本的哪个样本匹配程度最高。

孪生网络需要计算任意两两样本的匹配程度,而原型网络则进一步改进,提出对样本进行适当的embedding,然后计算样本每一类的样本中心,称为原型prototype,通过模型学习出prototype的位置,对测试样本计算到每个原型的距离,从而进行分类。

不论是孪生网络还是原型网络,在分析两个样本的时候都是通过embedding后的特征向量距离(比如欧氏距离)来反映,而关系网络则是通过构建神经网络来计算两个样本之间的距离从而分析匹配程度,和孪生网络、原型网络相比,关系网络可以看成提供了一个可学习的非线性分类器用于判断关系,而孪生网络、原型网络的距离只是一种线性的关系分类器。

我们以few-shot learning中的分类问题作为例子讨论一下关系网络的应用,首先查询集和样本集随机抽取样本交给embedding层处理得到feature map,然后把两个feature map拼接在一起,再交给关系网络处理,并计算出关系得分,假如这是一个5-way 1-shot问题,那么我们就可以得到5个得分,每个得分对应查询集样本属于每个分类的得分(概率),用公式可表示为:

ri,j=gϕ(C(fϕ(xi),fϕ(xj))), i=1,2,...,5r_{i,j} = g_{\phi}(C(f_{\phi}(x_i), f_{\phi}(x_j))), \ i=1,2,...,5

其中f表示embedding网络,C表示拼接操作,g表示关系网络。所以实际上,就是把各个类别的样本和测试样本的特征向量拼接起来,输入到神经网络中,通过神经网络分析他们之间的匹配度。

元学习系列(三):Relation Network(关系网络)
以上是few-way 1-shot的情况,如果是k-shot呢。论文提出把k个样本embedding后的向量相(按位),形成单个feature map,然后再继续上述的操作。

关于模型的训练,主要目的还是希望模型的关系网络对于正确分类的得分能更高,错误分类得分更低,损失函数形式为:

Loss=i=1mj=1n(ri,j1(yi==yj))2Loss = \sum_{i=1}^m \sum_{j=1}^n (r_{i,j} - 1(y_i==y_j))^2

其中1表示当yi等于yj就输出1,负责输出0。

所以简单来说,关系网络的创新点就是提出用神经网络,而不是欧氏距离去计算两个特征变量之间的匹配程度。

在github写的自然语言处理入门教程,持续更新:NLPBeginner

在github写的机器学习入门教程,持续更新:MachineLearningModels

想浏览更多关于数学、机器学习、深度学习的内容,可浏览本人博客