元学习系列(七):MAML与Meta-SGD

时间:2024-03-15 07:48:00

meta learning的一个重点在于如何在模型中引入先验知识,在孪生网络、原型网络等模型中,他们利用训练集作为先验知识,通过对比测试样本和训练样本进行分析,在神经图灵机等记忆增强神经网络中,模型引入了外部记忆,在训练过程中通过优化外部记忆,从而在训练新任务时,能通过外部记忆更快更准确地学习,而接下来要介绍的MAML,则从模型的参数初始化入手。

重新再说一次,小样本学习或者元学习的目的就是对于一个新任务,比如关于新类别的分类任务,模型在原来的基础上,只需要学习几个样本,经过几次迭代,就能达到很好的效果。

MAML算是给予了元学习一个新方向,像这种开拓性的模型,理解起来比其他改进性的模型要困难,所以我打算先详细讲一下如何通俗地理解MAML的思想。

首先我们来看一个例子,现在有四个洞打地鼠,之前地鼠每个洞都出现过,而下一次地鼠出现的地点是随机的,那么你要站在哪里才能在一分钟内打到地鼠。

如果我们用传统的机器学习或者深度学习模型去分析,模型可能会根据上几次地鼠出现的地点去预测下一次它出现的地点,然后结合不同地点的距离分析出你应该站的位置,比如说上几次地鼠都出现在1号洞,那么模型就很有可能建议你站在1号洞附近等待。

可是这样做可能有一个问题,就是你站在一号洞附近,就有可能离4号洞比较远,万一地鼠下一次出现在4号洞,你就赶不及了,所以,在这种要求下,最好的方法还是选一个距离四个洞都比较近的地点等待。

MAML的思想就是这样,找到一个模型的初始化参数,使得模型能够在短时间内通过少量样本进行学习,就能达到满意的效果,或者说这些初始化的参数,经过几次梯度下降,就能使模型达到较好的效果,MAML要找的就是这些参数。

这就是MAML和其他深度学习模型的不同之处,它要找的不是最优的参数,而是能够最快使模型达到最优的参数,为什么要这样呢,个人理解,还是和过拟合这个问题有关,对于小样本问题,深度学习模型因为层数较深,很容易就会过拟合,我觉得这可以表现在模型对于数量小的样本效果较差,特别在样本中同时存在数量较大和数量较小的不同类别的样本的时候,如果要做一个分类问题,模型肯定就会偏向于数量较大的类别。

所以MAML就提出,不如找一个平衡点,使得模型既不会偏向数量较大的类别,也不会偏向数量较小的类别,而是处于刚刚好的位置,使得在测试的时候,见到样本能够恰当地作出决策。

元学习系列(七):MAML与Meta-SGD

上图就挺经典的,假设这是一个参数关于模型的损失的图像,可以看到MAML找到的初始化参数会使得模型的损失非常大,效果非常差,但是要记住,MAML的目的不是找到最优参数使得模型效果最优,而是找到最合适的参数使得模型最快达到最优,而从这个角度来说,这个位置距离附近的最低点都是比较近的一个选择。

继续以小样本分类作为例子,以前的话,在样本充足的情况下,我们一直都希望模型能够对每个类别都能够有很好的分类效果。但是在部分类别样本不足的情况下,如果我们还是强行希望模型对他们能达到较好的分类效果,就会导致过拟合的问题,也就是,当我们在测试的时候,如果遇到了这些类别的样本,模型就很可能作出错误的分类,所以其实小样本学习,或者说元学习,虽然一开始说我们有一个目的是希望模型能够学会学习,听着很高深,但是实质上我们也只是在想办法解决小样本学习中的过拟合问题,如果模型能够在小样本的情况下也能达到较好的分类效果,那到底它学习能力强不强,根本就不重要,毕竟我们说到底只是希望利用模型做一些应用罢了。

啰嗦了那么久应该能够说清楚MAML的目的和思想了,接下来就讨论一下MAML的具体算法过程。

再说一次,MAML的目的是寻找一组最好的初始化参数,使得它能够在任意一个新任务上更新少数几次梯度就能达到比较好的效果,之后,再针对具体的任务进行微调,使得模型在小样本上也能达到较好的效果。

目标提到任意一个新任务,实际意思就是,在MAML的训练过程中,会针对多个任务进行训练,还是以分类作为例子,训练集包含多个task,但不同的task包含的类别不同,比如第一次训练模型需要分类猫狗人钢琴文具,第二次训练模型需要分类车人灯楼猫,反正每次训练的内容都不同:

元学习系列(七):MAML与Meta-SGD

具体来说,模型的训练数据可以分成两部分,一部分是为了找出最好的初始化参数,另一部分是关于具体任务用于模型的fine tuning,前者包含meta-train class和meta-test class,后者包含support set(有标签,用于调优)和query set(没有标签,用于预测)。其实也可以看成一个meta class和一个目标class,class里面再分成两部分是模型训练的需要。

接下来全部以分类任务作为例子介绍,假设现在我们一共有十个类别,我们从中随机抽5个样本构成一个task的train class,重复这个过程得到多个task,在这个过程中,可能train class或者test class中包含了重复的样本,但是影响不大,因为我们的目的是希望模型在面对大量的task中学习到足够强的泛化能力,所以重点在于不同task之间有一定区分度,而模型依然能准确做判断。

然后我们就可以开始训练我们的元模型,每一次的训练,模型会针对每个task做分类,然后计算梯度,更新参数,但是要注意,这个更新是相互独立的,比如现在一个batch有四个task,那么我们分别计算出四个不同的新参数:

θi=θaθL(fθ)\theta ' _i = \theta - a \nabla_\theta L (f_\theta)

把新的参数代进模型,就能得到四个新的模型,然后就可以用新模型分别计算每个task的meta-test class,得到四个损失函数,再加起来,作为一个batch的总损失:

Loss=L(fθi)Loss = \sum L (f_{\theta '_i})

为什么我们需要这个loss,还记得一开始,我们希望能够求出一个初始化参数,使得模型经过少数几次梯度下降就能得到较好的性能,而这里其实就是把梯度下降后的模型性能作为训练的损失函数,既然希望梯度下降之后模型的性能较好,那当然就是最小化梯度下降之后的模型的损失函数:

θnew=θβθL(fθi)\theta_{new} = \theta - \beta \nabla_{\theta} \sum L (f_{\theta '_i})

以上过程就展示了元模型的学习过程,经过学习,模型就能找到最优的初始点,接下来的任务就是针对具体的情况进行fine tuning。

首先当然就是把之前求出来的最有初始参数代进模型,然后就可以利用test class的support set进行训练,这个训练过程就是常规的训练过程了,不需要两次梯度更新,只是常规的利用support set进行梯度下降优化参数,最后再用query set测试模型。

可以看到,上述过程调整的主要是参数,能不能进一步通过模型学习优化的方向和学习率呢,这就是Meta-SGD针对MAML作出的一点改进:

θi=θaθL(fθ)\theta ' _i = \theta - a \cdot \nabla_\theta L (f_\theta)

和原式相比,主要改进是学习率不再是常数,而是需要学习的变量,同时a变成了一组和theta同样尺寸的向量,再和微分的损失函数按位相乘,为什么这里要这么做,我说一下我自己的理解。

比如现在模型一共有10个参数要学习,那么theta就是一个长度为10的向量,a也是一个长度为10的向量,现在模型训练了一个task,得到了损失函数,损失函数对这10个不同的参数求导,求出不同的参数在梯度下降时要减去的量,它们也构成了一个长度为10的向量。之前,我们用一个固定的常数a去乘它,现在,我们的a也不是固定的了,是可以通过模型学习的了,在这种情况下,我觉得模型就更加灵活,可以通过学习调整不同参数的学习率,比如对第一个参数,可以乘一个较大的学习率,对第二个参数,乘一个较小的学习率。个人觉得,这样求出来的theta,会比MAML更优。

至于a怎么学习,其实还是和普通的参数一样通过梯度下降学习:

anew=aβθL(fθi)a_{new} = a - \beta \nabla_{\theta} \sum L (f_{\theta '_i})

还记得MAML的损失函数是通过两次梯度下降求出来的,而a其实包含在第一次梯度下降后的公式中,所以a虽然作为学习率但其实也是模型损失函数的一部分,并不像平时我们进行一次梯度下降,对损失函数求导后再乘学习率。

在github写的自然语言处理入门教程,持续更新:NLPBeginner

在github写的机器学习入门教程,持续更新:MachineLearningModels

想浏览更多关于数学、机器学习、深度学习的内容,可浏览本人博客