元学习方法MAML应用于有监督学习
对论文Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks一部分难点的解读。
理解MAML中的元学习概念
MAML中,meta learning被体现为使用task进行训练、验证和测试的任务。这里的task是一个典型的机器学习任务,比如分类问题:
1.任务之间没有关系,而元学习的目的是通过学习一系列任务获得先验知识,从而在进行新的任务时能用更少的信息获得较好的学习效果。
2.任务的地位和普通的学习任务中的数据是一样的,虽然任务本身也是一个包含训练数据,测试数据和评价方法的机器学习任务。
MAML应用于有监督学习
原文的算法如下:
解读:
每一步,对第i个任务Ti执行如下过程:
- 1.初始状态:模型的全体参数暂时为,准备使用任务Ti(包括训练数据xi和测试数据yi)训练。
MAML中,模型的结构是完全确定的(提前由超参数确定)。这里说的模型参数不包括模型的超参数,而是确定模型forward函数的可变参数。- 2.训练:对训练数据xi和现有参数进行一步训练,即forward->计算loss->loss.backward的过程,得到模型新的参数
- 3.测试:对于参数为的模型,使用测试数据yi和现有参数进行一步测试,即forward->计算loss的过程,得到这个任务的loss
- 4.更新参数:把每个任务的loss加起来作为最终的loss,然后对进行backward,得到新的
经过多次迭代,的值得到优化,使得对于一个新的任务,给定的在一开始就使得模型具有较好的效果。
由此可见,MAML方法为新任务提供的先验知识是模型的初始化参数。这种方法要求所有的任务使用架构相同的模型,通过训练任务的训练,MAML能够为测试任务提供一个在一开始就表现较好的模型,从而在小样本训练中获得较好效果。
MAML/元学习和有监督学习的比较
将上述方法和传统的有监督学习比较:
MAML:
- 0.在这一步之前,模型的参数暂时为.
- 1.forward: H组输入,其中第i组是任务Ti.Ti输入模型得到输出
(相当于上一节过程中的第一步和第二步是forward)- 2.backward: 利用H个计算loss,并backward更新.
(相当于上一节过程中的第三步和第四步,根据forward结果计算loss并backward)
传统的有监督学习:
- 0.在这一步之前,模型的参数暂时为.
- 1.forward: H组输入,其中第i组是数据xi.输入模型得到输出
- 2.backward: 利用H个计算loss,并backward更新.
一个小trick
传统的有监督学习中,计算loss是对于计算的,但是对应的原论文的步骤是对计算的。
对于这一点,原论文的解释如下:
理论上确实应该用对于计算,但原论文为了计算方便,使用了一阶近似值,也就是用总loss对的梯度近似替代总loss对的梯度。
实验
原论文包括三部分实验,regression很dummy,classification就是对最典型的两个小样本学习数据集Omniglot和miniImagenet的测试,RL就是Benchmarking Deep Reinforcement Learning for Continuous Control里面的任务。在这方面没有什么可以学习的。论文在分类任务上获得了SOTA,在RL任务上大幅缩短了训练时间。